Compare commits

..

23 Commits

Author SHA1 Message Date
Otto
156636c79d fix(backend): auto-correct content-type based on file signature instead of erroring
When uploading media files, the browser-declared content-type header
sometimes doesn't match the actual file content (e.g., user renames
a PNG to .jpg). Instead of rejecting these files, we now:

1. Detect the actual content type from file magic bytes
2. Log when auto-correction occurs for debugging
3. Use the detected type for storage and processing

This improves UX while maintaining security - we still validate that
files are legitimate images/videos, just trust the actual content
over the header.

Fixes: File signature does not match content type errors in Sentry
2026-03-04 09:24:18 +00:00
Nicholas Tindle
b9aac42056 Merge branch 'master' into dev 2026-02-26 13:39:34 -06:00
Otto
95651d33da feat(backend): add fpdf2 dependency for PDF operations in copilot executor (#12216)
Requested by @majdyz

Adds [fpdf2](https://github.com/py-pdf/fpdf2) (v2.8.6) to backend
dependencies to enable PDF generation and manipulation in the copilot
executor.

fpdf2 is a lightweight PDF generation library (no external dependencies,
pure Python) that allows creating PDFs with text, images, tables, and
more.
2026-02-26 18:21:34 +00:00
Zamil Majdy
b30418d833 fix(copilot): inject working directory into SDK prompt + workspace download links (#12215)
## Summary

- Replaces the static `_SDK_TOOL_SUPPLEMENT` placeholder path with
`_build_sdk_tool_supplement(cwd: str)` that injects the session-specific
working directory
- `sdk_cwd` is computed once via `_make_sdk_cwd(session_id)`,
`os.makedirs` is called after lock acquisition (inside the protected
`try/finally`), and the same variable is used everywhere — no drift
between prompt and execution directory
- Added `ValueError`/`OSError` error handling for cwd preparation with
proper `StreamError` emission
- Teaches the SDK agent how to share workspace files with the user via
`workspace://` Markdown links (images render inline, videos render with
player controls, other files as download links)
- `WorkspaceWriteResponse` now includes `download_url` (pre-formatted
`workspace://file_id#mime` string) and a normalised `mime_type` field
(MIME parameters stripped, lowercased)
- Frontend: workspace `workspace://` regular links now resolve to
absolute URLs so Streamdown's "Copy link" copies the full URL
- Frontend: Streamdown's "Open link" button colour overridden to match
the design system (violet accent) — previously near-invisible in dark
mode due to `--primary` resolving to near-white

## Motivation

The SDK agent was seeing a hardcoded placeholder path in the system
prompt instead of the real working directory, causing it to reference
wrong paths in tool calls. Additionally, there was no guidance for the
agent on how to share files it writes to the workspace with the user in
chat.

## Test plan

- [ ] CI green (test 3.11 / 3.12 / 3.13)
- [ ] Start a copilot session with `CHAT_USE_CLAUDE_AGENT_SDK=true` and
verify the agent references the correct `sdk_cwd` path in its tool calls
- [ ] Ask the agent to write a file and confirm it responds with a
clickable download link / inline image using the `workspace://` syntax
- [ ] Verify the "Open link" button in the Streamdown external-link
dialog is visible in both light and dark mode
- [ ] Click "Copy link" on a workspace file link and confirm it copies
the full URL (including host)
2026-02-26 17:26:19 +00:00
Otto
ed729ddbe2 feat(copilot): Wait for agent execution completion (#12147)
Adds the ability for CoPilot to wait for agent execution to complete
before returning results.

Closes SECRT-2003.

## Changes

### New: `execution_utils.py`
- `wait_for_execution()` — uses Redis pubsub to wait for execution to
reach terminal state
- `TERMINAL_STATUSES` — shared frozenset of completed/failed/terminated
- `PAUSED_STATUSES` — handles REVIEW (human-in-the-loop) as a
stop-waiting state
- `get_execution_outputs()` — helper to extract outputs

### `run_agent.py`
- New `wait_for_result` parameter (0-300 seconds)
- When >0, waits for execution to complete and returns outputs directly
- Handles completed, failed, terminated, review, and timeout states with
appropriate responses

### `agent_output.py` (view_agent_output)
- New `wait_if_running` parameter (0-300 seconds)
- Includes running/queued/review executions when waiting is requested
- Status-aware response messages (completed, failed, running, review,
etc.)

## How it works
1. After starting execution, subscribes to Redis pubsub channel for
execution events
2. Re-checks DB after subscribing to close the race window
3. `asyncio.wait_for` enforces the timeout
4. On completion: returns full outputs via `AgentOutputResponse`
5. On timeout: returns current state with guidance to check again later
6. On error/terminated: returns `ErrorResponse` with details
7. Redis connection always cleaned up via `finally` block

## Testing

- [x] Run an agent with `wait_for_result=0` — should return immediately
with execution ID (existing behavior)
- [x] Run a fast agent with `wait_for_result=60` — should return
completed outputs
- [x] Run a slow agent with `wait_for_result=5` — should timeout and
return current status
- [x] Use `view_agent_output` with `wait_if_running=0` on a completed
execution — should return outputs
- [x] Use `view_agent_output` with `wait_if_running=30` on a running
execution — should wait and return
- [ ] ~~Verify Redis connections are cleaned up (no leaked pubsub
connections after timeout)~~
- [ ] ~~Test with a failed execution — should return error response~~
- [ ] ~~Test with a terminated execution — should return error response
(not "still running")~~

## Collaboration

This PR was developed in collaboration with @Pwuts.

---------

Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
2026-02-26 16:41:33 +00:00
Otto
8c7030af0b fix(copilot): handle 'all' keyword in find_library_agent tool (#12138)
When users ask CoPilot to "show all my agents" or similar, the LLM was
passing the literal string "all" as a search query to
`find_library_agent`, which matched no agents because there's no agent
named "all". (issue:
[SECRT-2002](https://linear.app/autogpt/issue/SECRT-2002))

## Changes

- **Make `query` parameter optional** in `FindLibraryAgentTool` - users
can now omit it to list all agents
- **Add special keyword handling** - keywords like "all", "*",
"everything", "any", or empty string are treated as "list all" rather
than literal searches
- **Update response messages** - differentiate between "listing all
agents" vs "searching for X"

## Example

Before:
```
User: Show me all my agents
CoPilot: find_library_agent(query="all")
Result: No agents matching 'all' found in your library
```

After:
```
User: Show me all my agents  
CoPilot: find_library_agent(query="all") OR find_library_agent()
Result: Found 5 agents in your library
```

## Testing

- [x] Test with "show me all my agents" prompt
- [x] Test with empty query
- [x] Test with specific search terms (should still work as before)

## Collaboration

This PR was developed in collaboration with @Pwuts.
2026-02-26 16:07:40 +00:00
Otto
195b14286a fix(frontend): fix Streamdown link safety modal and add origin check (#12209)
Requested by @ntindle

Fixes the Streamdown link safety modal in CoPilot with three changes:

**1. Fix invisible "Open link" button (HIGH)**
Added Streamdown's dist directory to the Tailwind content scan in
`tailwind.config.ts`. Previously, Tailwind was only scanning
`./src/**/*.{ts,tsx}`, so classes used by Streamdown's internal modal
components (like `bg-primary`, `text-primary-foreground`,
`hover:bg-primary/90`) were being purged. The "Open link" button
rendered invisible but remained clickable.

**2. Add same-origin URL whitelist (MEDIUM)**
Configured `linkSafety.onLinkCheck` on the `<Streamdown>` component in
`message.tsx` to whitelist same-origin URLs. Previously, ALL links
(including internal `/api/proxy/...` workspace download URLs) triggered
the "Open external link?" modal. Now same-origin links open directly.

**3. Add Storybook stories (LOW)**
Added `message.stories.tsx` with stories covering default messages, user
messages, internal/external links, the link safety modal, and
conversations.

### Testing
- [ ] Open link safety modal → "Open link" button is visible with proper
styling
- [ ] Click a workspace download link → opens directly (no modal)
- [ ] Click an external link → shows safety modal
- [ ] Verify in both light and dark mode
- [ ] Verify on mobile viewport
- [ ] Storybook stories render correctly

Fixes SECRT-2044
2026-02-26 15:19:54 +00:00
Zamil Majdy
29ca034e40 fix(backend/frontend): error handling, stream reconnection, and chat switching (#12205)
## Problem

CoPilot executions were experiencing:
1. **Duplicate error markers** - Both `execute()` and `_execute_async()`
called `mark_session_completed`, sending duplicate completion markers
2. **RuntimeError bypass** - RuntimeErrors that weren't SDK cleanup
issues bypassed error persistence logic
3. **Generic error messages** - StreamError showed "An error occurred"
instead of actual error text
4. **Empty chat on reconnect** - Messages cleared immediately when
reconnecting, before new messages arrived
5. **Stream not resuming** - Switching chats (A → B → A) didn't resume
active streams due to stale `hasResumedRef`
6. **Excessive diagnostic logging** - 60+ lines of STREAM_DIAG console
logs not needed in production

## Changes 🏗️

### 1. Consolidated Exception Handling
**Files:** `backend/copilot/executor/processor.py`,
`backend/copilot/sdk/service.py`

**processor.py:**
- Removed all error handling from `execute()` method
- Kept error handling only in `_execute_async()` where work happens
- Merged `CancelledError` and `BaseException` handlers into single catch
- Uses `isinstance()` to determine error message

**service.py:**
- Merged `CancelledError` and `Exception` handlers into single catch
- Moved RuntimeError check inside main Exception handler
- Prevents non-cancel-scope RuntimeErrors from bypassing error
persistence

**Impact:** Eliminates duplicate `mark_session_completed` calls, ~70
lines of code removed

---

### 2. Fixed StreamError Message
**File:** `backend/copilot/sdk/service.py`

- Changed from generic `"An error occurred. Please try again."`
- Now shows actual error: `errorText=error_msg`
- Provides real error details to frontend during active stream

---

### 3. Deferred Message Clearing on Reconnect
**File:** `frontend/src/app/(platform)/copilot/useCopilotPage.ts`

- Added `shouldClearOnNextMessageRef` flag
- Set flag when reconnect starts
- Clear old assistant messages only AFTER first new message arrives
- Prevents empty chat flicker during reconnection

---

### 4. Fixed Chat Switching Stream Resume
**File:** `frontend/src/app/(platform)/copilot/useCopilotPage.ts`

**Problem:** When switching Chat A → B → A, the stream didn't resume
because `hasResumedRef.current.get(sessionId)` was still `true`

**Fix:** Clear `hasResumedRef` entry when navigating away from session

**Flow now:**
1. In Chat A with active stream
2. Switch to Chat B → clears `hasResumedRef` for Chat A
3. Switch back to Chat A → `hasResumedRef` is false → resumes stream 

---

### 5. Removed Diagnostic Logging
**Files:** `frontend/useCopilotPage.ts`, `frontend/useChatSession.ts`,
`backend/stream_registry.py`, `backend/processor.py`,
`backend/routes.py`

- Removed all `[STREAM_DIAG]` console logs (60+ lines)
- Logs were useful for debugging but not needed in production
- Cleaner codebase, reduced noise in logs

---

### 6. Exception Handling Order Consistency
**File:** `backend/copilot/executor/processor.py`

- Made both CancelledError and regular exception branches follow same
pattern
- Set `error_msg` before logging in both cases
- Consistent code structure

---

## Architecture Quality: **9/10**

**Strengths:**
- Eliminated duplicate completion markers
- All RuntimeErrors now get proper error persistence
- Real error messages shown to users
- Stream resume works reliably when switching chats
- Cleaner codebase with diagnostic logs removed
- Consistent exception handling patterns

**Trade-offs:**
- Message clearing deferred means brief period with stale + new messages
(acceptable, prevents empty chat)

## Test Plan

- [x] Verify no duplicate completion markers sent
- [x] Trigger RuntimeError, verify error persists
- [x] Check StreamError shows actual error message
- [x] Reconnect, verify chat doesn't go empty
- [x] Switch Chat A → B → A with active stream, verify resume works
- [x] Verify no STREAM_DIAG logs in console
- [x] Run `pnpm format && pnpm lint && pnpm types` - all passed
- [x] Run `poetry run format` - all passed
- [ ] Test in production

## Checklist 📋

- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan  
- [x] I have tested my changes according to the test plan
- [x] `.env.default` is updated or compatible (no config changes)
- [x] `docker-compose.yml` is updated or compatible (no config changes)
2026-02-26 13:32:25 +00:00
Reinier van der Leer
1d9dd782a8 feat(backend/api): Add POST /graphs endpoint to external API (#12208)
- Resolves [SECRT-2031: Add upload agent to Library endpoint on external
API](https://linear.app/autogpt/issue/SECRT-2031)

### Changes 🏗️

- Add `POST /graphs` to v1 external API
- Add support for requiring multiple scopes in `require_permission`
middleware
- Add `WRITE_GRAPH` and `WRITE_LIBRARY` API permission scopes

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Test `POST /graphs` endpoint through `/docs` Swagger UI
2026-02-26 12:54:39 +01:00
Krzysztof Czerwinski
a1cb3d2a91 feat(blocks): Add Telegram blocks (#12141)
Add Telegram blocks that allow the use of [Telegram bots' API
features](https://core.telegram.org/bots/features).

### Changes 🏗️

1. Credentials & API layer: Bot token auth via `APIKeyCredentials`,
helper functions for JSON API calls (call_telegram_api) and multipart
file uploads (call_telegram_api_with_file)
2. Trigger blocks:
- `TelegramMessageTriggerBlock` — receives messages (text, photo, voice,
audio, document, video, edited message) with configurable event filters
- `TelegramMessageReactionTriggerBlock` — fires on reaction changes
(private chats auto, groups require admin)
2. Action blocks (11 total):
  - Send: Message, Photo, Voice, Audio, Document, Video
  - Reply to Message, Edit Message, Delete Message
  - Get File (download by file_id)
3. Webhook manager: Registers/deregisters webhooks via Telegram's
setWebhook API, validates incoming requests using
X-Telegram-Bot-Api-Secret-Token header
4. Provider registration: Added TELEGRAM to ProviderName enum and
registered `TelegramWebhooksManager`
5. Media send blocks support both URL passthrough (Telegram fetches
directly) and file upload for workspace/data URI inputs

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Non-AI UUIDs
  - [x] Blocks work correctly
    - [x] SendTelegramMessageBlock
    - [x] SendTelegramPhotoBlock
    - [x] SendTelegramVoiceBlock
    - [x] SendTelegramAudioBlock
    - [x] SendTelegramDocumentBlock
    - [x] SendTelegramVideoBlock
    - [x] ReplyToTelegramMessageBlock
    - [x] GetTelegramFileBlock
    - [x] DeleteTelegramMessageBlock
    - [x] EditTelegramMessageBlock
    - [x] TelegramMessageTriggerBlock (works for every trigger type)
    - [x] TelegramMessageReactionTriggerBlock

---------

Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
2026-02-26 10:25:08 +00:00
Otto
1b91327034 fix(builder): Show X button on edge line hover, not just button hover (#12083)
## Summary

Fixes the issue where the X button for removing connections between
nodes only appears when hovering directly over the button itself. Users
now see the button when hovering anywhere on the connection line.

## Changes

- Added an invisible interaction path along the edge with a 20px stroke
width
- The path triggers the same hover state as the button
- This makes the X button visible when hovering the line OR the button
- Preserves existing behavior for broken edges (always visible)

## Testing

1. Hover over an edge line (not the button) → X button should appear
2. Move from line to button → button should stay visible  
3. Move away from both → button should fade out
4. Broken edges should still show X button always

## Linear

Fixes SECRT-1943

## Screenshots

This is a UX improvement - no visual changes except the button now
appears on line hover.

<!-- greptile_comment -->

<h2>Greptile Overview</h2>

<details><summary><h3>Greptile Summary</h3></summary>

This PR improves the UX for edge deletion by adding an invisible
interaction path with a 20px stroke width that makes the delete button
(X) appear when hovering anywhere along the connection line, not just
when hovering directly over the button.

**Key Changes:**
- Added invisible `<path>` element before `BaseEdge` with
`stroke="transparent"` and `strokeWidth={20}`
- Path has `onMouseEnter` and `onMouseLeave` handlers that trigger the
same `setIsHovered` state used by the delete button
- Delete button visibility logic remains unchanged: fades in when
`isHovered` is true (or always visible for broken edges)
- Works uniformly for all edge types (regular, static, and broken edges)

**How It Works:**
The invisible path creates a wider hit area (20px) around the edge
curve, making it much easier for users to trigger the hover state. When
the mouse enters this area, `isHovered` becomes true, which causes the
delete button to fade in (via the existing opacity transition logic).
The button itself also has hover handlers, so moving from the line to
the button maintains the visible state smoothly.
</details>


<details><summary><h3>Confidence Score: 5/5</h3></summary>

- This PR is safe to merge with minimal risk - it's a small, focused UX
improvement with no logic changes
- The implementation is clean and focused: adds only 9 lines of code,
uses existing state management (`isHovered`), and doesn't modify any
deletion logic. The invisible path is a standard SVG/React pattern for
expanding hit areas, and the approach is consistent with how the delete
button already handles hover events. No breaking changes, no side
effects.
- No files require special attention
</details>


<!-- greptile_other_comments_section -->

<!-- /greptile_comment -->

---------

Co-authored-by: Krzysztof Czerwinski <kpczerwinski@gmail.com>
2026-02-26 10:02:01 +00:00
Krzysztof Czerwinski
c7cdb40c5b feat(platform): Update new builder search (#11806)
### Changes 🏗️

- Add materialized view for suggested blocks
- Make `search` in builder accept comma separated filter list in query
- Remove Otto suggestions
- Use hybrid search for blocks search in builder
- Exclude `AgentExecutorBlocks` from builder
- Remove `Block` suffix from builder items

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Materialized view function works (when running manually)
- [x] Higher execution count blocks are shown first in "suggested
blocks" (uses materialized view)
  - [x] Hybrid search works
- [x] `AgentExecutorBlocks` doesn't appear on search results and in
blocks list
  - [x] `Block` suffix isn't displayed on blocks names in builder items

---------

Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
2026-02-26 09:56:40 +00:00
Nicholas Tindle
77fb4419d0 Handle workspace:// URLs in regular markdown links (#12166)
### Changes 🏗️

Extended the `resolveWorkspaceUrls` function to handle both image syntax
(`![alt](workspace://id#mime)`) and regular link syntax
(`[text](workspace://id)`).

Previously, only image links were being resolved. Regular workspace
links were being blocked by Streamdown's rehype-harden sanitizer because
`workspace://` is not in the allowed URL-scheme whitelist, causing
"[blocked]" to appear next to link text.

The fix:
- Refactored the function to process image links first (existing
behavior)
- Added a second regex replacement to handle regular links using a
negative lookbehind (`(?<!!)`) to avoid matching image syntax
- Both patterns now resolve `workspace://` URLs to proxy download URLs
via `/api/proxy`
- Updated JSDoc comments to clarify the dual handling

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] Verified image links with MIME type hints still resolve correctly
- [x] Verified regular workspace links now resolve to proxy URLs instead
of being blocked
- [x] Confirmed negative lookbehind prevents double-processing of image
syntax

https://claude.ai/code/session_0184TVJJcEoB8wbX9htCnv4b

<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> **Low Risk**
> Low risk: a small, localized frontend markdown preprocessing change
that only rewrites `workspace://` URLs to existing `/api/proxy` download
URLs; main risk is regex edge cases affecting link rendering.
> 
> **Overview**
> Updates `resolveWorkspaceUrls` in `ChatMessagesContainer` to rewrite
**both** `workspace://` image markdown and regular markdown links into
`/api/proxy` download URLs so Streamdown sanitization no longer shows
`[blocked]` for workspace links.
> 
> Image handling is preserved (including `#video/*` MIME hints via
`video:` alt text), and a second regex pass with a negative lookbehind
avoids double-processing image syntax when rewriting plain links.
> 
> <sup>Written by [Cursor
Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit
e17749b72c. This will update automatically
on new commits. Configure
[here](https://cursor.com/dashboard?tab=bugbot).</sup>
<!-- /CURSOR_SUMMARY -->

---------

Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: Ubbe <hi@ubbe.dev>
Co-authored-by: Lluis Agusti <hi@llu.lu>
2026-02-25 12:33:10 +00:00
Bently
9f002ce8f6 fix(frontend): improve UX for expired or duplicate password reset links (#12123)
## Summary
Improves the user experience when a password reset link has expired or
been used, replacing the confusing generic error with a clean, helpful
message.

## Changes
- Added `ExpiredLinkMessage` component that displays a user-friendly
error state
- Updated reset password page to detect expired/used links from:
- Supabase error format
(`error=access_denied&error_code=otp_expired&error_description=...`)
  - Internal clean format (`error=link_expired`)
- Enhanced callback route to detect and map expired/invalid link errors
- Clear, actionable UI with:
  - Friendly error message explaining what happened
  - "Send Me a New Link" button to request a new reset email
  - Login link for users who already have access

## Before
Users saw a confusing URL with error parameters and an unclear form:
```
/reset-password?error=access_denied&error_code=otp_expired&error_description=Email+link+is+invalid+or+has+expired
```

## After
Users see a clean, helpful message explaining the issue and how to fix
it.

<img width="548" height="454" alt="image"
src="https://github.com/user-attachments/assets/e867e522-146c-4d43-91b3-9e62d2957f95"
/>


Closes SECRT-1369

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [ ] Navigate to `/reset-password?error=link_expired` and verify the
ExpiredLinkMessage component appears
  - [ ] Click "Send Me a New Link" and verify the email form appears
- [ ] Navigate to
`/reset-password?error=access_denied&error_code=otp_expired` and verify
same behavior

<!-- greptile_comment -->

<details><summary><h3>Greptile Summary</h3></summary>

Improved password reset UX by adding an `ExpiredLinkMessage` component
that displays when users follow expired or already-used reset links. The
implementation detects expired link errors from Supabase
(`error_code=otp_expired`) and internal format (`error=link_expired`),
replacing confusing URL parameters with a clean message.

**Key changes:**
- Added error detection logic in both callback route and reset password
page to identify expired/invalid links
- Created new `ExpiredLinkMessage` component with friendly messaging
- Enhanced error handling to differentiate between expired links and
other errors

**Issues found:**
- The "Send Me a New Link" button misleadingly suggests it will send an
email, but it only reveals the email form - user must still enter email
and submit
- `access_denied` error detection may be too broad and could incorrectly
classify non-expired errors as expired links
</details>


<details><summary><h3>Confidence Score: 3/5</h3></summary>

- This PR improves UX but has logic issues that could mislead users
- The implementation correctly detects expired links and displays
helpful UI, but the "Send Me a New Link" button doesn't actually send an
email (just shows the form), which creates a misleading user experience.
Additionally, the `access_denied` error check is overly broad and could
incorrectly classify errors. These are functional issues that should be
addressed before merge.
- Pay close attention to `page.tsx` - the `handleSendNewLink` function
and error detection logic need refinement
</details>


<details><summary><h3>Flowchart</h3></summary>

```mermaid
flowchart TD
    Start[User clicks reset link with code] --> Callback[API: /auth/callback/reset-password]
    Callback --> CheckCode{Code valid?}
    
    CheckCode -->|No - expired/invalid/used| DetectError[Detect error type]
    DetectError --> CheckExpired{Contains expired/<br/>invalid/otp_expired/<br/>already/used?}
    CheckExpired -->|Yes| RedirectExpired[Redirect to /reset-password?error=link_expired]
    CheckExpired -->|No| RedirectOther[Redirect to /reset-password?error=message]
    
    CheckCode -->|Yes| RedirectSuccess[Redirect to /reset-password]
    
    RedirectExpired --> PageLoad[Page: /reset-password]
    RedirectOther --> PageLoad
    RedirectSuccess --> PageLoad
    
    PageLoad --> ParseParams[Parse URL params]
    ParseParams --> CheckErrorParams{Has error or<br/>error_code?}
    
    CheckErrorParams -->|Yes| CheckExpiredParams{error=link_expired OR<br/>errorCode=otp_expired OR<br/>error=access_denied OR<br/>description contains<br/>expired/invalid?}
    CheckExpiredParams -->|Yes| ShowExpired[Show ExpiredLinkMessage]
    CheckExpiredParams -->|No| ShowToast[Show error toast]
    
    CheckErrorParams -->|No| CheckUser{User<br/>authenticated?}
    
    ShowExpired --> ClickButton[User clicks 'Send Me a New Link']
    ClickButton --> HideExpired[setShowExpiredMessage false]
    HideExpired --> ShowForm[Show email form]
    
    ShowToast --> ClearParams[Clear error params from URL]
    ClearParams --> CheckUser
    
    CheckUser -->|Yes| ShowPasswordForm[Show password change form]
    CheckUser -->|No| ShowForm[Show email form]
```
</details>


<sub>Last reviewed commit: 80e9f40</sub>

<!-- greptile_other_comments_section -->

<!-- /greptile_comment -->
2026-02-25 12:11:55 +00:00
Ubbe
74691076c6 fix(frontend/copilot): show clarification and agent-saved cards without accordion (#12204)
### Background

The CoPilot tool UI wraps several output cards (clarification questions,
agent saved confirmation) inside a collapsible `ToolAccordion`. This
means users have to expand the accordion to see important interactive
content — clarification questions they need to answer, or confirmation
that their agent was created/updated.

### Changes 🏗️

- **Clarification questions always visible**: Moved
`ClarificationQuestionsCard` out of the `ToolAccordion` in both
`CreateAgent` and `EditAgent` tools so users immediately see and can
answer questions without expanding an accordion
- **Agent saved card always visible**: Moved the agent-saved
confirmation card out of the `ToolAccordion` in both tools so the
success state with library/builder links is immediately visible
- **Extracted `AgentSavedCard` component**: The agent-saved card was
duplicated between `CreateAgent` and `EditAgent` — extracted it into a
shared `copilot/components/AgentSavedCard/AgentSavedCard.tsx` component,
parameterized by `message` ("has been saved to your library!" vs "has
been updated!")
- **ClarificationQuestionsCard polish**: Updated spacing, icon
(`ChatTeardropDotsIcon`), typography variants, border styles, and number
badge sizing for a cleaner look
- **Minor atom tweaks**: Lightened `secondary` button variant
(`zinc-200` → `zinc-100`), changed textarea border radius from
`rounded-3xl` to `rounded-xl`

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] `pnpm format` passes
  - [x] `pnpm lint` passes
  - [x] `pnpm types` passes
- [ ] Create an agent via CoPilot and verify the saved card shows
without accordion
- [ ] Trigger clarification questions and verify they show without
accordion
- [ ] Edit an agent via CoPilot and verify the updated card shows
without accordion
- [ ] Verify the ClarificationQuestionsCard styling looks correct
(spacing, icons, borders)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-25 18:13:01 +07:00
Otto
b15ad0df9b hotfix(frontend): fix null credits TypeError on /copilot (#12202)
Requested by @majdyz

Fix `TypeError: Cannot read properties of null (reading 'credits')` on
the /copilot page.

**Sentry:**
[BUILDER-71P](https://significant-gravitas.sentry.io/issues/7256025912/)
**Linear:** SENTRY-1110

## Root Cause

Two issues combined:

1. **`getUserCredit()` had a broken try/catch** — it wasn't `await`ing
`_get()`, so async errors (including null responses) were never caught
2. **`_makeClientRequest` returns `null` during logout** — when a user
is logging out and `/credits` races with auth teardown, the response is
`null`

Chain: logout starts → `/credits` fetch races → auth error →
`_makeClientRequest` returns `null` → `getUserCredit` passes `null`
through → `fetchCredits` does `null.credits` → 💥

## Fix

- `getUserCredit()`: Add `await` + null coalescing fallback to `{
credits: 0 }`
- `fetchCredits()`: Add optional chaining guard (`response?.credits ??
null`)
2026-02-25 10:38:08 +00:00
Abhimanyu Yadav
2136defea8 feat(library): implement folder organization system for agents (#12101)
### Changes 🏗️

This PR adds folder organization capabilities to the library, allowing
users to organize their agents into folders:

- Added new `LibraryFolder` model and database schema
- Created folder management API endpoints for CRUD operations
- Implemented folder tree structure with proper parent-child
relationships
- Added drag-and-drop functionality for moving agents between folders
- Created folder creation dialog with emoji picker for folder icons
- Added folder editing and deletion capabilities
- Implemented folder navigation in the library UI
- Added validation to prevent circular references and excessive nesting
- Created animation for favoriting agents
- Updated library agent list to show folder structure
- Added folder filtering to agent list queries

<img width="1512" height="950" alt="Screenshot 2026-02-13 at 9 08 45 PM"
src="https://github.com/user-attachments/assets/78778e03-4349-4d50-ad71-d83028ca004a"
/>

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Create a new folder with custom name, icon, and color
  - [x] Move agents into folders via drag and drop
  - [x] Move agents into folders via context menu
  - [x] Navigate between folders
  - [x] Edit folder properties (name, icon, color)
  - [x] Delete folders and verify agents return to root
  - [x] Verify favorite animation works when adding to favorites
  - [x] Test folder navigation with search functionality
  - [x] Verify folder tree structure is maintained

#### For configuration changes:

- [x] `.env.default` is updated or already compatible with my changes
- [x] `docker-compose.yml` is updated or already compatible with my
changes
- [x] I have included a list of my configuration changes in the PR
description (under **Changes**)

<!-- greptile_comment -->

<h2>Greptile Overview</h2>

<details><summary><h3>Greptile Summary</h3></summary>

This PR implements a comprehensive folder organization system for
library agents, enabling hierarchical structure up to 5 levels deep.

**Backend Changes:**
- Added `LibraryFolder` model with self-referential hierarchy
(`parentId` → `Parent`/`Children`)
- Implemented CRUD operations with validation for circular references
and depth limits (MAX_FOLDER_DEPTH=5)
- Added `folderId` foreign key to `LibraryAgent` table
- Created folder management endpoints: list, get, create, update, move,
delete, and bulk agent moves
- Proper soft-delete cascade handling for folders and their contained
agents

**Frontend Changes:**
- Created folder creation/edit/delete dialogs with emoji picker
integration
- Implemented folder navigation UI with breadcrumbs and folder tree
structure
- Added drag-and-drop support for moving agents between folders
- Created context menu for agent actions (move to folder, remove from
folder)
- Added favorite animation system with `FavoriteAnimationProvider`
- Integrated folder filtering into agent list queries

**Key Features:**
- Folders support custom names, emoji icons, and hex colors
- Unique constraint per parent folder per user prevents duplicate names
- Validation prevents circular folder hierarchies and excessive nesting
- Agents can be moved between folders via drag-drop or context menu
- Deleting a folder soft-deletes all descendant folders and contained
agents
</details>


<details><summary><h3>Confidence Score: 4/5</h3></summary>

- This PR is safe to merge with minor considerations for performance
optimization
- The implementation is well-structured with proper validation, error
handling, and database constraints. The folder hierarchy logic correctly
prevents circular references and enforces depth limits. However, there
are some performance concerns with N+1 queries in depth calculation and
circular reference checking that could be optimized for deeply nested
hierarchies. The foreign key constraint (ON DELETE RESTRICT) conflicts
with the hard-delete code path but shouldn't cause issues since
soft-deletes are the default. The client-side duplicate validation is
redundant but not harmful.
- Pay close attention to migration file (foreign key constraint) and
db.py (performance of recursive queries)
</details>


<details><summary><h3>Sequence Diagram</h3></summary>

```mermaid
sequenceDiagram
    participant User
    participant Frontend
    participant API
    participant DB

    User->>Frontend: Create folder with name/icon/color
    Frontend->>API: POST /v2/folders
    API->>DB: Validate parent exists & depth limit
    API->>DB: Check unique constraint (userId, parentId, name)
    DB-->>API: Folder created
    API-->>Frontend: LibraryFolder response
    Frontend-->>User: Show success toast

    User->>Frontend: Drag agent to folder
    Frontend->>API: POST /v2/folders/agents/bulk-move
    API->>DB: Verify folder exists
    API->>DB: Update LibraryAgent.folderId
    DB-->>API: Agents updated
    API-->>Frontend: Updated agents
    Frontend-->>User: Refresh agent list

    User->>Frontend: Navigate into folder
    Frontend->>API: GET /v2/library/agents?folder_id=X
    API->>DB: Query agents WHERE folderId=X
    DB-->>API: Filtered agents
    API-->>Frontend: Agent list
    Frontend-->>User: Display folder contents

    User->>Frontend: Delete folder
    Frontend->>API: DELETE /v2/folders/{id}
    API->>DB: Get descendant folders recursively
    API->>DB: Soft-delete folders + agents in transaction
    DB-->>API: Deletion complete
    API-->>Frontend: 204 No Content
    Frontend-->>User: Show success toast
```
</details>


<sub>Last reviewed commit: a6c2f64</sub>

<!-- greptile_other_comments_section -->

<!-- /greptile_comment -->
2026-02-24 15:04:56 +00:00
Zamil Majdy
6e61cb103c fix(copilot): workspace file listing fix (#12190)
Requested by @majdyz

Improves workspace file display in GenericTool:
- Base64 content decoding for workspace files
- Rich file object rendering (path, size, mime type)
- MCP text extraction from SDK tool responses (Read, Glob, Grep, Edit)
- Better file list formatting for both string and object file entries

---------

Co-authored-by: Otto (AGPT) <otto@agpt.co>
2026-02-24 12:33:24 +00:00
Zamil Majdy
0e72e1f5e7 fix(platform/copilot): fix stuck sessions, stop button, and StreamFinish reliability (#12191)
## Summary

- **Fix stuck sessions**: Root cause was `_stream_listener` infinite
xread loop when Redis session metadata TTL expired — `hget` returned
`None` which bypassed the `status != "running"` break condition. Fixed
by treating `None` status as non-running.
- **Fix stop button reliability**: Cancel endpoint now force-completes
via `mark_session_completed` when executor doesn't respond within 5s.
Returns `cancelled=True` for already-expired sessions.
- **Single-owner StreamFinish**: All `yield StreamFinish()` removed from
service layers (sdk/service.py, service.py, dummy.py).
`mark_session_completed` is now the single atomic source of truth for
publishing StreamFinish via Lua CAS script.
- **Rename task → session/turn**: Consistent terminology across
stream_registry and processor.
- **Frontend session refetch**: Added `refetchOnMount: true` so page
refresh re-fetches session state.
- **Test fixes**: Updated e2e, service, and run_agent tests for
StreamFinish removal; fixed async fixture decorators.

## Test plan
- [x] E2E dummy streaming tests pass (13 passed, 1 xfailed)
- [x] run_agent_test.py passes (async fixture decorator fix)
- [x] service_test.py passes (StreamFinish assertions removed)
- [ ] Manual: verify stuck sessions recover on page refresh
- [ ] Manual: verify stop button works for active and expired sessions
- [ ] Manual: verify no duplicate StreamFinish events in SSE stream
2026-02-24 10:49:22 +00:00
Swifty
163b0b3c9d feat(backend): pre-populate CoPilotUnderstanding from Tally form on signup (#12119)
When new users sign up, check if they previously filled out the Tally
beta application form and, if so, pre-populate their
CoPilotUnderstanding with business data extracted from that form. This
gives the CoPilot (Otto) immediate context about the user on their very
first chat interaction.

### Changes 🏗️

- **`backend/util/settings.py`**: Added `tally_api_key` to `Secrets`
class
- **`backend/.env.default`**: Added `TALLY_API_KEY=` env var entry
- **`backend/data/tally.py`** (new): Core Tally integration module
- Redis-cached email index of form submissions (1h TTL) with incremental
refresh via `startDate`
  - Paginated Tally API fetching with Bearer token auth
  - Email matching (case-insensitive) against submission data
- LLM extraction (gpt-4o-mini via OpenRouter) of
`BusinessUnderstandingInput` fields
  - Fire-and-forget orchestrator that is idempotent and never raises
- **`backend/api/features/v1.py`**: Added background task in
`get_or_create_user_route` to trigger Tally lookup on login (skips if
understanding already exists)
- **`backend/data/tally_test.py`** (new): 15 unit tests covering index
building, email case-insensitivity, cache hit/miss, format helpers,
idempotency, graceful degradation, and error resilience

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] All 15 unit tests pass (`poetry run pytest
backend/data/tally_test.py --noconftest -xvs`)
  - [x] Lint clean (`poetry run ruff check` on changed files)
  - [x] Type check clean (`poetry run pyright` on new files)
- [ ] Manual: Set `TALLY_API_KEY` in `.env`, create a new user, verify
CoPilotUnderstanding is populated
- [ ] Manual: Verify user creation succeeds when Tally API key is
missing or API is down

#### For configuration changes:
- [x] `.env.default` is updated or already compatible with my changes
- [x] `docker-compose.yml` is updated or already compatible with my
changes
- [x] I have included a list of my configuration changes in the PR
description (under **Changes**)
- Added `TALLY_API_KEY=` to `.env.default` (optional, empty by default —
feature is a no-op without it)

<!-- greptile_comment -->

<details><summary><h3>Greptile Summary</h3></summary>

This PR adds a Tally form integration that pre-populates
`CoPilotUnderstanding` for new users by matching their signup email
against cached Tally beta application form submissions, then using an
LLM (gpt-4o-mini via OpenRouter) to extract structured business data.

- **New module `tally.py`** implements Redis-cached email indexing of
Tally form submissions with incremental refresh, email matching, LLM
extraction, and an idempotent fire-and-forget orchestrator
- **`v1.py`** adds a background task on the `get_or_create_user_route`
to trigger Tally lookup on every login (idempotency check is inside the
called function)
- **`settings.py` / `.env.default`** adds `tally_api_key` as an optional
secret — feature is a no-op without it
- **`tally_test.py`** adds 15 unit tests with thorough mocking coverage
- **Bug: TTL mismatch** — `_LAST_FETCH_TTL` (2h) > `_INDEX_TTL` (1h)
creates a window where incremental refresh loses all previously indexed
emails because the base index has expired but `last_fetch` persists.
This will cause silent data loss for users whose form submissions were
indexed before the cache expiry
- **Bug: `str.format()` on LLM prompt** — form data containing `{` or
`}` will crash the prompt formatting, silently preventing understanding
population for those users
</details>


<details><summary><h3>Confidence Score: 2/5</h3></summary>

- This PR has two logic bugs that will cause silent data loss in
production — recommend fixing before merge.
- The TTL mismatch between `_LAST_FETCH_TTL` and `_INDEX_TTL` will
intermittently cause incomplete caches, silently dropping users from the
email index. The `str.format()` issue will cause failures for any form
submission containing curly braces. Both bugs are caught by the
top-level exception handler, so they won't crash the service, but they
will silently prevent the feature from working correctly for affected
users. The overall architecture is sound and well-tested for normal
paths.
- `autogpt_platform/backend/backend/data/tally.py` — contains both the
TTL mismatch bug in `_refresh_cache` and the `str.format()` issue in
`extract_business_understanding`
</details>


<details><summary><h3>Sequence Diagram</h3></summary>

```mermaid
sequenceDiagram
    participant User
    participant API as v1.py (get_or_create_user_route)
    participant Tally as tally.py (populate_understanding_from_tally)
    participant DB as Database (understanding)
    participant Redis
    participant TallyAPI as Tally API
    participant LLM as OpenRouter (gpt-4o-mini)

    User->>API: POST /auth/user (JWT)
    API->>API: get_or_create_user(user_data)
    API-->>User: Return user (immediate)
    API->>Tally: asyncio.create_task(populate_understanding_from_tally)

    Tally->>DB: get_business_understanding(user_id)
    alt Understanding exists
        DB-->>Tally: existing understanding
        Note over Tally: Skip (idempotent)
    else No understanding
        DB-->>Tally: None
        Tally->>Tally: Check tally_api_key configured
        Tally->>Redis: Check cached email index
        alt Cache hit
            Redis-->>Tally: email_index + questions
        else Cache miss
            Redis-->>Tally: None
            Tally->>TallyAPI: GET /forms/{id}/submissions (paginated)
            TallyAPI-->>Tally: submissions + questions
            Tally->>Tally: Build email index
            Tally->>Redis: Cache index (1h TTL)
        end
        Tally->>Tally: Lookup email in index
        alt Email found
            Tally->>Tally: format_submission_for_llm()
            Tally->>LLM: Extract BusinessUnderstandingInput
            LLM-->>Tally: JSON structured data
            Tally->>DB: upsert_business_understanding(user_id, input)
        end
    end
```
</details>


<sub>Last reviewed commit: 92d2da4</sub>

<!-- greptile_other_comments_section -->

<!-- /greptile_comment -->

---------

Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: Otto (AGPT) <otto@agpt.co>
Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
2026-02-24 11:31:29 +01:00
Bently
ef42b17e3b docs: add Podman compatibility warning (#12120)
## Summary
Adds a warning to the Getting Started docs clarifying that **Podman and
podman-compose are not supported**.

## Problem
Users on Windows using `podman-compose` instead of Docker get errors
like:
```
Error: the specified Containerfile or Dockerfile does not exist, ..\..\autogpt_platform\backend\Dockerfile
```

This is because Podman handles relative paths differently than Docker,
causing incorrect path resolution on Windows.

## Solution
- Added a clear warning section after the Windows WSL 2 notes
- Explains the error users might see
- Directs them to install Docker Desktop instead

Closes #11358

<!-- greptile_comment -->

<details><summary><h3>Greptile Summary</h3></summary>

Adds a "Podman Not Supported" warning section to the Getting Started
documentation, placed after the Windows/WSL 2 installation notes. The
section clarifies that Docker is required, shows the typical error
message users encounter when using Podman, and directs them to install
Docker Desktop instead. This addresses issue #11358 where Windows users
using `podman-compose` hit path resolution errors.

- Adds `### ⚠️ Podman Not Supported` section under Manual Setup, after
Windows Installation Note
- Includes the specific error message users see with Podman for easy
identification
- Links to Docker Desktop installation docs as the recommended solution
- Formatting is consistent with existing sections in the document (emoji
headings, code blocks for errors)
</details>


<details><summary><h3>Confidence Score: 5/5</h3></summary>

- This PR is safe to merge — it only adds a documentation warning
section with no code changes.
- The change is a small, well-written documentation addition that adds a
Podman compatibility warning. It touches only one markdown file,
introduces no code changes, and is consistent with the existing document
structure and style. No issues were found.
- No files require special attention.
</details>


<details><summary><h3>Flowchart</h3></summary>

```mermaid
flowchart TD
    A[User wants to run AutoGPT] --> B{Which container runtime?}
    B -->|Docker / Docker Desktop| C[docker compose up -d --build]
    C --> D[AutoGPT starts successfully]
    B -->|Podman / podman-compose| E[podman-compose up -d --build]
    E --> F[Error: Containerfile or Dockerfile does not exist]
    F --> G[New warning section directs user to install Docker Desktop]
    G --> C
```
</details>


<sub>Last reviewed commit: 23ea6bd</sub>

<!-- greptile_other_comments_section -->

<!-- /greptile_comment -->
2026-02-23 15:19:24 +00:00
Ubbe
a18ffd0b21 fix(frontend/copilot): always-visible credentials, inputs, and login prompts (#12194)
Credentials, inputs, and login prompts in copilot tool outputs were
hidden inside collapsible accordions — users could accidentally collapse
them, hiding blocking actionable UI. This PR extracts all blocking
requirements out of accordions so they're always visible.

### Changes 🏗️

- **RunAgent & RunBlock**: Extract `SetupRequirementsCard` (credentials
picker) out of `ToolAccordion` — renders standalone, always visible
- **RunAgent**: Also extract `AgentDetailsCard` (inputs needed) and
`need_login` message out of accordion
- **SetupRequirementsCard (RunBlock)**: Input form always visible
(removed toggle button and animation), unified "Proceed" button disabled
until credentials + inputs are satisfied
- **SetupRequirementsCard (RunAgent)**: "Proceed" button disabled until
all credentials are selected
- **Both cards**: Added titled box with border for credentials section
("Block credentials" / "Agent credentials"), matching the existing
inputs box pattern
- **CredentialsFlatView**: "Add" button uses `variant="primary"` when
user has no credentials (was `secondary`)
- **Styleguide**: Added mock `CredentialsProvidersContext` with two
scenarios:
  - No credentials → shows "add new" flow
  - Has credentials → shows selection list with existing accounts
- **CreateAgent & EditAgent**: Picked up user-initiated styling
refinements

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] `pnpm format && pnpm lint && pnpm types` all pass
  - [ ] Visit `/copilot/styleguide` and verify:
- [ ] "Setup requirements — no credentials" shows add-credential button
(primary variant)
- [ ] "Setup requirements — has credentials" shows credential selection
dropdown
- [ ] Both RunAgent and RunBlock setup requirements render outside
accordion
- [ ] Trigger a copilot agent run that requires credentials — credential
picker always visible
- [ ] Trigger a copilot block run that requires credentials + inputs —
both sections visible, "Proceed" disabled until ready
- [ ] Trigger a copilot agent run that returns "agent details" — card
renders outside accordion
- [ ] Verify other output types (execution_started, error) still render
inside accordions


🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-23 16:39:21 +07:00
Otto
e40c8c70ce fix(copilot): collision detection, session locking, and sync for concurrent message saves (#12177)
Requested by @majdyz

Concurrent writers (incremental streaming saves from PR #12173 and
long-running tool callbacks) can race to persist messages with the same
`(sessionId, sequence)` pair, causing unique constraint violations on
`ChatMessage`.

**Root cause:** The streaming loop tracks `saved_msg_count` in-memory,
but the long-running tool callback (`_build_long_running_callback`) also
appends messages and calls `upsert_chat_session` independently — without
coordinating sequence numbers. When the streaming loop does its next
incremental save with the stale `saved_msg_count`, it tries to insert at
a sequence that already exists.

**Fix:** Multi-layered defense-in-depth approach:

1. **Collision detection with retry** (db.py): `add_chat_messages_batch`
uses `create_many()` in a transaction. On `UniqueViolationError`,
queries `MAX(sequence)+1` from DB and retries with the correct offset
(max 5 attempts).

2. **Robust sequence tracking** (db.py): `get_next_sequence()` uses
indexed `find_first` with `order={"sequence": "desc"}` for O(1) MAX
lookup, immune to deleted messages.

3. **Session-based counter** (model.py): Added `saved_message_count`
field to `ChatSession`. `upsert_chat_session` returns the session with
updated count, eliminating tuple returns throughout the codebase.

4. **MessageCounter dataclass** (sdk/service.py): Replaced list[int]
mutable reference pattern with a clean `MessageCounter` dataclass for
shared state between streaming loop and long-running callbacks.

5. **Session locking** (sdk/service.py): Prevent concurrent streams on
the same session using Redis `SET NX EX` distributed locks with TTL
refresh on heartbeats (config.stream_ttl = 3600s).

6. **Atomic operations** (db.py): Single timestamp for all messages and
session update in batch operations for consistency. Parallel queries
with `asyncio.gather` for lower latency.

7. **Config-based TTL** (sdk/service.py, config.py): Consolidated all
TTL constants to use `config.stream_ttl` (3600s) with lock refresh on
heartbeats.

### Key implementation details

- **create_many**: Uses `sessionId` directly (not nested
`Session.connect`) as `create_many` doesn't support nested creates
- **Type narrowing**: Added explicit `assert session is not None`
statements for pyright type checking in async contexts
- **Parallel operations**: Use `asyncio.gather` for independent DB
operations (create_many + session update)
- **Single timestamp**: All messages in a batch share the same
`createdAt` timestamp for atomicity

### Changes
- `backend/copilot/db.py`: Collision detection with `create_many` +
retry, indexed sequence lookup, single timestamp, parallel queries
- `backend/copilot/model.py`: Added `saved_message_count` field,
simplified return types
- `backend/copilot/sdk/service.py`: MessageCounter dataclass, session
locking with refresh, config-based TTL, type narrowing
- `backend/copilot/service.py`: Updated all callers to handle new return
types
- `backend/copilot/config.py`: Increased long_running_operation_ttl to
3600s with clarified docstring
- `backend/copilot/*_test.py`: Tests updated for new signatures

---------

Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
2026-02-20 15:05:03 +00:00
173 changed files with 12504 additions and 6224 deletions

View File

@@ -190,5 +190,8 @@ ZEROBOUNCE_API_KEY=
POSTHOG_API_KEY=
POSTHOG_HOST=https://eu.i.posthog.com
# Tally Form Integration (pre-populate business understanding on signup)
TALLY_API_KEY=
# Other Services
AUTOMOD_API_KEY=

View File

@@ -88,20 +88,23 @@ async def require_auth(
)
def require_permission(permission: APIKeyPermission):
def require_permission(*permissions: APIKeyPermission):
"""
Dependency function for checking specific permissions
Dependency function for checking required permissions.
All listed permissions must be present.
(works with API keys and OAuth tokens)
"""
async def check_permission(
async def check_permissions(
auth: APIAuthorizationInfo = Security(require_auth),
) -> APIAuthorizationInfo:
if permission not in auth.scopes:
missing = [p for p in permissions if p not in auth.scopes]
if missing:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Missing required permission: {permission.value}",
detail=f"Missing required permission(s): "
f"{', '.join(p.value for p in missing)}",
)
return auth
return check_permission
return check_permissions

View File

@@ -18,6 +18,7 @@ from backend.data import user as user_db
from backend.data.auth.base import APIAuthorizationInfo
from backend.data.block import BlockInput, CompletedBlockOutput
from backend.executor.utils import add_graph_execution
from backend.integrations.webhooks.graph_lifecycle_hooks import on_graph_activate
from backend.util.settings import Settings
from .integrations import integrations_router
@@ -95,6 +96,43 @@ async def execute_graph_block(
return output
@v1_router.post(
path="/graphs",
tags=["graphs"],
status_code=201,
dependencies=[
Security(
require_permission(
APIKeyPermission.WRITE_GRAPH, APIKeyPermission.WRITE_LIBRARY
)
)
],
)
async def create_graph(
graph: graph_db.Graph,
auth: APIAuthorizationInfo = Security(
require_permission(APIKeyPermission.WRITE_GRAPH, APIKeyPermission.WRITE_LIBRARY)
),
) -> graph_db.GraphModel:
"""
Create a new agent graph.
The graph will be validated and assigned a new ID.
It is automatically added to the user's library.
"""
from backend.api.features.library import db as library_db
graph_model = graph_db.make_graph_model(graph, auth.user_id)
graph_model.reassign_ids(user_id=auth.user_id, reassign_graph_id=True)
graph_model.validate_graph(for_run=False)
await graph_db.create_graph(graph_model, user_id=auth.user_id)
await library_db.create_library_agent(graph_model, auth.user_id)
activated_graph = await on_graph_activate(graph_model, user_id=auth.user_id)
return activated_graph
@v1_router.post(
path="/graphs/{graph_id}/execute/{graph_version}",
tags=["graphs"],

View File

@@ -1,15 +1,17 @@
import logging
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from difflib import SequenceMatcher
from typing import Sequence
from typing import Any, Sequence, get_args, get_origin
import prisma
from prisma.enums import ContentType
from prisma.models import mv_suggested_blocks
import backend.api.features.library.db as library_db
import backend.api.features.library.model as library_model
import backend.api.features.store.db as store_db
import backend.api.features.store.model as store_model
from backend.api.features.store.hybrid_search import unified_hybrid_search
from backend.blocks import load_all_blocks
from backend.blocks._base import (
AnyBlockSchema,
@@ -19,7 +21,6 @@ from backend.blocks._base import (
BlockType,
)
from backend.blocks.llm import LlmModel
from backend.data.db import query_raw_with_schema
from backend.integrations.providers import ProviderName
from backend.util.cache import cached
from backend.util.models import Pagination
@@ -42,6 +43,16 @@ MAX_LIBRARY_AGENT_RESULTS = 100
MAX_MARKETPLACE_AGENT_RESULTS = 100
MIN_SCORE_FOR_FILTERED_RESULTS = 10.0
# Boost blocks over marketplace agents in search results
BLOCK_SCORE_BOOST = 50.0
# Block IDs to exclude from search results
EXCLUDED_BLOCK_IDS = frozenset(
{
"e189baac-8c20-45a1-94a7-55177ea42565", # AgentExecutorBlock
}
)
SearchResultItem = BlockInfo | library_model.LibraryAgent | store_model.StoreAgent
@@ -64,8 +75,8 @@ def get_block_categories(category_blocks: int = 3) -> list[BlockCategoryResponse
for block_type in load_all_blocks().values():
block: AnyBlockSchema = block_type()
# Skip disabled blocks
if block.disabled:
# Skip disabled and excluded blocks
if block.disabled or block.id in EXCLUDED_BLOCK_IDS:
continue
# Skip blocks that don't have categories (all should have at least one)
if not block.categories:
@@ -116,6 +127,9 @@ def get_blocks(
# Skip disabled blocks
if block.disabled:
continue
# Skip excluded blocks
if block.id in EXCLUDED_BLOCK_IDS:
continue
# Skip blocks that don't match the category
if category and category not in {c.name.lower() for c in block.categories}:
continue
@@ -255,14 +269,25 @@ async def _build_cached_search_results(
"my_agents": 0,
}
block_results, block_total, integration_total = _collect_block_results(
normalized_query=normalized_query,
include_blocks=include_blocks,
include_integrations=include_integrations,
)
scored_items.extend(block_results)
total_items["blocks"] = block_total
total_items["integrations"] = integration_total
# Use hybrid search when query is present, otherwise list all blocks
if (include_blocks or include_integrations) and normalized_query:
block_results, block_total, integration_total = await _hybrid_search_blocks(
query=search_query,
include_blocks=include_blocks,
include_integrations=include_integrations,
)
scored_items.extend(block_results)
total_items["blocks"] = block_total
total_items["integrations"] = integration_total
elif include_blocks or include_integrations:
# No query - list all blocks using in-memory approach
block_results, block_total, integration_total = _collect_block_results(
include_blocks=include_blocks,
include_integrations=include_integrations,
)
scored_items.extend(block_results)
total_items["blocks"] = block_total
total_items["integrations"] = integration_total
if include_library_agents:
library_response = await library_db.list_library_agents(
@@ -307,10 +332,14 @@ async def _build_cached_search_results(
def _collect_block_results(
*,
normalized_query: str,
include_blocks: bool,
include_integrations: bool,
) -> tuple[list[_ScoredItem], int, int]:
"""
Collect all blocks for listing (no search query).
All blocks get BLOCK_SCORE_BOOST to prioritize them over marketplace agents.
"""
results: list[_ScoredItem] = []
block_count = 0
integration_count = 0
@@ -323,6 +352,10 @@ def _collect_block_results(
if block.disabled:
continue
# Skip excluded blocks
if block.id in EXCLUDED_BLOCK_IDS:
continue
block_info = block.get_info()
credentials = list(block.input_schema.get_credentials_fields().values())
is_integration = len(credentials) > 0
@@ -332,10 +365,6 @@ def _collect_block_results(
if not is_integration and not include_blocks:
continue
score = _score_block(block, block_info, normalized_query)
if not _should_include_item(score, normalized_query):
continue
filter_type: FilterType = "integrations" if is_integration else "blocks"
if is_integration:
integration_count += 1
@@ -346,8 +375,122 @@ def _collect_block_results(
_ScoredItem(
item=block_info,
filter_type=filter_type,
score=score,
sort_key=_get_item_name(block_info),
score=BLOCK_SCORE_BOOST,
sort_key=block_info.name.lower(),
)
)
return results, block_count, integration_count
async def _hybrid_search_blocks(
*,
query: str,
include_blocks: bool,
include_integrations: bool,
) -> tuple[list[_ScoredItem], int, int]:
"""
Search blocks using hybrid search with builder-specific filtering.
Uses unified_hybrid_search for semantic + lexical search, then applies
post-filtering for block/integration types and scoring adjustments.
Scoring:
- Base: hybrid relevance score (0-1) scaled to 0-100, plus BLOCK_SCORE_BOOST
to prioritize blocks over marketplace agents in combined results
- +30 for exact name match, +15 for prefix name match
- +20 if the block has an LlmModel field and the query matches an LLM model name
Args:
query: The search query string
include_blocks: Whether to include regular blocks
include_integrations: Whether to include integration blocks
Returns:
Tuple of (scored_items, block_count, integration_count)
"""
results: list[_ScoredItem] = []
block_count = 0
integration_count = 0
if not include_blocks and not include_integrations:
return results, block_count, integration_count
normalized_query = query.strip().lower()
# Fetch more results to account for post-filtering
search_results, _ = await unified_hybrid_search(
query=query,
content_types=[ContentType.BLOCK],
page=1,
page_size=150,
min_score=0.10,
)
# Load all blocks for getting BlockInfo
all_blocks = load_all_blocks()
for result in search_results:
block_id = result["content_id"]
# Skip excluded blocks
if block_id in EXCLUDED_BLOCK_IDS:
continue
metadata = result.get("metadata", {})
hybrid_score = result.get("relevance", 0.0)
# Get the actual block class
if block_id not in all_blocks:
continue
block_cls = all_blocks[block_id]
block: AnyBlockSchema = block_cls()
if block.disabled:
continue
# Check block/integration filter using metadata
is_integration = metadata.get("is_integration", False)
if is_integration and not include_integrations:
continue
if not is_integration and not include_blocks:
continue
# Get block info
block_info = block.get_info()
# Calculate final score: scale hybrid score and add builder-specific bonuses
# Hybrid scores are 0-1, builder scores were 0-200+
# Add BLOCK_SCORE_BOOST to prioritize blocks over marketplace agents
final_score = hybrid_score * 100 + BLOCK_SCORE_BOOST
# Add LLM model match bonus
has_llm_field = metadata.get("has_llm_model_field", False)
if has_llm_field and _matches_llm_model(block.input_schema, normalized_query):
final_score += 20
# Add exact/prefix match bonus for deterministic tie-breaking
name = block_info.name.lower()
if name == normalized_query:
final_score += 30
elif name.startswith(normalized_query):
final_score += 15
# Track counts
filter_type: FilterType = "integrations" if is_integration else "blocks"
if is_integration:
integration_count += 1
else:
block_count += 1
results.append(
_ScoredItem(
item=block_info,
filter_type=filter_type,
score=final_score,
sort_key=name,
)
)
@@ -472,6 +615,8 @@ async def _get_static_counts():
block: AnyBlockSchema = block_type()
if block.disabled:
continue
if block.id in EXCLUDED_BLOCK_IDS:
continue
all_blocks += 1
@@ -498,47 +643,25 @@ async def _get_static_counts():
}
def _contains_type(annotation: Any, target: type) -> bool:
"""Check if an annotation is or contains the target type (handles Optional/Union/Annotated)."""
if annotation is target:
return True
origin = get_origin(annotation)
if origin is None:
return False
return any(_contains_type(arg, target) for arg in get_args(annotation))
def _matches_llm_model(schema_cls: type[BlockSchema], query: str) -> bool:
for field in schema_cls.model_fields.values():
if field.annotation == LlmModel:
if _contains_type(field.annotation, LlmModel):
# Check if query matches any value in llm_models
if any(query in name for name in llm_models):
return True
return False
def _score_block(
block: AnyBlockSchema,
block_info: BlockInfo,
normalized_query: str,
) -> float:
if not normalized_query:
return 0.0
name = block_info.name.lower()
description = block_info.description.lower()
score = _score_primary_fields(name, description, normalized_query)
category_text = " ".join(
category.get("category", "").lower() for category in block_info.categories
)
score += _score_additional_field(category_text, normalized_query, 12, 6)
credentials_info = block.input_schema.get_credentials_fields_info().values()
provider_names = [
provider.value.lower()
for info in credentials_info
for provider in info.provider
]
provider_text = " ".join(provider_names)
score += _score_additional_field(provider_text, normalized_query, 15, 6)
if _matches_llm_model(block.input_schema, normalized_query):
score += 20
return score
def _score_library_agent(
agent: library_model.LibraryAgent,
normalized_query: str,
@@ -645,31 +768,20 @@ def _get_all_providers() -> dict[ProviderName, Provider]:
return providers
@cached(ttl_seconds=3600)
@cached(ttl_seconds=3600, shared_cache=True)
async def get_suggested_blocks(count: int = 5) -> list[BlockInfo]:
suggested_blocks = []
# Sum the number of executions for each block type
# Prisma cannot group by nested relations, so we do a raw query
# Calculate the cutoff timestamp
timestamp_threshold = datetime.now(timezone.utc) - timedelta(days=30)
"""Return the most-executed blocks from the last 14 days.
results = await query_raw_with_schema(
"""
SELECT
agent_node."agentBlockId" AS block_id,
COUNT(execution.id) AS execution_count
FROM {schema_prefix}"AgentNodeExecution" execution
JOIN {schema_prefix}"AgentNode" agent_node ON execution."agentNodeId" = agent_node.id
WHERE execution."endedTime" >= $1::timestamp
GROUP BY agent_node."agentBlockId"
ORDER BY execution_count DESC;
""",
timestamp_threshold,
)
Queries the mv_suggested_blocks materialized view (refreshed hourly via pg_cron)
and returns the top `count` blocks sorted by execution count, excluding
Input/Output/Agent block types and blocks in EXCLUDED_BLOCK_IDS.
"""
results = await mv_suggested_blocks.prisma().find_many()
# Get the top blocks based on execution count
# But ignore Input and Output blocks
# But ignore Input, Output, Agent, and excluded blocks
blocks: list[tuple[BlockInfo, int]] = []
execution_counts = {row.block_id: row.execution_count for row in results}
for block_type in load_all_blocks().values():
block: AnyBlockSchema = block_type()
@@ -679,11 +791,9 @@ async def get_suggested_blocks(count: int = 5) -> list[BlockInfo]:
BlockType.AGENT,
):
continue
# Find the execution count for this block
execution_count = next(
(row["execution_count"] for row in results if row["block_id"] == block.id),
0,
)
if block.id in EXCLUDED_BLOCK_IDS:
continue
execution_count = execution_counts.get(block.id, 0)
blocks.append((block.get_info(), execution_count))
# Sort blocks by execution count
blocks.sort(key=lambda x: x[1], reverse=True)

View File

@@ -27,7 +27,6 @@ class SearchEntry(BaseModel):
# Suggestions
class SuggestionsResponse(BaseModel):
otto_suggestions: list[str]
recent_searches: list[SearchEntry]
providers: list[ProviderName]
top_blocks: list[BlockInfo]

View File

@@ -1,5 +1,5 @@
import logging
from typing import Annotated, Sequence
from typing import Annotated, Sequence, cast, get_args
import fastapi
from autogpt_libs.auth.dependencies import get_user_id, requires_user
@@ -10,6 +10,8 @@ from backend.util.models import Pagination
from . import db as builder_db
from . import model as builder_model
VALID_FILTER_VALUES = get_args(builder_model.FilterType)
logger = logging.getLogger(__name__)
router = fastapi.APIRouter(
@@ -49,11 +51,6 @@ async def get_suggestions(
Get all suggestions for the Blocks Menu.
"""
return builder_model.SuggestionsResponse(
otto_suggestions=[
"What blocks do I need to get started?",
"Help me create a list",
"Help me feed my data to Google Maps",
],
recent_searches=await builder_db.get_recent_searches(user_id),
providers=[
ProviderName.TWITTER,
@@ -151,7 +148,7 @@ async def get_providers(
async def search(
user_id: Annotated[str, fastapi.Security(get_user_id)],
search_query: Annotated[str | None, fastapi.Query()] = None,
filter: Annotated[list[builder_model.FilterType] | None, fastapi.Query()] = None,
filter: Annotated[str | None, fastapi.Query()] = None,
search_id: Annotated[str | None, fastapi.Query()] = None,
by_creator: Annotated[list[str] | None, fastapi.Query()] = None,
page: Annotated[int, fastapi.Query()] = 1,
@@ -160,9 +157,20 @@ async def search(
"""
Search for blocks (including integrations), marketplace agents, and user library agents.
"""
# If no filters are provided, then we will return all types
if not filter:
filter = [
# Parse and validate filter parameter
filters: list[builder_model.FilterType]
if filter:
filter_values = [f.strip() for f in filter.split(",")]
invalid_filters = [f for f in filter_values if f not in VALID_FILTER_VALUES]
if invalid_filters:
raise fastapi.HTTPException(
status_code=400,
detail=f"Invalid filter value(s): {', '.join(invalid_filters)}. "
f"Valid values are: {', '.join(VALID_FILTER_VALUES)}",
)
filters = cast(list[builder_model.FilterType], filter_values)
else:
filters = [
"blocks",
"integrations",
"marketplace_agents",
@@ -174,7 +182,7 @@ async def search(
cached_results = await builder_db.get_sorted_search_results(
user_id=user_id,
search_query=search_query,
filters=filter,
filters=filters,
by_creator=by_creator,
)
@@ -196,7 +204,7 @@ async def search(
user_id,
builder_model.SearchEntry(
search_query=search_query,
filter=filter,
filter=filters,
by_creator=by_creator,
search_id=search_id,
),

View File

@@ -2,23 +2,19 @@
import asyncio
import logging
import uuid as uuid_module
from collections.abc import AsyncGenerator
from typing import Annotated
from uuid import uuid4
from autogpt_libs import auth
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Response, Security
from fastapi import APIRouter, Depends, HTTPException, Query, Response, Security
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from backend.copilot import service as chat_service
from backend.copilot import stream_registry
from backend.copilot.completion_handler import (
process_operation_failure,
process_operation_success,
)
from backend.copilot.config import ChatConfig
from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_task
from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_turn
from backend.copilot.model import (
ChatMessage,
ChatSession,
@@ -46,9 +42,6 @@ from backend.copilot.tools.models import (
InputValidationErrorResponse,
NeedLoginResponse,
NoResultsResponse,
OperationInProgressResponse,
OperationPendingResponse,
OperationStartedResponse,
SetupRequirementsResponse,
SuggestedGoalResponse,
UnderstandingUpdatedResponse,
@@ -99,10 +92,8 @@ class CreateSessionResponse(BaseModel):
class ActiveStreamInfo(BaseModel):
"""Information about an active stream for reconnection."""
task_id: str
turn_id: str
last_message_id: str # Redis Stream message ID for resumption
operation_id: str # Operation ID for completion tracking
tool_name: str # Name of the tool being executed
class SessionDetailResponse(BaseModel):
@@ -132,22 +123,13 @@ class ListSessionsResponse(BaseModel):
total: int
class CancelTaskResponse(BaseModel):
"""Response model for the cancel task endpoint."""
class CancelSessionResponse(BaseModel):
"""Response model for the cancel session endpoint."""
cancelled: bool
task_id: str | None = None
reason: str | None = None
class OperationCompleteRequest(BaseModel):
"""Request model for external completion webhook."""
success: bool
result: dict | str | None = None
error: str | None = None
# ========== Routes ==========
@@ -270,7 +252,7 @@ async def get_session(
Retrieve the details of a specific chat session.
Looks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.
If there's an active stream for this session, returns the task_id for reconnection.
If there's an active stream for this session, returns active_stream info for reconnection.
Args:
session_id: The unique identifier for the desired chat session.
@@ -288,28 +270,21 @@ async def get_session(
# Check if there's an active stream for this session
active_stream_info = None
active_task, last_message_id = await stream_registry.get_active_task_for_session(
active_session, last_message_id = await stream_registry.get_active_session(
session_id, user_id
)
logger.info(
f"[GET_SESSION] session={session_id}, active_task={active_task is not None}, "
f"[GET_SESSION] session={session_id}, active_session={active_session is not None}, "
f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}"
)
if active_task:
# Filter out the in-progress assistant message from the session response.
# The client will receive the complete assistant response through the SSE
# stream replay instead, preventing duplicate content.
if messages and messages[-1].get("role") == "assistant":
messages = messages[:-1]
# Use "0-0" as last_message_id to replay the stream from the beginning.
# Since we filtered out the cached assistant message, the client needs
# the full stream to reconstruct the response.
if active_session:
# Keep the assistant message (including tool_calls) so the frontend can
# render the correct tool UI (e.g. CreateAgent with mini game).
# convertChatSessionToUiMessages handles isComplete=false by setting
# tool parts without output to state "input-available".
active_stream_info = ActiveStreamInfo(
task_id=active_task.task_id,
last_message_id="0-0",
operation_id=active_task.operation_id,
tool_name=active_task.tool_name,
turn_id=active_session.turn_id,
last_message_id=last_message_id,
)
return SessionDetailResponse(
@@ -329,7 +304,7 @@ async def get_session(
async def cancel_session_task(
session_id: str,
user_id: Annotated[str | None, Depends(auth.get_user_id)],
) -> CancelTaskResponse:
) -> CancelSessionResponse:
"""Cancel the active streaming task for a session.
Publishes a cancel event to the executor via RabbitMQ FANOUT, then
@@ -338,39 +313,33 @@ async def cancel_session_task(
"""
await _validate_and_get_session(session_id, user_id)
active_task, _ = await stream_registry.get_active_task_for_session(
session_id, user_id
)
if not active_task:
return CancelTaskResponse(cancelled=False, reason="no_active_task")
active_session, _ = await stream_registry.get_active_session(session_id, user_id)
if not active_session:
return CancelSessionResponse(cancelled=True, reason="no_active_session")
task_id = active_task.task_id
await enqueue_cancel_task(task_id)
logger.info(
f"[CANCEL] Published cancel for task ...{task_id[-8:]} "
f"session ...{session_id[-8:]}"
)
await enqueue_cancel_task(session_id)
logger.info(f"[CANCEL] Published cancel for session ...{session_id[-8:]}")
# Poll until the executor confirms the task is no longer running.
# Keep max_wait below typical reverse-proxy read timeouts.
poll_interval = 0.5
max_wait = 5.0
waited = 0.0
while waited < max_wait:
await asyncio.sleep(poll_interval)
waited += poll_interval
task = await stream_registry.get_task(task_id)
if task is None or task.status != "running":
session_state = await stream_registry.get_session(session_id)
if session_state is None or session_state.status != "running":
logger.info(
f"[CANCEL] Task ...{task_id[-8:]} confirmed stopped "
f"(status={task.status if task else 'gone'}) after {waited:.1f}s"
f"[CANCEL] Session ...{session_id[-8:]} confirmed stopped "
f"(status={session_state.status if session_state else 'gone'}) after {waited:.1f}s"
)
return CancelTaskResponse(cancelled=True, task_id=task_id)
return CancelSessionResponse(cancelled=True)
logger.warning(f"[CANCEL] Task ...{task_id[-8:]} not confirmed after {max_wait}s")
return CancelTaskResponse(
cancelled=True, task_id=task_id, reason="cancel_published_not_confirmed"
logger.warning(
f"[CANCEL] Session ...{session_id[-8:]} not confirmed after {max_wait}s, force-completing"
)
await stream_registry.mark_session_completed(session_id, error_message="Cancelled")
return CancelSessionResponse(cancelled=True)
@router.post(
@@ -390,16 +359,15 @@ async def stream_chat_post(
- Tool execution results
The AI generation runs in a background task that continues even if the client disconnects.
All chunks are written to Redis for reconnection support. If the client disconnects,
they can reconnect using GET /tasks/{task_id}/stream to resume from where they left off.
All chunks are written to a per-turn Redis stream for reconnection support. If the client
disconnects, they can reconnect using GET /sessions/{session_id}/stream to resume.
Args:
session_id: The chat session identifier to associate with the streamed messages.
request: Request body containing message, is_user_message, and optional context.
user_id: Optional authenticated user ID.
Returns:
StreamingResponse: SSE-formatted response chunks. First chunk is a "start" event
containing the task_id for reconnection.
StreamingResponse: SSE-formatted response chunks.
"""
import asyncio
@@ -446,35 +414,35 @@ async def stream_chat_post(
logger.info(f"[STREAM] User message saved for session {session_id}")
# Create a task in the stream registry for reconnection support
task_id = str(uuid_module.uuid4())
operation_id = str(uuid_module.uuid4())
log_meta["task_id"] = task_id
turn_id = str(uuid4())
log_meta["turn_id"] = turn_id
task_create_start = time.perf_counter()
await stream_registry.create_task(
task_id=task_id,
session_create_start = time.perf_counter()
await stream_registry.create_session(
session_id=session_id,
user_id=user_id,
tool_call_id="chat_stream", # Not a tool call, but needed for the model
tool_call_id="chat_stream",
tool_name="chat",
operation_id=operation_id,
turn_id=turn_id,
)
logger.info(
f"[TIMING] create_task completed in {(time.perf_counter() - task_create_start) * 1000:.1f}ms",
f"[TIMING] create_session completed in {(time.perf_counter() - session_create_start) * 1000:.1f}ms",
extra={
"json_fields": {
**log_meta,
"duration_ms": (time.perf_counter() - task_create_start) * 1000,
"duration_ms": (time.perf_counter() - session_create_start) * 1000,
}
},
)
await enqueue_copilot_task(
task_id=task_id,
# Per-turn stream is always fresh (unique turn_id), subscribe from beginning
subscribe_from_id = "0-0"
await enqueue_copilot_turn(
session_id=session_id,
user_id=user_id,
operation_id=operation_id,
message=request.message,
turn_id=turn_id,
is_user_message=request.is_user_message,
context=request.context,
)
@@ -491,7 +459,7 @@ async def stream_chat_post(
event_gen_start = time_module.perf_counter()
logger.info(
f"[TIMING] event_generator STARTED, task={task_id}, session={session_id}, "
f"[TIMING] event_generator STARTED, turn={turn_id}, session={session_id}, "
f"user={user_id}",
extra={"json_fields": log_meta},
)
@@ -499,11 +467,12 @@ async def stream_chat_post(
first_chunk_yielded = False
chunks_yielded = 0
try:
# Subscribe to the task stream (this replays existing messages + live updates)
subscriber_queue = await stream_registry.subscribe_to_task(
task_id=task_id,
# Subscribe from the position we captured before enqueuing
# This avoids replaying old messages while catching all new ones
subscriber_queue = await stream_registry.subscribe_to_session(
session_id=session_id,
user_id=user_id,
last_message_id="0-0", # Get all messages from the beginning
last_message_id=subscribe_from_id,
)
if subscriber_queue is None:
@@ -586,19 +555,19 @@ async def stream_chat_post(
# Unsubscribe when client disconnects or stream ends
if subscriber_queue is not None:
try:
await stream_registry.unsubscribe_from_task(
task_id, subscriber_queue
await stream_registry.unsubscribe_from_session(
session_id, subscriber_queue
)
except Exception as unsub_err:
logger.error(
f"Error unsubscribing from task {task_id}: {unsub_err}",
f"Error unsubscribing from session {session_id}: {unsub_err}",
exc_info=True,
)
# AI SDK protocol termination - always yield even if unsubscribe fails
total_time = time_module.perf_counter() - event_gen_start
logger.info(
f"[TIMING] event_generator FINISHED in {total_time:.2f}s; "
f"task={task_id}, session={session_id}, n_chunks={chunks_yielded}",
f"turn={turn_id}, session={session_id}, n_chunks={chunks_yielded}",
extra={
"json_fields": {
**log_meta,
@@ -645,17 +614,21 @@ async def resume_session_stream(
"""
import asyncio
active_task, _last_id = await stream_registry.get_active_task_for_session(
active_session, last_message_id = await stream_registry.get_active_session(
session_id, user_id
)
if not active_task:
if not active_session:
return Response(status_code=204)
subscriber_queue = await stream_registry.subscribe_to_task(
task_id=active_task.task_id,
# Always replay from the beginning ("0-0") on resume.
# We can't use last_message_id because it's the latest ID in the backend
# stream, not the latest the frontend received — the gap causes lost
# messages. The frontend deduplicates replayed content.
subscriber_queue = await stream_registry.subscribe_to_session(
session_id=session_id,
user_id=user_id,
last_message_id="0-0", # Full replay so useChat rebuilds the message
last_message_id="0-0",
)
if subscriber_queue is None:
@@ -691,12 +664,12 @@ async def resume_session_stream(
logger.error(f"Error in resume stream for session {session_id}: {e}")
finally:
try:
await stream_registry.unsubscribe_from_task(
active_task.task_id, subscriber_queue
await stream_registry.unsubscribe_from_session(
session_id, subscriber_queue
)
except Exception as unsub_err:
logger.error(
f"Error unsubscribing from task {active_task.task_id}: {unsub_err}",
f"Error unsubscribing from session {active_session.session_id}: {unsub_err}",
exc_info=True,
)
logger.info(
@@ -747,229 +720,6 @@ async def session_assign_user(
return {"status": "ok"}
# ========== Task Streaming (SSE Reconnection) ==========
@router.get(
"/tasks/{task_id}/stream",
)
async def stream_task(
task_id: str,
user_id: str | None = Depends(auth.get_user_id),
last_message_id: str = Query(
default="0-0",
description="Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.",
),
):
"""
Reconnect to a long-running task's SSE stream.
When a long-running operation (like agent generation) starts, the client
receives a task_id. If the connection drops, the client can reconnect
using this endpoint to resume receiving updates.
Args:
task_id: The task ID from the operation_started response.
user_id: Authenticated user ID for ownership validation.
last_message_id: Last Redis Stream message ID received ("0-0" for full replay).
Returns:
StreamingResponse: SSE-formatted response chunks starting after last_message_id.
Raises:
HTTPException: 404 if task not found, 410 if task expired, 403 if access denied.
"""
# Check task existence and expiry before subscribing
task, error_code = await stream_registry.get_task_with_expiry_info(task_id)
if error_code == "TASK_EXPIRED":
raise HTTPException(
status_code=410,
detail={
"code": "TASK_EXPIRED",
"message": "This operation has expired. Please try again.",
},
)
if error_code == "TASK_NOT_FOUND":
raise HTTPException(
status_code=404,
detail={
"code": "TASK_NOT_FOUND",
"message": f"Task {task_id} not found.",
},
)
# Validate ownership if task has an owner
if task and task.user_id and user_id != task.user_id:
raise HTTPException(
status_code=403,
detail={
"code": "ACCESS_DENIED",
"message": "You do not have access to this task.",
},
)
# Get subscriber queue from stream registry
subscriber_queue = await stream_registry.subscribe_to_task(
task_id=task_id,
user_id=user_id,
last_message_id=last_message_id,
)
if subscriber_queue is None:
raise HTTPException(
status_code=404,
detail={
"code": "TASK_NOT_FOUND",
"message": f"Task {task_id} not found or access denied.",
},
)
async def event_generator() -> AsyncGenerator[str, None]:
heartbeat_interval = 15.0 # Send heartbeat every 15 seconds
try:
while True:
try:
# Wait for next chunk with timeout for heartbeats
chunk = await asyncio.wait_for(
subscriber_queue.get(), timeout=heartbeat_interval
)
yield chunk.to_sse()
# Check for finish signal
if isinstance(chunk, StreamFinish):
break
except asyncio.TimeoutError:
# Send heartbeat to keep connection alive
yield StreamHeartbeat().to_sse()
except Exception as e:
logger.error(f"Error in task stream {task_id}: {e}", exc_info=True)
finally:
# Unsubscribe when client disconnects or stream ends
try:
await stream_registry.unsubscribe_from_task(task_id, subscriber_queue)
except Exception as unsub_err:
logger.error(
f"Error unsubscribing from task {task_id}: {unsub_err}",
exc_info=True,
)
# AI SDK protocol termination - always yield even if unsubscribe fails
yield "data: [DONE]\n\n"
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
"x-vercel-ai-ui-message-stream": "v1",
},
)
@router.get(
"/tasks/{task_id}",
)
async def get_task_status(
task_id: str,
user_id: str | None = Depends(auth.get_user_id),
) -> dict:
"""
Get the status of a long-running task.
Args:
task_id: The task ID to check.
user_id: Authenticated user ID for ownership validation.
Returns:
dict: Task status including task_id, status, tool_name, and operation_id.
Raises:
NotFoundError: If task_id is not found or user doesn't have access.
"""
task = await stream_registry.get_task(task_id)
if task is None:
raise NotFoundError(f"Task {task_id} not found.")
# Validate ownership - if task has an owner, requester must match
if task.user_id and user_id != task.user_id:
raise NotFoundError(f"Task {task_id} not found.")
return {
"task_id": task.task_id,
"session_id": task.session_id,
"status": task.status,
"tool_name": task.tool_name,
"operation_id": task.operation_id,
"created_at": task.created_at.isoformat(),
}
# ========== External Completion Webhook ==========
@router.post(
"/operations/{operation_id}/complete",
status_code=200,
)
async def complete_operation(
operation_id: str,
request: OperationCompleteRequest,
x_api_key: str | None = Header(default=None),
) -> dict:
"""
External completion webhook for long-running operations.
Called by Agent Generator (or other services) when an operation completes.
This triggers the stream registry to publish completion and continue LLM generation.
Args:
operation_id: The operation ID to complete.
request: Completion payload with success status and result/error.
x_api_key: Internal API key for authentication.
Returns:
dict: Status of the completion.
Raises:
HTTPException: If API key is invalid or operation not found.
"""
# Validate internal API key - reject if not configured or invalid
if not config.internal_api_key:
logger.error(
"Operation complete webhook rejected: CHAT_INTERNAL_API_KEY not configured"
)
raise HTTPException(
status_code=503,
detail="Webhook not available: internal API key not configured",
)
if x_api_key != config.internal_api_key:
raise HTTPException(status_code=401, detail="Invalid API key")
# Find task by operation_id
task = await stream_registry.find_task_by_operation_id(operation_id)
if task is None:
raise HTTPException(
status_code=404,
detail=f"Operation {operation_id} not found",
)
logger.info(
f"Received completion webhook for operation {operation_id} "
f"(task_id={task.task_id}, success={request.success})"
)
if request.success:
await process_operation_success(task, request.result)
else:
await process_operation_failure(task, request.error)
return {"status": "ok", "task_id": task.task_id}
# ========== Configuration ==========
@@ -1050,9 +800,6 @@ ToolResponseUnion = (
| BlockOutputResponse
| DocSearchResultsResponse
| DocPageResponse
| OperationStartedResponse
| OperationPendingResponse
| OperationInProgressResponse
)

File diff suppressed because it is too large Load Diff

View File

@@ -144,6 +144,7 @@ async def test_add_agent_to_library(mocker):
)
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
mock_library_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
mock_library_agent.return_value.find_unique = mocker.AsyncMock(return_value=None)
mock_library_agent.return_value.create = mocker.AsyncMock(
return_value=mock_library_agent_data
@@ -178,7 +179,6 @@ async def test_add_agent_to_library(mocker):
"agentGraphVersion": 1,
}
},
include={"AgentGraph": True},
)
# Check that create was called with the expected data including settings
create_call_args = mock_library_agent.return_value.create.call_args

View File

@@ -0,0 +1,10 @@
class FolderValidationError(Exception):
"""Raised when folder operations fail validation."""
pass
class FolderAlreadyExistsError(FolderValidationError):
"""Raised when a folder with the same name already exists in the location."""
pass

View File

@@ -26,6 +26,95 @@ class LibraryAgentStatus(str, Enum):
ERROR = "ERROR"
# === Folder Models ===
class LibraryFolder(pydantic.BaseModel):
"""Represents a folder for organizing library agents."""
id: str
user_id: str
name: str
icon: str | None = None
color: str | None = None
parent_id: str | None = None
created_at: datetime.datetime
updated_at: datetime.datetime
agent_count: int = 0 # Direct agents in folder
subfolder_count: int = 0 # Direct child folders
@staticmethod
def from_db(
folder: prisma.models.LibraryFolder,
agent_count: int = 0,
subfolder_count: int = 0,
) -> "LibraryFolder":
"""Factory method that constructs a LibraryFolder from a Prisma model."""
return LibraryFolder(
id=folder.id,
user_id=folder.userId,
name=folder.name,
icon=folder.icon,
color=folder.color,
parent_id=folder.parentId,
created_at=folder.createdAt,
updated_at=folder.updatedAt,
agent_count=agent_count,
subfolder_count=subfolder_count,
)
class LibraryFolderTree(LibraryFolder):
"""Folder with nested children for tree view."""
children: list["LibraryFolderTree"] = []
class FolderCreateRequest(pydantic.BaseModel):
"""Request model for creating a folder."""
name: str = pydantic.Field(..., min_length=1, max_length=100)
icon: str | None = None
color: str | None = pydantic.Field(
None, pattern=r"^#[0-9A-Fa-f]{6}$", description="Hex color code (#RRGGBB)"
)
parent_id: str | None = None
class FolderUpdateRequest(pydantic.BaseModel):
"""Request model for updating a folder."""
name: str | None = pydantic.Field(None, min_length=1, max_length=100)
icon: str | None = None
color: str | None = None
class FolderMoveRequest(pydantic.BaseModel):
"""Request model for moving a folder to a new parent."""
target_parent_id: str | None = None # None = move to root
class BulkMoveAgentsRequest(pydantic.BaseModel):
"""Request model for moving multiple agents to a folder."""
agent_ids: list[str]
folder_id: str | None = None # None = move to root
class FolderListResponse(pydantic.BaseModel):
"""Response schema for a list of folders."""
folders: list[LibraryFolder]
pagination: Pagination
class FolderTreeResponse(pydantic.BaseModel):
"""Response schema for folder tree structure."""
tree: list[LibraryFolderTree]
class MarketplaceListingCreator(pydantic.BaseModel):
"""Creator information for a marketplace listing."""
@@ -120,6 +209,9 @@ class LibraryAgent(pydantic.BaseModel):
can_access_graph: bool
is_latest_version: bool
is_favorite: bool
folder_id: str | None = None
folder_name: str | None = None # Denormalized for display
recommended_schedule_cron: str | None = None
settings: GraphSettings = pydantic.Field(default_factory=GraphSettings)
marketplace_listing: Optional["MarketplaceListing"] = None
@@ -259,6 +351,8 @@ class LibraryAgent(pydantic.BaseModel):
can_access_graph=can_access_graph,
is_latest_version=is_latest_version,
is_favorite=agent.isFavorite,
folder_id=agent.folderId,
folder_name=agent.Folder.name if agent.Folder else None,
recommended_schedule_cron=agent.AgentGraph.recommendedScheduleCron,
settings=_parse_settings(agent.settings),
marketplace_listing=marketplace_listing_data,
@@ -470,3 +564,7 @@ class LibraryAgentUpdateRequest(pydantic.BaseModel):
settings: Optional[GraphSettings] = pydantic.Field(
default=None, description="User-specific settings for this library agent"
)
folder_id: Optional[str] = pydantic.Field(
default=None,
description="Folder ID to move agent to (None to move to root)",
)

View File

@@ -1,9 +1,11 @@
import fastapi
from .agents import router as agents_router
from .folders import router as folders_router
from .presets import router as presets_router
router = fastapi.APIRouter()
router.include_router(presets_router)
router.include_router(folders_router)
router.include_router(agents_router)

View File

@@ -41,6 +41,14 @@ async def list_library_agents(
ge=1,
description="Number of agents per page (must be >= 1)",
),
folder_id: Optional[str] = Query(
None,
description="Filter by folder ID",
),
include_root_only: bool = Query(
False,
description="Only return agents without a folder (root-level agents)",
),
) -> library_model.LibraryAgentResponse:
"""
Get all agents in the user's library (both created and saved).
@@ -51,6 +59,8 @@ async def list_library_agents(
sort_by=sort_by,
page=page,
page_size=page_size,
folder_id=folder_id,
include_root_only=include_root_only,
)
@@ -168,6 +178,7 @@ async def update_library_agent(
is_favorite=payload.is_favorite,
is_archived=payload.is_archived,
settings=payload.settings,
folder_id=payload.folder_id,
)

View File

@@ -0,0 +1,287 @@
from typing import Optional
import autogpt_libs.auth as autogpt_auth_lib
from fastapi import APIRouter, Query, Security, status
from fastapi.responses import Response
from .. import db as library_db
from .. import model as library_model
router = APIRouter(
prefix="/folders",
tags=["library", "folders", "private"],
dependencies=[Security(autogpt_auth_lib.requires_user)],
)
@router.get(
"",
summary="List Library Folders",
response_model=library_model.FolderListResponse,
responses={
200: {"description": "List of folders"},
500: {"description": "Server error"},
},
)
async def list_folders(
user_id: str = Security(autogpt_auth_lib.get_user_id),
parent_id: Optional[str] = Query(
None,
description="Filter by parent folder ID. If not provided, returns root-level folders.",
),
include_relations: bool = Query(
True,
description="Include agent and subfolder relations (for counts)",
),
) -> library_model.FolderListResponse:
"""
List folders for the authenticated user.
Args:
user_id: ID of the authenticated user.
parent_id: Optional parent folder ID to filter by.
include_relations: Whether to include agent and subfolder relations for counts.
Returns:
A FolderListResponse containing folders.
"""
folders = await library_db.list_folders(
user_id=user_id,
parent_id=parent_id,
include_relations=include_relations,
)
return library_model.FolderListResponse(
folders=folders,
pagination=library_model.Pagination(
total_items=len(folders),
total_pages=1,
current_page=1,
page_size=len(folders),
),
)
@router.get(
"/tree",
summary="Get Folder Tree",
response_model=library_model.FolderTreeResponse,
responses={
200: {"description": "Folder tree structure"},
500: {"description": "Server error"},
},
)
async def get_folder_tree(
user_id: str = Security(autogpt_auth_lib.get_user_id),
) -> library_model.FolderTreeResponse:
"""
Get the full folder tree for the authenticated user.
Args:
user_id: ID of the authenticated user.
Returns:
A FolderTreeResponse containing the nested folder structure.
"""
tree = await library_db.get_folder_tree(user_id=user_id)
return library_model.FolderTreeResponse(tree=tree)
@router.get(
"/{folder_id}",
summary="Get Folder",
response_model=library_model.LibraryFolder,
responses={
200: {"description": "Folder details"},
404: {"description": "Folder not found"},
500: {"description": "Server error"},
},
)
async def get_folder(
folder_id: str,
user_id: str = Security(autogpt_auth_lib.get_user_id),
) -> library_model.LibraryFolder:
"""
Get a specific folder.
Args:
folder_id: ID of the folder to retrieve.
user_id: ID of the authenticated user.
Returns:
The requested LibraryFolder.
"""
return await library_db.get_folder(folder_id=folder_id, user_id=user_id)
@router.post(
"",
summary="Create Folder",
status_code=status.HTTP_201_CREATED,
response_model=library_model.LibraryFolder,
responses={
201: {"description": "Folder created successfully"},
400: {"description": "Validation error"},
404: {"description": "Parent folder not found"},
409: {"description": "Folder name conflict"},
500: {"description": "Server error"},
},
)
async def create_folder(
payload: library_model.FolderCreateRequest,
user_id: str = Security(autogpt_auth_lib.get_user_id),
) -> library_model.LibraryFolder:
"""
Create a new folder.
Args:
payload: The folder creation request.
user_id: ID of the authenticated user.
Returns:
The created LibraryFolder.
"""
return await library_db.create_folder(
user_id=user_id,
name=payload.name,
parent_id=payload.parent_id,
icon=payload.icon,
color=payload.color,
)
@router.patch(
"/{folder_id}",
summary="Update Folder",
response_model=library_model.LibraryFolder,
responses={
200: {"description": "Folder updated successfully"},
400: {"description": "Validation error"},
404: {"description": "Folder not found"},
409: {"description": "Folder name conflict"},
500: {"description": "Server error"},
},
)
async def update_folder(
folder_id: str,
payload: library_model.FolderUpdateRequest,
user_id: str = Security(autogpt_auth_lib.get_user_id),
) -> library_model.LibraryFolder:
"""
Update a folder's properties.
Args:
folder_id: ID of the folder to update.
payload: The folder update request.
user_id: ID of the authenticated user.
Returns:
The updated LibraryFolder.
"""
return await library_db.update_folder(
folder_id=folder_id,
user_id=user_id,
name=payload.name,
icon=payload.icon,
color=payload.color,
)
@router.post(
"/{folder_id}/move",
summary="Move Folder",
response_model=library_model.LibraryFolder,
responses={
200: {"description": "Folder moved successfully"},
400: {"description": "Validation error (circular reference)"},
404: {"description": "Folder or target parent not found"},
409: {"description": "Folder name conflict in target location"},
500: {"description": "Server error"},
},
)
async def move_folder(
folder_id: str,
payload: library_model.FolderMoveRequest,
user_id: str = Security(autogpt_auth_lib.get_user_id),
) -> library_model.LibraryFolder:
"""
Move a folder to a new parent.
Args:
folder_id: ID of the folder to move.
payload: The move request with target parent.
user_id: ID of the authenticated user.
Returns:
The moved LibraryFolder.
"""
return await library_db.move_folder(
folder_id=folder_id,
user_id=user_id,
target_parent_id=payload.target_parent_id,
)
@router.delete(
"/{folder_id}",
summary="Delete Folder",
status_code=status.HTTP_204_NO_CONTENT,
responses={
204: {"description": "Folder deleted successfully"},
404: {"description": "Folder not found"},
500: {"description": "Server error"},
},
)
async def delete_folder(
folder_id: str,
user_id: str = Security(autogpt_auth_lib.get_user_id),
) -> Response:
"""
Soft-delete a folder and all its contents.
Args:
folder_id: ID of the folder to delete.
user_id: ID of the authenticated user.
Returns:
204 No Content if successful.
"""
await library_db.delete_folder(
folder_id=folder_id,
user_id=user_id,
soft_delete=True,
)
return Response(status_code=status.HTTP_204_NO_CONTENT)
# === Bulk Agent Operations ===
@router.post(
"/agents/bulk-move",
summary="Bulk Move Agents",
response_model=list[library_model.LibraryAgent],
responses={
200: {"description": "Agents moved successfully"},
404: {"description": "Folder not found"},
500: {"description": "Server error"},
},
)
async def bulk_move_agents(
payload: library_model.BulkMoveAgentsRequest,
user_id: str = Security(autogpt_auth_lib.get_user_id),
) -> list[library_model.LibraryAgent]:
"""
Move multiple agents to a folder.
Args:
payload: The bulk move request with agent IDs and target folder.
user_id: ID of the authenticated user.
Returns:
The updated LibraryAgents.
"""
return await library_db.bulk_move_agents_to_folder(
agent_ids=payload.agent_ids,
folder_id=payload.folder_id,
user_id=user_id,
)

View File

@@ -115,6 +115,8 @@ async def test_get_library_agents_success(
sort_by=library_model.LibraryAgentSort.UPDATED_AT,
page=1,
page_size=15,
folder_id=None,
include_root_only=False,
)

View File

@@ -9,15 +9,26 @@ import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from typing import Any, get_args, get_origin
from prisma.enums import ContentType
from backend.blocks.llm import LlmModel
from backend.data.db import query_raw_with_schema
logger = logging.getLogger(__name__)
def _contains_type(annotation: Any, target: type) -> bool:
"""Check if an annotation is or contains the target type (handles Optional/Union/Annotated)."""
if annotation is target:
return True
origin = get_origin(annotation)
if origin is None:
return False
return any(_contains_type(arg, target) for arg in get_args(annotation))
@dataclass
class ContentItem:
"""Represents a piece of content to be embedded."""
@@ -188,45 +199,51 @@ class BlockHandler(ContentHandler):
try:
block_instance = block_cls()
# Skip disabled blocks - they shouldn't be indexed
if block_instance.disabled:
continue
# Build searchable text from block metadata
parts = []
if hasattr(block_instance, "name") and block_instance.name:
if block_instance.name:
parts.append(block_instance.name)
if (
hasattr(block_instance, "description")
and block_instance.description
):
if block_instance.description:
parts.append(block_instance.description)
if hasattr(block_instance, "categories") and block_instance.categories:
# Convert BlockCategory enum to strings
if block_instance.categories:
parts.append(
" ".join(str(cat.value) for cat in block_instance.categories)
)
# Add input/output schema info
if hasattr(block_instance, "input_schema"):
schema = block_instance.input_schema
if hasattr(schema, "model_json_schema"):
schema_dict = schema.model_json_schema()
if "properties" in schema_dict:
for prop_name, prop_info in schema_dict[
"properties"
].items():
if "description" in prop_info:
parts.append(
f"{prop_name}: {prop_info['description']}"
)
# Add input schema field descriptions
block_input_fields = block_instance.input_schema.model_fields
parts += [
f"{field_name}: {field_info.description}"
for field_name, field_info in block_input_fields.items()
if field_info.description
]
searchable_text = " ".join(parts)
# Convert categories set of enums to list of strings for JSON serialization
categories = getattr(block_instance, "categories", set())
categories_list = (
[cat.value for cat in categories] if categories else []
[cat.value for cat in block_instance.categories]
if block_instance.categories
else []
)
# Extract provider names from credentials fields
credentials_info = (
block_instance.input_schema.get_credentials_fields_info()
)
is_integration = len(credentials_info) > 0
provider_names = [
provider.value.lower()
for info in credentials_info.values()
for provider in info.provider
]
# Check if block has LlmModel field in input schema
has_llm_model_field = any(
_contains_type(field.annotation, LlmModel)
for field in block_instance.input_schema.model_fields.values()
)
items.append(
@@ -235,8 +252,11 @@ class BlockHandler(ContentHandler):
content_type=ContentType.BLOCK,
searchable_text=searchable_text,
metadata={
"name": getattr(block_instance, "name", ""),
"name": block_instance.name,
"categories": categories_list,
"providers": provider_names,
"has_llm_model_field": has_llm_model_field,
"is_integration": is_integration,
},
user_id=None, # Blocks are public
)

View File

@@ -82,9 +82,10 @@ async def test_block_handler_get_missing_items(mocker):
mock_block_instance.description = "Performs calculations"
mock_block_instance.categories = [MagicMock(value="MATH")]
mock_block_instance.disabled = False
mock_block_instance.input_schema.model_json_schema.return_value = {
"properties": {"expression": {"description": "Math expression to evaluate"}}
}
mock_field = MagicMock()
mock_field.description = "Math expression to evaluate"
mock_block_instance.input_schema.model_fields = {"expression": mock_field}
mock_block_instance.input_schema.get_credentials_fields_info.return_value = {}
mock_block_class.return_value = mock_block_instance
mock_blocks = {"block-uuid-1": mock_block_class}
@@ -309,19 +310,19 @@ async def test_content_handlers_registry():
@pytest.mark.asyncio(loop_scope="session")
async def test_block_handler_handles_missing_attributes():
"""Test BlockHandler gracefully handles blocks with missing attributes."""
async def test_block_handler_handles_empty_attributes():
"""Test BlockHandler handles blocks with empty/falsy attribute values."""
handler = BlockHandler()
# Mock block with minimal attributes
# Mock block with empty values (all attributes exist but are falsy)
mock_block_class = MagicMock()
mock_block_instance = MagicMock()
mock_block_instance.name = "Minimal Block"
mock_block_instance.disabled = False
# No description, categories, or schema
del mock_block_instance.description
del mock_block_instance.categories
del mock_block_instance.input_schema
mock_block_instance.description = ""
mock_block_instance.categories = set()
mock_block_instance.input_schema.model_fields = {}
mock_block_instance.input_schema.get_credentials_fields_info.return_value = {}
mock_block_class.return_value = mock_block_instance
mock_blocks = {"block-minimal": mock_block_class}
@@ -352,6 +353,8 @@ async def test_block_handler_skips_failed_blocks():
good_instance.description = "Works fine"
good_instance.categories = []
good_instance.disabled = False
good_instance.input_schema.model_fields = {}
good_instance.input_schema.get_credentials_fields_info.return_value = {}
good_block.return_value = good_instance
bad_block = MagicMock()

View File

@@ -71,46 +71,41 @@ async def upload_media(
logger.error(f"Error reading file content: {str(e)}")
raise store_exceptions.FileReadError("Failed to read file content") from e
# Validate file signature/magic bytes
if file.content_type in ALLOWED_IMAGE_TYPES:
# Check image file signatures
if content.startswith(b"\xff\xd8\xff"): # JPEG
if file.content_type != "image/jpeg":
raise store_exceptions.InvalidFileTypeError(
"File signature does not match content type"
)
elif content.startswith(b"\x89PNG\r\n\x1a\n"): # PNG
if file.content_type != "image/png":
raise store_exceptions.InvalidFileTypeError(
"File signature does not match content type"
)
elif content.startswith(b"GIF87a") or content.startswith(b"GIF89a"): # GIF
if file.content_type != "image/gif":
raise store_exceptions.InvalidFileTypeError(
"File signature does not match content type"
)
elif content.startswith(b"RIFF") and content[8:12] == b"WEBP": # WebP
if file.content_type != "image/webp":
raise store_exceptions.InvalidFileTypeError(
"File signature does not match content type"
)
else:
raise store_exceptions.InvalidFileTypeError("Invalid image file signature")
# Detect actual content type from file signature/magic bytes
# Trust the file signature over the declared content-type header
detected_content_type: str | None = None
elif file.content_type in ALLOWED_VIDEO_TYPES:
# Check video file signatures
if content.startswith(b"\x00\x00\x00") and (content[4:8] == b"ftyp"): # MP4
if file.content_type != "video/mp4":
raise store_exceptions.InvalidFileTypeError(
"File signature does not match content type"
)
elif content.startswith(b"\x1a\x45\xdf\xa3"): # WebM
if file.content_type != "video/webm":
raise store_exceptions.InvalidFileTypeError(
"File signature does not match content type"
)
else:
raise store_exceptions.InvalidFileTypeError("Invalid video file signature")
# Check image file signatures
if content.startswith(b"\xff\xd8\xff"): # JPEG
detected_content_type = "image/jpeg"
elif content.startswith(b"\x89PNG\r\n\x1a\n"): # PNG
detected_content_type = "image/png"
elif content.startswith(b"GIF87a") or content.startswith(b"GIF89a"): # GIF
detected_content_type = "image/gif"
elif content.startswith(b"RIFF") and len(content) >= 12 and content[8:12] == b"WEBP": # WebP
detected_content_type = "image/webp"
# Check video file signatures
elif content.startswith(b"\x00\x00\x00") and len(content) >= 8 and content[4:8] == b"ftyp": # MP4
detected_content_type = "video/mp4"
elif content.startswith(b"\x1a\x45\xdf\xa3"): # WebM
detected_content_type = "video/webm"
# If we detected a valid type, use it; otherwise reject the file
if detected_content_type is None:
raise store_exceptions.InvalidFileTypeError(
"Could not detect a valid image or video file signature. "
"Supported formats: JPEG, PNG, GIF, WebP, MP4, WebM"
)
# Log if we're auto-correcting a mismatched content-type
if file.content_type != detected_content_type:
logger.info(
f"Auto-correcting content-type from '{file.content_type}' to "
f"'{detected_content_type}' based on file signature"
)
# Use the detected content type going forward
content_type = detected_content_type
settings = Settings()
@@ -122,19 +117,7 @@ async def upload_media(
)
try:
# Validate file type
content_type = file.content_type
if content_type is None:
content_type = "image/jpeg"
if (
content_type not in ALLOWED_IMAGE_TYPES
and content_type not in ALLOWED_VIDEO_TYPES
):
logger.warning(f"Invalid file type attempted: {content_type}")
raise store_exceptions.InvalidFileTypeError(
f"File type not supported. Must be jpeg, png, gif, webp, mp4 or webm. Content type: {content_type}"
)
# content_type is already validated from file signature detection above
# Validate file size
file_size = 0

View File

@@ -191,23 +191,35 @@ async def test_upload_media_webm_success(mock_settings, mock_storage_client):
assert result.endswith(".webm")
async def test_upload_media_mismatched_signature(mock_settings, mock_storage_client):
async def test_upload_media_mismatched_signature_auto_corrects(
mock_settings, mock_storage_client
):
"""Test that mismatched content-type is auto-corrected based on file signature."""
test_file = fastapi.UploadFile(
filename="test.jpeg",
file=io.BytesIO(b"\x89PNG\r\n\x1a\n"), # PNG signature with JPEG content type
headers=starlette.datastructures.Headers({"content-type": "image/jpeg"}),
)
with pytest.raises(store_exceptions.InvalidFileTypeError):
await store_media.upload_media("test-user", test_file)
# Should auto-correct to PNG and succeed
result = await store_media.upload_media("test-user", test_file)
assert result.startswith(
"https://storage.googleapis.com/test-bucket/users/test-user/images/"
)
# File should be stored as PNG based on actual content
mock_storage_client.upload.assert_called_once()
async def test_upload_media_invalid_signature(mock_settings, mock_storage_client):
"""Test that files with unrecognized signatures are rejected."""
test_file = fastapi.UploadFile(
filename="test.jpeg",
file=io.BytesIO(b"invalid signature"),
headers=starlette.datastructures.Headers({"content-type": "image/jpeg"}),
)
with pytest.raises(store_exceptions.InvalidFileTypeError):
with pytest.raises(store_exceptions.InvalidFileTypeError) as exc_info:
await store_media.upload_media("test-user", test_file)
assert "Could not detect a valid image or video file signature" in str(
exc_info.value
)

View File

@@ -126,6 +126,9 @@ v1_router = APIRouter()
########################################################
_tally_background_tasks: set[asyncio.Task] = set()
@v1_router.post(
"/auth/user",
summary="Get or create user",
@@ -134,6 +137,24 @@ v1_router = APIRouter()
)
async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
user = await get_or_create_user(user_data)
# Fire-and-forget: populate business understanding from Tally form.
# We use created_at proximity instead of an is_new flag because
# get_or_create_user is cached — a separate is_new return value would be
# unreliable on repeated calls within the cache TTL.
age_seconds = (datetime.now(timezone.utc) - user.created_at).total_seconds()
if age_seconds < 30:
try:
from backend.data.tally import populate_understanding_from_tally
task = asyncio.create_task(
populate_understanding_from_tally(user.id, user.email)
)
_tally_background_tasks.add(task)
task.add_done_callback(_tally_background_tasks.discard)
except Exception:
logger.debug("Failed to start Tally population task", exc_info=True)
return user.model_dump()

View File

@@ -1,5 +1,5 @@
import json
from datetime import datetime
from datetime import datetime, timezone
from io import BytesIO
from unittest.mock import AsyncMock, Mock, patch
@@ -43,6 +43,7 @@ def test_get_or_create_user_route(
) -> None:
"""Test get or create user endpoint"""
mock_user = Mock()
mock_user.created_at = datetime.now(timezone.utc)
mock_user.model_dump.return_value = {
"id": test_user_id,
"email": "test@example.com",

View File

@@ -41,11 +41,11 @@ import backend.data.user
import backend.integrations.webhooks.utils
import backend.util.service
import backend.util.settings
from backend.blocks.llm import DEFAULT_LLM_MODEL
from backend.copilot.completion_consumer import (
start_completion_consumer,
stop_completion_consumer,
from backend.api.features.library.exceptions import (
FolderAlreadyExistsError,
FolderValidationError,
)
from backend.blocks.llm import DEFAULT_LLM_MODEL
from backend.data.model import Credentials
from backend.integrations.providers import ProviderName
from backend.monitoring.instrumentation import instrument_fastapi
@@ -123,21 +123,9 @@ async def lifespan_context(app: fastapi.FastAPI):
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
# Start chat completion consumer for Redis Streams notifications
try:
await start_completion_consumer()
except Exception as e:
logger.warning(f"Could not start chat completion consumer: {e}")
with launch_darkly_context():
yield
# Stop chat completion consumer
try:
await stop_completion_consumer()
except Exception as e:
logger.warning(f"Error stopping chat completion consumer: {e}")
try:
await shutdown_cloud_storage_handler()
except Exception as e:
@@ -277,6 +265,10 @@ async def validation_error_handler(
app.add_exception_handler(PrismaError, handle_internal_http_error(500))
app.add_exception_handler(
FolderAlreadyExistsError, handle_internal_http_error(409, False)
)
app.add_exception_handler(FolderValidationError, handle_internal_http_error(400, False))
app.add_exception_handler(NotFoundError, handle_internal_http_error(404, False))
app.add_exception_handler(NotAuthorizedError, handle_internal_http_error(403, False))
app.add_exception_handler(RequestValidationError, validation_error_handler)

View File

@@ -24,7 +24,7 @@ def run_processes(*processes: "AppProcess", **kwargs):
# Run the last process in the foreground.
processes[-1].start(background=False, **kwargs)
finally:
for process in processes:
for process in reversed(processes):
try:
process.stop()
except Exception as e:

View File

@@ -0,0 +1,182 @@
"""
Telegram Bot API helper functions.
Provides utilities for making authenticated requests to the Telegram Bot API.
"""
import logging
from io import BytesIO
from typing import Any, Optional
from pydantic import BaseModel
from backend.data.model import APIKeyCredentials
from backend.util.request import Requests
logger = logging.getLogger(__name__)
TELEGRAM_API_BASE = "https://api.telegram.org"
class TelegramMessageResult(BaseModel, extra="allow"):
"""Result from Telegram send/edit message API calls."""
message_id: int = 0
chat: dict[str, Any] = {}
date: int = 0
text: str = ""
class TelegramFileResult(BaseModel, extra="allow"):
"""Result from Telegram getFile API call."""
file_id: str = ""
file_unique_id: str = ""
file_size: int = 0
file_path: str = ""
class TelegramAPIException(ValueError):
"""Exception raised for Telegram API errors."""
def __init__(self, message: str, error_code: int = 0):
super().__init__(message)
self.error_code = error_code
def get_bot_api_url(bot_token: str, method: str) -> str:
"""Construct Telegram Bot API URL for a method."""
return f"{TELEGRAM_API_BASE}/bot{bot_token}/{method}"
def get_file_url(bot_token: str, file_path: str) -> str:
"""Construct Telegram file download URL."""
return f"{TELEGRAM_API_BASE}/file/bot{bot_token}/{file_path}"
async def call_telegram_api(
credentials: APIKeyCredentials,
method: str,
data: Optional[dict[str, Any]] = None,
) -> TelegramMessageResult:
"""
Make a request to the Telegram Bot API.
Args:
credentials: Bot token credentials
method: API method name (e.g., "sendMessage", "getFile")
data: Request parameters
Returns:
API response result
Raises:
TelegramAPIException: If the API returns an error
"""
token = credentials.api_key.get_secret_value()
url = get_bot_api_url(token, method)
response = await Requests().post(url, json=data or {})
result = response.json()
if not result.get("ok"):
error_code = result.get("error_code", 0)
description = result.get("description", "Unknown error")
raise TelegramAPIException(description, error_code)
return TelegramMessageResult(**result.get("result", {}))
async def call_telegram_api_with_file(
credentials: APIKeyCredentials,
method: str,
file_field: str,
file_data: bytes,
filename: str,
content_type: str,
data: Optional[dict[str, Any]] = None,
) -> TelegramMessageResult:
"""
Make a multipart/form-data request to the Telegram Bot API with a file upload.
Args:
credentials: Bot token credentials
method: API method name (e.g., "sendPhoto", "sendVoice")
file_field: Form field name for the file (e.g., "photo", "voice")
file_data: Raw file bytes
filename: Filename for the upload
content_type: MIME type of the file
data: Additional form parameters
Returns:
API response result
Raises:
TelegramAPIException: If the API returns an error
"""
token = credentials.api_key.get_secret_value()
url = get_bot_api_url(token, method)
files = [(file_field, (filename, BytesIO(file_data), content_type))]
response = await Requests().post(url, files=files, data=data or {})
result = response.json()
if not result.get("ok"):
error_code = result.get("error_code", 0)
description = result.get("description", "Unknown error")
raise TelegramAPIException(description, error_code)
return TelegramMessageResult(**result.get("result", {}))
async def get_file_info(
credentials: APIKeyCredentials, file_id: str
) -> TelegramFileResult:
"""
Get file information from Telegram.
Args:
credentials: Bot token credentials
file_id: Telegram file_id from message
Returns:
File info dict containing file_id, file_unique_id, file_size, file_path
"""
result = await call_telegram_api(credentials, "getFile", {"file_id": file_id})
return TelegramFileResult(**result.model_dump())
async def get_file_download_url(credentials: APIKeyCredentials, file_id: str) -> str:
"""
Get the download URL for a Telegram file.
Args:
credentials: Bot token credentials
file_id: Telegram file_id from message
Returns:
Full download URL
"""
token = credentials.api_key.get_secret_value()
result = await get_file_info(credentials, file_id)
file_path = result.file_path
if not file_path:
raise TelegramAPIException("No file_path returned from getFile")
return get_file_url(token, file_path)
async def download_telegram_file(credentials: APIKeyCredentials, file_id: str) -> bytes:
"""
Download a file from Telegram servers.
Args:
credentials: Bot token credentials
file_id: Telegram file_id
Returns:
File content as bytes
"""
url = await get_file_download_url(credentials, file_id)
response = await Requests().get(url)
return response.content

View File

@@ -0,0 +1,43 @@
"""
Telegram Bot credentials handling.
Telegram bots use an API key (bot token) obtained from @BotFather.
"""
from typing import Literal
from pydantic import SecretStr
from backend.data.model import APIKeyCredentials, CredentialsField, CredentialsMetaInput
from backend.integrations.providers import ProviderName
# Bot token credentials (API key style)
TelegramCredentials = APIKeyCredentials
TelegramCredentialsInput = CredentialsMetaInput[
Literal[ProviderName.TELEGRAM], Literal["api_key"]
]
def TelegramCredentialsField() -> TelegramCredentialsInput:
"""Creates a Telegram bot token credentials field."""
return CredentialsField(
description="Telegram Bot API token from @BotFather. "
"Create a bot at https://t.me/BotFather to get your token."
)
# Test credentials for unit tests
TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="telegram",
api_key=SecretStr("test_telegram_bot_token"),
title="Mock Telegram Bot Token",
expires_at=None,
)
TEST_CREDENTIALS_INPUT = {
"provider": TEST_CREDENTIALS.provider,
"id": TEST_CREDENTIALS.id,
"type": TEST_CREDENTIALS.type,
"title": TEST_CREDENTIALS.title,
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,377 @@
"""
Telegram trigger blocks for receiving messages via webhooks.
"""
import logging
from pydantic import BaseModel
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
BlockWebhookConfig,
)
from backend.data.model import SchemaField
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks.telegram import TelegramWebhookType
from ._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
TelegramCredentialsField,
TelegramCredentialsInput,
)
logger = logging.getLogger(__name__)
# Example payload for testing
EXAMPLE_MESSAGE_PAYLOAD = {
"update_id": 123456789,
"message": {
"message_id": 1,
"from": {
"id": 12345678,
"is_bot": False,
"first_name": "John",
"last_name": "Doe",
"username": "johndoe",
"language_code": "en",
},
"chat": {
"id": 12345678,
"first_name": "John",
"last_name": "Doe",
"username": "johndoe",
"type": "private",
},
"date": 1234567890,
"text": "Hello, bot!",
},
}
class TelegramTriggerBase:
"""Base class for Telegram trigger blocks."""
class Input(BlockSchemaInput):
credentials: TelegramCredentialsInput = TelegramCredentialsField()
payload: dict = SchemaField(hidden=True, default_factory=dict)
class TelegramMessageTriggerBlock(TelegramTriggerBase, Block):
"""
Triggers when a message is received or edited in your Telegram bot.
Supports text, photos, voice messages, audio files, documents, and videos.
Connect the outputs to other blocks to process messages and send responses.
"""
class Input(TelegramTriggerBase.Input):
class EventsFilter(BaseModel):
"""Filter for message types to receive."""
text: bool = True
photo: bool = False
voice: bool = False
audio: bool = False
document: bool = False
video: bool = False
edited_message: bool = False
events: EventsFilter = SchemaField(
title="Message Types", description="Types of messages to receive"
)
class Output(BlockSchemaOutput):
payload: dict = SchemaField(
description="The complete webhook payload from Telegram"
)
chat_id: int = SchemaField(
description="The chat ID where the message was received. "
"Use this to send replies."
)
message_id: int = SchemaField(description="The unique message ID")
user_id: int = SchemaField(description="The user ID who sent the message")
username: str = SchemaField(description="Username of the sender (may be empty)")
first_name: str = SchemaField(description="First name of the sender")
event: str = SchemaField(
description="The message type (text, photo, voice, audio, etc.)"
)
text: str = SchemaField(
description="Text content of the message (for text messages)"
)
photo_file_id: str = SchemaField(
description="File ID of the photo (for photo messages). "
"Use GetTelegramFileBlock to download."
)
voice_file_id: str = SchemaField(
description="File ID of the voice message (for voice messages). "
"Use GetTelegramFileBlock to download."
)
audio_file_id: str = SchemaField(
description="File ID of the audio file (for audio messages). "
"Use GetTelegramFileBlock to download."
)
file_id: str = SchemaField(
description="File ID for document/video messages. "
"Use GetTelegramFileBlock to download."
)
file_name: str = SchemaField(
description="Original filename (for document/audio messages)"
)
caption: str = SchemaField(description="Caption for media messages")
is_edited: bool = SchemaField(
description="Whether this is an edit of a previously sent message"
)
def __init__(self):
super().__init__(
id="4435e4e0-df6e-4301-8f35-ad70b12fc9ec",
description="Triggers when a message is received or edited in your Telegram bot. "
"Supports text, photos, voice messages, audio files, documents, and videos.",
categories={BlockCategory.SOCIAL},
input_schema=TelegramMessageTriggerBlock.Input,
output_schema=TelegramMessageTriggerBlock.Output,
webhook_config=BlockWebhookConfig(
provider=ProviderName.TELEGRAM,
webhook_type=TelegramWebhookType.BOT,
resource_format="bot",
event_filter_input="events",
event_format="message.{event}",
),
test_input={
"events": {"text": True, "photo": True},
"credentials": TEST_CREDENTIALS_INPUT,
"payload": EXAMPLE_MESSAGE_PAYLOAD,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("payload", EXAMPLE_MESSAGE_PAYLOAD),
("chat_id", 12345678),
("message_id", 1),
("user_id", 12345678),
("username", "johndoe"),
("first_name", "John"),
("is_edited", False),
("event", "text"),
("text", "Hello, bot!"),
("photo_file_id", ""),
("voice_file_id", ""),
("audio_file_id", ""),
("file_id", ""),
("file_name", ""),
("caption", ""),
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
payload = input_data.payload
is_edited = "edited_message" in payload
message = payload.get("message") or payload.get("edited_message", {})
# Extract common fields
chat = message.get("chat", {})
sender = message.get("from", {})
yield "payload", payload
yield "chat_id", chat.get("id", 0)
yield "message_id", message.get("message_id", 0)
yield "user_id", sender.get("id", 0)
yield "username", sender.get("username", "")
yield "first_name", sender.get("first_name", "")
yield "is_edited", is_edited
# For edited messages, yield event as "edited_message" and extract
# all content fields from the edited message body
if is_edited:
yield "event", "edited_message"
yield "text", message.get("text", "")
photos = message.get("photo", [])
yield "photo_file_id", photos[-1].get("file_id", "") if photos else ""
voice = message.get("voice", {})
yield "voice_file_id", voice.get("file_id", "")
audio = message.get("audio", {})
yield "audio_file_id", audio.get("file_id", "")
document = message.get("document", {})
video = message.get("video", {})
yield "file_id", (document.get("file_id", "") or video.get("file_id", ""))
yield "file_name", (
document.get("file_name", "") or audio.get("file_name", "")
)
yield "caption", message.get("caption", "")
# Determine message type and extract content
elif "text" in message:
yield "event", "text"
yield "text", message.get("text", "")
yield "photo_file_id", ""
yield "voice_file_id", ""
yield "audio_file_id", ""
yield "file_id", ""
yield "file_name", ""
yield "caption", ""
elif "photo" in message:
# Get the largest photo (last in array)
photos = message.get("photo", [])
photo_fid = photos[-1].get("file_id", "") if photos else ""
yield "event", "photo"
yield "text", ""
yield "photo_file_id", photo_fid
yield "voice_file_id", ""
yield "audio_file_id", ""
yield "file_id", ""
yield "file_name", ""
yield "caption", message.get("caption", "")
elif "voice" in message:
voice = message.get("voice", {})
yield "event", "voice"
yield "text", ""
yield "photo_file_id", ""
yield "voice_file_id", voice.get("file_id", "")
yield "audio_file_id", ""
yield "file_id", ""
yield "file_name", ""
yield "caption", message.get("caption", "")
elif "audio" in message:
audio = message.get("audio", {})
yield "event", "audio"
yield "text", ""
yield "photo_file_id", ""
yield "voice_file_id", ""
yield "audio_file_id", audio.get("file_id", "")
yield "file_id", ""
yield "file_name", audio.get("file_name", "")
yield "caption", message.get("caption", "")
elif "document" in message:
document = message.get("document", {})
yield "event", "document"
yield "text", ""
yield "photo_file_id", ""
yield "voice_file_id", ""
yield "audio_file_id", ""
yield "file_id", document.get("file_id", "")
yield "file_name", document.get("file_name", "")
yield "caption", message.get("caption", "")
elif "video" in message:
video = message.get("video", {})
yield "event", "video"
yield "text", ""
yield "photo_file_id", ""
yield "voice_file_id", ""
yield "audio_file_id", ""
yield "file_id", video.get("file_id", "")
yield "file_name", video.get("file_name", "")
yield "caption", message.get("caption", "")
else:
yield "event", "other"
yield "text", ""
yield "photo_file_id", ""
yield "voice_file_id", ""
yield "audio_file_id", ""
yield "file_id", ""
yield "file_name", ""
yield "caption", ""
# Example payload for reaction trigger testing
EXAMPLE_REACTION_PAYLOAD = {
"update_id": 123456790,
"message_reaction": {
"chat": {
"id": 12345678,
"first_name": "John",
"last_name": "Doe",
"username": "johndoe",
"type": "private",
},
"message_id": 42,
"user": {
"id": 12345678,
"is_bot": False,
"first_name": "John",
"username": "johndoe",
},
"date": 1234567890,
"new_reaction": [{"type": "emoji", "emoji": "👍"}],
"old_reaction": [],
},
}
class TelegramMessageReactionTriggerBlock(TelegramTriggerBase, Block):
"""
Triggers when a reaction to a message is changed.
Works automatically in private chats. In group chats, the bot must be
an administrator to receive reaction updates.
"""
class Input(TelegramTriggerBase.Input):
pass
class Output(BlockSchemaOutput):
payload: dict = SchemaField(
description="The complete webhook payload from Telegram"
)
chat_id: int = SchemaField(
description="The chat ID where the reaction occurred"
)
message_id: int = SchemaField(description="The message ID that was reacted to")
user_id: int = SchemaField(description="The user ID who changed the reaction")
username: str = SchemaField(description="Username of the user (may be empty)")
new_reactions: list = SchemaField(
description="List of new reactions on the message"
)
old_reactions: list = SchemaField(
description="List of previous reactions on the message"
)
def __init__(self):
super().__init__(
id="82525328-9368-4966-8f0c-cd78e80181fd",
description="Triggers when a reaction to a message is changed. "
"Works in private chats automatically. "
"In groups, the bot must be an administrator.",
categories={BlockCategory.SOCIAL},
input_schema=TelegramMessageReactionTriggerBlock.Input,
output_schema=TelegramMessageReactionTriggerBlock.Output,
webhook_config=BlockWebhookConfig(
provider=ProviderName.TELEGRAM,
webhook_type=TelegramWebhookType.BOT,
resource_format="bot",
event_filter_input="",
event_format="message_reaction",
),
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
"payload": EXAMPLE_REACTION_PAYLOAD,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("payload", EXAMPLE_REACTION_PAYLOAD),
("chat_id", 12345678),
("message_id", 42),
("user_id", 12345678),
("username", "johndoe"),
("new_reactions", [{"type": "emoji", "emoji": "👍"}]),
("old_reactions", []),
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
payload = input_data.payload
reaction = payload.get("message_reaction", {})
chat = reaction.get("chat", {})
user = reaction.get("user", {})
yield "payload", payload
yield "chat_id", chat.get("id", 0)
yield "message_id", reaction.get("message_id", 0)
yield "user_id", user.get("id", 0)
yield "username", user.get("username", "")
yield "new_reactions", reaction.get("new_reaction", [])
yield "old_reactions", reaction.get("old_reaction", [])

View File

@@ -1,349 +0,0 @@
"""Redis Streams consumer for operation completion messages.
This module provides a consumer (ChatCompletionConsumer) that listens for
completion notifications (OperationCompleteMessage) from external services
(like Agent Generator) and triggers the appropriate stream registry and
chat service updates via process_operation_success/process_operation_failure.
Why Redis Streams instead of RabbitMQ?
--------------------------------------
While the project typically uses RabbitMQ for async task queues (e.g., execution
queue), Redis Streams was chosen for chat completion notifications because:
1. **Unified Infrastructure**: The SSE reconnection feature already uses Redis
Streams (via stream_registry) for message persistence and replay. Using Redis
Streams for completion notifications keeps all chat streaming infrastructure
in one system, simplifying operations and reducing cross-system coordination.
2. **Message Replay**: Redis Streams support XREAD with arbitrary message IDs,
allowing consumers to replay missed messages after reconnection. This aligns
with the SSE reconnection pattern where clients can resume from last_message_id.
3. **Consumer Groups with XAUTOCLAIM**: Redis consumer groups provide automatic
load balancing across pods with explicit message claiming (XAUTOCLAIM) for
recovering from dead consumers - ideal for the completion callback pattern.
4. **Lower Latency**: For real-time SSE updates, Redis (already in-memory for
stream_registry) provides lower latency than an additional RabbitMQ hop.
5. **Atomicity with Task State**: Completion processing often needs to update
task metadata stored in Redis. Keeping both in Redis enables simpler
transactional semantics without distributed coordination.
The consumer uses Redis Streams with consumer groups for reliable message
processing across multiple platform pods, with XAUTOCLAIM for reclaiming
stale pending messages from dead consumers.
"""
import asyncio
import logging
import uuid
from typing import Any
import orjson
from pydantic import BaseModel
from redis.exceptions import ResponseError
from backend.data.redis_client import get_redis_async
from . import stream_registry
from .completion_handler import process_operation_failure, process_operation_success
from .config import ChatConfig
logger = logging.getLogger(__name__)
config = ChatConfig()
class OperationCompleteMessage(BaseModel):
"""Message format for operation completion notifications."""
operation_id: str
task_id: str
success: bool
result: dict | str | None = None
error: str | None = None
class ChatCompletionConsumer:
"""Consumer for chat operation completion messages from Redis Streams.
Database operations are handled through the chat_db() accessor, which
routes through DatabaseManager RPC when Prisma is not directly connected.
Uses Redis consumer groups to allow multiple platform pods to consume
messages reliably with automatic redelivery on failure.
"""
def __init__(self):
self._consumer_task: asyncio.Task | None = None
self._running = False
self._consumer_name = f"consumer-{uuid.uuid4().hex[:8]}"
async def start(self) -> None:
"""Start the completion consumer."""
if self._running:
logger.warning("Completion consumer already running")
return
# Create consumer group if it doesn't exist
try:
redis = await get_redis_async()
await redis.xgroup_create(
config.stream_completion_name,
config.stream_consumer_group,
id="0",
mkstream=True,
)
logger.info(
f"Created consumer group '{config.stream_consumer_group}' "
f"on stream '{config.stream_completion_name}'"
)
except ResponseError as e:
if "BUSYGROUP" in str(e):
logger.debug(
f"Consumer group '{config.stream_consumer_group}' already exists"
)
else:
raise
self._running = True
self._consumer_task = asyncio.create_task(self._consume_messages())
logger.info(
f"Chat completion consumer started (consumer: {self._consumer_name})"
)
async def stop(self) -> None:
"""Stop the completion consumer."""
self._running = False
if self._consumer_task:
self._consumer_task.cancel()
try:
await self._consumer_task
except asyncio.CancelledError:
pass
self._consumer_task = None
logger.info("Chat completion consumer stopped")
async def _consume_messages(self) -> None:
"""Main message consumption loop with retry logic."""
max_retries = 10
retry_delay = 5 # seconds
retry_count = 0
block_timeout = 5000 # milliseconds
while self._running and retry_count < max_retries:
try:
redis = await get_redis_async()
# Reset retry count on successful connection
retry_count = 0
while self._running:
# First, claim any stale pending messages from dead consumers
# Redis does NOT auto-redeliver pending messages; we must explicitly
# claim them using XAUTOCLAIM
try:
claimed_result = await redis.xautoclaim(
name=config.stream_completion_name,
groupname=config.stream_consumer_group,
consumername=self._consumer_name,
min_idle_time=config.stream_claim_min_idle_ms,
start_id="0-0",
count=10,
)
# xautoclaim returns: (next_start_id, [(id, data), ...], [deleted_ids])
if claimed_result and len(claimed_result) >= 2:
claimed_entries = claimed_result[1]
if claimed_entries:
logger.info(
f"Claimed {len(claimed_entries)} stale pending messages"
)
for entry_id, data in claimed_entries:
if not self._running:
return
await self._process_entry(redis, entry_id, data)
except Exception as e:
logger.warning(f"XAUTOCLAIM failed (non-fatal): {e}")
# Read new messages from the stream
messages = await redis.xreadgroup(
groupname=config.stream_consumer_group,
consumername=self._consumer_name,
streams={config.stream_completion_name: ">"},
block=block_timeout,
count=10,
)
if not messages:
continue
for stream_name, entries in messages:
for entry_id, data in entries:
if not self._running:
return
await self._process_entry(redis, entry_id, data)
except asyncio.CancelledError:
logger.info("Consumer cancelled")
return
except Exception as e:
retry_count += 1
logger.error(
f"Consumer error (retry {retry_count}/{max_retries}): {e}",
exc_info=True,
)
if self._running and retry_count < max_retries:
await asyncio.sleep(retry_delay)
else:
logger.error("Max retries reached, stopping consumer")
return
async def _process_entry(
self, redis: Any, entry_id: str, data: dict[str, Any]
) -> None:
"""Process a single stream entry and acknowledge it on success.
Args:
redis: Redis client connection
entry_id: The stream entry ID
data: The entry data dict
"""
try:
# Handle the message
message_data = data.get("data")
if message_data:
await self._handle_message(
message_data.encode()
if isinstance(message_data, str)
else message_data
)
# Acknowledge the message after successful processing
await redis.xack(
config.stream_completion_name,
config.stream_consumer_group,
entry_id,
)
except Exception as e:
logger.error(
f"Error processing completion message {entry_id}: {e}",
exc_info=True,
)
# Message remains in pending state and will be claimed by
# XAUTOCLAIM after min_idle_time expires
async def _handle_message(self, body: bytes) -> None:
"""Handle a completion message."""
try:
data = orjson.loads(body)
message = OperationCompleteMessage(**data)
except Exception as e:
logger.error(f"Failed to parse completion message: {e}")
return
logger.info(
f"[COMPLETION] Received completion for operation {message.operation_id} "
f"(task_id={message.task_id}, success={message.success})"
)
# Find task in registry
task = await stream_registry.find_task_by_operation_id(message.operation_id)
if task is None:
task = await stream_registry.get_task(message.task_id)
if task is None:
logger.warning(
f"[COMPLETION] Task not found for operation {message.operation_id} "
f"(task_id={message.task_id})"
)
return
logger.info(
f"[COMPLETION] Found task: task_id={task.task_id}, "
f"session_id={task.session_id}, tool_call_id={task.tool_call_id}"
)
# Guard against empty task fields
if not task.task_id or not task.session_id or not task.tool_call_id:
logger.error(
f"[COMPLETION] Task has empty critical fields! "
f"task_id={task.task_id!r}, session_id={task.session_id!r}, "
f"tool_call_id={task.tool_call_id!r}"
)
return
if message.success:
await self._handle_success(task, message)
else:
await self._handle_failure(task, message)
async def _handle_success(
self,
task: stream_registry.ActiveTask,
message: OperationCompleteMessage,
) -> None:
"""Handle successful operation completion."""
await process_operation_success(task, message.result)
async def _handle_failure(
self,
task: stream_registry.ActiveTask,
message: OperationCompleteMessage,
) -> None:
"""Handle failed operation completion."""
await process_operation_failure(task, message.error)
# Module-level consumer instance
_consumer: ChatCompletionConsumer | None = None
async def start_completion_consumer() -> None:
"""Start the global completion consumer."""
global _consumer
if _consumer is None:
_consumer = ChatCompletionConsumer()
await _consumer.start()
async def stop_completion_consumer() -> None:
"""Stop the global completion consumer."""
global _consumer
if _consumer:
await _consumer.stop()
_consumer = None
async def publish_operation_complete(
operation_id: str,
task_id: str,
success: bool,
result: dict | str | None = None,
error: str | None = None,
) -> None:
"""Publish an operation completion message to Redis Streams.
Args:
operation_id: The operation ID that completed.
task_id: The task ID associated with the operation.
success: Whether the operation succeeded.
result: The result data (for success).
error: The error message (for failure).
"""
message = OperationCompleteMessage(
operation_id=operation_id,
task_id=task_id,
success=success,
result=result,
error=error,
)
redis = await get_redis_async()
await redis.xadd(
config.stream_completion_name,
{"data": message.model_dump_json()},
maxlen=config.stream_max_length,
)
logger.info(f"Published completion for operation {operation_id}")

View File

@@ -1,329 +0,0 @@
"""Shared completion handling for operation success and failure.
This module provides common logic for handling operation completion from both:
- The Redis Streams consumer (completion_consumer.py)
- The HTTP webhook endpoint (routes.py)
"""
import logging
from typing import Any
import orjson
from backend.data.db_accessors import chat_db
from . import service as chat_service
from . import stream_registry
from .response_model import StreamError, StreamToolOutputAvailable
from .tools.models import ErrorResponse
logger = logging.getLogger(__name__)
# Tools that produce agent_json that needs to be saved to library
AGENT_GENERATION_TOOLS = {"create_agent", "edit_agent"}
# Keys that should be stripped from agent_json when returning in error responses
SENSITIVE_KEYS = frozenset(
{
"api_key",
"apikey",
"api_secret",
"password",
"secret",
"credentials",
"credential",
"token",
"access_token",
"refresh_token",
"private_key",
"privatekey",
"auth",
"authorization",
}
)
def _sanitize_agent_json(obj: Any) -> Any:
"""Recursively sanitize agent_json by removing sensitive keys.
Args:
obj: The object to sanitize (dict, list, or primitive)
Returns:
Sanitized copy with sensitive keys removed/redacted
"""
if isinstance(obj, dict):
return {
k: "[REDACTED]" if k.lower() in SENSITIVE_KEYS else _sanitize_agent_json(v)
for k, v in obj.items()
}
elif isinstance(obj, list):
return [_sanitize_agent_json(item) for item in obj]
else:
return obj
class ToolMessageUpdateError(Exception):
"""Raised when updating a tool message in the database fails."""
pass
async def _update_tool_message(
session_id: str,
tool_call_id: str,
content: str,
) -> None:
"""Update tool message in database using the chat_db accessor.
Routes through DatabaseManager RPC when Prisma is not directly
connected (e.g. in the CoPilot Executor microservice).
Args:
session_id: The session ID
tool_call_id: The tool call ID to update
content: The new content for the message
Raises:
ToolMessageUpdateError: If the database update fails.
"""
try:
updated = await chat_db().update_tool_message_content(
session_id=session_id,
tool_call_id=tool_call_id,
new_content=content,
)
if not updated:
raise ToolMessageUpdateError(
f"No message found with tool_call_id="
f"{tool_call_id} in session {session_id}"
)
except ToolMessageUpdateError:
raise
except Exception as e:
logger.error(
f"[COMPLETION] Failed to update tool message: {e}",
exc_info=True,
)
raise ToolMessageUpdateError(
f"Failed to update tool message for tool call #{tool_call_id}: {e}"
) from e
def serialize_result(result: dict | list | str | int | float | bool | None) -> str:
"""Serialize result to JSON string with sensible defaults.
Args:
result: The result to serialize. Can be a dict, list, string,
number, boolean, or None.
Returns:
JSON string representation of the result. Returns '{"status": "completed"}'
only when result is explicitly None.
"""
if isinstance(result, str):
return result
if result is None:
return '{"status": "completed"}'
return orjson.dumps(result).decode("utf-8")
async def _save_agent_from_result(
result: dict[str, Any],
user_id: str | None,
tool_name: str,
) -> dict[str, Any]:
"""Save agent to library if result contains agent_json.
Args:
result: The result dict that may contain agent_json
user_id: The user ID to save the agent for
tool_name: The tool name (create_agent or edit_agent)
Returns:
Updated result dict with saved agent details, or original result if no agent_json
"""
if not user_id:
logger.warning("[COMPLETION] Cannot save agent: no user_id in task")
return result
agent_json = result.get("agent_json")
if not agent_json:
logger.warning(
f"[COMPLETION] {tool_name} completed but no agent_json in result"
)
return result
try:
from .tools.agent_generator import save_agent_to_library
is_update = tool_name == "edit_agent"
created_graph, library_agent = await save_agent_to_library(
agent_json, user_id, is_update=is_update
)
logger.info(
f"[COMPLETION] Saved agent '{created_graph.name}' to library "
f"(graph_id={created_graph.id}, library_agent_id={library_agent.id})"
)
# Return a response similar to AgentSavedResponse
return {
"type": "agent_saved",
"message": f"Agent '{created_graph.name}' has been saved to your library!",
"agent_id": created_graph.id,
"agent_name": created_graph.name,
"library_agent_id": library_agent.id,
"library_agent_link": f"/library/agents/{library_agent.id}",
"agent_page_link": f"/build?flowID={created_graph.id}",
}
except Exception as e:
logger.error(
f"[COMPLETION] Failed to save agent to library: {e}",
exc_info=True,
)
# Return error but don't fail the whole operation
# Sanitize agent_json to remove sensitive keys before returning
return {
"type": "error",
"message": f"Agent was generated but failed to save: {str(e)}",
"error": str(e),
"agent_json": _sanitize_agent_json(agent_json),
}
async def process_operation_success(
task: stream_registry.ActiveTask,
result: dict | str | None,
) -> None:
"""Handle successful operation completion.
Publishes the result to the stream registry, updates the database,
generates LLM continuation, and marks the task as completed.
Args:
task: The active task that completed
result: The result data from the operation
Raises:
ToolMessageUpdateError: If the database update fails. The task
will be marked as failed instead of completed.
"""
# For agent generation tools, save the agent to library
if task.tool_name in AGENT_GENERATION_TOOLS and isinstance(result, dict):
result = await _save_agent_from_result(result, task.user_id, task.tool_name)
# Serialize result for output (only substitute default when result is exactly None)
result_output = result if result is not None else {"status": "completed"}
output_str = (
result_output
if isinstance(result_output, str)
else orjson.dumps(result_output).decode("utf-8")
)
# Publish result to stream registry
await stream_registry.publish_chunk(
task.task_id,
StreamToolOutputAvailable(
toolCallId=task.tool_call_id,
toolName=task.tool_name,
output=output_str,
success=True,
),
)
# Update pending operation in database
# If this fails, we must not continue to mark the task as completed
result_str = serialize_result(result)
try:
await _update_tool_message(
session_id=task.session_id,
tool_call_id=task.tool_call_id,
content=result_str,
)
except ToolMessageUpdateError:
# DB update failed - mark task as failed to avoid inconsistent state
logger.error(
f"[COMPLETION] DB update failed for task {task.task_id}, "
"marking as failed instead of completed"
)
await stream_registry.publish_chunk(
task.task_id,
StreamError(errorText="Failed to save operation result to database"),
)
await stream_registry.mark_task_completed(task.task_id, status="failed")
raise
# Generate LLM continuation with streaming
try:
await chat_service._generate_llm_continuation_with_streaming(
session_id=task.session_id,
user_id=task.user_id,
task_id=task.task_id,
)
except Exception as e:
logger.error(
f"[COMPLETION] Failed to generate LLM continuation: {e}",
exc_info=True,
)
# Mark task as completed and release Redis lock
await stream_registry.mark_task_completed(task.task_id, status="completed")
try:
await chat_service._mark_operation_completed(task.tool_call_id)
except Exception as e:
logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")
logger.info(
f"[COMPLETION] Successfully processed completion for task {task.task_id}"
)
async def process_operation_failure(
task: stream_registry.ActiveTask,
error: str | None,
) -> None:
"""Handle failed operation completion.
Publishes the error to the stream registry, updates the database
with the error response, and marks the task as failed.
Args:
task: The active task that failed
error: The error message from the operation
"""
error_msg = error or "Operation failed"
# Publish error to stream registry
await stream_registry.publish_chunk(
task.task_id,
StreamError(errorText=error_msg),
)
# Update pending operation with error
# If this fails, we still continue to mark the task as failed
error_response = ErrorResponse(
message=error_msg,
error=error,
)
try:
await _update_tool_message(
session_id=task.session_id,
tool_call_id=task.tool_call_id,
content=error_response.model_dump_json(),
)
except ToolMessageUpdateError:
# DB update failed - log but continue with cleanup
logger.error(
f"[COMPLETION] DB update failed while processing failure for task {task.task_id}, "
"continuing with cleanup"
)
# Mark task as failed and release Redis lock
await stream_registry.mark_task_completed(task.task_id, status="failed")
try:
await chat_service._mark_operation_completed(task.tool_call_id)
except Exception as e:
logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")
logger.info(f"[COMPLETION] Processed failure for task {task.task_id}: {error_msg}")

View File

@@ -36,54 +36,29 @@ class ChatConfig(BaseSettings):
default=30, description="Maximum number of agent schedules"
)
# Long-running operation configuration
long_running_operation_ttl: int = Field(
default=3600,
description="TTL in seconds for long-running operation deduplication lock "
"(1 hour, matches stream_ttl). Prevents duplicate operations if pod dies. "
"For longer operations, the stream_registry heartbeat keeps them alive.",
)
# Stream registry configuration for SSE reconnection
stream_ttl: int = Field(
default=3600,
description="TTL in seconds for stream data in Redis (1 hour)",
)
stream_lock_ttl: int = Field(
default=120,
description="TTL in seconds for stream lock (2 minutes). Short timeout allows "
"reconnection after refresh/crash without long waits.",
)
stream_max_length: int = Field(
default=10000,
description="Maximum number of messages to store per stream",
)
# Redis Streams configuration for completion consumer
stream_completion_name: str = Field(
default="chat:completions",
description="Redis Stream name for operation completions",
)
stream_consumer_group: str = Field(
default="chat_consumers",
description="Consumer group name for completion stream",
)
stream_claim_min_idle_ms: int = Field(
default=60000,
description="Minimum idle time in milliseconds before claiming pending messages from dead consumers",
)
# Redis key prefixes for stream registry
task_meta_prefix: str = Field(
session_meta_prefix: str = Field(
default="chat:task:meta:",
description="Prefix for task metadata hash keys",
description="Prefix for session metadata hash keys",
)
task_stream_prefix: str = Field(
turn_stream_prefix: str = Field(
default="chat:stream:",
description="Prefix for task message stream keys",
)
task_op_prefix: str = Field(
default="chat:task:op:",
description="Prefix for operation ID to task ID mapping keys",
)
internal_api_key: str | None = Field(
default=None,
description="API key for internal webhook callbacks (env: CHAT_INTERNAL_API_KEY)",
description="Prefix for turn message stream keys",
)
# Langfuse Prompt Management Configuration
@@ -110,7 +85,7 @@ class ChatConfig(BaseSettings):
)
claude_agent_max_subtasks: int = Field(
default=10,
description="Max number of sub-agent Tasks the SDK can spawn per session.",
description="Max number of concurrent sub-agent Tasks the SDK can run per session.",
)
claude_agent_use_resume: bool = Field(
default=True,
@@ -155,14 +130,6 @@ class ChatConfig(BaseSettings):
v = "https://openrouter.ai/api/v1"
return v
@field_validator("internal_api_key", mode="before")
@classmethod
def get_internal_api_key(cls, v):
"""Get internal API key from environment if not provided."""
if v is None:
v = os.getenv("CHAT_INTERNAL_API_KEY")
return v
@field_validator("use_claude_agent_sdk", mode="before")
@classmethod
def get_use_claude_agent_sdk(cls, v):

View File

@@ -3,8 +3,9 @@
import asyncio
import logging
from datetime import UTC, datetime
from typing import Any, cast
from typing import Any
from prisma.errors import UniqueViolationError
from prisma.models import ChatMessage as PrismaChatMessage
from prisma.models import ChatSession as PrismaChatSession
from prisma.types import (
@@ -92,10 +93,9 @@ async def add_chat_message(
function_call: dict[str, Any] | None = None,
) -> ChatMessage:
"""Add a message to a chat session."""
# Build input dict dynamically rather than using ChatMessageCreateInput directly
# because Prisma's TypedDict validation rejects optional fields set to None.
# We only include fields that have values, then cast at the end.
data: dict[str, Any] = {
# Build ChatMessageCreateInput with only non-None values
# (Prisma TypedDict rejects optional fields set to None)
data: ChatMessageCreateInput = {
"Session": {"connect": {"id": session_id}},
"role": role,
"sequence": sequence,
@@ -123,7 +123,7 @@ async def add_chat_message(
where={"id": session_id},
data={"updatedAt": datetime.now(UTC)},
),
PrismaChatMessage.prisma().create(data=cast(ChatMessageCreateInput, data)),
PrismaChatMessage.prisma().create(data=data),
)
return ChatMessage.from_db(message)
@@ -132,38 +132,42 @@ async def add_chat_messages_batch(
session_id: str,
messages: list[dict[str, Any]],
start_sequence: int,
) -> tuple[list[ChatMessage], int]:
) -> int:
"""Add multiple messages to a chat session in a batch.
Uses collision detection with retry: tries to create messages starting
at start_sequence. If a unique constraint violation occurs (e.g., the
streaming loop and long-running callback race), queries MAX(sequence)
and retries with the correct next sequence number. This avoids
unnecessary upserts and DB queries in the common case (no collision).
streaming loop and long-running callback race), queries the latest
sequence and retries with the correct offset. This avoids unnecessary
upserts and DB queries in the common case (no collision).
Returns:
Tuple of (messages, final_message_count) where final_message_count
is the total number of messages in the session after insertion.
This allows callers to update their counters even when collision
detection adjusts start_sequence.
Next sequence number for the next message to be inserted. This equals
start_sequence + len(messages) and allows callers to update their
counters even when collision detection adjusts start_sequence.
"""
if not messages:
# No messages to add - return current count
return [], start_sequence
return start_sequence
max_retries = 3
max_retries = 5
for attempt in range(max_retries):
try:
created_messages = []
# Single timestamp for all messages and session update
now = datetime.now(UTC)
async with db.transaction() as tx:
# Build all message data
messages_data = []
for i, msg in enumerate(messages):
# Build input dict dynamically rather than using ChatMessageCreateInput
# directly because Prisma's TypedDict validation rejects optional fields
# set to None. We only include fields that have values, then cast.
data: dict[str, Any] = {
"Session": {"connect": {"id": session_id}},
# Build ChatMessageCreateInput with only non-None values
# (Prisma TypedDict rejects optional fields set to None)
# Note: create_many doesn't support nested creates, use sessionId directly
data: ChatMessageCreateInput = {
"sessionId": session_id,
"role": msg["role"],
"sequence": start_sequence + i,
"createdAt": now,
}
# Add optional string fields
@@ -182,31 +186,23 @@ async def add_chat_messages_batch(
if msg.get("function_call") is not None:
data["functionCall"] = SafeJson(msg["function_call"])
created = await PrismaChatMessage.prisma(tx).create(
data=cast(ChatMessageCreateInput, data)
)
created_messages.append(created)
messages_data.append(data)
# Update session's updatedAt timestamp within the same transaction.
# Note: Token usage (total_prompt_tokens, total_completion_tokens) is updated
# separately via update_chat_session() after streaming completes.
await PrismaChatSession.prisma(tx).update(
where={"id": session_id},
data={"updatedAt": datetime.now(UTC)},
# Run create_many and session update in parallel within transaction
# Both use the same timestamp for consistency
await asyncio.gather(
PrismaChatMessage.prisma(tx).create_many(data=messages_data),
PrismaChatSession.prisma(tx).update(
where={"id": session_id},
data={"updatedAt": now},
),
)
# Return messages and final message count (for shared counter sync)
final_count = start_sequence + len(messages)
return [ChatMessage.from_db(m) for m in created_messages], final_count
# Return next sequence number for counter sync
return start_sequence + len(messages)
except Exception as e:
# Check if it's a unique constraint violation
error_msg = str(e).lower()
is_unique_constraint = (
"unique constraint" in error_msg or "duplicate key" in error_msg
)
if is_unique_constraint and attempt < max_retries - 1:
except UniqueViolationError:
if attempt < max_retries - 1:
# Collision detected - query MAX(sequence)+1 and retry with correct offset
logger.info(
f"Collision detected for session {session_id} at sequence "
@@ -218,7 +214,7 @@ async def add_chat_messages_batch(
)
continue
else:
# Not a collision or max retries exceeded - propagate error
# Max retries exceeded - propagate error
raise
# Should never reach here due to raise in exception handler
@@ -281,18 +277,15 @@ async def get_next_sequence(session_id: str) -> int:
Uses MAX(sequence) + 1 for robustness. Returns 0 if no messages exist.
More robust than COUNT(*) because it's immune to deleted messages.
Optimized to select only the sequence column using raw SQL.
The unique index on (sessionId, sequence) makes this query fast.
"""
result = await db.prisma.query_raw(
"""
SELECT COALESCE(MAX(sequence) + 1, 0) as next_seq
FROM "ChatMessage"
WHERE "sessionId" = $1
""",
results = await db.query_raw_with_schema(
'SELECT "sequence" FROM {schema_prefix}"ChatMessage" WHERE "sessionId" = $1 ORDER BY "sequence" DESC LIMIT 1',
session_id,
)
if not result or len(result) == 0:
return 0
return int(result[0]["next_seq"])
return 0 if not results else results[0]["sequence"] + 1
async def update_tool_message_content(

View File

@@ -4,6 +4,7 @@ This module contains the CoPilotExecutor class that consumes chat tasks from
RabbitMQ and processes them using a thread pool, following the graph executor pattern.
"""
import asyncio
import logging
import os
import threading
@@ -25,7 +26,7 @@ from backend.util.process import AppProcess
from backend.util.retry import continuous_retry
from backend.util.settings import Settings
from .processor import execute_copilot_task, init_worker
from .processor import execute_copilot_turn, init_worker
from .utils import (
COPILOT_CANCEL_QUEUE_NAME,
COPILOT_EXECUTION_QUEUE_NAME,
@@ -181,13 +182,13 @@ class CoPilotExecutor(AppProcess):
self._executor.shutdown(wait=False)
# Release any remaining locks
for task_id, lock in list(self._task_locks.items()):
for session_id, lock in list(self._task_locks.items()):
try:
lock.release()
logger.info(f"[cleanup {pid}] Released lock for {task_id}")
logger.info(f"[cleanup {pid}] Released lock for {session_id}")
except Exception as e:
logger.error(
f"[cleanup {pid}] Failed to release lock for {task_id}: {e}"
f"[cleanup {pid}] Failed to release lock for {session_id}: {e}"
)
logger.info(f"[cleanup {pid}] Graceful shutdown completed")
@@ -267,20 +268,20 @@ class CoPilotExecutor(AppProcess):
):
"""Handle cancel message from FANOUT exchange."""
request = CancelCoPilotEvent.model_validate_json(body)
task_id = request.task_id
if not task_id:
logger.warning("Cancel message missing 'task_id'")
session_id = request.session_id
if not session_id:
logger.warning("Cancel message missing 'session_id'")
return
if task_id not in self.active_tasks:
logger.debug(f"Cancel received for {task_id} but not active")
if session_id not in self.active_tasks:
logger.debug(f"Cancel received for {session_id} but not active")
return
_, cancel_event = self.active_tasks[task_id]
logger.info(f"Received cancel for {task_id}")
_, cancel_event = self.active_tasks[session_id]
logger.info(f"Received cancel for {session_id}")
if not cancel_event.is_set():
cancel_event.set()
else:
logger.debug(f"Cancel already set for {task_id}")
logger.debug(f"Cancel already set for {session_id}")
def _handle_run_message(
self,
@@ -352,12 +353,12 @@ class CoPilotExecutor(AppProcess):
ack_message(reject=True, requeue=False)
return
task_id = entry.task_id
session_id = entry.session_id
# Check for local duplicate - task is already running on this executor
if task_id in self.active_tasks:
# Check for local duplicate - session is already running on this executor
if session_id in self.active_tasks:
logger.warning(
f"Task {task_id} already running locally, rejecting duplicate"
f"Session {session_id} already running locally, rejecting duplicate"
)
ack_message(reject=True, requeue=False)
return
@@ -365,64 +366,69 @@ class CoPilotExecutor(AppProcess):
# Try to acquire cluster-wide lock
cluster_lock = ClusterLock(
redis=redis.get_redis(),
key=f"copilot:task:{task_id}:lock",
key=f"copilot:session:{session_id}:lock",
owner_id=self.executor_id,
timeout=settings.config.cluster_lock_timeout,
)
current_owner = cluster_lock.try_acquire()
if current_owner != self.executor_id:
if current_owner is not None:
logger.warning(f"Task {task_id} already running on pod {current_owner}")
logger.warning(
f"Session {session_id} already running on pod {current_owner}"
)
ack_message(reject=True, requeue=False)
else:
logger.warning(
f"Could not acquire lock for {task_id} - Redis unavailable"
f"Could not acquire lock for {session_id} - Redis unavailable"
)
ack_message(reject=True, requeue=True)
return
# Execute the task
try:
self._task_locks[task_id] = cluster_lock
self._task_locks[session_id] = cluster_lock
logger.info(
f"Acquired cluster lock for {task_id}, executor_id={self.executor_id}"
f"Acquired cluster lock for {session_id}, "
f"executor_id={self.executor_id}"
)
cancel_event = threading.Event()
future = self.executor.submit(
execute_copilot_task, entry, cancel_event, cluster_lock
execute_copilot_turn, entry, cancel_event, cluster_lock
)
self.active_tasks[task_id] = (future, cancel_event)
self.active_tasks[session_id] = (future, cancel_event)
except Exception as e:
logger.warning(f"Failed to setup execution for {task_id}: {e}")
logger.warning(f"Failed to setup execution for {session_id}: {e}")
cluster_lock.release()
if task_id in self._task_locks:
del self._task_locks[task_id]
if session_id in self._task_locks:
del self._task_locks[session_id]
ack_message(reject=True, requeue=True)
return
self._update_metrics()
def on_run_done(f: Future):
logger.info(f"Run completed for {task_id}")
logger.info(f"Run completed for {session_id}")
error_msg = None
try:
if exec_error := f.exception():
logger.error(f"Execution for {task_id} failed: {exec_error}")
# Don't requeue failed tasks - they've been marked as failed
# in the stream registry. Requeuing would cause infinite retries
# for deterministic failures.
error_msg = str(exec_error) or type(exec_error).__name__
logger.error(f"Execution for {session_id} failed: {error_msg}")
ack_message(reject=True, requeue=False)
else:
ack_message(reject=False, requeue=False)
except asyncio.CancelledError:
logger.info(f"Run completion callback cancelled for {session_id}")
except BaseException as e:
logger.exception(f"Error in run completion callback: {e}")
error_msg = str(e) or type(e).__name__
logger.exception(f"Error in run completion callback: {error_msg}")
finally:
# Release the cluster lock
if task_id in self._task_locks:
logger.info(f"Releasing cluster lock for {task_id}")
self._task_locks[task_id].release()
del self._task_locks[task_id]
if session_id in self._task_locks:
logger.info(f"Releasing cluster lock for {session_id}")
self._task_locks[session_id].release()
del self._task_locks[session_id]
self._cleanup_completed_tasks()
future.add_done_callback(on_run_done)
@@ -433,11 +439,11 @@ class CoPilotExecutor(AppProcess):
"""Remove completed futures from active_tasks and update metrics."""
completed_tasks = []
with self._active_tasks_lock:
for task_id, (future, _) in list(self.active_tasks.items()):
for session_id, (future, _) in list(self.active_tasks.items()):
if future.done():
completed_tasks.append(task_id)
self.active_tasks.pop(task_id, None)
logger.info(f"Cleaned up completed task {task_id}")
completed_tasks.append(session_id)
self.active_tasks.pop(session_id, None)
logger.info(f"Cleaned up completed session {session_id}")
self._update_metrics()
return completed_tasks

View File

@@ -1,6 +1,6 @@
"""CoPilot execution processor - per-worker execution logic.
This module contains the processor class that handles CoPilot task execution
This module contains the processor class that handles CoPilot session execution
in a thread-local context, following the graph executor pattern.
"""
@@ -12,7 +12,7 @@ import time
from backend.copilot import service as copilot_service
from backend.copilot import stream_registry
from backend.copilot.config import ChatConfig
from backend.copilot.response_model import StreamError, StreamFinish, StreamFinishStep
from backend.copilot.response_model import StreamFinish
from backend.copilot.sdk import service as sdk_service
from backend.executor.cluster_lock import ClusterLock
from backend.util.decorator import error_logged
@@ -32,17 +32,17 @@ logger = TruncatedLogger(logging.getLogger(__name__), prefix="[CoPilotExecutor]"
_tls = threading.local()
def execute_copilot_task(
def execute_copilot_turn(
entry: CoPilotExecutionEntry,
cancel: threading.Event,
cluster_lock: ClusterLock,
):
"""Execute a CoPilot task using the thread-local processor.
"""Execute a single CoPilot turn (user message → AI response).
This function is the entry point called by the thread pool executor.
Args:
entry: The task payload
entry: The turn payload
cancel: Threading event to signal cancellation
cluster_lock: Distributed lock for this execution
"""
@@ -76,16 +76,16 @@ def cleanup_worker():
class CoPilotProcessor:
"""Per-worker execution logic for CoPilot tasks.
"""Per-worker execution logic for CoPilot sessions.
This class is instantiated once per worker thread and handles the execution
of CoPilot chat generation tasks. It maintains an async event loop for
of CoPilot chat generation sessions. It maintains an async event loop for
running the async service code.
The execution flow:
1. CoPilot task is picked from RabbitMQ queue
2. Manager submits task to thread pool
3. Processor executes the task in its event loop
1. Session entry is picked from RabbitMQ queue
2. Manager submits to thread pool
3. Processor executes in its event loop
4. Results are published to Redis Streams
"""
@@ -125,7 +125,10 @@ class CoPilotProcessor:
)
future.result(timeout=5)
except Exception as e:
logger.warning(f"[CoPilotExecutor] Worker {self.tid} cleanup error: {e}")
error_msg = str(e) or type(e).__name__
logger.warning(
f"[CoPilotExecutor] Worker {self.tid} cleanup error: {error_msg}"
)
# Stop the event loop
self.execution_loop.call_soon_threadsafe(self.execution_loop.stop)
@@ -139,19 +142,17 @@ class CoPilotProcessor:
cancel: threading.Event,
cluster_lock: ClusterLock,
):
"""Execute a CoPilot task.
"""Execute a CoPilot turn.
This is the main entry point for task execution. It runs the async
execution logic in the worker's event loop and handles errors.
Runs the async logic in the worker's event loop and handles errors.
Args:
entry: The task payload containing session and message info
entry: The turn payload containing session and message info
cancel: Threading event to signal cancellation
cluster_lock: Distributed lock to prevent duplicate execution
"""
log = CoPilotLogMetadata(
logging.getLogger(__name__),
task_id=entry.task_id,
session_id=entry.session_id,
user_id=entry.user_id,
)
@@ -159,38 +160,30 @@ class CoPilotProcessor:
start_time = time.monotonic()
try:
# Run the async execution in our event loop
future = asyncio.run_coroutine_threadsafe(
self._execute_async(entry, cancel, cluster_lock, log),
self.execution_loop,
)
# Run the async execution in our event loop
future = asyncio.run_coroutine_threadsafe(
self._execute_async(entry, cancel, cluster_lock, log),
self.execution_loop,
)
# Wait for completion, checking cancel periodically
while not future.done():
try:
future.result(timeout=1.0)
except asyncio.TimeoutError:
if cancel.is_set():
log.info("Cancellation requested")
future.cancel()
break
# Refresh cluster lock to maintain ownership
cluster_lock.refresh()
# Wait for completion, checking cancel periodically
while not future.done():
try:
future.result(timeout=1.0)
except asyncio.TimeoutError:
if cancel.is_set():
log.info("Cancellation requested")
future.cancel()
break
# Refresh cluster lock to maintain ownership
cluster_lock.refresh()
if not future.cancelled():
# Get result to propagate any exceptions
future.result()
if not future.cancelled():
# Get result to propagate any exceptions
future.result()
elapsed = time.monotonic() - start_time
log.info(f"Execution completed in {elapsed:.2f}s")
except Exception as e:
elapsed = time.monotonic() - start_time
log.error(f"Execution failed after {elapsed:.2f}s: {e}")
# Note: _execute_async already marks the task as failed before re-raising,
# so we don't call _mark_task_failed here to avoid duplicate error events.
raise
elapsed = time.monotonic() - start_time
log.info(f"Execution completed in {elapsed:.2f}s")
async def _execute_async(
self,
@@ -199,19 +192,20 @@ class CoPilotProcessor:
cluster_lock: ClusterLock,
log: CoPilotLogMetadata,
):
"""Async execution logic for CoPilot task.
"""Async execution logic for a CoPilot turn.
This method calls the existing stream_chat_completion service function
and publishes results to the stream registry.
Calls the stream_chat_completion service function and publishes
results to the stream registry.
Args:
entry: The task payload
entry: The turn payload
cancel: Threading event to signal cancellation
cluster_lock: Distributed lock for refresh
log: Structured logger for this task
log: Structured logger
"""
last_refresh = time.monotonic()
refresh_interval = 30.0 # Refresh lock every 30 seconds
error_msg = None
try:
# Choose service based on LaunchDarkly flag
@@ -228,7 +222,7 @@ class CoPilotProcessor:
)
log.info(f"Using {'SDK' if use_sdk else 'standard'} service")
# Stream chat completion and publish chunks to Redis
# Stream chat completion and publish chunks to Redis.
async for chunk in stream_fn(
session_id=entry.session_id,
message=entry.message if entry.message else None,
@@ -236,56 +230,47 @@ class CoPilotProcessor:
user_id=entry.user_id,
context=entry.context,
):
# Check for cancellation
if cancel.is_set():
log.info("Cancelled during streaming")
await stream_registry.publish_chunk(
entry.task_id, StreamError(errorText="Operation cancelled")
)
await stream_registry.publish_chunk(
entry.task_id, StreamFinishStep()
)
await stream_registry.publish_chunk(entry.task_id, StreamFinish())
await stream_registry.mark_task_completed(
entry.task_id, status="failed"
)
return
log.info("Cancel requested, breaking stream")
break
# Refresh cluster lock periodically
current_time = time.monotonic()
if current_time - last_refresh >= refresh_interval:
cluster_lock.refresh()
last_refresh = current_time
# Publish chunk to stream registry
await stream_registry.publish_chunk(entry.task_id, chunk)
# Skip StreamFinish — mark_session_completed publishes it.
if isinstance(chunk, StreamFinish):
continue
# Mark task as completed
await stream_registry.mark_task_completed(entry.task_id, status="completed")
log.info("Task completed successfully")
try:
await stream_registry.publish_chunk(entry.turn_id, chunk)
except Exception as e:
log.error(
f"Error publishing chunk {type(chunk).__name__}: {e}",
exc_info=True,
)
except asyncio.CancelledError:
log.info("Task cancelled")
await stream_registry.mark_task_completed(
entry.task_id,
status="failed",
error_message="Task was cancelled",
)
# Stream loop completed
if cancel.is_set():
log.info("Stream cancelled by user")
except BaseException as e:
# Handle all exceptions (including CancelledError) with appropriate logging
if isinstance(e, asyncio.CancelledError):
log.info("Turn cancelled")
error_msg = "Operation cancelled"
else:
error_msg = str(e) or type(e).__name__
log.error(f"Turn failed: {error_msg}")
raise
except Exception as e:
log.error(f"Task failed: {e}")
await self._mark_task_failed(entry.task_id, str(e))
raise
async def _mark_task_failed(self, task_id: str, error_message: str):
"""Mark a task as failed and publish error to stream registry."""
try:
await stream_registry.publish_chunk(
task_id, StreamError(errorText=error_message)
)
await stream_registry.publish_chunk(task_id, StreamFinishStep())
await stream_registry.publish_chunk(task_id, StreamFinish())
await stream_registry.mark_task_completed(task_id, status="failed")
except Exception as e:
logger.error(f"Failed to mark task {task_id} as failed: {e}")
finally:
# If no exception but user cancelled, still mark as cancelled
if not error_msg and cancel.is_set():
error_msg = "Operation cancelled"
try:
await stream_registry.mark_session_completed(
entry.session_id, error_message=error_msg
)
except Exception as mark_err:
log.error(f"Failed to mark session completed: {mark_err}")

View File

@@ -28,7 +28,7 @@ class CoPilotLogMetadata(TruncatedLogger):
Args:
logger: The underlying logger instance
max_length: Maximum log message length before truncation
**kwargs: Metadata key-value pairs (e.g., task_id="abc", session_id="xyz")
**kwargs: Metadata key-value pairs (e.g., session_id="xyz", turn_id="abc")
These are added to json_fields in cloud mode, or to the prefix in local mode.
"""
@@ -135,18 +135,15 @@ class CoPilotExecutionEntry(BaseModel):
This model represents a chat generation task to be processed by the executor.
"""
task_id: str
"""Unique identifier for this task (used for stream registry)"""
session_id: str
"""Chat session ID"""
"""Chat session ID (also used for dedup/locking)"""
turn_id: str = ""
"""Per-turn UUID for Redis stream isolation"""
user_id: str | None
"""User ID (may be None for anonymous users)"""
operation_id: str
"""Operation ID for webhook callbacks and completion tracking"""
message: str
"""User's message to process"""
@@ -160,40 +157,37 @@ class CoPilotExecutionEntry(BaseModel):
class CancelCoPilotEvent(BaseModel):
"""Event to cancel a CoPilot operation."""
task_id: str
"""Task ID to cancel"""
session_id: str
"""Session ID to cancel"""
# ============ Queue Publishing Helpers ============ #
async def enqueue_copilot_task(
task_id: str,
async def enqueue_copilot_turn(
session_id: str,
user_id: str | None,
operation_id: str,
message: str,
turn_id: str,
is_user_message: bool = True,
context: dict[str, str] | None = None,
) -> None:
"""Enqueue a CoPilot task for processing by the executor service.
Args:
task_id: Unique identifier for this task (used for stream registry)
session_id: Chat session ID
session_id: Chat session ID (also used for dedup/locking)
user_id: User ID (may be None for anonymous users)
operation_id: Operation ID for webhook callbacks and completion tracking
message: User's message to process
turn_id: Per-turn UUID for Redis stream isolation
is_user_message: Whether the message is from the user (vs system/assistant)
context: Optional context for the message (e.g., {url: str, content: str})
"""
from backend.util.clients import get_async_copilot_queue
entry = CoPilotExecutionEntry(
task_id=task_id,
session_id=session_id,
turn_id=turn_id,
user_id=user_id,
operation_id=operation_id,
message=message,
is_user_message=is_user_message,
context=context,
@@ -207,15 +201,15 @@ async def enqueue_copilot_task(
)
async def enqueue_cancel_task(task_id: str) -> None:
"""Publish a cancel request for a running CoPilot task.
async def enqueue_cancel_task(session_id: str) -> None:
"""Publish a cancel request for a running CoPilot session.
Sends a ``CancelCoPilotEvent`` to the FANOUT exchange so all executor
pods receive the cancellation signal.
"""
from backend.util.clients import get_async_copilot_queue
event = CancelCoPilotEvent(task_id=task_id)
event = CancelCoPilotEvent(session_id=session_id)
queue_client = await get_async_copilot_queue()
await queue_client.publish_message(
routing_key="", # FANOUT ignores routing key

View File

@@ -434,25 +434,13 @@ async def _get_session_from_db(session_id: str) -> ChatSession | None:
async def upsert_chat_session(
session: ChatSession,
*,
existing_message_count: int | None = None,
) -> tuple[ChatSession, int]:
) -> ChatSession:
"""Update a chat session in both cache and database.
Uses session-level locking to prevent race conditions when concurrent
operations (e.g., background title update and main stream handler)
attempt to upsert the same session simultaneously.
Args:
existing_message_count: If provided, skip the DB query to count
existing messages. The caller is responsible for tracking this
accurately. Useful for incremental saves in a streaming loop
where the caller already knows how many messages are persisted.
Returns:
Tuple of (session, final_message_count) where final_message_count is
the actual persisted message count after collision detection adjustments.
Raises:
DatabaseError: If the database write fails. The cache is still updated
as a best-effort optimization, but the error is propagated to ensure
@@ -463,18 +451,14 @@ async def upsert_chat_session(
lock = await _get_session_lock(session.session_id)
async with lock:
# Get existing message count from DB for incremental saves
if existing_message_count is None:
existing_message_count = await chat_db().get_next_sequence(
session.session_id
)
# Always query DB for existing message count to ensure consistency
existing_message_count = await chat_db().get_next_sequence(session.session_id)
db_error: Exception | None = None
final_count = existing_message_count
# Save to database (primary storage)
try:
final_count = await _save_session_to_db(
await _save_session_to_db(
session,
existing_message_count,
skip_existence_check=existing_message_count > 0,
@@ -505,7 +489,7 @@ async def upsert_chat_session(
f"Failed to persist chat session {session.session_id} to database"
) from db_error
return session, final_count
return session
async def _save_session_to_db(
@@ -513,16 +497,13 @@ async def _save_session_to_db(
existing_message_count: int,
*,
skip_existence_check: bool = False,
) -> int:
) -> None:
"""Save or update a chat session in the database.
Args:
skip_existence_check: When True, skip the ``get_chat_session`` query
and assume the session row already exists. Saves one DB round trip
for incremental saves during streaming.
Returns:
Final message count after save (accounting for collision detection).
"""
db = chat_db()
@@ -554,7 +535,6 @@ async def _save_session_to_db(
# Add new messages (only those after existing count)
new_messages = session.messages[existing_message_count:]
final_count = existing_message_count
if new_messages:
messages_data = []
for msg in new_messages:
@@ -574,14 +554,12 @@ async def _save_session_to_db(
f"roles={[m['role'] for m in messages_data]}, "
f"start_sequence={existing_message_count}"
)
_, final_count = await db.add_chat_messages_batch(
await db.add_chat_messages_batch(
session_id=session.session_id,
messages=messages_data,
start_sequence=existing_message_count,
)
return final_count
async def append_and_save_message(session_id: str, message: ChatMessage) -> ChatSession:
"""Atomically append a message to a session and persist it.

View File

@@ -60,7 +60,7 @@ async def test_chatsession_redis_storage(setup_test_user, test_user_id):
s = ChatSession.new(user_id=test_user_id)
s.messages = messages
s, _ = await upsert_chat_session(s)
s = await upsert_chat_session(s)
s2 = await get_chat_session(
session_id=s.session_id,
@@ -77,7 +77,7 @@ async def test_chatsession_redis_storage_user_id_mismatch(
s = ChatSession.new(user_id=test_user_id)
s.messages = messages
s, _ = await upsert_chat_session(s)
s = await upsert_chat_session(s)
s2 = await get_chat_session(s.session_id, "different_user_id")
@@ -94,7 +94,7 @@ async def test_chatsession_db_storage(setup_test_user, test_user_id):
s.messages = messages # Contains user, assistant, and tool messages
assert s.session_id is not None, "Session id is not set"
# Upsert to save to both cache and DB
s, _ = await upsert_chat_session(s)
s = await upsert_chat_session(s)
# Clear the Redis cache to force DB load
redis_key = f"chat:session:{s.session_id}"
@@ -331,3 +331,96 @@ def test_to_openai_messages_merges_split_assistants():
tc_list = merged.get("tool_calls")
assert tc_list is not None and len(list(tc_list)) == 1
assert list(tc_list)[0]["id"] == "tc1"
# --------------------------------------------------------------------------- #
# Concurrent save collision detection #
# --------------------------------------------------------------------------- #
@pytest.mark.asyncio(loop_scope="session")
async def test_concurrent_saves_collision_detection(setup_test_user, test_user_id):
"""Test that concurrent saves from streaming loop and callback handle collisions correctly.
Simulates the race condition where:
1. Streaming loop starts with saved_msg_count=5
2. Long-running callback appends message #5 and saves
3. Streaming loop tries to save with stale count=5
The collision detection should handle this gracefully.
"""
import asyncio
# Create a session with initial messages
session = ChatSession.new(user_id=test_user_id)
for i in range(3):
session.messages.append(
ChatMessage(
role="user" if i % 2 == 0 else "assistant", content=f"Message {i}"
)
)
# Save initial messages
session = await upsert_chat_session(session)
# Simulate streaming loop and callback saving concurrently
async def streaming_loop_save():
"""Simulates streaming loop saving messages."""
# Add 2 messages
session.messages.append(ChatMessage(role="user", content="Streaming message 1"))
session.messages.append(
ChatMessage(role="assistant", content="Streaming message 2")
)
# Wait a bit to let callback potentially save first
await asyncio.sleep(0.01)
# Save (will query DB for existing count)
return await upsert_chat_session(session)
async def callback_save():
"""Simulates long-running callback saving a message."""
# Add 1 message
session.messages.append(
ChatMessage(role="tool", content="Callback result", tool_call_id="tc1")
)
# Save immediately (will query DB for existing count)
return await upsert_chat_session(session)
# Run both saves concurrently - one will hit collision detection
results = await asyncio.gather(streaming_loop_save(), callback_save())
# Both should succeed
assert all(r is not None for r in results)
# Reload session from DB to verify
from backend.data.redis_client import get_redis_async
redis_key = f"chat:session:{session.session_id}"
async_redis = await get_redis_async()
await async_redis.delete(redis_key) # Clear cache to force DB load
loaded_session = await get_chat_session(session.session_id, test_user_id)
assert loaded_session is not None
# Should have all 6 messages (3 initial + 2 streaming + 1 callback)
assert len(loaded_session.messages) == 6
# Verify no duplicate sequences
sequences = []
for i, msg in enumerate(loaded_session.messages):
# Messages should have sequential sequence numbers starting from 0
sequences.append(i)
# All sequences should be unique and sequential
assert sequences == list(range(6))
# Verify message content is preserved
contents = [m.content for m in loaded_session.messages]
assert "Message 0" in contents
assert "Message 1" in contents
assert "Message 2" in contents
assert "Streaming message 1" in contents
assert "Streaming message 2" in contents
assert "Callback result" in contents

View File

@@ -14,7 +14,6 @@ import pytest
@pytest.mark.asyncio
async def test_parallel_tool_calls_run_concurrently():
"""Multiple tool calls should complete in ~max(delays), not sum(delays)."""
# Import here to allow module-level mocking if needed
from backend.copilot.response_model import (
StreamToolInputAvailable,
StreamToolOutputAvailable,
@@ -32,7 +31,6 @@ async def test_parallel_tool_calls_run_concurrently():
for i in range(n_tools)
]
# Minimal session mock
class FakeSession:
session_id = "test"
user_id = "test"
@@ -42,7 +40,7 @@ async def test_parallel_tool_calls_run_concurrently():
original_yield = None
async def fake_yield(tc_list, idx, sess, lock=None):
async def fake_yield(tc_list, idx, sess):
yield StreamToolInputAvailable(
toolCallId=tc_list[idx]["id"],
toolName=tc_list[idx]["function"]["name"],
@@ -101,7 +99,7 @@ async def test_single_tool_call_works():
def __init__(self):
self.messages = []
async def fake_yield(tc_list, idx, sess, lock=None):
async def fake_yield(tc_list, idx, sess):
yield StreamToolInputAvailable(toolCallId="call_0", toolName="t", input={})
yield StreamToolOutputAvailable(toolCallId="call_0", toolName="t", output="{}")
@@ -144,7 +142,7 @@ async def test_retryable_error_propagates():
def __init__(self):
self.messages = []
async def fake_yield(tc_list, idx, sess, lock=None):
async def fake_yield(tc_list, idx, sess):
if idx == 1:
raise KeyError("bad")
from backend.copilot.response_model import StreamToolInputAvailable
@@ -175,8 +173,8 @@ async def test_retryable_error_propagates():
@pytest.mark.asyncio
async def test_session_lock_shared():
"""All parallel tools should receive the same lock instance."""
async def test_session_shared_across_parallel_tools():
"""All parallel tools should receive the same session instance."""
from backend.copilot.response_model import (
StreamToolInputAvailable,
StreamToolOutputAvailable,
@@ -199,10 +197,10 @@ async def test_session_lock_shared():
def __init__(self):
self.messages = []
observed_locks = []
observed_sessions = []
async def fake_yield(tc_list, idx, sess, lock=None):
observed_locks.append(lock)
async def fake_yield(tc_list, idx, sess):
observed_sessions.append(sess)
yield StreamToolInputAvailable(
toolCallId=tc_list[idx]["id"], toolName=f"t_{idx}", input={}
)
@@ -222,9 +220,8 @@ async def test_session_lock_shared():
finally:
svc._yield_tool_call = orig
assert len(observed_locks) == 3
assert observed_locks[0] is observed_locks[1] is observed_locks[2]
assert isinstance(observed_locks[0], asyncio.Lock)
assert len(observed_sessions) == 3
assert observed_sessions[0] is observed_sessions[1] is observed_sessions[2]
@pytest.mark.asyncio
@@ -251,7 +248,7 @@ async def test_cancellation_cleans_up():
started = asyncio.Event()
async def fake_yield(tc_list, idx, sess, lock=None):
async def fake_yield(tc_list, idx, sess):
yield StreamToolInputAvailable(
toolCallId=tc_list[idx]["id"], toolName=f"t_{idx}", input={}
)

View File

@@ -5,6 +5,8 @@ This module implements the AI SDK UI Stream Protocol (v1) for streaming chat res
See: https://ai-sdk.dev/docs/ai-sdk-ui/stream-protocol
"""
import json
import logging
from enum import Enum
from typing import Any
@@ -12,6 +14,8 @@ from pydantic import BaseModel, Field
from backend.util.json import dumps as json_dumps
logger = logging.getLogger(__name__)
class ResponseType(str, Enum):
"""Types of streaming responses following AI SDK protocol."""
@@ -47,7 +51,8 @@ class StreamBaseResponse(BaseModel):
def to_sse(self) -> str:
"""Convert to SSE format."""
return f"data: {self.model_dump_json()}\n\n"
json_str = self.model_dump_json(exclude_none=True)
return f"data: {json_str}\n\n"
# ========== Message Lifecycle ==========
@@ -58,15 +63,13 @@ class StreamStart(StreamBaseResponse):
type: ResponseType = ResponseType.START
messageId: str = Field(..., description="Unique message ID")
taskId: str | None = Field(
sessionId: str | None = Field(
default=None,
description="Task ID for SSE reconnection. Clients can reconnect using GET /tasks/{taskId}/stream",
description="Session ID for SSE reconnection.",
)
def to_sse(self) -> str:
"""Convert to SSE format, excluding non-protocol fields like taskId."""
import json
"""Convert to SSE format, excluding non-protocol fields like sessionId."""
data: dict[str, Any] = {
"type": self.type.value,
"messageId": self.messageId,
@@ -163,8 +166,6 @@ class StreamToolOutputAvailable(StreamBaseResponse):
def to_sse(self) -> str:
"""Convert to SSE format, excluding non-spec fields."""
import json
data = {
"type": self.type.value,
"toolCallId": self.toolCallId,

View File

@@ -0,0 +1,57 @@
"""Dummy SDK service for testing copilot streaming.
Returns mock streaming responses without calling Claude Agent SDK.
Enable via COPILOT_TEST_MODE=true environment variable.
WARNING: This is for testing only. Do not use in production.
"""
import asyncio
import logging
import uuid
from collections.abc import AsyncGenerator
from ..model import ChatSession
from ..response_model import StreamBaseResponse, StreamStart, StreamTextDelta
logger = logging.getLogger(__name__)
async def stream_chat_completion_dummy(
session_id: str,
message: str | None = None,
tool_call_response: str | None = None,
is_user_message: bool = True,
user_id: str | None = None,
retry_count: int = 0,
session: ChatSession | None = None,
context: dict[str, str] | None = None,
) -> AsyncGenerator[StreamBaseResponse, None]:
"""Stream dummy chat completion for testing.
Returns a simple streaming response with text deltas to test:
- Streaming infrastructure works
- No timeout occurs
- Text arrives in chunks
- StreamFinish is sent by mark_session_completed
"""
logger.warning(
f"[TEST MODE] Using dummy copilot streaming for session {session_id}"
)
message_id = str(uuid.uuid4())
text_block_id = str(uuid.uuid4())
# Start the stream
yield StreamStart(messageId=message_id, sessionId=session_id)
# Simulate streaming text response with delays
dummy_response = "I counted: 1... 2... 3. All done!"
words = dummy_response.split()
for i, word in enumerate(words):
# Add space except for last word
text = word if i == len(words) - 1 else f"{word} "
yield StreamTextDelta(id=text_block_id, delta=text)
# Small delay to simulate real streaming
await asyncio.sleep(0.1)

View File

@@ -55,13 +55,8 @@ class SDKResponseAdapter:
self.has_ended_text = False
self.current_tool_calls: dict[str, dict[str, str]] = {}
self.resolved_tool_calls: set[str] = set()
self.task_id: str | None = None
self.step_open = False
def set_task_id(self, task_id: str) -> None:
"""Set the task ID for reconnection support."""
self.task_id = task_id
@property
def has_unresolved_tool_calls(self) -> bool:
"""True when there are tool calls that haven't received output yet."""
@@ -74,7 +69,7 @@ class SDKResponseAdapter:
if isinstance(sdk_message, SystemMessage):
if sdk_message.subtype == "init":
responses.append(
StreamStart(messageId=self.message_id, taskId=self.task_id)
StreamStart(messageId=self.message_id, sessionId=self.session_id)
)
# Open the first step (matches non-SDK: StreamStart then StreamStartStep)
responses.append(StreamStartStep())

View File

@@ -37,9 +37,7 @@ from .tool_adapter import wait_for_stash
def _adapter() -> SDKResponseAdapter:
a = SDKResponseAdapter(message_id="msg-1")
a.set_task_id("task-1")
return a
return SDKResponseAdapter(message_id="msg-1", session_id="session-1")
# -- SystemMessage -----------------------------------------------------------
@@ -51,7 +49,7 @@ def test_system_init_emits_start_and_step():
assert len(results) == 2
assert isinstance(results[0], StreamStart)
assert results[0].messageId == "msg-1"
assert results[0].taskId == "task-1"
assert results[0].sessionId == "session-1"
assert isinstance(results[1], StreamStartStep)

View File

@@ -160,7 +160,7 @@ def create_security_hooks(
Args:
user_id: Current user ID for isolation validation
sdk_cwd: SDK working directory for workspace-scoped tool validation
max_subtasks: Maximum Task (sub-agent) spawns allowed per session
max_subtasks: Maximum concurrent Task (sub-agent) spawns allowed per session
on_stop: Callback ``(transcript_path, sdk_session_id)`` invoked when
the SDK finishes processing — used to read the JSONL transcript
before the CLI process exits.
@@ -172,8 +172,9 @@ def create_security_hooks(
from claude_agent_sdk import HookMatcher
from claude_agent_sdk.types import HookContext, HookInput, SyncHookJSONOutput
# Per-session counter for Task sub-agent spawns
task_spawn_count = 0
# Per-session tracking for Task sub-agent concurrency.
# Set of tool_use_ids that consumed a slot — len() is the active count.
task_tool_use_ids: set[str] = set()
async def pre_tool_use_hook(
input_data: HookInput,
@@ -181,7 +182,6 @@ def create_security_hooks(
context: HookContext,
) -> SyncHookJSONOutput:
"""Combined pre-tool-use validation hook."""
nonlocal task_spawn_count
_ = context # unused but required by signature
tool_name = cast(str, input_data.get("tool_name", ""))
tool_input = cast(dict[str, Any], input_data.get("tool_input", {}))
@@ -200,18 +200,18 @@ def create_security_hooks(
"(remove the run_in_background parameter)."
),
)
if task_spawn_count >= max_subtasks:
if len(task_tool_use_ids) >= max_subtasks:
logger.warning(
f"[SDK] Task limit reached ({max_subtasks}), user={user_id}"
)
return cast(
SyncHookJSONOutput,
_deny(
f"Maximum {max_subtasks} sub-tasks per session. "
"Please continue in the main conversation."
f"Maximum {max_subtasks} concurrent sub-tasks. "
"Wait for running sub-tasks to finish, "
"or continue in the main conversation."
),
)
task_spawn_count += 1
# Strip MCP prefix for consistent validation
is_copilot_tool = tool_name.startswith(MCP_TOOL_PREFIX)
@@ -229,9 +229,24 @@ def create_security_hooks(
if result:
return cast(SyncHookJSONOutput, result)
# Reserve the Task slot only after all validations pass
if tool_name == "Task" and tool_use_id is not None:
task_tool_use_ids.add(tool_use_id)
logger.debug(f"[SDK] Tool start: {tool_name}, user={user_id}")
return cast(SyncHookJSONOutput, {})
def _release_task_slot(tool_name: str, tool_use_id: str | None) -> None:
"""Release a Task concurrency slot if one was reserved."""
if tool_name == "Task" and tool_use_id in task_tool_use_ids:
task_tool_use_ids.discard(tool_use_id)
logger.info(
"[SDK] Task slot released, active=%d/%d, user=%s",
len(task_tool_use_ids),
max_subtasks,
user_id,
)
async def post_tool_use_hook(
input_data: HookInput,
tool_use_id: str | None,
@@ -246,6 +261,8 @@ def create_security_hooks(
"""
_ = context
tool_name = cast(str, input_data.get("tool_name", ""))
_release_task_slot(tool_name, tool_use_id)
is_builtin = not tool_name.startswith(MCP_TOOL_PREFIX)
logger.info(
"[SDK] PostToolUse: %s (builtin=%s, tool_use_id=%s)",
@@ -289,6 +306,9 @@ def create_security_hooks(
f"[SDK] Tool failed: {tool_name}, error={error}, "
f"user={user_id}, tool_use_id={tool_use_id}"
)
_release_task_slot(tool_name, tool_use_id)
return cast(SyncHookJSONOutput, {})
async def pre_compact_hook(

View File

@@ -208,19 +208,22 @@ def test_bash_builtin_blocked_message_clarity():
@pytest.fixture()
def _hooks():
"""Create security hooks and return the PreToolUse handler."""
"""Create security hooks and return (pre, post, post_failure) handlers."""
from .security_hooks import create_security_hooks
hooks = create_security_hooks(user_id="u1", sdk_cwd=SDK_CWD, max_subtasks=2)
pre = hooks["PreToolUse"][0].hooks[0]
return pre
post = hooks["PostToolUse"][0].hooks[0]
post_failure = hooks["PostToolUseFailure"][0].hooks[0]
return pre, post, post_failure
@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
async def test_task_background_blocked(_hooks):
"""Task with run_in_background=true must be denied."""
result = await _hooks(
pre, _, _ = _hooks
result = await pre(
{"tool_name": "Task", "tool_input": {"run_in_background": True, "prompt": "x"}},
tool_use_id=None,
context={},
@@ -233,9 +236,10 @@ async def test_task_background_blocked(_hooks):
@pytest.mark.asyncio
async def test_task_foreground_allowed(_hooks):
"""Task without run_in_background should be allowed."""
result = await _hooks(
pre, _, _ = _hooks
result = await pre(
{"tool_name": "Task", "tool_input": {"prompt": "do stuff"}},
tool_use_id=None,
tool_use_id="tu-1",
context={},
)
assert not _is_denied(result)
@@ -245,25 +249,102 @@ async def test_task_foreground_allowed(_hooks):
@pytest.mark.asyncio
async def test_task_limit_enforced(_hooks):
"""Task spawns beyond max_subtasks should be denied."""
pre, _, _ = _hooks
# First two should pass
for _ in range(2):
result = await _hooks(
for i in range(2):
result = await pre(
{"tool_name": "Task", "tool_input": {"prompt": "ok"}},
tool_use_id=None,
tool_use_id=f"tu-limit-{i}",
context={},
)
assert not _is_denied(result)
# Third should be denied (limit=2)
result = await _hooks(
result = await pre(
{"tool_name": "Task", "tool_input": {"prompt": "over limit"}},
tool_use_id=None,
tool_use_id="tu-limit-2",
context={},
)
assert _is_denied(result)
assert "Maximum" in _reason(result)
@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
async def test_task_slot_released_on_completion(_hooks):
"""Completing a Task should free a slot so new Tasks can be spawned."""
pre, post, _ = _hooks
# Fill both slots
for i in range(2):
result = await pre(
{"tool_name": "Task", "tool_input": {"prompt": "ok"}},
tool_use_id=f"tu-comp-{i}",
context={},
)
assert not _is_denied(result)
# Third should be denied — at capacity
result = await pre(
{"tool_name": "Task", "tool_input": {"prompt": "over"}},
tool_use_id="tu-comp-2",
context={},
)
assert _is_denied(result)
# Complete first task — frees a slot
await post(
{"tool_name": "Task", "tool_input": {}},
tool_use_id="tu-comp-0",
context={},
)
# Now a new Task should be allowed
result = await pre(
{"tool_name": "Task", "tool_input": {"prompt": "after release"}},
tool_use_id="tu-comp-3",
context={},
)
assert not _is_denied(result)
@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
async def test_task_slot_released_on_failure(_hooks):
"""A failed Task should also free its concurrency slot."""
pre, _, post_failure = _hooks
# Fill both slots
for i in range(2):
result = await pre(
{"tool_name": "Task", "tool_input": {"prompt": "ok"}},
tool_use_id=f"tu-fail-{i}",
context={},
)
assert not _is_denied(result)
# At capacity
result = await pre(
{"tool_name": "Task", "tool_input": {"prompt": "over"}},
tool_use_id="tu-fail-2",
context={},
)
assert _is_denied(result)
# Fail first task — should free a slot
await post_failure(
{"tool_name": "Task", "tool_input": {}, "error": "something broke"},
tool_use_id="tu-fail-0",
context={},
)
# New Task should be allowed
result = await pre(
{"tool_name": "Task", "tool_input": {"prompt": "after failure"}},
tool_use_id="tu-fail-3",
context={},
)
assert not _is_denied(result)
# -- _is_tool_error_or_denial ------------------------------------------------
@@ -298,7 +379,9 @@ class TestIsToolErrorOrDenial:
def test_subtask_limit_denial(self):
assert (
_is_tool_error_or_denial(
"Maximum 2 sub-tasks per session. Please continue in the main conversation."
"Maximum 2 concurrent sub-tasks. "
"Wait for running sub-tasks to finish, "
"or continue in the main conversation."
)
is True
)

View File

@@ -7,12 +7,12 @@ import os
import uuid
from collections.abc import AsyncGenerator
from dataclasses import dataclass
from typing import Any
from typing import Any, cast
from backend.data.redis_client import get_redis_async
from backend.executor.cluster_lock import AsyncClusterLock
from backend.util.exceptions import NotFoundError
from .. import stream_registry
from ..config import ChatConfig
from ..model import (
ChatMessage,
@@ -25,19 +25,13 @@ from ..response_model import (
StreamBaseResponse,
StreamError,
StreamFinish,
StreamFinishStep,
StreamHeartbeat,
StreamStart,
StreamTextDelta,
StreamToolInputAvailable,
StreamToolOutputAvailable,
)
from ..service import (
_build_system_prompt,
_execute_long_running_tool_with_streaming,
_generate_session_title,
)
from ..tools.models import OperationPendingResponse, OperationStartedResponse
from ..service import _build_system_prompt, _generate_session_title
from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
from ..tracking import track_user_message
from .response_adapter import SDKResponseAdapter
@@ -45,7 +39,6 @@ from .security_hooks import create_security_hooks
from .tool_adapter import (
COPILOT_TOOL_NAMES,
SDK_DISALLOWED_TOOLS,
LongRunningCallback,
create_copilot_mcp_server,
set_execution_context,
wait_for_stash,
@@ -62,6 +55,7 @@ from .transcript import (
logger = logging.getLogger(__name__)
config = ChatConfig()
# Set to hold background tasks to prevent garbage collection
_background_tasks: set[asyncio.Task[Any]] = set()
@@ -81,13 +75,21 @@ class CapturedTranscript:
_SDK_CWD_PREFIX = WORKSPACE_PREFIX
# Special message prefixes for text-based markers (parsed by frontend)
COPILOT_ERROR_PREFIX = "[COPILOT_ERROR]" # Renders as ErrorCard
COPILOT_SYSTEM_PREFIX = "[COPILOT_SYSTEM]" # Renders as system info message
# Heartbeat interval — keep SSE alive through proxies/LBs during tool execution.
_HEARTBEAT_INTERVAL = 15.0 # seconds
# IMPORTANT: Must be less than frontend timeout (12s in useCopilotPage.ts)
_HEARTBEAT_INTERVAL = 10.0 # seconds
# Appended to the system prompt to inform the agent about available tools.
# The SDK built-in Bash is NOT available — use mcp__copilot__bash_exec instead,
# which has kernel-level network isolation (unshare --net).
_SDK_TOOL_SUPPLEMENT = """
def _build_sdk_tool_supplement(cwd: str) -> str:
"""Build the SDK tool supplement with the actual working directory injected."""
return f"""
## Tool notes
@@ -95,9 +97,16 @@ _SDK_TOOL_SUPPLEMENT = """
- The SDK built-in Bash tool is NOT available. Use the `bash_exec` MCP tool
for shell commands — it runs in a network-isolated sandbox.
### Working directory
- Your working directory is: `{cwd}`
- All SDK Read/Write/Edit/Glob/Grep tools AND `bash_exec` operate inside this
directory. This is the ONLY writable path — do not attempt to read or write
anywhere else on the filesystem.
- Use relative paths or absolute paths under `{cwd}` for all file operations.
### Two storage systems — CRITICAL to understand
1. **Ephemeral working directory** (`/tmp/copilot-<session>/`):
1. **Ephemeral working directory** (`{cwd}`):
- Shared by SDK Read/Write/Edit/Glob/Grep tools AND `bash_exec`
- Files here are **lost between turns** — do NOT rely on them persisting
- Use for temporary work: running scripts, processing data, etc.
@@ -123,6 +132,21 @@ When you create or modify important files (code, configs, outputs), you MUST:
2. At the start of a new turn, call `list_workspace_files` to see what files
are available from previous turns
### Sharing files with the user
After saving a file to the persistent workspace with `write_workspace_file`,
share it with the user by embedding the `download_url` from the response in
your message as a Markdown link or image:
- **Any file** — shows as a clickable download link:
`[report.csv](workspace://file_id#text/csv)`
- **Image** — renders inline in chat:
`![chart](workspace://file_id#image/png)`
- **Video** — renders inline in chat with player controls:
`![recording](workspace://file_id#video/mp4)`
The `download_url` field in the `write_workspace_file` response is already
in the correct format — paste it directly after the `(` in the Markdown.
### Long-running tools
Long-running tools (create_agent, edit_agent, etc.) are handled
asynchronously. You will receive an immediate response; the actual result
@@ -133,187 +157,8 @@ is delivered to the user via a background stream.
All tasks must run in the foreground.
"""
# Session streaming lock configuration
STREAM_LOCK_PREFIX = "copilot:stream:lock:"
STREAM_LOCK_TTL = 3600 # 1 hour - matches stream_ttl
async def _acquire_stream_lock(session_id: str, stream_id: str) -> bool:
"""Acquire an exclusive lock for streaming to this session.
Prevents multiple concurrent streams to the same session which can cause:
- Message duplication
- Race conditions in message saves
- Confusing UX with multiple AI responses
Returns:
True if lock was acquired, False if another stream is active.
"""
redis = await get_redis_async()
lock_key = f"{STREAM_LOCK_PREFIX}{session_id}"
# SET NX EX - atomic "set if not exists" with expiry
result = await redis.set(lock_key, stream_id, ex=STREAM_LOCK_TTL, nx=True)
return result is not None
async def _release_stream_lock(session_id: str, stream_id: str) -> None:
"""Release the stream lock if we still own it.
Only releases the lock if the stored stream_id matches ours (prevents
releasing another stream's lock if we somehow timed out).
"""
redis = await get_redis_async()
lock_key = f"{STREAM_LOCK_PREFIX}{session_id}"
# Lua script for atomic compare-and-delete (only delete if value matches)
script = """
if redis.call("GET", KEYS[1]) == ARGV[1] then
return redis.call("DEL", KEYS[1])
else
return 0
end
"""
await redis.eval(script, 1, lock_key, stream_id) # type: ignore[misc]
async def check_active_stream(session_id: str) -> str | None:
"""Check if a stream is currently active for this session.
Returns:
The active stream_id if one exists, None otherwise.
"""
redis = await get_redis_async()
lock_key = f"{STREAM_LOCK_PREFIX}{session_id}"
active_stream = await redis.get(lock_key)
return active_stream.decode() if isinstance(active_stream, bytes) else active_stream
def _build_long_running_callback(
user_id: str | None,
saved_msg_count_ref: list[int] | None = None,
) -> LongRunningCallback:
"""Build a callback that delegates long-running tools to the non-SDK infrastructure.
Long-running tools (create_agent, edit_agent, etc.) are delegated to the
existing background infrastructure: stream_registry (Redis Streams),
database persistence, and SSE reconnection. This means results survive
page refreshes / pod restarts, and the frontend shows the proper loading
widget with progress updates.
Args:
user_id: User ID for the session
saved_msg_count_ref: Mutable reference [count] shared with streaming loop
for coordinating message saves. When provided, the callback will update
it after appending messages to prevent counter drift.
The returned callback matches the ``LongRunningCallback`` signature:
``(tool_name, args, session) -> MCP response dict``.
"""
async def _callback(
tool_name: str, args: dict[str, Any], session: ChatSession
) -> dict[str, Any]:
operation_id = str(uuid.uuid4())
task_id = str(uuid.uuid4())
tool_call_id = f"sdk-{uuid.uuid4().hex[:12]}"
session_id = session.session_id
# --- Build user-friendly messages (matches non-SDK service) ---
if tool_name == "create_agent":
desc = args.get("description", "")
desc_preview = (desc[:100] + "...") if len(desc) > 100 else desc
pending_msg = (
f"Creating your agent: {desc_preview}"
if desc_preview
else "Creating agent... This may take a few minutes."
)
started_msg = (
"Agent creation started. You can close this tab - "
"check your library in a few minutes."
)
elif tool_name == "edit_agent":
changes = args.get("changes", "")
changes_preview = (changes[:100] + "...") if len(changes) > 100 else changes
pending_msg = (
f"Editing agent: {changes_preview}"
if changes_preview
else "Editing agent... This may take a few minutes."
)
started_msg = (
"Agent edit started. You can close this tab - "
"check your library in a few minutes."
)
else:
pending_msg = f"Running {tool_name}... This may take a few minutes."
started_msg = (
f"{tool_name} started. You can close this tab - "
"check back in a few minutes."
)
# --- Register task in Redis for SSE reconnection ---
await stream_registry.create_task(
task_id=task_id,
session_id=session_id,
user_id=user_id,
tool_call_id=tool_call_id,
tool_name=tool_name,
operation_id=operation_id,
)
# --- Save OperationPendingResponse to chat history ---
pending_message = ChatMessage(
role="tool",
content=OperationPendingResponse(
message=pending_msg,
operation_id=operation_id,
tool_name=tool_name,
).model_dump_json(),
tool_call_id=tool_call_id,
)
session.messages.append(pending_message)
# Collision detection happens in add_chat_messages_batch (db.py)
_, final_count = await upsert_chat_session(session)
# Update shared counter so streaming loop stays in sync
if saved_msg_count_ref is not None:
saved_msg_count_ref[0] = final_count
# --- Spawn background task (reuses non-SDK infrastructure) ---
bg_task = asyncio.create_task(
_execute_long_running_tool_with_streaming(
tool_name=tool_name,
parameters=args,
tool_call_id=tool_call_id,
operation_id=operation_id,
task_id=task_id,
session_id=session_id,
user_id=user_id,
)
)
_background_tasks.add(bg_task)
bg_task.add_done_callback(_background_tasks.discard)
await stream_registry.set_task_asyncio_task(task_id, bg_task)
logger.info(
f"[SDK] Long-running tool {tool_name} delegated to background "
f"(operation_id={operation_id}, task_id={task_id})"
)
# --- Return OperationStartedResponse as MCP tool result ---
# This flows through SDK → response adapter → frontend, triggering
# the loading widget with SSE reconnection support.
started_json = OperationStartedResponse(
message=started_msg,
operation_id=operation_id,
tool_name=tool_name,
task_id=task_id,
).model_dump_json()
return {
"content": [{"type": "text", "text": started_json}],
"isError": False,
}
return _callback
def _resolve_sdk_model() -> str | None:
@@ -595,6 +440,23 @@ async def stream_chat_completion_sdk(
f"Session {session_id} not found. Please create a new session first."
)
# Type narrowing: session is guaranteed ChatSession after the check above
session = cast(ChatSession, session)
# Clean up stale error markers from previous turn before starting new turn
# If the last message contains an error marker, remove it (user is retrying)
if (
len(session.messages) > 0
and session.messages[-1].role == "assistant"
and session.messages[-1].content
and COPILOT_ERROR_PREFIX in session.messages[-1].content
):
logger.info(
"[SDK] [%s] Removing stale error marker from previous turn",
session_id[:12],
)
session.messages.pop()
# Append the new message to the session if it's not already there
new_message_role = "user" if is_user_message else "assistant"
if message and (
@@ -610,7 +472,7 @@ async def stream_chat_completion_sdk(
user_id=user_id, session_id=session_id, message_length=len(message)
)
session, _ = await upsert_chat_session(session)
session = await upsert_chat_session(session)
# Generate title for new sessions (first user message)
if is_user_message and not session.title:
@@ -624,59 +486,61 @@ async def stream_chat_completion_sdk(
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
# Build system prompt (reuses non-SDK path with Langfuse support)
has_history = len(session.messages) > 1
system_prompt, _ = await _build_system_prompt(
user_id, has_conversation_history=has_history
)
system_prompt += _SDK_TOOL_SUPPLEMENT
message_id = str(uuid.uuid4())
task_id = str(uuid.uuid4())
stream_id = task_id # Use task_id as unique stream identifier
# Acquire stream lock to prevent concurrent streams to the same session
lock_acquired = await _acquire_stream_lock(session_id, stream_id)
if not lock_acquired:
# Another stream is active - check if it's still alive
active_stream = await check_active_stream(session_id)
logger.warning(
f"[SDK] Session {session_id} already has an active stream: {active_stream}"
)
yield StreamError(
errorText="Another stream is already active for this session. "
"Please wait for it to complete or refresh the page.",
code="stream_already_active",
)
yield StreamFinish()
return
yield StreamStart(messageId=message_id, taskId=task_id)
stream_id = str(uuid.uuid4())
stream_completed = False
# Initialise variables before the try so the finally block can
# always attempt transcript upload regardless of errors.
sdk_cwd = ""
use_resume = False
resume_file: str | None = None
captured_transcript = CapturedTranscript()
sdk_cwd = ""
try:
# Use a session-specific temp dir to avoid cleanup race conditions
# between concurrent sessions.
sdk_cwd = _make_sdk_cwd(session_id)
os.makedirs(sdk_cwd, exist_ok=True)
# Acquire stream lock to prevent concurrent streams to the same session
lock = AsyncClusterLock(
redis=await get_redis_async(),
key=f"{STREAM_LOCK_PREFIX}{session_id}",
owner_id=stream_id,
timeout=config.stream_lock_ttl,
)
# Initialize saved message counter as mutable list so long-running
# callback and streaming loop can coordinate
saved_msg_count_ref: list[int] = [len(session.messages)]
set_execution_context(
user_id,
session,
long_running_callback=_build_long_running_callback(
user_id, saved_msg_count_ref
),
lock_owner = await lock.try_acquire()
if lock_owner != stream_id:
# Another stream is active
logger.warning(
f"[SDK] Session {session_id} already has an active stream: {lock_owner}"
)
yield StreamError(
errorText="Another stream is already active for this session. "
"Please wait or stop it.",
code="stream_already_active",
)
return
# Make sure there is no more code between the lock acquitition and try-block.
try:
# Build system prompt (reuses non-SDK path with Langfuse support).
# Pre-compute the cwd here so the exact working directory path can be
# injected into the supplement instead of the generic placeholder.
# Catch ValueError early so the failure yields a clean StreamError rather
# than propagating outside the stream error-handling path.
has_history = len(session.messages) > 1
try:
sdk_cwd = _make_sdk_cwd(session_id)
os.makedirs(sdk_cwd, exist_ok=True)
except (ValueError, OSError) as e:
logger.error("[SDK] [%s] Invalid SDK cwd: %s", session_id[:12], e)
yield StreamError(
errorText="Unable to initialize working directory.",
code="sdk_cwd_error",
)
return
system_prompt, _ = await _build_system_prompt(
user_id, has_conversation_history=has_history
)
system_prompt += _build_sdk_tool_supplement(sdk_cwd)
yield StreamStart(messageId=message_id, sessionId=session_id)
set_execution_context(user_id, session)
try:
from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient
@@ -768,7 +632,6 @@ async def stream_chat_completion_sdk(
options = ClaudeAgentOptions(**sdk_options_kwargs) # type: ignore[arg-type]
adapter = SDKResponseAdapter(message_id=message_id, session_id=session_id)
adapter.set_task_id(task_id)
async with ClaudeSDKClient(options=options) as client:
current_message = message or ""
@@ -782,7 +645,6 @@ async def stream_chat_completion_sdk(
errorText="Message cannot be empty.",
code="empty_prompt",
)
yield StreamFinish()
return
query_message = await _build_query_message(
@@ -793,8 +655,7 @@ async def stream_chat_completion_sdk(
session_id,
)
logger.info(
"[SDK] [%s] Sending query — resume=%s, "
"total_msgs=%d, query_len=%d",
"[SDK] [%s] Sending query — resume=%s, total_msgs=%d, query_len=%d",
session_id[:12],
use_resume,
len(session.messages),
@@ -806,8 +667,6 @@ async def stream_chat_completion_sdk(
accumulated_tool_calls: list[dict[str, Any]] = []
has_appended_assistant = False
has_tool_results = False
# Track persisted message count. Uses shared ref so long-running
# callback can update it for coordination
# Use an explicit async iterator with non-cancelling heartbeats.
# CRITICAL: we must NOT cancel __anext__() mid-flight — doing so
@@ -834,6 +693,8 @@ async def stream_chat_completion_sdk(
if not done:
# Timeout — emit heartbeat but keep the task alive
# Also refresh lock TTL to keep it alive
await lock.refresh()
yield StreamHeartbeat()
continue
@@ -843,8 +704,7 @@ async def stream_chat_completion_sdk(
sdk_msg = done.pop().result()
except StopAsyncIteration:
logger.info(
"[SDK] [%s] Stream ended normally "
"(StopAsyncIteration)",
"[SDK] [%s] Stream ended normally (StopAsyncIteration)",
session_id[:12],
)
break
@@ -917,6 +777,25 @@ async def stream_chat_completion_sdk(
- len(adapter.resolved_tool_calls),
)
# Log ResultMessage details for debugging
if isinstance(sdk_msg, ResultMessage):
logger.info(
"[SDK] [%s] Received: ResultMessage %s "
"(unresolved=%d, current=%d, resolved=%d)",
session_id[:12],
sdk_msg.subtype,
len(adapter.current_tool_calls)
- len(adapter.resolved_tool_calls),
len(adapter.current_tool_calls),
len(adapter.resolved_tool_calls),
)
if sdk_msg.subtype in ("error", "error_during_execution"):
logger.error(
"[SDK] [%s] SDK execution failed with error: %s",
session_id[:12],
sdk_msg.result or "(no error message provided)",
)
for response in adapter.convert_message(sdk_msg):
if isinstance(response, StreamStart):
continue
@@ -941,6 +820,15 @@ async def stream_chat_completion_sdk(
extra,
)
# Log errors being sent to frontend
if isinstance(response, StreamError):
logger.error(
"[SDK] [%s] Sending error to frontend: %s (code=%s)",
session_id[:12],
response.errorText,
response.code,
)
yield response
if isinstance(response, StreamTextDelta):
@@ -981,20 +869,6 @@ async def stream_chat_completion_sdk(
if not has_appended_assistant:
session.messages.append(assistant_response)
has_appended_assistant = True
# Save before tool execution starts so the
# pending tool call is visible on refresh /
# other devices. Collision detection happens
# in add_chat_messages_batch (db.py).
try:
_, final_count = await upsert_chat_session(session)
# Update shared ref so callback stays in sync
saved_msg_count_ref[0] = final_count
except Exception as save_err:
logger.warning(
"[SDK] [%s] Incremental save " "failed: %s",
session_id[:12],
save_err,
)
elif isinstance(response, StreamToolOutputAvailable):
session.messages.append(
@@ -1009,19 +883,6 @@ async def stream_chat_completion_sdk(
)
)
has_tool_results = True
# Save after tool completes so the result is
# visible on refresh / other devices.
# Collision detection happens in add_chat_messages_batch (db.py).
try:
_, final_count = await upsert_chat_session(session)
# Update shared ref so callback stays in sync
saved_msg_count_ref[0] = final_count
except Exception as save_err:
logger.warning(
"[SDK] [%s] Incremental save " "failed: %s",
session_id[:12],
save_err,
)
elif isinstance(response, StreamFinish):
stream_completed = True
@@ -1031,8 +892,7 @@ async def stream_chat_completion_sdk(
# server shutdown). Log and let the safety-net / finally
# blocks handle cleanup.
logger.warning(
"[SDK] [%s] Streaming loop cancelled "
"(asyncio.CancelledError)",
"[SDK] [%s] Streaming loop cancelled (asyncio.CancelledError)",
session_id[:12],
)
raise
@@ -1074,25 +934,29 @@ async def stream_chat_completion_sdk(
)
yield response
# If the stream ended without a ResultMessage (no
# StreamFinish), the SDK CLI exited unexpectedly. Close
# the open step and emit StreamFinish so the frontend
# transitions to the "ready" state.
# If the stream ended without a ResultMessage, the SDK
# CLI exited unexpectedly or the user stopped execution.
# Close any open text/step so chunks are well-formed, and
# append a cancellation message so users see feedback.
# StreamFinish is published by mark_session_completed in the processor.
if not stream_completed:
logger.warning(
"[SDK] [%s] Stream ended without ResultMessage "
"(StopAsyncIteration) — emitting StreamFinish",
logger.info(
"[SDK] [%s] Stream ended without ResultMessage (stopped by user)",
session_id[:12],
)
if adapter.step_open:
yield StreamFinishStep()
adapter.step_open = False
closing_responses: list[StreamBaseResponse] = []
adapter._end_text_if_open(closing_responses)
for r in closing_responses:
yield r
yield StreamFinish()
stream_completed = True
# Add "Stopped by user" message so it persists after refresh
# Use COPILOT_SYSTEM_PREFIX so frontend renders it as system message, not assistant
session.messages.append(
ChatMessage(
role="assistant",
content=f"{COPILOT_SYSTEM_PREFIX} Execution stopped by user",
)
)
if (
assistant_response.content or assistant_response.tool_calls
@@ -1112,7 +976,7 @@ async def stream_chat_completion_sdk(
elif captured_transcript.path:
raw_transcript = read_transcript_file(captured_transcript.path)
logger.debug(
"[SDK] Transcript source: stop hook (%s), " "read result: %s",
"[SDK] Transcript source: stop hook (%s), read result: %s",
captured_transcript.path,
f"{len(raw_transcript)}B" if raw_transcript else "None",
)
@@ -1147,34 +1011,76 @@ async def stream_chat_completion_sdk(
"to use the OpenAI-compatible fallback."
)
_, final_count = await asyncio.shield(upsert_chat_session(session))
logger.info(
"[SDK] [%s] Session saved with %d messages (DB count: %d)",
"[SDK] [%s] Stream completed successfully with %d messages",
session_id[:12],
len(session.messages),
final_count,
)
if not stream_completed:
yield StreamFinish()
except BaseException as e:
# Catch BaseException to handle both Exception and CancelledError
# (CancelledError inherits from BaseException in Python 3.8+)
if isinstance(e, asyncio.CancelledError):
logger.warning("[SDK] [%s] Session cancelled", session_id[:12])
error_msg = "Operation cancelled"
else:
error_msg = str(e) or type(e).__name__
# SDK cleanup RuntimeError is expected during cancellation, log as warning
if isinstance(e, RuntimeError) and "cancel scope" in str(e):
logger.warning(
"[SDK] [%s] SDK cleanup error: %s", session_id[:12], error_msg
)
else:
logger.error(
f"[SDK] [%s] Error: {error_msg}", session_id[:12], exc_info=True
)
except asyncio.CancelledError:
# Client disconnect / server shutdown — log but re-raise so
# the framework can clean up. The finally block still runs
# for transcript upload.
logger.warning("[SDK] [%s] Session cancelled (CancelledError)", session_id[:12])
raise
except Exception as e:
logger.error(f"[SDK] Error: {e}", exc_info=True)
try:
await asyncio.shield(upsert_chat_session(session))
except Exception as save_err:
logger.error(f"[SDK] Failed to save session on error: {save_err}")
yield StreamError(
errorText="An error occurred. Please try again.",
code="sdk_error",
# Append error marker to session (non-invasive text parsing approach)
# The finally block will persist the session with this error marker
if session:
session.messages.append(
ChatMessage(
role="assistant", content=f"{COPILOT_ERROR_PREFIX} {error_msg}"
)
)
logger.debug(
"[SDK] [%s] Appended error marker, will be persisted in finally",
session_id[:12],
)
# Yield StreamError for immediate feedback (only for non-cancellation errors)
# Skip for CancelledError and RuntimeError cleanup issues (both are cancellations)
is_cancellation = isinstance(e, asyncio.CancelledError) or (
isinstance(e, RuntimeError) and "cancel scope" in str(e)
)
yield StreamFinish()
if not is_cancellation:
yield StreamError(
errorText=error_msg,
code="sdk_error",
)
raise
finally:
# --- Persist session messages ---
# This MUST run in finally to persist messages even when the generator
# is stopped early (e.g., user clicks stop, processor breaks stream loop).
# Without this, messages disappear after refresh because they were never
# saved to the database.
if session is not None:
try:
await asyncio.shield(upsert_chat_session(session))
logger.info(
"[SDK] [%s] Session persisted in finally with %d messages",
session_id[:12],
len(session.messages),
)
except Exception as persist_err:
logger.error(
"[SDK] [%s] Failed to persist session in finally: %s",
session_id[:12],
persist_err,
exc_info=True,
)
# --- Upload transcript for next-turn --resume ---
# This MUST run in finally so the transcript is uploaded even when
# the streaming loop raises an exception. The CLI uses
@@ -1190,7 +1096,7 @@ async def stream_chat_completion_sdk(
if not raw_transcript and use_resume and resume_file:
raw_transcript = read_transcript_file(resume_file)
if raw_transcript:
if raw_transcript and session is not None:
await asyncio.shield(
_try_upload_transcript(
user_id,
@@ -1211,7 +1117,7 @@ async def stream_chat_completion_sdk(
_cleanup_sdk_tool_results(sdk_cwd)
# Release stream lock to allow new streams for this session
await _release_stream_lock(session_id, stream_id)
await lock.release()
async def _try_upload_transcript(

View File

@@ -2,11 +2,6 @@
This module provides the adapter layer that converts existing BaseTool implementations
into in-process MCP tools that can be used with the Claude Agent SDK.
Long-running tools (``is_long_running=True``) are delegated to the non-SDK
background infrastructure (stream_registry, Redis persistence, SSE reconnection)
via a callback provided by the service layer. This avoids wasteful SDK polling
and makes results survive page refreshes.
"""
import asyncio
@@ -15,7 +10,6 @@ import json
import logging
import os
import uuid
from collections.abc import Awaitable, Callable
from contextvars import ContextVar
from typing import Any
@@ -43,7 +37,8 @@ _current_session: ContextVar[ChatSession | None] = ContextVar(
# Keyed by tool_name → full output string. Consumed (popped) by the
# response adapter when it builds StreamToolOutputAvailable.
_pending_tool_outputs: ContextVar[dict[str, list[str]]] = ContextVar(
"pending_tool_outputs", default=None # type: ignore[arg-type]
"pending_tool_outputs",
default=None, # type: ignore[arg-type]
)
# Event signaled whenever stash_pending_tool_output() adds a new entry.
# Used by the streaming loop to wait for PostToolUse hooks to complete
@@ -54,22 +49,10 @@ _stash_event: ContextVar[asyncio.Event | None] = ContextVar(
"_stash_event", default=None
)
# Callback type for delegating long-running tools to the non-SDK infrastructure.
# Args: (tool_name, arguments, session) → MCP-formatted response dict.
LongRunningCallback = Callable[
[str, dict[str, Any], ChatSession], Awaitable[dict[str, Any]]
]
# ContextVar so the service layer can inject the callback per-request.
_long_running_callback: ContextVar[LongRunningCallback | None] = ContextVar(
"long_running_callback", default=None
)
def set_execution_context(
user_id: str | None,
session: ChatSession,
long_running_callback: LongRunningCallback | None = None,
) -> None:
"""Set the execution context for tool calls.
@@ -79,14 +62,11 @@ def set_execution_context(
Args:
user_id: Current user's ID.
session: Current chat session.
long_running_callback: Optional callback to delegate long-running tools
to the non-SDK background infrastructure (stream_registry + Redis).
"""
_current_user_id.set(user_id)
_current_session.set(session)
_pending_tool_outputs.set({})
_stash_event.set(asyncio.Event())
_long_running_callback.set(long_running_callback)
def get_execution_context() -> tuple[str | None, ChatSession | None]:
@@ -276,11 +256,6 @@ def create_tool_handler(base_tool: BaseTool):
This wraps the existing BaseTool._execute method to be compatible
with the Claude Agent SDK MCP tool format.
Long-running tools (``is_long_running=True``) are delegated to the
non-SDK background infrastructure via a callback set in the execution
context. The callback persists the operation in Redis (stream_registry)
so results survive page refreshes and pod restarts.
"""
async def tool_handler(args: dict[str, Any]) -> dict[str, Any]:
@@ -290,25 +265,6 @@ def create_tool_handler(base_tool: BaseTool):
if session is None:
return _mcp_error("No session context available")
# --- Long-running: delegate to non-SDK background infrastructure ---
if base_tool.is_long_running:
callback = _long_running_callback.get(None)
if callback:
try:
return await callback(base_tool.name, args, session)
except Exception as e:
logger.error(
f"Long-running callback failed for {base_tool.name}: {e}",
exc_info=True,
)
return _mcp_error(f"Failed to start {base_tool.name}: {e}")
# No callback — fall through to synchronous execution
logger.warning(
f"[SDK] No long-running callback for {base_tool.name}, "
f"executing synchronously (may block)"
)
# --- Normal (fast) tool: execute synchronously ---
try:
return await _execute_tool_sync(base_tool, user_id, session, args)
except Exception as e:

File diff suppressed because it is too large Load Diff

View File

@@ -6,12 +6,7 @@ import pytest
from . import service as chat_service
from .model import create_chat_session, get_chat_session, upsert_chat_session
from .response_model import (
StreamError,
StreamFinish,
StreamTextDelta,
StreamToolOutputAvailable,
)
from .response_model import StreamError, StreamTextDelta, StreamToolOutputAvailable
from .sdk import service as sdk_service
from .sdk.transcript import download_transcript
@@ -30,7 +25,6 @@ async def test_stream_chat_completion(setup_test_user, test_user_id):
session = await create_chat_session(test_user_id)
has_errors = False
has_ended = False
assistant_message = ""
async for chunk in chat_service.stream_chat_completion(
session.session_id, "Hello, how are you?", user_id=session.user_id
@@ -40,10 +34,9 @@ async def test_stream_chat_completion(setup_test_user, test_user_id):
has_errors = True
if isinstance(chunk, StreamTextDelta):
assistant_message += chunk.delta
if isinstance(chunk, StreamFinish):
has_ended = True
assert has_ended, "Chat completion did not end"
# StreamFinish is published by mark_session_completed (processor layer),
# not by the service. The generator completing means the stream ended.
assert not has_errors, "Error occurred while streaming chat completion"
assert assistant_message, "Assistant message is empty"
@@ -58,10 +51,9 @@ async def test_stream_chat_completion_with_tool_calls(setup_test_user, test_user
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
session = await create_chat_session(test_user_id)
session, _ = await upsert_chat_session(session)
session = await upsert_chat_session(session)
has_errors = False
has_ended = False
had_tool_calls = False
async for chunk in chat_service.stream_chat_completion(
session.session_id,
@@ -71,13 +63,9 @@ async def test_stream_chat_completion_with_tool_calls(setup_test_user, test_user
logger.info(chunk)
if isinstance(chunk, StreamError):
has_errors = True
if isinstance(chunk, StreamFinish):
has_ended = True
if isinstance(chunk, StreamToolOutputAvailable):
had_tool_calls = True
assert has_ended, "Chat completion did not end"
assert not has_errors, "Error occurred while streaming chat completion"
assert had_tool_calls, "Tool calls did not occur"
session = await get_chat_session(session.session_id)
@@ -104,7 +92,7 @@ async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
return pytest.skip("CLAUDE_AGENT_USE_RESUME is not enabled, skipping test")
session = await create_chat_session(test_user_id)
session, _ = await upsert_chat_session(session)
session = await upsert_chat_session(session)
# --- Turn 1: send a message with a unique keyword ---
keyword = "ZEPHYR42"
@@ -114,7 +102,6 @@ async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
)
turn1_text = ""
turn1_errors: list[str] = []
turn1_ended = False
async for chunk in sdk_service.stream_chat_completion_sdk(
session.session_id,
@@ -125,10 +112,7 @@ async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
turn1_text += chunk.delta
elif isinstance(chunk, StreamError):
turn1_errors.append(chunk.errorText)
elif isinstance(chunk, StreamFinish):
turn1_ended = True
assert turn1_ended, "Turn 1 did not finish"
assert not turn1_errors, f"Turn 1 errors: {turn1_errors}"
assert turn1_text, "Turn 1 produced no text"
@@ -159,7 +143,6 @@ async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
turn2_msg = "What was the special keyword I asked you to remember?"
turn2_text = ""
turn2_errors: list[str] = []
turn2_ended = False
async for chunk in sdk_service.stream_chat_completion_sdk(
session.session_id,
@@ -171,10 +154,7 @@ async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
turn2_text += chunk.delta
elif isinstance(chunk, StreamError):
turn2_errors.append(chunk.errorText)
elif isinstance(chunk, StreamFinish):
turn2_ended = True
assert turn2_ended, "Turn 2 did not finish"
assert not turn2_errors, f"Turn 2 errors: {turn2_errors}"
assert turn2_text, "Turn 2 produced no text"
assert keyword in turn2_text, (

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,401 @@
"""End-to-end tests for Copilot streaming with dummy implementations.
These tests verify the complete copilot flow using dummy implementations
for agent generator and SDK service, allowing automated testing without
external LLM calls.
Enable test mode with COPILOT_TEST_MODE=true environment variable.
Note: StreamFinish is NOT emitted by the dummy service — it is published
by mark_session_completed in the processor layer. These tests only cover
the service-level streaming output (StreamStart + StreamTextDelta).
"""
import asyncio
import os
from uuid import uuid4
import pytest
from backend.copilot.model import ChatMessage, ChatSession, upsert_chat_session
from backend.copilot.response_model import (
StreamError,
StreamHeartbeat,
StreamStart,
StreamTextDelta,
)
from backend.copilot.sdk.dummy import stream_chat_completion_dummy
@pytest.fixture(autouse=True)
def enable_test_mode():
"""Enable test mode for all tests in this module."""
os.environ["COPILOT_TEST_MODE"] = "true"
yield
os.environ.pop("COPILOT_TEST_MODE", None)
@pytest.mark.asyncio
async def test_dummy_streaming_basic_flow():
"""Test that dummy streaming produces correct event sequence."""
events = []
async for event in stream_chat_completion_dummy(
session_id="test-session-basic",
message="Hello",
is_user_message=True,
user_id="test-user",
):
events.append(event)
# Verify we got events
assert len(events) > 0, "Should receive events"
# Verify StreamStart
start_events = [e for e in events if isinstance(e, StreamStart)]
assert len(start_events) == 1
assert start_events[0].messageId
assert start_events[0].sessionId
# Verify StreamTextDelta events
text_events = [e for e in events if isinstance(e, StreamTextDelta)]
assert len(text_events) > 0
full_text = "".join(e.delta for e in text_events)
assert len(full_text) > 0
# Verify order: start before text
start_idx = events.index(start_events[0])
first_text_idx = events.index(text_events[0]) if text_events else -1
if first_text_idx >= 0:
assert start_idx < first_text_idx
print(f"✅ Basic flow: {len(events)} events, {len(text_events)} text deltas")
@pytest.mark.asyncio
async def test_streaming_no_timeout():
"""Test that streaming completes within reasonable time without timeout."""
import time
start_time = time.monotonic()
event_count = 0
async for _event in stream_chat_completion_dummy(
session_id="test-session-timeout",
message="count to 10",
is_user_message=True,
user_id="test-user",
):
event_count += 1
elapsed = time.monotonic() - start_time
# Should complete in < 5 seconds (dummy has 0.1s delays between words)
assert elapsed < 5.0, f"Streaming took {elapsed:.1f}s, expected < 5s"
assert event_count > 0, "Should receive events"
print(f"✅ No timeout: completed in {elapsed:.2f}s with {event_count} events")
@pytest.mark.asyncio
async def test_streaming_event_types():
"""Test that all expected event types are present."""
event_types = set()
async for event in stream_chat_completion_dummy(
session_id="test-session-types",
message="test",
is_user_message=True,
user_id="test-user",
):
event_types.add(type(event).__name__)
# Required event types (StreamFinish is published by processor, not service)
assert "StreamStart" in event_types, "Missing StreamStart"
assert "StreamTextDelta" in event_types, "Missing StreamTextDelta"
print(f"✅ Event types: {sorted(event_types)}")
@pytest.mark.asyncio
async def test_streaming_text_content():
"""Test that streamed text is coherent and complete."""
text_events = []
async for event in stream_chat_completion_dummy(
session_id="test-session-content",
message="count to 3",
is_user_message=True,
user_id="test-user",
):
if isinstance(event, StreamTextDelta):
text_events.append(event)
# Verify text deltas
assert len(text_events) > 0, "Should have text deltas"
# Reconstruct full text
full_text = "".join(e.delta for e in text_events)
assert len(full_text) > 0, "Text should not be empty"
assert (
"1" in full_text or "counted" in full_text.lower()
), "Text should contain count"
# Verify all deltas have IDs
for text_event in text_events:
assert text_event.id, "Text delta must have ID"
assert text_event.delta, "Text delta must have content"
print(f"✅ Text content: '{full_text}' ({len(text_events)} deltas)")
@pytest.mark.asyncio
async def test_streaming_heartbeat_timing():
"""Test that heartbeats are sent at correct interval during long operations."""
# This test would need a dummy that takes longer
# For now, just verify heartbeat structure if we receive one
heartbeats = []
async for event in stream_chat_completion_dummy(
session_id="test-session-heartbeat",
message="test",
is_user_message=True,
user_id="test-user",
):
if isinstance(event, StreamHeartbeat):
heartbeats.append(event)
# Dummy is fast, so we might not get heartbeats
# But if we do, verify they're valid
if heartbeats:
print(f"✅ Heartbeat structure verified ({len(heartbeats)} received)")
else:
print("✅ No heartbeats (dummy executes quickly)")
@pytest.mark.asyncio
async def test_error_handling():
"""Test that errors are properly formatted and sent."""
# This would require a dummy that can trigger errors
# For now, just verify error event structure
error = StreamError(errorText="Test error", code="test_error")
assert error.errorText == "Test error"
assert error.code == "test_error"
assert str(error.type.value) in ["error", "error"]
print("✅ Error structure verified")
@pytest.mark.asyncio
async def test_concurrent_sessions():
"""Test that multiple sessions can stream concurrently."""
async def stream_session(session_id: str) -> int:
count = 0
async for _event in stream_chat_completion_dummy(
session_id=session_id,
message="test",
is_user_message=True,
user_id="test-user",
):
count += 1
return count
# Run 3 concurrent sessions
results = await asyncio.gather(
stream_session("session-1"),
stream_session("session-2"),
stream_session("session-3"),
)
# All should complete successfully
assert all(count > 0 for count in results), "All sessions should produce events"
print(f"✅ Concurrent sessions: {results} events each")
@pytest.mark.asyncio
@pytest.mark.xfail(
reason="Event loop isolation issue with DB operations in tests - needs fixture refactoring"
)
async def test_session_state_persistence():
"""Test that session state is maintained across multiple messages."""
from datetime import datetime, timezone
session_id = f"test-session-{uuid4()}"
user_id = "test-user"
# Create session with first message
session = ChatSession(
session_id=session_id,
user_id=user_id,
messages=[
ChatMessage(role="user", content="Hello"),
ChatMessage(role="assistant", content="Hi there!"),
],
usage=[],
started_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
await upsert_chat_session(session)
# Stream second message
events = []
async for event in stream_chat_completion_dummy(
session_id=session_id,
message="How are you?",
is_user_message=True,
user_id=user_id,
session=session, # Pass existing session
):
events.append(event)
# Verify events were produced
assert len(events) > 0, "Should produce events for second message"
print(f"✅ Session persistence: {len(events)} events for second message")
@pytest.mark.asyncio
async def test_message_deduplication():
"""Test that duplicate messages are filtered out."""
# Simulate receiving duplicate events (e.g., from reconnection)
events = []
# First stream
async for event in stream_chat_completion_dummy(
session_id="test-dedup-1",
message="Hello",
is_user_message=True,
user_id="test-user",
):
events.append(event)
# Count unique message IDs in StreamStart events
start_events = [e for e in events if isinstance(e, StreamStart)]
message_ids = [e.messageId for e in start_events]
# Verify all IDs are present
assert len(message_ids) == len(set(message_ids)), "Message IDs should be unique"
print(f"✅ Deduplication: {len(events)} events, all unique")
@pytest.mark.asyncio
async def test_event_ordering():
"""Test that events arrive in correct order."""
events = []
async for event in stream_chat_completion_dummy(
session_id="test-ordering",
message="Test",
is_user_message=True,
user_id="test-user",
):
events.append(event)
# Find event indices
start_idx = next(
(i for i, e in enumerate(events) if isinstance(e, StreamStart)), None
)
text_indices = [i for i, e in enumerate(events) if isinstance(e, StreamTextDelta)]
# Verify ordering
assert start_idx is not None, "Should have StreamStart"
assert start_idx == 0, "StreamStart should be first"
if text_indices:
assert all(
start_idx < i for i in text_indices
), "Text deltas should be after start"
print(f"✅ Event ordering: start({start_idx}) < text deltas")
@pytest.mark.asyncio
async def test_stream_completeness():
"""Test that stream includes all required event types."""
events = []
async for event in stream_chat_completion_dummy(
session_id="test-completeness",
message="Complete stream test",
is_user_message=True,
user_id="test-user",
):
events.append(event)
# Check for required events (StreamFinish is published by processor)
has_start = any(isinstance(e, StreamStart) for e in events)
has_text = any(isinstance(e, StreamTextDelta) for e in events)
assert has_start, "Stream must include StreamStart"
assert has_text, "Stream must include text deltas"
# Verify exactly one start
start_count = sum(1 for e in events if isinstance(e, StreamStart))
assert start_count == 1, f"Should have exactly 1 StreamStart, got {start_count}"
print(
f"✅ Completeness: 1 start, {sum(1 for e in events if isinstance(e, StreamTextDelta))} text deltas"
)
@pytest.mark.asyncio
async def test_text_delta_consistency():
"""Test that text deltas have consistent IDs and build coherent text."""
text_events = []
async for event in stream_chat_completion_dummy(
session_id="test-consistency",
message="Test consistency",
is_user_message=True,
user_id="test-user",
):
if isinstance(event, StreamTextDelta):
text_events.append(event)
# Verify all text deltas have IDs
assert all(e.id for e in text_events), "All text deltas must have IDs"
# Verify all deltas have the same ID (same text block)
if text_events:
first_id = text_events[0].id
assert all(
e.id == first_id for e in text_events
), "All text deltas should share the same block ID"
# Verify deltas build coherent text
full_text = "".join(e.delta for e in text_events)
assert len(full_text) > 0, "Deltas should build non-empty text"
assert (
full_text == full_text.strip()
), "Text should not have leading/trailing whitespace artifacts"
print(
f"✅ Consistency: {len(text_events)} deltas with ID '{text_events[0].id if text_events else 'N/A'}', text: '{full_text}'"
)
if __name__ == "__main__":
# Run tests directly
print("Running Copilot E2E tests with dummy implementations...")
print("=" * 60)
asyncio.run(test_dummy_streaming_basic_flow())
asyncio.run(test_streaming_no_timeout())
asyncio.run(test_streaming_event_types())
asyncio.run(test_streaming_text_content())
asyncio.run(test_streaming_heartbeat_timing())
asyncio.run(test_error_handling())
asyncio.run(test_concurrent_sessions())
asyncio.run(test_session_state_persistence())
asyncio.run(test_message_deduplication())
asyncio.run(test_event_ordering())
asyncio.run(test_stream_completeness())
asyncio.run(test_text_delta_consistency())
print("=" * 60)
print("✅ All E2E tests passed!")

View File

@@ -10,7 +10,6 @@ from .add_understanding import AddUnderstandingTool
from .agent_output import AgentOutputTool
from .base import BaseTool
from .bash_exec import BashExecTool
from .check_operation_status import CheckOperationStatusTool
from .create_agent import CreateAgentTool
from .customize_agent import CustomizeAgentTool
from .edit_agent import EditAgentTool
@@ -47,7 +46,6 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
"run_agent": RunAgentTool(),
"run_block": RunBlockTool(),
"view_agent_output": AgentOutputTool(),
"check_operation_status": CheckOperationStatusTool(),
"search_docs": SearchDocsTool(),
"get_doc_page": GetDocPageTool(),
# Web fetch for safe URL retrieval

View File

@@ -3,6 +3,7 @@ from datetime import UTC, datetime
from os import getenv
import pytest
import pytest_asyncio
from prisma.types import ProfileCreateInput
from pydantic import SecretStr
@@ -31,14 +32,16 @@ def make_session(user_id: str):
)
@pytest.fixture(scope="session")
async def setup_test_data():
@pytest_asyncio.fixture(scope="session", loop_scope="session")
async def setup_test_data(server):
"""
Set up test data for run_agent tests:
1. Create a test user
2. Create a test graph (agent input -> agent output)
3. Create a store listing and store listing version
4. Approve the store listing version
Depends on ``server`` to ensure Prisma is connected.
"""
# 1. Create a test user
user_data = {
@@ -150,14 +153,16 @@ async def setup_test_data():
}
@pytest.fixture(scope="session")
async def setup_llm_test_data():
@pytest_asyncio.fixture(scope="session", loop_scope="session")
async def setup_llm_test_data(server):
"""
Set up test data for LLM agent tests:
1. Create a test user
2. Create test OpenAI credentials for the user
3. Create a test graph with input -> LLM block -> output
4. Create and approve a store listing
Depends on ``server`` to ensure Prisma is connected.
"""
key = getenv("OPENAI_API_KEY")
if not key:
@@ -315,13 +320,15 @@ async def setup_llm_test_data():
}
@pytest.fixture(scope="session")
async def setup_firecrawl_test_data():
@pytest_asyncio.fixture(scope="session", loop_scope="session")
async def setup_firecrawl_test_data(server):
"""
Set up test data for Firecrawl agent tests (missing credentials scenario):
1. Create a test user (WITHOUT Firecrawl credentials)
2. Create a test graph with input -> Firecrawl block -> output
3. Create and approve a store listing
Depends on ``server`` to ensure Prisma is connected.
"""
# 1. Create a test user
user_data = {

View File

@@ -19,6 +19,7 @@ from .core import (
get_all_relevant_agents_for_generation,
get_library_agent_by_graph_id,
get_library_agent_by_id,
get_library_agents_by_ids,
get_library_agents_for_generation,
graph_to_json,
json_to_graph,
@@ -49,6 +50,7 @@ __all__ = [
"get_all_relevant_agents_for_generation",
"get_library_agent_by_graph_id",
"get_library_agent_by_id",
"get_library_agents_by_ids",
"get_library_agents_for_generation",
"get_user_message_for_error",
"graph_to_json",

View File

@@ -3,6 +3,7 @@
import logging
import re
import uuid
from collections.abc import Sequence
from typing import Any, NotRequired, TypedDict
from backend.data.db_accessors import graph_db, library_db, store_db
@@ -78,7 +79,7 @@ AgentSummary = LibraryAgentSummary | MarketplaceAgentSummary | dict[str, Any]
def _to_dict_list(
agents: list[AgentSummary] | list[dict[str, Any]] | None,
agents: Sequence[AgentSummary] | Sequence[dict[str, Any]] | None,
) -> list[dict[str, Any]] | None:
"""Convert typed agent summaries to plain dicts for external service calls."""
if agents is None:
@@ -190,6 +191,36 @@ async def get_library_agent_by_id(
get_library_agent_by_graph_id = get_library_agent_by_id
async def get_library_agents_by_ids(
user_id: str,
agent_ids: list[str],
) -> list[LibraryAgentSummary]:
"""Fetch multiple library agents by their IDs.
Args:
user_id: The user ID
agent_ids: List of agent IDs (can be graph_ids or library agent IDs)
Returns:
List of LibraryAgentSummary for found agents (silently skips not found)
"""
agents: list[LibraryAgentSummary] = []
for agent_id in agent_ids:
try:
agent = await get_library_agent_by_id(user_id, agent_id)
if agent:
agents.append(agent)
logger.debug(f"Fetched library agent by ID: {agent['name']}")
else:
logger.warning(f"Library agent not found for ID: {agent_id}")
except Exception as e:
logger.warning(f"Failed to fetch library agent {agent_id}: {e}")
continue
logger.info(f"Fetched {len(agents)}/{len(agent_ids)} library agents by ID")
return agents
async def get_library_agents_for_generation(
user_id: str,
search_query: str | None = None,
@@ -214,10 +245,17 @@ async def get_library_agents_for_generation(
Returns:
List of LibraryAgentSummary with schemas and recent executions for sub-agent composition
"""
search_term = search_query.strip() if search_query else None
if search_term and len(search_term) > 100:
raise ValueError(
f"Search query is too long ({len(search_term)} chars, max 100). "
f"Please use a shorter, more specific search term."
)
try:
response = await library_db().list_library_agents(
user_id=user_id,
search_term=search_query,
search_term=search_term,
page=1,
page_size=max_results,
include_executions=True,
@@ -271,9 +309,16 @@ async def search_marketplace_agents_for_generation(
Returns:
List of LibraryAgentSummary with full input/output schemas
"""
search_term = search_query.strip()
if len(search_term) > 100:
raise ValueError(
f"Search query is too long ({len(search_term)} chars, max 100). "
f"Please use a shorter, more specific search term."
)
try:
response = await store_db().get_store_agents(
search_query=search_query,
search_query=search_term,
page=1,
page_size=max_results,
)
@@ -424,7 +469,7 @@ def extract_search_terms_from_steps(
async def enrich_library_agents_from_steps(
user_id: str,
decomposition_result: DecompositionResult | dict[str, Any],
existing_agents: list[AgentSummary] | list[dict[str, Any]],
existing_agents: Sequence[AgentSummary] | Sequence[dict[str, Any]],
exclude_graph_id: str | None = None,
include_marketplace: bool = True,
max_additional_results: int = 10,
@@ -448,7 +493,7 @@ async def enrich_library_agents_from_steps(
search_terms = extract_search_terms_from_steps(decomposition_result)
if not search_terms:
return existing_agents
return list(existing_agents)
existing_ids: set[str] = set()
existing_names: set[str] = set()
@@ -511,7 +556,7 @@ async def enrich_library_agents_from_steps(
async def decompose_goal(
description: str,
context: str = "",
library_agents: list[AgentSummary] | None = None,
library_agents: Sequence[AgentSummary] | None = None,
) -> DecompositionResult | None:
"""Break down a goal into steps or return clarifying questions.
@@ -539,22 +584,16 @@ async def decompose_goal(
async def generate_agent(
instructions: DecompositionResult | dict[str, Any],
library_agents: list[AgentSummary] | list[dict[str, Any]] | None = None,
operation_id: str | None = None,
task_id: str | None = None,
library_agents: Sequence[AgentSummary] | Sequence[dict[str, Any]] | None = None,
) -> dict[str, Any] | None:
"""Generate agent JSON from instructions.
Args:
instructions: Structured instructions from decompose_goal
library_agents: User's library agents available for sub-agent composition
operation_id: Operation ID for async processing (enables Redis Streams
completion notification)
task_id: Task ID for async processing (enables Redis Streams persistence
and SSE delivery)
Returns:
Agent JSON dict, {"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error
Agent JSON dict, error dict {"type": "error", ...}, or None on error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
@@ -562,13 +601,9 @@ async def generate_agent(
_check_service_configured()
logger.info("Calling external Agent Generator service for generate_agent")
result = await generate_agent_external(
dict(instructions), _to_dict_list(library_agents), operation_id, task_id
dict(instructions), _to_dict_list(library_agents)
)
# Don't modify async response
if result and result.get("status") == "accepted":
return result
if result:
if isinstance(result, dict) and result.get("type") == "error":
return result
@@ -758,9 +793,7 @@ async def get_agent_as_json(
async def generate_agent_patch(
update_request: str,
current_agent: dict[str, Any],
library_agents: list[AgentSummary] | None = None,
operation_id: str | None = None,
task_id: str | None = None,
library_agents: Sequence[AgentSummary] | None = None,
) -> dict[str, Any] | None:
"""Update an existing agent using natural language.
@@ -773,12 +806,10 @@ async def generate_agent_patch(
update_request: Natural language description of changes
current_agent: Current agent JSON
library_agents: User's library agents available for sub-agent composition
operation_id: Operation ID for async processing (enables Redis Streams callback)
task_id: Task ID for async processing (enables Redis Streams callback)
Returns:
Updated agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
{"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error
error dict {"type": "error", ...}, or None on error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
@@ -789,8 +820,6 @@ async def generate_agent_patch(
update_request,
current_agent,
_to_dict_list(library_agents),
operation_id,
task_id,
)

View File

@@ -102,10 +102,15 @@ async def generate_agent_dummy(
instructions: dict[str, Any],
library_agents: list[dict[str, Any]] | None = None,
operation_id: str | None = None,
task_id: str | None = None,
session_id: str | None = None,
) -> dict[str, Any]:
"""Return dummy agent JSON after a simulated delay."""
logger.info("Using dummy agent generator for generate_agent (30s delay)")
"""Return dummy agent synchronously (blocks for 30s, returns agent JSON).
Note: operation_id and session_id parameters are ignored - we always use synchronous mode.
"""
logger.info(
"Using dummy agent generator (sync mode): returning agent JSON after 30s"
)
await asyncio.sleep(30)
return _generate_dummy_agent_json()
@@ -115,10 +120,16 @@ async def generate_agent_patch_dummy(
current_agent: dict[str, Any],
library_agents: list[dict[str, Any]] | None = None,
operation_id: str | None = None,
task_id: str | None = None,
session_id: str | None = None,
) -> dict[str, Any]:
"""Return dummy patched agent (returns the current agent with updated description)."""
logger.info("Using dummy agent generator for generate_agent_patch")
"""Return dummy patched agent synchronously (blocks for 30s, returns patched agent JSON).
Note: operation_id and session_id parameters are ignored - we always use synchronous mode.
"""
logger.info(
"Using dummy agent generator patch (sync mode): returning patched agent after 30s"
)
await asyncio.sleep(30)
patched = current_agent.copy()
patched["description"] = (
f"{current_agent.get('description', '')} (updated: {update_request})"

View File

@@ -242,24 +242,18 @@ async def decompose_goal_external(
async def generate_agent_external(
instructions: dict[str, Any],
library_agents: list[dict[str, Any]] | None = None,
operation_id: str | None = None,
task_id: str | None = None,
) -> dict[str, Any] | None:
"""Call the external service to generate an agent from instructions.
Args:
instructions: Structured instructions from decompose_goal
library_agents: User's library agents available for sub-agent composition
operation_id: Operation ID for async processing (enables Redis Streams callback)
task_id: Task ID for async processing (enables Redis Streams callback)
Returns:
Agent JSON dict, {"status": "accepted"} for async, or error dict {"type": "error", ...} on error
Agent JSON dict or error dict {"type": "error", ...} on error
"""
if _is_dummy_mode():
return await generate_agent_dummy(
instructions, library_agents, operation_id, task_id
)
return await generate_agent_dummy(instructions, library_agents)
client = _get_client()
@@ -267,25 +261,9 @@ async def generate_agent_external(
payload: dict[str, Any] = {"instructions": instructions}
if library_agents:
payload["library_agents"] = library_agents
if operation_id and task_id:
payload["operation_id"] = operation_id
payload["task_id"] = task_id
try:
response = await client.post("/api/generate-agent", json=payload)
# Handle 202 Accepted for async processing
if response.status_code == 202:
logger.info(
f"Agent Generator accepted async request "
f"(operation_id={operation_id}, task_id={task_id})"
)
return {
"status": "accepted",
"operation_id": operation_id,
"task_id": task_id,
}
response.raise_for_status()
data = response.json()
@@ -317,8 +295,6 @@ async def generate_agent_patch_external(
update_request: str,
current_agent: dict[str, Any],
library_agents: list[dict[str, Any]] | None = None,
operation_id: str | None = None,
task_id: str | None = None,
) -> dict[str, Any] | None:
"""Call the external service to generate a patch for an existing agent.
@@ -327,14 +303,14 @@ async def generate_agent_patch_external(
current_agent: Current agent JSON
library_agents: User's library agents available for sub-agent composition
operation_id: Operation ID for async processing (enables Redis Streams callback)
task_id: Task ID for async processing (enables Redis Streams callback)
session_id: Session ID for async processing (enables Redis Streams callback)
Returns:
Updated agent JSON, clarifying questions dict, {"status": "accepted"} for async, or error dict on error
"""
if _is_dummy_mode():
return await generate_agent_patch_dummy(
update_request, current_agent, library_agents, operation_id, task_id
update_request, current_agent, library_agents
)
client = _get_client()
@@ -346,25 +322,9 @@ async def generate_agent_patch_external(
}
if library_agents:
payload["library_agents"] = library_agents
if operation_id and task_id:
payload["operation_id"] = operation_id
payload["task_id"] = task_id
try:
response = await client.post("/api/update-agent", json=payload)
# Handle 202 Accepted for async processing
if response.status_code == 202:
logger.info(
f"Agent Generator accepted async update request "
f"(operation_id={operation_id}, task_id={task_id})"
)
return {
"status": "accepted",
"operation_id": operation_id,
"task_id": task_id,
}
response.raise_for_status()
data = response.json()
@@ -419,6 +379,8 @@ async def customize_template_external(
template_agent: The template agent JSON to customize
modification_request: Natural language description of customizations
context: Additional context (e.g., answers to previous questions)
operation_id: Operation ID for async processing (enables Redis Streams callback)
session_id: Session ID for async processing (enables Redis Streams callback)
Returns:
Customized agent JSON, clarifying questions dict, or error dict on error

View File

@@ -5,7 +5,7 @@ import re
from datetime import datetime, timedelta, timezone
from typing import Any
from pydantic import BaseModel, field_validator
from pydantic import BaseModel, Field, field_validator
from backend.api.features.library.model import LibraryAgent
from backend.copilot.model import ChatSession
@@ -13,6 +13,7 @@ from backend.data.db_accessors import execution_db, library_db
from backend.data.execution import ExecutionStatus, GraphExecution, GraphExecutionMeta
from .base import BaseTool
from .execution_utils import TERMINAL_STATUSES, wait_for_execution
from .models import (
AgentOutputResponse,
ErrorResponse,
@@ -33,6 +34,7 @@ class AgentOutputInput(BaseModel):
store_slug: str = ""
execution_id: str = ""
run_time: str = "latest"
wait_if_running: int = Field(default=0, ge=0, le=300)
@field_validator(
"agent_name",
@@ -116,6 +118,11 @@ class AgentOutputTool(BaseTool):
Select which run to retrieve using:
- execution_id: Specific execution ID
- run_time: 'latest' (default), 'yesterday', 'last week', or ISO date 'YYYY-MM-DD'
Wait for completion (optional):
- wait_if_running: Max seconds to wait if execution is still running (0-300).
If the execution is running/queued, waits up to this many seconds for completion.
Returns current status on timeout. If already finished, returns immediately.
"""
@property
@@ -145,6 +152,13 @@ class AgentOutputTool(BaseTool):
"Time filter: 'latest', 'yesterday', 'last week', or 'YYYY-MM-DD'"
),
},
"wait_if_running": {
"type": "integer",
"description": (
"Max seconds to wait if execution is still running (0-300). "
"If running, waits for completion. Returns current state on timeout."
),
},
},
"required": [],
}
@@ -224,10 +238,14 @@ class AgentOutputTool(BaseTool):
execution_id: str | None,
time_start: datetime | None,
time_end: datetime | None,
include_running: bool = False,
) -> tuple[GraphExecution | None, list[GraphExecutionMeta], str | None]:
"""
Fetch execution(s) based on filters.
Returns (single_execution, available_executions_meta, error_message).
Args:
include_running: If True, also look for running/queued executions (for waiting)
"""
exec_db = execution_db()
@@ -242,11 +260,25 @@ class AgentOutputTool(BaseTool):
return None, [], f"Execution '{execution_id}' not found"
return execution, [], None
# Get completed executions with time filters
# Determine which statuses to query
statuses = [ExecutionStatus.COMPLETED]
if include_running:
statuses.extend(
[
ExecutionStatus.RUNNING,
ExecutionStatus.QUEUED,
ExecutionStatus.INCOMPLETE,
ExecutionStatus.REVIEW,
ExecutionStatus.FAILED,
ExecutionStatus.TERMINATED,
]
)
# Get executions with time filters
executions = await exec_db.get_graph_executions(
graph_id=graph_id,
user_id=user_id,
statuses=[ExecutionStatus.COMPLETED],
statuses=statuses,
created_time_gte=time_start,
created_time_lte=time_end,
limit=10,
@@ -313,10 +345,33 @@ class AgentOutputTool(BaseTool):
for e in available_executions[:5]
]
message = f"Found execution outputs for agent '{agent.name}'"
# Build appropriate message based on execution status
if execution.status == ExecutionStatus.COMPLETED:
message = f"Found execution outputs for agent '{agent.name}'"
elif execution.status == ExecutionStatus.FAILED:
message = f"Execution for agent '{agent.name}' failed"
elif execution.status == ExecutionStatus.TERMINATED:
message = f"Execution for agent '{agent.name}' was terminated"
elif execution.status == ExecutionStatus.REVIEW:
message = (
f"Execution for agent '{agent.name}' is awaiting human review. "
"The user needs to approve it before it can continue."
)
elif execution.status in (
ExecutionStatus.RUNNING,
ExecutionStatus.QUEUED,
ExecutionStatus.INCOMPLETE,
):
message = (
f"Execution for agent '{agent.name}' is still {execution.status.value}. "
"Results may be incomplete. Use wait_if_running to wait for completion."
)
else:
message = f"Found execution for agent '{agent.name}' (status: {execution.status.value})"
if len(available_executions) > 1:
message += (
f". Showing latest of {len(available_executions)} matching executions."
f" Showing latest of {len(available_executions)} matching executions."
)
return AgentOutputResponse(
@@ -431,13 +486,17 @@ class AgentOutputTool(BaseTool):
# Parse time expression
time_start, time_end = parse_time_expression(input_data.run_time)
# Fetch execution(s)
# Check if we should wait for running executions
wait_timeout = input_data.wait_if_running
# Fetch execution(s) - include running if we're going to wait
execution, available_executions, exec_error = await self._get_execution(
user_id=user_id,
graph_id=agent.graph_id,
execution_id=input_data.execution_id or None,
time_start=time_start,
time_end=time_end,
include_running=wait_timeout > 0,
)
if exec_error:
@@ -446,4 +505,17 @@ class AgentOutputTool(BaseTool):
session_id=session_id,
)
# If we have an execution that's still running and we should wait
if execution and wait_timeout > 0 and execution.status not in TERMINAL_STATUSES:
logger.info(
f"Execution {execution.id} is {execution.status}, "
f"waiting up to {wait_timeout}s for completion"
)
execution = await wait_for_execution(
user_id=user_id,
graph_id=agent.graph_id,
execution_id=execution.id,
timeout_seconds=wait_timeout,
)
return self._build_response(agent, execution, available_executions, session_id)

View File

@@ -1,8 +1,13 @@
"""Shared agent search functionality for find_agent and find_library_agent tools."""
from __future__ import annotations
import logging
import re
from typing import Literal
from typing import TYPE_CHECKING, Literal
if TYPE_CHECKING:
from backend.api.features.library.model import LibraryAgent
from backend.data.db_accessors import library_db, store_db
from backend.util.exceptions import DatabaseError, NotFoundError
@@ -24,94 +29,24 @@ _UUID_PATTERN = re.compile(
re.IGNORECASE,
)
def _is_uuid(text: str) -> bool:
"""Check if text is a valid UUID v4."""
return bool(_UUID_PATTERN.match(text.strip()))
async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | None:
"""Fetch a library agent by ID (library agent ID or graph_id).
Tries multiple lookup strategies:
1. First by graph_id (AgentGraph primary key)
2. Then by library agent ID (LibraryAgent primary key)
Args:
user_id: The user ID
agent_id: The ID to look up (can be graph_id or library agent ID)
Returns:
AgentInfo if found, None otherwise
"""
lib_db = library_db()
try:
agent = await lib_db.get_library_agent_by_graph_id(user_id, agent_id)
if agent:
logger.debug(f"Found library agent by graph_id: {agent.name}")
return AgentInfo(
id=agent.id,
name=agent.name,
description=agent.description or "",
source="library",
in_library=True,
creator=agent.creator_name,
status=agent.status.value,
can_access_graph=agent.can_access_graph,
has_external_trigger=agent.has_external_trigger,
new_output=agent.new_output,
graph_id=agent.graph_id,
)
except DatabaseError:
raise
except Exception as e:
logger.warning(
f"Could not fetch library agent by graph_id {agent_id}: {e}",
exc_info=True,
)
try:
agent = await lib_db.get_library_agent(agent_id, user_id)
if agent:
logger.debug(f"Found library agent by library_id: {agent.name}")
return AgentInfo(
id=agent.id,
name=agent.name,
description=agent.description or "",
source="library",
in_library=True,
creator=agent.creator_name,
status=agent.status.value,
can_access_graph=agent.can_access_graph,
has_external_trigger=agent.has_external_trigger,
new_output=agent.new_output,
graph_id=agent.graph_id,
)
except NotFoundError:
logger.debug(f"Library agent not found by library_id: {agent_id}")
except DatabaseError:
raise
except Exception as e:
logger.warning(
f"Could not fetch library agent by library_id {agent_id}: {e}",
exc_info=True,
)
return None
# Keywords that should be treated as "list all" rather than a literal search
_LIST_ALL_KEYWORDS = frozenset({"all", "*", "everything", "any", ""})
async def search_agents(
query: str,
source: SearchSource,
session_id: str | None,
session_id: str | None = None,
user_id: str | None = None,
) -> ToolResponseBase:
"""
Search for agents in marketplace or user library.
For library searches, keywords like "all", "*", "everything", or an empty
query will list all agents without filtering.
Args:
query: Search query string
query: Search query string. Special keywords list all library agents.
source: "marketplace" or "library"
session_id: Chat session ID
user_id: User ID (required for library search)
@@ -119,7 +54,11 @@ async def search_agents(
Returns:
AgentsFoundResponse, NoResultsResponse, or ErrorResponse
"""
if not query:
# Normalize list-all keywords to empty string for library searches
if source == "library" and query.lower().strip() in _LIST_ALL_KEYWORDS:
query = ""
if source == "marketplace" and not query:
return ErrorResponse(
message="Please provide a search query", session_id=session_id
)
@@ -159,28 +98,18 @@ async def search_agents(
logger.info(f"Found agent by direct ID lookup: {agent.name}")
if not agents:
logger.info(f"Searching user library for: {query}")
search_term = query or None
logger.info(
f"{'Listing all agents in' if not query else 'Searching'} "
f"user library{'' if not query else f' for: {query}'}"
)
results = await library_db().list_library_agents(
user_id=user_id, # type: ignore[arg-type]
search_term=query,
page_size=10,
search_term=search_term,
page_size=50 if not query else 10,
)
for agent in results.agents:
agents.append(
AgentInfo(
id=agent.id,
name=agent.name,
description=agent.description or "",
source="library",
in_library=True,
creator=agent.creator_name,
status=agent.status.value,
can_access_graph=agent.can_access_graph,
has_external_trigger=agent.has_external_trigger,
new_output=agent.new_output,
graph_id=agent.graph_id,
)
)
agents.append(_library_agent_to_info(agent))
logger.info(f"Found {len(agents)} agents in {source}")
except NotFoundError:
pass
@@ -193,42 +122,62 @@ async def search_agents(
)
if not agents:
suggestions = (
[
if source == "marketplace":
suggestions = [
"Try more general terms",
"Browse categories in the marketplace",
"Check spelling",
]
if source == "marketplace"
else [
no_results_msg = (
f"No agents found matching '{query}'. Let the user know they can "
"try different keywords or browse the marketplace. Also let them "
"know you can create a custom agent for them based on their needs."
)
elif not query:
# User asked to list all but library is empty
suggestions = [
"Browse the marketplace to find and add agents",
"Use find_agent to search the marketplace",
]
no_results_msg = (
"Your library is empty. Let the user know they can browse the "
"marketplace to find agents, or you can create a custom agent "
"for them based on their needs."
)
else:
suggestions = [
"Try different keywords",
"Use find_agent to search the marketplace",
"Check your library at /library",
]
)
no_results_msg = (
f"No agents found matching '{query}'. Let the user know they can try different keywords or browse the marketplace. Also let them know you can create a custom agent for them based on their needs."
if source == "marketplace"
else f"No agents matching '{query}' found in your library. Let the user know you can create a custom agent for them based on their needs."
)
no_results_msg = (
f"No agents matching '{query}' found in your library. Let the "
"user know you can create a custom agent for them based on "
"their needs."
)
return NoResultsResponse(
message=no_results_msg, session_id=session_id, suggestions=suggestions
)
title = f"Found {len(agents)} agent{'s' if len(agents) != 1 else ''} "
title += (
f"for '{query}'"
if source == "marketplace"
else f"in your library for '{query}'"
)
if source == "marketplace":
title = (
f"Found {len(agents)} agent{'s' if len(agents) != 1 else ''} for '{query}'"
)
elif not query:
title = f"Found {len(agents)} agent{'s' if len(agents) != 1 else ''} in your library"
else:
title = f"Found {len(agents)} agent{'s' if len(agents) != 1 else ''} in your library for '{query}'"
message = (
"Now you have found some options for the user to choose from. "
"You can add a link to a recommended agent at: /marketplace/agent/agent_id "
"Please ask the user if they would like to use any of these agents. Let the user know we can create a custom agent for them based on their needs."
"Please ask the user if they would like to use any of these agents. "
"Let the user know we can create a custom agent for them based on their needs."
if source == "marketplace"
else "Found agents in the user's library. You can provide a link to view an agent at: "
"/library/agents/{agent_id}. Use agent_output to get execution results, or run_agent to execute. Let the user know we can create a custom agent for them based on their needs."
else "Found agents in the user's library. You can provide a link to view "
"an agent at: /library/agents/{agent_id}. Use agent_output to get "
"execution results, or run_agent to execute. Let the user know we can "
"create a custom agent for them based on their needs."
)
return AgentsFoundResponse(
@@ -238,3 +187,67 @@ async def search_agents(
count=len(agents),
session_id=session_id,
)
def _is_uuid(text: str) -> bool:
"""Check if text is a valid UUID v4."""
return bool(_UUID_PATTERN.match(text.strip()))
def _library_agent_to_info(agent: LibraryAgent) -> AgentInfo:
"""Convert a library agent model to an AgentInfo."""
return AgentInfo(
id=agent.id,
name=agent.name,
description=agent.description or "",
source="library",
in_library=True,
creator=agent.creator_name,
status=agent.status.value,
can_access_graph=agent.can_access_graph,
has_external_trigger=agent.has_external_trigger,
new_output=agent.new_output,
graph_id=agent.graph_id,
)
async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | None:
"""Fetch a library agent by ID (library agent ID or graph_id).
Tries multiple lookup strategies:
1. First by graph_id (AgentGraph primary key)
2. Then by library agent ID (LibraryAgent primary key)
"""
lib_db = library_db()
try:
agent = await lib_db.get_library_agent_by_graph_id(user_id, agent_id)
if agent:
logger.debug(f"Found library agent by graph_id: {agent.name}")
return _library_agent_to_info(agent)
except NotFoundError:
logger.debug(f"Library agent not found by graph_id: {agent_id}")
except DatabaseError:
raise
except Exception as e:
logger.warning(
f"Could not fetch library agent by graph_id {agent_id}: {e}",
exc_info=True,
)
try:
agent = await lib_db.get_library_agent(agent_id, user_id)
if agent:
logger.debug(f"Found library agent by library_id: {agent.name}")
return _library_agent_to_info(agent)
except NotFoundError:
logger.debug(f"Library agent not found by library_id: {agent_id}")
except DatabaseError:
raise
except Exception as e:
logger.warning(
f"Could not fetch library agent by library_id {agent_id}: {e}",
exc_info=True,
)
return None

View File

@@ -36,16 +36,6 @@ class BaseTool:
"""Whether this tool requires authentication."""
return False
@property
def is_long_running(self) -> bool:
"""Whether this tool is long-running and should execute in background.
Long-running tools (like agent generation) are executed via background
tasks to survive SSE disconnections. The result is persisted to chat
history and visible when the user refreshes.
"""
return False
def as_openai_tool(self) -> ChatCompletionToolParam:
"""Convert to OpenAI tool format."""
return ChatCompletionToolParam(

View File

@@ -1,124 +0,0 @@
"""CheckOperationStatusTool — query the status of a long-running operation."""
import logging
from typing import Any
from backend.copilot.model import ChatSession
from .base import BaseTool
from .models import ErrorResponse, ResponseType, ToolResponseBase
logger = logging.getLogger(__name__)
class OperationStatusResponse(ToolResponseBase):
"""Response for check_operation_status tool."""
type: ResponseType = ResponseType.OPERATION_STATUS
task_id: str
operation_id: str
status: str # "running", "completed", "failed"
tool_name: str | None = None
message: str = ""
class CheckOperationStatusTool(BaseTool):
"""Check the status of a long-running operation (create_agent, edit_agent, etc.).
The CoPilot uses this tool to report back to the user whether an
operation that was started earlier has completed, failed, or is still
running.
"""
@property
def name(self) -> str:
return "check_operation_status"
@property
def description(self) -> str:
return (
"Check the current status of a long-running operation such as "
"create_agent or edit_agent. Accepts either an operation_id or "
"task_id from a previous operation_started response. "
"Returns the current status: running, completed, or failed."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"operation_id": {
"type": "string",
"description": (
"The operation_id from an operation_started response."
),
},
"task_id": {
"type": "string",
"description": (
"The task_id from an operation_started response. "
"Used as fallback if operation_id is not provided."
),
},
},
"required": [],
}
@property
def requires_auth(self) -> bool:
return False
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
from backend.copilot import stream_registry
operation_id = (kwargs.get("operation_id") or "").strip()
task_id = (kwargs.get("task_id") or "").strip()
if not operation_id and not task_id:
return ErrorResponse(
message="Please provide an operation_id or task_id.",
error="missing_parameter",
)
task = None
if operation_id:
task = await stream_registry.find_task_by_operation_id(operation_id)
if task is None and task_id:
task = await stream_registry.get_task(task_id)
if task is None:
# Task not in Redis — it may have already expired (TTL).
# Check conversation history for the result instead.
return ErrorResponse(
message=(
"Operation not found — it may have already completed and "
"expired from the status tracker. Check the conversation "
"history for the result."
),
error="not_found",
)
status_messages = {
"running": (
f"The {task.tool_name or 'operation'} is still running. "
"Please wait for it to complete."
),
"completed": (
f"The {task.tool_name or 'operation'} has completed successfully."
),
"failed": f"The {task.tool_name or 'operation'} has failed.",
}
return OperationStatusResponse(
task_id=task.task_id,
operation_id=task.operation_id,
status=task.status,
tool_name=task.tool_name,
message=status_messages.get(task.status, f"Status: {task.status}"),
)

View File

@@ -10,7 +10,6 @@ from .agent_generator import (
decompose_goal,
enrich_library_agents_from_steps,
generate_agent,
get_all_relevant_agents_for_generation,
get_user_message_for_error,
save_agent_to_library,
)
@@ -18,7 +17,6 @@ from .base import BaseTool
from .models import (
AgentPreviewResponse,
AgentSavedResponse,
AsyncProcessingResponse,
ClarificationNeededResponse,
ClarifyingQuestion,
ErrorResponse,
@@ -40,17 +38,16 @@ class CreateAgentTool(BaseTool):
def description(self) -> str:
return (
"Create a new agent workflow from a natural language description. "
"First generates a preview, then saves to library if save=true."
"First generates a preview, then saves to library if save=true. "
"\n\nIMPORTANT: Before calling this tool, search for relevant existing agents "
"using find_library_agent that could be used as building blocks. "
"Pass their IDs in the library_agent_ids parameter so the generator can compose them."
)
@property
def requires_auth(self) -> bool:
return True
@property
def is_long_running(self) -> bool:
return True
@property
def parameters(self) -> dict[str, Any]:
return {
@@ -70,6 +67,15 @@ class CreateAgentTool(BaseTool):
"Include any preferences or constraints mentioned by the user."
),
},
"library_agent_ids": {
"type": "array",
"items": {"type": "string"},
"description": (
"List of library agent IDs to use as building blocks. "
"Search for relevant agents using find_library_agent first, "
"then pass their IDs here so they can be composed into the new agent."
),
},
"save": {
"type": "boolean",
"description": (
@@ -97,12 +103,14 @@ class CreateAgentTool(BaseTool):
"""
description = kwargs.get("description", "").strip()
context = kwargs.get("context", "")
library_agent_ids = kwargs.get("library_agent_ids", [])
save = kwargs.get("save", True)
session_id = session.session_id if session else None
# Extract async processing params (passed by long-running tool handler)
operation_id = kwargs.get("_operation_id")
task_id = kwargs.get("_task_id")
logger.info(
f"[AGENT_CREATE_DEBUG] START - description_len={len(description)}, "
f"library_agent_ids={library_agent_ids}, save={save}, user_id={user_id}, session_id={session_id}"
)
if not description:
return ErrorResponse(
@@ -111,25 +119,34 @@ class CreateAgentTool(BaseTool):
session_id=session_id,
)
# Fetch library agents by IDs if provided
library_agents = None
if user_id:
if user_id and library_agent_ids:
try:
library_agents = await get_all_relevant_agents_for_generation(
from .agent_generator import get_library_agents_by_ids
library_agents = await get_library_agents_by_ids(
user_id=user_id,
search_query=description,
include_marketplace=True,
agent_ids=library_agent_ids,
)
logger.debug(
f"Found {len(library_agents)} relevant agents for sub-agent composition"
f"Fetched {len(library_agents)} library agents by ID for sub-agent composition"
)
except Exception as e:
logger.warning(f"Failed to fetch library agents: {e}")
logger.warning(f"Failed to fetch library agents by IDs: {e}")
try:
decomposition_result = await decompose_goal(
description, context, library_agents
)
logger.info(
f"[AGENT_CREATE_DEBUG] DECOMPOSE - type={decomposition_result.get('type') if decomposition_result else None}, "
f"session_id={session_id}"
)
except AgentGeneratorNotConfiguredError:
logger.error(
f"[AGENT_CREATE_DEBUG] ERROR - AgentGeneratorNotConfigured, session_id={session_id}"
)
return ErrorResponse(
message=(
"Agent generation is not available. "
@@ -230,10 +247,17 @@ class CreateAgentTool(BaseTool):
agent_json = await generate_agent(
decomposition_result,
library_agents,
operation_id=operation_id,
task_id=task_id,
)
logger.info(
f"[AGENT_CREATE_DEBUG] GENERATE - "
f"success={agent_json is not None}, "
f"is_error={isinstance(agent_json, dict) and agent_json.get('type') == 'error'}, "
f"session_id={session_id}"
)
except AgentGeneratorNotConfiguredError:
logger.error(
f"[AGENT_CREATE_DEBUG] ERROR - AgentGeneratorNotConfigured during generation, session_id={session_id}"
)
return ErrorResponse(
message=(
"Agent generation is not available. "
@@ -276,25 +300,20 @@ class CreateAgentTool(BaseTool):
session_id=session_id,
)
# Check if Agent Generator accepted for async processing
if agent_json.get("status") == "accepted":
logger.info(
f"Agent generation delegated to async processing "
f"(operation_id={operation_id}, task_id={task_id})"
)
return AsyncProcessingResponse(
message="Agent generation started. You'll be notified when it's complete.",
operation_id=operation_id,
task_id=task_id,
session_id=session_id,
)
agent_name = agent_json.get("name", "Generated Agent")
agent_description = agent_json.get("description", "")
node_count = len(agent_json.get("nodes", []))
link_count = len(agent_json.get("links", []))
logger.info(
f"[AGENT_CREATE_DEBUG] AGENT_JSON - name={agent_name}, "
f"nodes={node_count}, links={link_count}, save={save}, session_id={session_id}"
)
if not save:
logger.info(
f"[AGENT_CREATE_DEBUG] RETURN - AgentPreviewResponse, session_id={session_id}"
)
return AgentPreviewResponse(
message=(
f"I've generated an agent called '{agent_name}' with {node_count} blocks. "
@@ -320,6 +339,13 @@ class CreateAgentTool(BaseTool):
agent_json, user_id
)
logger.info(
f"[AGENT_CREATE_DEBUG] SAVED - graph_id={created_graph.id}, "
f"library_agent_id={library_agent.id}, session_id={session_id}"
)
logger.info(
f"[AGENT_CREATE_DEBUG] RETURN - AgentSavedResponse, session_id={session_id}"
)
return AgentSavedResponse(
message=f"Agent '{created_graph.name}' has been saved to your library!",
agent_id=created_graph.id,
@@ -330,6 +356,12 @@ class CreateAgentTool(BaseTool):
session_id=session_id,
)
except Exception as e:
logger.error(
f"[AGENT_CREATE_DEBUG] ERROR - save_failed: {str(e)}, session_id={session_id}"
)
logger.info(
f"[AGENT_CREATE_DEBUG] RETURN - ErrorResponse (save_failed), session_id={session_id}"
)
return ErrorResponse(
message=f"Failed to save the agent: {str(e)}",
error="save_failed",

View File

@@ -43,11 +43,6 @@ async def test_vague_goal_returns_suggested_goal_response(tool, session):
}
with (
patch(
"backend.copilot.tools.create_agent.get_all_relevant_agents_for_generation",
new_callable=AsyncMock,
return_value=[],
),
patch(
"backend.copilot.tools.create_agent.decompose_goal",
new_callable=AsyncMock,
@@ -78,11 +73,6 @@ async def test_unachievable_goal_returns_suggested_goal_response(tool, session):
}
with (
patch(
"backend.copilot.tools.create_agent.get_all_relevant_agents_for_generation",
new_callable=AsyncMock,
return_value=[],
),
patch(
"backend.copilot.tools.create_agent.decompose_goal",
new_callable=AsyncMock,
@@ -120,11 +110,6 @@ async def test_clarifying_questions_returns_clarification_needed_response(
}
with (
patch(
"backend.copilot.tools.create_agent.get_all_relevant_agents_for_generation",
new_callable=AsyncMock,
return_value=[],
),
patch(
"backend.copilot.tools.create_agent.decompose_goal",
new_callable=AsyncMock,

View File

@@ -46,10 +46,6 @@ class CustomizeAgentTool(BaseTool):
def requires_auth(self) -> bool:
return True
@property
def is_long_running(self) -> bool:
return True
@property
def parameters(self) -> dict[str, Any]:
return {

View File

@@ -9,7 +9,6 @@ from .agent_generator import (
AgentGeneratorNotConfiguredError,
generate_agent_patch,
get_agent_as_json,
get_all_relevant_agents_for_generation,
get_user_message_for_error,
save_agent_to_library,
)
@@ -17,7 +16,6 @@ from .base import BaseTool
from .models import (
AgentPreviewResponse,
AgentSavedResponse,
AsyncProcessingResponse,
ClarificationNeededResponse,
ClarifyingQuestion,
ErrorResponse,
@@ -38,17 +36,16 @@ class EditAgentTool(BaseTool):
def description(self) -> str:
return (
"Edit an existing agent from the user's library using natural language. "
"Generates updates to the agent while preserving unchanged parts."
"Generates updates to the agent while preserving unchanged parts. "
"\n\nIMPORTANT: Before calling this tool, if the changes involve adding new "
"functionality, search for relevant existing agents using find_library_agent "
"that could be used as building blocks. Pass their IDs in library_agent_ids."
)
@property
def requires_auth(self) -> bool:
return True
@property
def is_long_running(self) -> bool:
return True
@property
def parameters(self) -> dict[str, Any]:
return {
@@ -74,6 +71,15 @@ class EditAgentTool(BaseTool):
"Additional context or answers to previous clarifying questions."
),
},
"library_agent_ids": {
"type": "array",
"items": {"type": "string"},
"description": (
"List of library agent IDs to use as building blocks for the changes. "
"If adding new functionality, search for relevant agents using "
"find_library_agent first, then pass their IDs here."
),
},
"save": {
"type": "boolean",
"description": (
@@ -102,13 +108,10 @@ class EditAgentTool(BaseTool):
agent_id = kwargs.get("agent_id", "").strip()
changes = kwargs.get("changes", "").strip()
context = kwargs.get("context", "")
library_agent_ids = kwargs.get("library_agent_ids", [])
save = kwargs.get("save", True)
session_id = session.session_id if session else None
# Extract async processing params (passed by long-running tool handler)
operation_id = kwargs.get("_operation_id")
task_id = kwargs.get("_task_id")
if not agent_id:
return ErrorResponse(
message="Please provide the agent ID to edit.",
@@ -132,21 +135,25 @@ class EditAgentTool(BaseTool):
session_id=session_id,
)
# Fetch library agents by IDs if provided
library_agents = None
if user_id:
if user_id and library_agent_ids:
try:
from .agent_generator import get_library_agents_by_ids
graph_id = current_agent.get("id")
library_agents = await get_all_relevant_agents_for_generation(
# Filter out the current agent being edited
filtered_ids = [id for id in library_agent_ids if id != graph_id]
library_agents = await get_library_agents_by_ids(
user_id=user_id,
search_query=changes,
exclude_graph_id=graph_id,
include_marketplace=True,
agent_ids=filtered_ids,
)
logger.debug(
f"Found {len(library_agents)} relevant agents for sub-agent composition"
f"Fetched {len(library_agents)} library agents by ID for sub-agent composition"
)
except Exception as e:
logger.warning(f"Failed to fetch library agents: {e}")
logger.warning(f"Failed to fetch library agents by IDs: {e}")
update_request = changes
if context:
@@ -157,8 +164,6 @@ class EditAgentTool(BaseTool):
update_request,
current_agent,
library_agents,
operation_id=operation_id,
task_id=task_id,
)
except AgentGeneratorNotConfiguredError:
return ErrorResponse(
@@ -178,19 +183,6 @@ class EditAgentTool(BaseTool):
session_id=session_id,
)
# Check if Agent Generator accepted for async processing
if result.get("status") == "accepted":
logger.info(
f"Agent edit delegated to async processing "
f"(operation_id={operation_id}, task_id={task_id})"
)
return AsyncProcessingResponse(
message="Agent edit started. You'll be notified when it's complete.",
operation_id=operation_id,
task_id=task_id,
session_id=session_id,
)
# Check if the result is an error from the external service
if isinstance(result, dict) and result.get("type") == "error":
error_msg = result.get("error", "Unknown error")

View File

@@ -0,0 +1,186 @@
"""Shared utilities for execution waiting and status handling."""
import asyncio
import logging
from typing import Any
from backend.data.db_accessors import execution_db
from backend.data.execution import (
AsyncRedisExecutionEventBus,
ExecutionStatus,
GraphExecution,
GraphExecutionEvent,
)
logger = logging.getLogger(__name__)
# Terminal statuses that indicate execution is complete
TERMINAL_STATUSES = frozenset(
{
ExecutionStatus.COMPLETED,
ExecutionStatus.FAILED,
ExecutionStatus.TERMINATED,
}
)
# Statuses where execution is paused but not finished (e.g. human-in-the-loop)
PAUSED_STATUSES = frozenset(
{
ExecutionStatus.REVIEW,
}
)
# Statuses that mean "stop waiting" (terminal or paused)
STOP_WAITING_STATUSES = TERMINAL_STATUSES | PAUSED_STATUSES
_POST_SUBSCRIBE_RECHECK_DELAY = 0.1 # seconds to wait for subscription to establish
async def wait_for_execution(
user_id: str,
graph_id: str,
execution_id: str,
timeout_seconds: int,
) -> GraphExecution | None:
"""
Wait for an execution to reach a terminal or paused status using Redis pubsub.
Handles the race condition between checking status and subscribing by
re-checking the DB after the subscription is established.
Args:
user_id: User ID
graph_id: Graph ID
execution_id: Execution ID to wait for
timeout_seconds: Max seconds to wait
Returns:
The execution with current status, or None if not found
"""
exec_db = execution_db()
# Quick check — maybe it's already done
execution = await exec_db.get_graph_execution(
user_id=user_id,
execution_id=execution_id,
include_node_executions=False,
)
if not execution:
return None
if execution.status in STOP_WAITING_STATUSES:
logger.debug(
f"Execution {execution_id} already in stop-waiting state: "
f"{execution.status}"
)
return execution
logger.info(
f"Waiting up to {timeout_seconds}s for execution {execution_id} "
f"(current status: {execution.status})"
)
event_bus = AsyncRedisExecutionEventBus()
channel_key = f"{user_id}/{graph_id}/{execution_id}"
# Mutable container so _subscribe_and_wait can surface the task even if
# asyncio.wait_for cancels the coroutine before it returns.
task_holder: list[asyncio.Task] = []
try:
result = await asyncio.wait_for(
_subscribe_and_wait(
event_bus, channel_key, user_id, execution_id, exec_db, task_holder
),
timeout=timeout_seconds,
)
return result
except asyncio.TimeoutError:
logger.info(f"Timeout waiting for execution {execution_id}")
except Exception as e:
logger.error(f"Error waiting for execution: {e}", exc_info=True)
finally:
for task in task_holder:
if not task.done():
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
await event_bus.close()
# Return current state on timeout/error
return await exec_db.get_graph_execution(
user_id=user_id,
execution_id=execution_id,
include_node_executions=False,
)
async def _subscribe_and_wait(
event_bus: AsyncRedisExecutionEventBus,
channel_key: str,
user_id: str,
execution_id: str,
exec_db: Any,
task_holder: list[asyncio.Task],
) -> GraphExecution | None:
"""
Subscribe to execution events and wait for a terminal/paused status.
Appends the consumer task to ``task_holder`` so the caller can clean it up
even if this coroutine is cancelled by ``asyncio.wait_for``.
To avoid the race condition where the execution completes between the
initial DB check and the Redis subscription, we:
1. Start listening (which subscribes internally)
2. Re-check the DB after subscription is active
3. If still running, wait for pubsub events
"""
listen_iter = event_bus.listen_events(channel_key).__aiter__()
done = asyncio.Event()
result_execution: GraphExecution | None = None
async def _consume() -> None:
nonlocal result_execution
try:
async for event in listen_iter:
if isinstance(event, GraphExecutionEvent):
logger.debug(f"Received execution update: {event.status}")
if event.status in STOP_WAITING_STATUSES:
result_execution = await exec_db.get_graph_execution(
user_id=user_id,
execution_id=execution_id,
include_node_executions=False,
)
done.set()
return
except Exception as e:
logger.error(f"Error in execution consumer: {e}", exc_info=True)
done.set()
consume_task = asyncio.create_task(_consume())
task_holder.append(consume_task)
# Give the subscription a moment to establish, then re-check DB
await asyncio.sleep(_POST_SUBSCRIBE_RECHECK_DELAY)
execution = await exec_db.get_graph_execution(
user_id=user_id,
execution_id=execution_id,
include_node_executions=False,
)
if execution and execution.status in STOP_WAITING_STATUSES:
return execution
# Wait for the pubsub consumer to find a terminal event
await done.wait()
return result_execution
def get_execution_outputs(execution: GraphExecution | None) -> dict[str, Any] | None:
"""Extract outputs from an execution, or return None."""
if execution is None:
return None
return execution.outputs

View File

@@ -366,12 +366,15 @@ class TestFindBlockFiltering:
return_value=(search_results, len(search_results))
)
with patch(
"backend.copilot.tools.find_block.search",
return_value=mock_search_db,
), patch(
"backend.copilot.tools.find_block.get_block",
side_effect=lambda bid: mock_blocks.get(bid),
with (
patch(
"backend.copilot.tools.find_block.search",
return_value=mock_search_db,
),
patch(
"backend.copilot.tools.find_block.get_block",
side_effect=lambda bid: mock_blocks.get(bid),
),
):
tool = FindBlockTool()
response = await tool._execute(

View File

@@ -19,9 +19,10 @@ class FindLibraryAgentTool(BaseTool):
@property
def description(self) -> str:
return (
"Search for agents in the user's library. Use this to find agents "
"the user has already added to their library, including agents they "
"created or added from the marketplace."
"Search for or list agents in the user's library. Use this to find "
"agents the user has already added to their library, including agents "
"they created or added from the marketplace. "
"Omit the query to list all agents."
)
@property
@@ -31,10 +32,13 @@ class FindLibraryAgentTool(BaseTool):
"properties": {
"query": {
"type": "string",
"description": "Search query to find agents by name or description.",
"description": (
"Search query to find agents by name or description. "
"Omit to list all agents in the library."
),
},
},
"required": ["query"],
"required": [],
}
@property
@@ -45,7 +49,7 @@ class FindLibraryAgentTool(BaseTool):
self, user_id: str | None, session: ChatSession, **kwargs
) -> ToolResponseBase:
return await search_agents(
query=kwargs.get("query", "").strip(),
query=(kwargs.get("query") or "").strip(),
source="library",
session_id=session.session_id,
user_id=user_id,

View File

@@ -36,8 +36,6 @@ class ResponseType(str, Enum):
WORKSPACE_FILE_WRITTEN = "workspace_file_written"
WORKSPACE_FILE_DELETED = "workspace_file_deleted"
# Long-running operation types
OPERATION_STARTED = "operation_started"
OPERATION_PENDING = "operation_pending"
OPERATION_IN_PROGRESS = "operation_in_progress"
# Input validation
INPUT_VALIDATION_ERROR = "input_validation_error"
@@ -45,8 +43,6 @@ class ResponseType(str, Enum):
WEB_FETCH = "web_fetch"
# Code execution
BASH_EXEC = "bash_exec"
# Operation status check
OPERATION_STATUS = "operation_status"
# Feature request types
FEATURE_REQUEST_SEARCH = "feature_request_search"
FEATURE_REQUEST_CREATED = "feature_request_created"
@@ -420,34 +416,6 @@ class BlockOutputResponse(ToolResponseBase):
# Long-running operation models
class OperationStartedResponse(ToolResponseBase):
"""Response when a long-running operation has been started in the background.
This is returned immediately to the client while the operation continues
to execute. The user can close the tab and check back later.
The task_id can be used to reconnect to the SSE stream via
GET /chat/tasks/{task_id}/stream?last_idx=0
"""
type: ResponseType = ResponseType.OPERATION_STARTED
operation_id: str
tool_name: str
task_id: str | None = None # For SSE reconnection
class OperationPendingResponse(ToolResponseBase):
"""Response stored in chat history while a long-running operation is executing.
This is persisted to the database so users see a pending state when they
refresh before the operation completes.
"""
type: ResponseType = ResponseType.OPERATION_PENDING
operation_id: str
tool_name: str
class OperationInProgressResponse(ToolResponseBase):
"""Response when an operation is already in progress.
@@ -459,23 +427,6 @@ class OperationInProgressResponse(ToolResponseBase):
tool_call_id: str
class AsyncProcessingResponse(ToolResponseBase):
"""Response when an operation has been delegated to async processing.
This is returned by tools when the external service accepts the request
for async processing (HTTP 202 Accepted). The Redis Streams completion
consumer will handle the result when the external service completes.
The status field is specifically "accepted" to allow the long-running tool
handler to detect this response and skip LLM continuation.
"""
type: ResponseType = ResponseType.OPERATION_STARTED
status: str = "accepted" # Must be "accepted" for detection
operation_id: str | None = None
task_id: str | None = None
class WebFetchResponse(ToolResponseBase):
"""Response for web_fetch tool."""

View File

@@ -9,6 +9,7 @@ from backend.copilot.config import ChatConfig
from backend.copilot.model import ChatSession
from backend.copilot.tracking import track_agent_run_success, track_agent_scheduled
from backend.data.db_accessors import graph_db, library_db, user_db
from backend.data.execution import ExecutionStatus
from backend.data.graph import GraphModel
from backend.data.model import CredentialsMetaInput
from backend.executor import utils as execution_utils
@@ -20,12 +21,15 @@ from backend.util.timezone_utils import (
)
from .base import BaseTool
from .execution_utils import get_execution_outputs, wait_for_execution
from .helpers import get_inputs_from_schema
from .models import (
AgentDetails,
AgentDetailsResponse,
AgentOutputResponse,
ErrorResponse,
ExecutionOptions,
ExecutionOutputInfo,
ExecutionStartedResponse,
InputValidationErrorResponse,
SetupInfo,
@@ -66,6 +70,7 @@ class RunAgentInput(BaseModel):
schedule_name: str = ""
cron: str = ""
timezone: str = "UTC"
wait_for_result: int = Field(default=0, ge=0, le=300)
@field_validator(
"username_agent_slug",
@@ -147,6 +152,14 @@ class RunAgentTool(BaseTool):
"type": "string",
"description": "IANA timezone for schedule (default: UTC)",
},
"wait_for_result": {
"type": "integer",
"description": (
"Max seconds to wait for execution to complete (0-300). "
"If >0, blocks until the execution finishes or times out. "
"Returns execution outputs when complete."
),
},
},
"required": [],
}
@@ -341,6 +354,7 @@ class RunAgentTool(BaseTool):
graph=graph,
graph_credentials=graph_credentials,
inputs=params.inputs,
wait_for_result=params.wait_for_result,
)
except NotFoundError as e:
@@ -424,8 +438,9 @@ class RunAgentTool(BaseTool):
graph: GraphModel,
graph_credentials: dict[str, CredentialsMetaInput],
inputs: dict[str, Any],
wait_for_result: int = 0,
) -> ToolResponseBase:
"""Execute an agent immediately."""
"""Execute an agent immediately, optionally waiting for completion."""
session_id = session.session_id
# Check rate limits
@@ -462,6 +477,91 @@ class RunAgentTool(BaseTool):
)
library_agent_link = f"/library/agents/{library_agent.id}"
# If wait_for_result is requested, wait for execution to complete
if wait_for_result > 0:
logger.info(
f"Waiting up to {wait_for_result}s for execution {execution.id}"
)
completed = await wait_for_execution(
user_id=user_id,
graph_id=library_agent.graph_id,
execution_id=execution.id,
timeout_seconds=wait_for_result,
)
if completed and completed.status == ExecutionStatus.COMPLETED:
outputs = get_execution_outputs(completed)
return AgentOutputResponse(
message=(
f"Agent '{library_agent.name}' completed successfully. "
f"View at {library_agent_link}."
),
session_id=session_id,
agent_name=library_agent.name,
agent_id=library_agent.graph_id,
library_agent_id=library_agent.id,
library_agent_link=library_agent_link,
execution=ExecutionOutputInfo(
execution_id=execution.id,
status=completed.status.value,
started_at=completed.started_at,
ended_at=completed.ended_at,
outputs=outputs or {},
),
)
elif completed and completed.status == ExecutionStatus.FAILED:
error_detail = completed.stats.error if completed.stats else None
return ErrorResponse(
message=(
f"Agent '{library_agent.name}' execution failed. "
f"View details at {library_agent_link}."
),
session_id=session_id,
error=error_detail,
)
elif completed and completed.status == ExecutionStatus.TERMINATED:
error_detail = completed.stats.error if completed.stats else None
return ErrorResponse(
message=(
f"Agent '{library_agent.name}' execution was terminated. "
f"View details at {library_agent_link}."
),
session_id=session_id,
error=error_detail,
)
elif completed and completed.status == ExecutionStatus.REVIEW:
return ExecutionStartedResponse(
message=(
f"Agent '{library_agent.name}' is awaiting human review. "
f"Check at {library_agent_link}."
),
session_id=session_id,
execution_id=execution.id,
graph_id=library_agent.graph_id,
graph_name=library_agent.name,
library_agent_id=library_agent.id,
library_agent_link=library_agent_link,
status=ExecutionStatus.REVIEW.value,
)
else:
status = completed.status.value if completed else "unknown"
return ExecutionStartedResponse(
message=(
f"Agent '{library_agent.name}' is still {status} after "
f"{wait_for_result}s. Check results later at "
f"{library_agent_link}. "
f"Use view_agent_output with wait_if_running to check again."
),
session_id=session_id,
execution_id=execution.id,
graph_id=library_agent.graph_id,
graph_name=library_agent.name,
library_agent_id=library_agent.id,
library_agent_link=library_agent_link,
status=status,
)
return ExecutionStartedResponse(
message=(
f"Agent '{library_agent.name}' execution started successfully. "

View File

@@ -160,9 +160,10 @@ class RunBlockTool(BaseTool):
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
creds_manager = IntegrationCredentialsManager()
matched_credentials, missing_credentials = (
await self._resolve_block_credentials(user_id, block, input_data)
)
(
matched_credentials,
missing_credentials,
) = await self._resolve_block_credentials(user_id, block, input_data)
# Get block schemas for details/validation
try:

View File

@@ -214,7 +214,11 @@ class WorkspaceWriteResponse(ToolResponseBase):
file_id: str
name: str
path: str
mime_type: str
size_bytes: int
# workspace:// URL the agent can embed directly in chat to give the user a link.
# Format: workspace://<file_id>#<mime_type> (frontend resolves to download URL)
download_url: str
source: str | None = None # "content", "base64", or "copied from <path>"
content_preview: str | None = None # First 200 chars for text files
@@ -680,11 +684,21 @@ class WriteWorkspaceFileTool(BaseTool):
except Exception:
pass
# Strip MIME parameters (e.g. "text/html; charset=utf-8" → "text/html")
# and normalise to lowercase so the fragment is URL-safe.
normalized_mime = (rec.mime_type or "").split(";", 1)[0].strip().lower()
download_url = (
f"workspace://{rec.id}#{normalized_mime}"
if normalized_mime
else f"workspace://{rec.id}"
)
return WorkspaceWriteResponse(
file_id=rec.id,
name=rec.name,
path=rec.path,
mime_type=normalized_mime,
size_bytes=rec.size_bytes,
download_url=download_url,
source=source,
content_preview=preview,
message=msg,

View File

@@ -79,6 +79,12 @@ INTEGRATION_WEBHOOK_INCLUDE: prisma.types.IntegrationWebhookInclude = {
}
LIBRARY_FOLDER_INCLUDE: prisma.types.LibraryFolderInclude = {
"LibraryAgents": {"where": {"isDeleted": False}},
"Children": {"where": {"isDeleted": False}},
}
def library_agent_include(
user_id: str,
include_nodes: bool = True,
@@ -105,6 +111,7 @@ def library_agent_include(
"""
result: prisma.types.LibraryAgentInclude = {
"Creator": True, # Always needed for creator info
"Folder": True, # Always needed for folder info
}
# Build AgentGraph include based on requested options

View File

@@ -184,7 +184,7 @@ async def find_webhook_by_credentials_and_props(
credentials_id: str,
webhook_type: str,
resource: str,
events: list[str],
events: Optional[list[str]],
) -> Webhook | None:
webhook = await IntegrationWebhook.prisma().find_first(
where={
@@ -192,7 +192,7 @@ async def find_webhook_by_credentials_and_props(
"credentialsId": credentials_id,
"webhookType": webhook_type,
"resource": resource,
"events": {"has_every": events},
**({"events": {"has_every": events}} if events else {}),
},
)
return Webhook.from_db(webhook) if webhook else None

View File

@@ -0,0 +1,426 @@
"""Tally form integration: cache submissions, match by email, extract business understanding."""
import asyncio
import json
import logging
from datetime import datetime, timezone
from typing import Optional
from openai import AsyncOpenAI
from backend.data.redis_client import get_redis_async
from backend.data.understanding import (
BusinessUnderstandingInput,
get_business_understanding,
upsert_business_understanding,
)
from backend.util.request import Requests
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
TALLY_API_BASE = "https://api.tally.so"
_settings = Settings()
TALLY_FORM_ID = _settings.secrets.tally_form_id
# Redis key templates
_EMAIL_INDEX_KEY = "tally:form:{form_id}:email_index"
_QUESTIONS_KEY = "tally:form:{form_id}:questions"
_LAST_FETCH_KEY = "tally:form:{form_id}:last_fetch"
# TTLs — keep aligned so last_fetch never outlives the index
_INDEX_TTL = 3600 # 1 hour
_LAST_FETCH_TTL = 3600 # 1 hour (same as index)
# Pagination
_PAGE_LIMIT = 500
_MAX_PAGES = 100
# LLM extraction timeout (seconds)
_LLM_TIMEOUT = 30
def _mask_email(email: str) -> str:
"""Mask an email for safe logging: 'alice@example.com' -> 'a***e@example.com'."""
try:
local, domain = email.rsplit("@", 1)
if len(local) <= 2:
masked_local = local[0] + "***"
else:
masked_local = local[0] + "***" + local[-1]
return f"{masked_local}@{domain}"
except (ValueError, IndexError):
return "***"
async def _fetch_tally_page(
client: Requests,
form_id: str,
page: int,
limit: int = _PAGE_LIMIT,
start_date: Optional[str] = None,
) -> dict:
"""Fetch a single page of submissions from the Tally API."""
url = f"{TALLY_API_BASE}/forms/{form_id}/submissions?page={page}&limit={limit}"
if start_date:
url += f"&startDate={start_date}"
response = await client.get(url)
return response.json()
def _make_tally_client(api_key: str) -> Requests:
"""Create a Requests client configured for the Tally API."""
return Requests(
trusted_origins=[TALLY_API_BASE],
raise_for_status=True,
extra_headers={
"Authorization": f"Bearer {api_key}",
"Accept": "application/json",
},
)
async def _fetch_all_submissions(
client: Requests,
form_id: str,
start_date: Optional[str] = None,
max_pages: int = _MAX_PAGES,
) -> tuple[list[dict], list[dict]]:
"""Paginate through all Tally submissions. Returns (questions, submissions)."""
questions: list[dict] = []
all_submissions: list[dict] = []
page = 1
while True:
data = await _fetch_tally_page(client, form_id, page, start_date=start_date)
if page == 1:
questions = data.get("questions", [])
submissions = data.get("submissions", [])
all_submissions.extend(submissions)
# Tally API uses `hasMore` for pagination
has_more = data.get("hasMore", False)
if not has_more:
break
if page >= max_pages:
total = data.get("totalNumberOfSubmissionsPerFilter", {}).get("all", "?")
logger.warning(
f"Tally: hit max page cap ({max_pages}) for form {form_id}, "
f"fetched {len(all_submissions)} of {total} total submissions"
)
break
page += 1
return questions, all_submissions
def _build_email_index(
submissions: list[dict], questions: list[dict]
) -> dict[str, dict]:
"""Build an {email -> submission_data} index from submissions.
Scans question titles for email/contact fields to find the email answer.
"""
# Find question IDs that are likely email fields
email_question_ids: list[str] = []
for q in questions:
label = (q.get("label") or q.get("title") or q.get("name") or "").lower()
q_type = (q.get("type") or "").lower()
if q_type in ("input_email", "email"):
email_question_ids.append(q["id"])
elif any(kw in label for kw in ("email", "e-mail", "contact")):
email_question_ids.append(q["id"])
index: dict[str, dict] = {}
for sub in submissions:
email = _extract_email_from_submission(sub, email_question_ids)
if email:
index[email.lower()] = {
"responses": sub.get("responses", sub.get("fields", [])),
"submitted_at": sub.get("submittedAt", sub.get("createdAt", "")),
"questions": sub.get("questions", []),
}
return index
def _extract_email_from_submission(
submission: dict, email_question_ids: list[str]
) -> Optional[str]:
"""Extract email address from a submission by checking respondentEmail, then field responses."""
# Try respondent email first (Tally often includes this)
respondent_email = submission.get("respondentEmail")
if respondent_email:
return respondent_email
# Search through responses/fields for matching question IDs
responses = submission.get("responses", submission.get("fields", []))
if isinstance(responses, list):
for resp in responses:
q_id = resp.get("questionId") or resp.get("key") or resp.get("id")
if q_id in email_question_ids:
value = resp.get("value") or resp.get("answer")
if isinstance(value, str) and "@" in value:
return value
elif isinstance(responses, dict):
for q_id in email_question_ids:
value = responses.get(q_id)
if isinstance(value, str) and "@" in value:
return value
return None
async def _get_cached_index(
form_id: str,
) -> tuple[Optional[dict], Optional[list]]:
"""Return (email_index, questions) from Redis, or (None, None) on cache miss."""
redis = await get_redis_async()
index_key = _EMAIL_INDEX_KEY.format(form_id=form_id)
questions_key = _QUESTIONS_KEY.format(form_id=form_id)
raw_index = await redis.get(index_key)
raw_questions = await redis.get(questions_key)
if raw_index and raw_questions:
return json.loads(raw_index), json.loads(raw_questions)
return None, None
async def _refresh_cache(form_id: str) -> tuple[dict, list]:
"""Refresh the Tally submission cache. Uses incremental fetch when possible.
Returns (email_index, questions).
"""
settings = Settings()
client = _make_tally_client(settings.secrets.tally_api_key)
redis = await get_redis_async()
last_fetch_key = _LAST_FETCH_KEY.format(form_id=form_id)
index_key = _EMAIL_INDEX_KEY.format(form_id=form_id)
questions_key = _QUESTIONS_KEY.format(form_id=form_id)
last_fetch = await redis.get(last_fetch_key)
if last_fetch:
# Try to load existing index for incremental merge
raw_existing = await redis.get(index_key)
if raw_existing is None:
# Index expired but last_fetch still present — fall back to full fetch
logger.info("Tally: last_fetch present but index missing, doing full fetch")
questions, submissions = await _fetch_all_submissions(client, form_id)
email_index = _build_email_index(submissions, questions)
else:
# Incremental fetch: only get new submissions since last fetch
logger.info(f"Tally incremental fetch since {last_fetch}")
questions, new_submissions = await _fetch_all_submissions(
client, form_id, start_date=last_fetch
)
existing_index: dict[str, dict] = json.loads(raw_existing)
if not questions:
raw_q = await redis.get(questions_key)
if raw_q:
questions = json.loads(raw_q)
new_index = _build_email_index(new_submissions, questions)
existing_index.update(new_index)
email_index = existing_index
else:
# Full initial fetch
logger.info("Tally full initial fetch")
questions, submissions = await _fetch_all_submissions(client, form_id)
email_index = _build_email_index(submissions, questions)
# Store in Redis
now = datetime.now(timezone.utc).isoformat()
await redis.setex(index_key, _INDEX_TTL, json.dumps(email_index))
await redis.setex(questions_key, _INDEX_TTL, json.dumps(questions))
await redis.setex(last_fetch_key, _LAST_FETCH_TTL, now)
logger.info(f"Tally cache refreshed: {len(email_index)} emails indexed")
return email_index, questions
async def find_submission_by_email(
form_id: str, email: str
) -> Optional[tuple[dict, list]]:
"""Look up a Tally submission by email. Uses cache when available.
Returns (submission_data, questions) or None.
"""
email_lower = email.lower()
# Try cache first
email_index, questions = await _get_cached_index(form_id)
if email_index is not None and questions is not None:
sub = email_index.get(email_lower)
if sub is not None:
return sub, questions
return None
# Cache miss - refresh
email_index, questions = await _refresh_cache(form_id)
sub = email_index.get(email_lower)
if sub is not None:
return sub, questions
return None
def format_submission_for_llm(submission: dict, questions: list[dict]) -> str:
"""Format a submission as readable Q&A text for LLM consumption."""
# Build question ID -> title lookup
q_titles: dict[str, str] = {}
for q in questions:
q_id = q.get("id", "")
title = q.get("label") or q.get("title") or q.get("name") or f"Question {q_id}"
q_titles[q_id] = title
lines: list[str] = []
responses = submission.get("responses", [])
if isinstance(responses, list):
for resp in responses:
q_id = resp.get("questionId") or resp.get("key") or resp.get("id") or ""
title = q_titles.get(q_id, f"Question {q_id}")
value = resp.get("value") or resp.get("answer") or ""
lines.append(f"Q: {title}\nA: {_format_answer(value)}")
elif isinstance(responses, dict):
for q_id, value in responses.items():
title = q_titles.get(q_id, f"Question {q_id}")
lines.append(f"Q: {title}\nA: {_format_answer(value)}")
return "\n\n".join(lines)
def _format_answer(value: object) -> str:
"""Convert an answer value (str, list, dict, None) to a human-readable string."""
if value is None:
return "(no answer)"
if isinstance(value, list):
return ", ".join(str(v) for v in value)
if isinstance(value, dict):
parts = [f"{k}: {v}" for k, v in value.items() if v]
return "; ".join(parts) if parts else "(no answer)"
return str(value)
_EXTRACTION_PROMPT = """\
You are a business analyst. Given the following form submission data, extract structured business understanding information.
Return a JSON object with ONLY the fields that can be confidently extracted. Use null for fields that cannot be determined.
Fields:
- user_name (string): the person's name
- job_title (string): their job title
- business_name (string): company/business name
- industry (string): industry or sector
- business_size (string): company size e.g. "1-10", "11-50", "51-200"
- user_role (string): their role context e.g. "decision maker", "implementer"
- key_workflows (list of strings): key business workflows
- daily_activities (list of strings): daily activities performed
- pain_points (list of strings): current pain points
- bottlenecks (list of strings): process bottlenecks
- manual_tasks (list of strings): manual/repetitive tasks
- automation_goals (list of strings): desired automation goals
- current_software (list of strings): software/tools currently used
- existing_automation (list of strings): existing automations
- additional_notes (string): any additional context
Form data:
"""
_EXTRACTION_SUFFIX = "\n\nReturn ONLY valid JSON."
async def extract_business_understanding(
formatted_text: str,
) -> BusinessUnderstandingInput:
"""Use an LLM to extract structured business understanding from form text.
Raises on timeout or unparseable response so the caller can handle it.
"""
settings = Settings()
api_key = settings.secrets.open_router_api_key
client = AsyncOpenAI(api_key=api_key, base_url="https://openrouter.ai/api/v1")
try:
response = await asyncio.wait_for(
client.chat.completions.create(
model="openai/gpt-4o-mini",
messages=[
{
"role": "user",
"content": f"{_EXTRACTION_PROMPT}{formatted_text}{_EXTRACTION_SUFFIX}",
}
],
response_format={"type": "json_object"},
temperature=0.0,
),
timeout=_LLM_TIMEOUT,
)
except asyncio.TimeoutError:
logger.warning("Tally: LLM extraction timed out")
raise
raw = response.choices[0].message.content or "{}"
try:
data = json.loads(raw)
except json.JSONDecodeError:
logger.warning("Tally: LLM returned invalid JSON, skipping extraction")
raise
# Filter out null values before constructing
cleaned = {k: v for k, v in data.items() if v is not None}
return BusinessUnderstandingInput(**cleaned)
async def populate_understanding_from_tally(user_id: str, email: str) -> None:
"""Main orchestrator: check Tally for a matching submission and populate understanding.
Fire-and-forget safe — all exceptions are caught and logged.
"""
try:
# Check if understanding already exists (idempotency)
existing = await get_business_understanding(user_id)
if existing is not None:
logger.debug(
f"Tally: user {user_id} already has business understanding, skipping"
)
return
# Check API key is configured
settings = Settings()
if not settings.secrets.tally_api_key:
logger.debug("Tally: no API key configured, skipping")
return
# Look up submission by email
masked = _mask_email(email)
result = await find_submission_by_email(TALLY_FORM_ID, email)
if result is None:
logger.debug(f"Tally: no submission found for {masked}")
return
submission, questions = result
logger.info(f"Tally: found submission for {masked}, extracting understanding")
# Format and extract
formatted = format_submission_for_llm(submission, questions)
if not formatted.strip():
logger.warning("Tally: formatted submission was empty, skipping")
return
understanding_input = await extract_business_understanding(formatted)
# Upsert into database
await upsert_business_understanding(user_id, understanding_input)
logger.info(f"Tally: successfully populated understanding for user {user_id}")
except Exception:
logger.exception(f"Tally: error populating understanding for user {user_id}")

View File

@@ -0,0 +1,589 @@
"""Tests for backend.data.tally module."""
import asyncio
import json
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.data.tally import (
_EXTRACTION_PROMPT,
_EXTRACTION_SUFFIX,
_build_email_index,
_format_answer,
_make_tally_client,
_mask_email,
_refresh_cache,
extract_business_understanding,
find_submission_by_email,
format_submission_for_llm,
populate_understanding_from_tally,
)
# ── Fixtures ──────────────────────────────────────────────────────────────────
SAMPLE_QUESTIONS = [
{"id": "q1", "label": "What is your name?", "type": "INPUT_TEXT"},
{"id": "q2", "label": "Email address", "type": "INPUT_EMAIL"},
{"id": "q3", "label": "Company name", "type": "INPUT_TEXT"},
{"id": "q4", "label": "Industry", "type": "INPUT_TEXT"},
]
SAMPLE_SUBMISSIONS = [
{
"respondentEmail": None,
"responses": [
{"questionId": "q1", "value": "Alice Smith"},
{"questionId": "q2", "value": "alice@example.com"},
{"questionId": "q3", "value": "Acme Corp"},
{"questionId": "q4", "value": "Technology"},
],
"submittedAt": "2025-01-15T10:00:00Z",
},
{
"respondentEmail": "bob@example.com",
"responses": [
{"questionId": "q1", "value": "Bob Jones"},
{"questionId": "q2", "value": "bob@example.com"},
{"questionId": "q3", "value": "Bob's Burgers"},
{"questionId": "q4", "value": "Food"},
],
"submittedAt": "2025-01-16T10:00:00Z",
},
]
# ── _build_email_index ────────────────────────────────────────────────────────
def test_build_email_index():
index = _build_email_index(SAMPLE_SUBMISSIONS, SAMPLE_QUESTIONS)
assert "alice@example.com" in index
assert "bob@example.com" in index
assert len(index) == 2
def test_build_email_index_case_insensitive():
submissions = [
{
"respondentEmail": None,
"responses": [
{"questionId": "q2", "value": "Alice@Example.COM"},
],
"submittedAt": "2025-01-15T10:00:00Z",
},
]
index = _build_email_index(submissions, SAMPLE_QUESTIONS)
assert "alice@example.com" in index
assert "Alice@Example.COM" not in index
def test_build_email_index_empty():
index = _build_email_index([], SAMPLE_QUESTIONS)
assert index == {}
def test_build_email_index_no_email_field():
questions = [{"id": "q1", "label": "Name", "type": "INPUT_TEXT"}]
submissions = [
{
"responses": [{"questionId": "q1", "value": "Alice"}],
"submittedAt": "2025-01-15T10:00:00Z",
}
]
index = _build_email_index(submissions, questions)
assert index == {}
def test_build_email_index_respondent_email():
"""respondentEmail takes precedence over field scanning."""
submissions = [
{
"respondentEmail": "direct@example.com",
"responses": [
{"questionId": "q2", "value": "field@example.com"},
],
"submittedAt": "2025-01-15T10:00:00Z",
}
]
index = _build_email_index(submissions, SAMPLE_QUESTIONS)
assert "direct@example.com" in index
assert "field@example.com" not in index
# ── format_submission_for_llm ─────────────────────────────────────────────────
def test_format_submission_for_llm():
submission = {
"responses": [
{"questionId": "q1", "value": "Alice Smith"},
{"questionId": "q3", "value": "Acme Corp"},
],
}
result = format_submission_for_llm(submission, SAMPLE_QUESTIONS)
assert "Q: What is your name?" in result
assert "A: Alice Smith" in result
assert "Q: Company name" in result
assert "A: Acme Corp" in result
def test_format_submission_for_llm_dict_responses():
submission = {
"responses": {
"q1": "Alice Smith",
"q3": "Acme Corp",
},
}
result = format_submission_for_llm(submission, SAMPLE_QUESTIONS)
assert "A: Alice Smith" in result
assert "A: Acme Corp" in result
def test_format_answer_types():
assert _format_answer(None) == "(no answer)"
assert _format_answer("hello") == "hello"
assert _format_answer(["a", "b"]) == "a, b"
assert _format_answer({"key": "val"}) == "key: val"
assert _format_answer(42) == "42"
# ── find_submission_by_email ──────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_find_submission_by_email_cache_hit():
cached_index = {
"alice@example.com": {"responses": [], "submitted_at": "2025-01-15"},
}
cached_questions = SAMPLE_QUESTIONS
with patch(
"backend.data.tally._get_cached_index",
new_callable=AsyncMock,
return_value=(cached_index, cached_questions),
) as mock_cache:
result = await find_submission_by_email("form123", "alice@example.com")
mock_cache.assert_awaited_once_with("form123")
assert result is not None
sub, questions = result
assert sub["submitted_at"] == "2025-01-15"
@pytest.mark.asyncio
async def test_find_submission_by_email_cache_miss():
refreshed_index = {
"alice@example.com": {"responses": [], "submitted_at": "2025-01-15"},
}
with (
patch(
"backend.data.tally._get_cached_index",
new_callable=AsyncMock,
return_value=(None, None),
),
patch(
"backend.data.tally._refresh_cache",
new_callable=AsyncMock,
return_value=(refreshed_index, SAMPLE_QUESTIONS),
) as mock_refresh,
):
result = await find_submission_by_email("form123", "alice@example.com")
mock_refresh.assert_awaited_once_with("form123")
assert result is not None
@pytest.mark.asyncio
async def test_find_submission_by_email_no_match():
cached_index = {
"alice@example.com": {"responses": [], "submitted_at": "2025-01-15"},
}
with patch(
"backend.data.tally._get_cached_index",
new_callable=AsyncMock,
return_value=(cached_index, SAMPLE_QUESTIONS),
):
result = await find_submission_by_email("form123", "unknown@example.com")
assert result is None
# ── populate_understanding_from_tally ─────────────────────────────────────────
@pytest.mark.asyncio
async def test_populate_understanding_skips_existing():
"""If user already has understanding, skip entirely."""
mock_understanding = MagicMock()
with (
patch(
"backend.data.tally.get_business_understanding",
new_callable=AsyncMock,
return_value=mock_understanding,
) as mock_get,
patch(
"backend.data.tally.find_submission_by_email",
new_callable=AsyncMock,
) as mock_find,
):
await populate_understanding_from_tally("user-1", "test@example.com")
mock_get.assert_awaited_once_with("user-1")
mock_find.assert_not_awaited()
@pytest.mark.asyncio
async def test_populate_understanding_skips_no_api_key():
"""If no Tally API key, skip gracefully."""
mock_settings = MagicMock()
mock_settings.secrets.tally_api_key = ""
with (
patch(
"backend.data.tally.get_business_understanding",
new_callable=AsyncMock,
return_value=None,
),
patch("backend.data.tally.Settings", return_value=mock_settings),
patch(
"backend.data.tally.find_submission_by_email",
new_callable=AsyncMock,
) as mock_find,
):
await populate_understanding_from_tally("user-1", "test@example.com")
mock_find.assert_not_awaited()
@pytest.mark.asyncio
async def test_populate_understanding_handles_errors():
"""Must never raise, even on unexpected errors."""
with patch(
"backend.data.tally.get_business_understanding",
new_callable=AsyncMock,
side_effect=RuntimeError("DB down"),
):
# Should not raise
await populate_understanding_from_tally("user-1", "test@example.com")
@pytest.mark.asyncio
async def test_populate_understanding_full_flow():
"""Happy path: no existing understanding, finds submission, extracts, upserts."""
mock_settings = MagicMock()
mock_settings.secrets.tally_api_key = "test-key"
submission = {
"responses": [
{"questionId": "q1", "value": "Alice"},
{"questionId": "q3", "value": "Acme"},
],
}
mock_input = MagicMock()
with (
patch(
"backend.data.tally.get_business_understanding",
new_callable=AsyncMock,
return_value=None,
),
patch("backend.data.tally.Settings", return_value=mock_settings),
patch(
"backend.data.tally.find_submission_by_email",
new_callable=AsyncMock,
return_value=(submission, SAMPLE_QUESTIONS),
),
patch(
"backend.data.tally.extract_business_understanding",
new_callable=AsyncMock,
return_value=mock_input,
) as mock_extract,
patch(
"backend.data.tally.upsert_business_understanding",
new_callable=AsyncMock,
) as mock_upsert,
):
await populate_understanding_from_tally("user-1", "alice@example.com")
mock_extract.assert_awaited_once()
mock_upsert.assert_awaited_once_with("user-1", mock_input)
@pytest.mark.asyncio
async def test_populate_understanding_handles_llm_timeout():
"""LLM timeout is caught and doesn't raise."""
import asyncio
mock_settings = MagicMock()
mock_settings.secrets.tally_api_key = "test-key"
submission = {
"responses": [{"questionId": "q1", "value": "Alice"}],
}
with (
patch(
"backend.data.tally.get_business_understanding",
new_callable=AsyncMock,
return_value=None,
),
patch("backend.data.tally.Settings", return_value=mock_settings),
patch(
"backend.data.tally.find_submission_by_email",
new_callable=AsyncMock,
return_value=(submission, SAMPLE_QUESTIONS),
),
patch(
"backend.data.tally.extract_business_understanding",
new_callable=AsyncMock,
side_effect=asyncio.TimeoutError(),
),
patch(
"backend.data.tally.upsert_business_understanding",
new_callable=AsyncMock,
) as mock_upsert,
):
await populate_understanding_from_tally("user-1", "alice@example.com")
mock_upsert.assert_not_awaited()
# ── _mask_email ───────────────────────────────────────────────────────────────
def test_mask_email():
assert _mask_email("alice@example.com") == "a***e@example.com"
assert _mask_email("ab@example.com") == "a***@example.com"
assert _mask_email("a@example.com") == "a***@example.com"
def test_mask_email_invalid():
assert _mask_email("no-at-sign") == "***"
# ── Prompt construction (curly-brace safety) ─────────────────────────────────
def test_extraction_prompt_safe_with_curly_braces():
"""User content with curly braces must not break prompt construction.
Previously _EXTRACTION_PROMPT.format(submission_text=...) would raise
KeyError/ValueError if the user text contained { or }.
"""
text_with_braces = "Q: What tools do you use?\nA: We use {Slack} and {{Jira}}"
# This must not raise — the old .format() call would fail here
prompt = f"{_EXTRACTION_PROMPT}{text_with_braces}{_EXTRACTION_SUFFIX}"
assert text_with_braces in prompt
assert prompt.startswith("You are a business analyst.")
assert prompt.endswith("Return ONLY valid JSON.")
def test_extraction_prompt_no_format_placeholders():
"""_EXTRACTION_PROMPT must not contain Python format placeholders."""
assert "{submission_text}" not in _EXTRACTION_PROMPT
# Ensure no stray single-brace placeholders
# (double braces {{ are fine — they're literal in format strings)
import re
single_braces = re.findall(r"(?<!\{)\{[^{].*?\}(?!\})", _EXTRACTION_PROMPT)
assert single_braces == [], f"Found format placeholders: {single_braces}"
# ── extract_business_understanding ────────────────────────────────────────────
@pytest.mark.asyncio
async def test_extract_business_understanding_success():
"""Happy path: LLM returns valid JSON that maps to BusinessUnderstandingInput."""
mock_choice = MagicMock()
mock_choice.message.content = json.dumps(
{
"user_name": "Alice",
"business_name": "Acme Corp",
"industry": "Technology",
"pain_points": ["manual reporting"],
}
)
mock_response = MagicMock()
mock_response.choices = [mock_choice]
mock_client = AsyncMock()
mock_client.chat.completions.create.return_value = mock_response
with patch("backend.data.tally.AsyncOpenAI", return_value=mock_client):
result = await extract_business_understanding("Q: Name?\nA: Alice")
assert result.user_name == "Alice"
assert result.business_name == "Acme Corp"
assert result.industry == "Technology"
assert result.pain_points == ["manual reporting"]
@pytest.mark.asyncio
async def test_extract_business_understanding_filters_nulls():
"""Null values from LLM should be excluded from the result."""
mock_choice = MagicMock()
mock_choice.message.content = json.dumps(
{"user_name": "Alice", "business_name": None, "industry": None}
)
mock_response = MagicMock()
mock_response.choices = [mock_choice]
mock_client = AsyncMock()
mock_client.chat.completions.create.return_value = mock_response
with patch("backend.data.tally.AsyncOpenAI", return_value=mock_client):
result = await extract_business_understanding("Q: Name?\nA: Alice")
assert result.user_name == "Alice"
assert result.business_name is None
assert result.industry is None
@pytest.mark.asyncio
async def test_extract_business_understanding_invalid_json():
"""Invalid JSON from LLM should raise JSONDecodeError."""
mock_choice = MagicMock()
mock_choice.message.content = "not valid json {"
mock_response = MagicMock()
mock_response.choices = [mock_choice]
mock_client = AsyncMock()
mock_client.chat.completions.create.return_value = mock_response
with (
patch("backend.data.tally.AsyncOpenAI", return_value=mock_client),
pytest.raises(json.JSONDecodeError),
):
await extract_business_understanding("Q: Name?\nA: Alice")
@pytest.mark.asyncio
async def test_extract_business_understanding_timeout():
"""LLM timeout should propagate as asyncio.TimeoutError."""
mock_client = AsyncMock()
mock_client.chat.completions.create.side_effect = asyncio.TimeoutError()
with (
patch("backend.data.tally.AsyncOpenAI", return_value=mock_client),
patch("backend.data.tally._LLM_TIMEOUT", 0.001),
pytest.raises(asyncio.TimeoutError),
):
await extract_business_understanding("Q: Name?\nA: Alice")
# ── _refresh_cache ───────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_refresh_cache_full_fetch():
"""First fetch (no last_fetch in Redis) should do a full fetch and store in Redis."""
mock_settings = MagicMock()
mock_settings.secrets.tally_api_key = "test-key"
mock_redis = AsyncMock()
mock_redis.get.return_value = None # No last_fetch, no cached index
questions = SAMPLE_QUESTIONS
submissions = SAMPLE_SUBMISSIONS
with (
patch("backend.data.tally.Settings", return_value=mock_settings),
patch(
"backend.data.tally.get_redis_async",
new_callable=AsyncMock,
return_value=mock_redis,
),
patch(
"backend.data.tally._fetch_all_submissions",
new_callable=AsyncMock,
return_value=(questions, submissions),
) as mock_fetch,
):
index, returned_questions = await _refresh_cache("form123")
mock_fetch.assert_awaited_once()
assert "alice@example.com" in index
assert "bob@example.com" in index
assert returned_questions == questions
# Verify Redis setex was called for index, questions, and last_fetch
assert mock_redis.setex.await_count == 3
@pytest.mark.asyncio
async def test_refresh_cache_incremental_fetch():
"""When last_fetch and index both exist, should do incremental fetch and merge."""
mock_settings = MagicMock()
mock_settings.secrets.tally_api_key = "test-key"
existing_index = {
"old@example.com": {"responses": [], "submitted_at": "2025-01-01"}
}
mock_redis = AsyncMock()
def mock_get(key):
if "last_fetch" in key:
return "2025-01-14T00:00:00Z"
if "email_index" in key:
return json.dumps(existing_index)
if "questions" in key:
return json.dumps(SAMPLE_QUESTIONS)
return None
mock_redis.get.side_effect = mock_get
new_submissions = [SAMPLE_SUBMISSIONS[0]] # Just Alice
with (
patch("backend.data.tally.Settings", return_value=mock_settings),
patch(
"backend.data.tally.get_redis_async",
new_callable=AsyncMock,
return_value=mock_redis,
),
patch(
"backend.data.tally._fetch_all_submissions",
new_callable=AsyncMock,
return_value=(SAMPLE_QUESTIONS, new_submissions),
),
):
index, _ = await _refresh_cache("form123")
# Should contain both old and new entries
assert "old@example.com" in index
assert "alice@example.com" in index
# ── _make_tally_client ───────────────────────────────────────────────────────
def test_make_tally_client_returns_configured_client():
"""_make_tally_client should create a Requests client with auth headers."""
client = _make_tally_client("test-api-key")
assert client.extra_headers is not None
assert client.extra_headers.get("Authorization") == "Bearer test-api-key"
@pytest.mark.asyncio
async def test_fetch_tally_page_uses_provided_client():
"""_fetch_tally_page should use the passed client, not create its own."""
from backend.data.tally import _fetch_tally_page
mock_response = MagicMock()
mock_response.json.return_value = {"submissions": [], "questions": []}
mock_client = AsyncMock()
mock_client.get.return_value = mock_response
result = await _fetch_tally_page(mock_client, "form123", page=1)
mock_client.get.assert_awaited_once()
call_url = mock_client.get.call_args[0][0]
assert "form123" in call_url
assert "page=1" in call_url
assert result == {"submissions": [], "questions": []}

View File

@@ -1,5 +1,6 @@
"""Redis-based distributed locking for cluster coordination."""
import asyncio
import logging
import threading
import time
@@ -7,6 +8,7 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from redis import Redis
from redis.asyncio import Redis as AsyncRedis
logger = logging.getLogger(__name__)
@@ -126,3 +128,124 @@ class ClusterLock:
with self._refresh_lock:
self._last_refresh = 0.0
class AsyncClusterLock:
"""Async Redis-based distributed lock for preventing duplicate execution."""
def __init__(
self, redis: "AsyncRedis", key: str, owner_id: str, timeout: int = 300
):
self.redis = redis
self.key = key
self.owner_id = owner_id
self.timeout = timeout
self._last_refresh = 0.0
self._refresh_lock = asyncio.Lock()
async def try_acquire(self) -> str | None:
"""Try to acquire the lock.
Returns:
- owner_id (self.owner_id) if successfully acquired
- different owner_id if someone else holds the lock
- None if Redis is unavailable or other error
"""
try:
success = await self.redis.set(
self.key, self.owner_id, nx=True, ex=self.timeout
)
if success:
async with self._refresh_lock:
self._last_refresh = time.time()
return self.owner_id # Successfully acquired
# Failed to acquire, get current owner
current_value = await self.redis.get(self.key)
if current_value:
current_owner = (
current_value.decode("utf-8")
if isinstance(current_value, bytes)
else str(current_value)
)
return current_owner
# Key doesn't exist but we failed to set it - race condition or Redis issue
return None
except Exception as e:
logger.error(f"AsyncClusterLock.try_acquire failed for key {self.key}: {e}")
return None
async def refresh(self) -> bool:
"""Refresh lock TTL if we still own it.
Rate limited to at most once every timeout/10 seconds (minimum 1 second).
During rate limiting, still verifies lock existence but skips TTL extension.
Setting _last_refresh to 0 bypasses rate limiting for testing.
Async-safe: uses asyncio.Lock to protect _last_refresh access.
"""
# Calculate refresh interval: max(timeout // 10, 1)
refresh_interval = max(self.timeout // 10, 1)
current_time = time.time()
# Check if we're within the rate limit period (async-safe read)
# _last_refresh == 0 forces a refresh (bypasses rate limiting for testing)
async with self._refresh_lock:
last_refresh = self._last_refresh
is_rate_limited = (
last_refresh > 0 and (current_time - last_refresh) < refresh_interval
)
try:
# Always verify lock existence, even during rate limiting
current_value = await self.redis.get(self.key)
if not current_value:
async with self._refresh_lock:
self._last_refresh = 0
return False
stored_owner = (
current_value.decode("utf-8")
if isinstance(current_value, bytes)
else str(current_value)
)
if stored_owner != self.owner_id:
async with self._refresh_lock:
self._last_refresh = 0
return False
# If rate limited, return True but don't update TTL or timestamp
if is_rate_limited:
return True
# Perform actual refresh
if await self.redis.expire(self.key, self.timeout):
async with self._refresh_lock:
self._last_refresh = current_time
return True
async with self._refresh_lock:
self._last_refresh = 0
return False
except Exception as e:
logger.error(f"AsyncClusterLock.refresh failed for key {self.key}: {e}")
async with self._refresh_lock:
self._last_refresh = 0
return False
async def release(self):
"""Release the lock."""
async with self._refresh_lock:
if self._last_refresh == 0:
return
try:
await self.redis.delete(self.key)
except Exception:
pass
async with self._refresh_lock:
self._last_refresh = 0.0

View File

@@ -47,6 +47,7 @@ class ProviderName(str, Enum):
SLANT3D = "slant3d"
SMARTLEAD = "smartlead"
SMTP = "smtp"
TELEGRAM = "telegram"
TWITTER = "twitter"
TODOIST = "todoist"
UNREAL_SPEECH = "unreal_speech"

View File

@@ -15,6 +15,7 @@ def load_webhook_managers() -> dict["ProviderName", type["BaseWebhooksManager"]]
from .compass import CompassWebhookManager
from .github import GithubWebhooksManager
from .slant3d import Slant3DWebhooksManager
from .telegram import TelegramWebhooksManager
webhook_managers.update(
{
@@ -23,6 +24,7 @@ def load_webhook_managers() -> dict["ProviderName", type["BaseWebhooksManager"]]
CompassWebhookManager,
GithubWebhooksManager,
Slant3DWebhooksManager,
TelegramWebhooksManager,
]
}
)

View File

@@ -0,0 +1,242 @@
"""
Telegram Bot API Webhooks Manager.
Handles webhook registration and validation for Telegram bots.
"""
import hmac
import logging
from fastapi import HTTPException, Request
from strenum import StrEnum
from backend.data import integrations
from backend.data.model import APIKeyCredentials, Credentials
from backend.integrations.providers import ProviderName
from backend.util.exceptions import MissingConfigError
from backend.util.request import Requests
from backend.util.settings import Config
from ._base import BaseWebhooksManager
from .utils import webhook_ingress_url
logger = logging.getLogger(__name__)
class TelegramWebhookType(StrEnum):
BOT = "bot"
class TelegramWebhooksManager(BaseWebhooksManager):
"""
Manages Telegram bot webhooks.
Telegram webhooks are registered via the setWebhook API method.
Incoming requests are validated using the secret_token header.
"""
PROVIDER_NAME = ProviderName.TELEGRAM
WebhookType = TelegramWebhookType
TELEGRAM_API_BASE = "https://api.telegram.org"
async def get_suitable_auto_webhook(
self,
user_id: str,
credentials: Credentials,
webhook_type: TelegramWebhookType,
resource: str,
events: list[str],
) -> integrations.Webhook:
"""
Telegram only supports one webhook per bot. Instead of creating a new
webhook object when events change (which causes the old one to be pruned
and deregistered — removing the ONLY webhook for the bot), we find the
existing webhook and update its events in place.
"""
app_config = Config()
if not app_config.platform_base_url:
raise MissingConfigError(
"PLATFORM_BASE_URL must be set to use Webhook functionality"
)
# Exact match — no re-registration needed
if webhook := await integrations.find_webhook_by_credentials_and_props(
user_id=user_id,
credentials_id=credentials.id,
webhook_type=webhook_type,
resource=resource,
events=events,
):
return webhook
# Find any existing webhook for the same bot, regardless of events
if existing := await integrations.find_webhook_by_credentials_and_props(
user_id=user_id,
credentials_id=credentials.id,
webhook_type=webhook_type,
resource=resource,
events=None, # Ignore events for this lookup
):
# Re-register with Telegram using the same URL but new allowed_updates
ingress_url = webhook_ingress_url(self.PROVIDER_NAME, existing.id)
_, config = await self._register_webhook(
credentials,
webhook_type,
resource,
events,
ingress_url,
existing.secret,
)
return await integrations.update_webhook(
existing.id, events=events, config=config
)
# No existing webhook at all — create a new one
return await self._create_webhook(
user_id=user_id,
webhook_type=webhook_type,
events=events,
resource=resource,
credentials=credentials,
)
@classmethod
async def validate_payload(
cls,
webhook: integrations.Webhook,
request: Request,
credentials: Credentials | None,
) -> tuple[dict, str]:
"""
Validates incoming Telegram webhook request.
Telegram sends X-Telegram-Bot-Api-Secret-Token header when secret_token
was set in setWebhook call.
Returns:
tuple: (payload dict, event_type string)
"""
# Verify secret token header
secret_header = request.headers.get("X-Telegram-Bot-Api-Secret-Token")
if not secret_header or not hmac.compare_digest(secret_header, webhook.secret):
raise HTTPException(
status_code=403,
detail="Invalid or missing X-Telegram-Bot-Api-Secret-Token",
)
payload = await request.json()
# Determine event type based on update content
if "message" in payload:
message = payload["message"]
if "text" in message:
event_type = "message.text"
elif "photo" in message:
event_type = "message.photo"
elif "voice" in message:
event_type = "message.voice"
elif "audio" in message:
event_type = "message.audio"
elif "document" in message:
event_type = "message.document"
elif "video" in message:
event_type = "message.video"
else:
logger.warning(
"Unknown Telegram webhook payload type; "
f"message.keys() = {message.keys()}"
)
event_type = "message.other"
elif "edited_message" in payload:
event_type = "message.edited_message"
elif "message_reaction" in payload:
event_type = "message_reaction"
else:
event_type = "unknown"
return payload, event_type
async def _register_webhook(
self,
credentials: Credentials,
webhook_type: TelegramWebhookType,
resource: str,
events: list[str],
ingress_url: str,
secret: str,
) -> tuple[str, dict]:
"""
Register webhook with Telegram using setWebhook API.
Args:
credentials: Bot token credentials
webhook_type: Type of webhook (always BOT for Telegram)
resource: Resource identifier (unused for Telegram, bots are global)
events: Events to subscribe to
ingress_url: URL to receive webhook payloads
secret: Secret token for request validation
Returns:
tuple: (provider_webhook_id, config dict)
"""
if not isinstance(credentials, APIKeyCredentials):
raise ValueError("API key (bot token) is required for Telegram webhooks")
token = credentials.api_key.get_secret_value()
url = f"{self.TELEGRAM_API_BASE}/bot{token}/setWebhook"
# Map event filter to Telegram allowed_updates
if events:
telegram_updates: set[str] = set()
for event in events:
telegram_updates.add(event.split(".")[0])
# "message.edited_message" requires the "edited_message" update type
if "edited_message" in event:
telegram_updates.add("edited_message")
sorted_updates = sorted(telegram_updates)
else:
sorted_updates = ["message", "message_reaction"]
webhook_data = {
"url": ingress_url,
"secret_token": secret,
"allowed_updates": sorted_updates,
}
response = await Requests().post(url, json=webhook_data)
result = response.json()
if not result.get("ok"):
error_desc = result.get("description", "Unknown error")
raise ValueError(f"Failed to set Telegram webhook: {error_desc}")
# Telegram doesn't return a webhook ID, use empty string
config = {
"url": ingress_url,
"allowed_updates": webhook_data["allowed_updates"],
}
return "", config
async def _deregister_webhook(
self, webhook: integrations.Webhook, credentials: Credentials
) -> None:
"""
Deregister webhook by calling setWebhook with empty URL.
This removes the webhook from Telegram's servers.
"""
if not isinstance(credentials, APIKeyCredentials):
raise ValueError("API key (bot token) is required for Telegram webhooks")
token = credentials.api_key.get_secret_value()
url = f"{self.TELEGRAM_API_BASE}/bot{token}/setWebhook"
# Setting empty URL removes the webhook
response = await Requests().post(url, json={"url": ""})
result = response.json()
if not result.get("ok"):
error_desc = result.get("description", "Unknown error")
logger.warning(f"Failed to deregister Telegram webhook: {error_desc}")

View File

@@ -372,7 +372,7 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
description="The port for the Agent Generator service",
)
agentgenerator_timeout: int = Field(
default=600,
default=1800,
description="The timeout in seconds for Agent Generator service requests (includes retries for rate limits)",
)
agentgenerator_use_dummy: bool = Field(
@@ -691,6 +691,15 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
screenshotone_api_key: str = Field(default="", description="ScreenshotOne API Key")
tally_api_key: str = Field(
default="",
description="Tally API key for form submission lookup on signup",
)
tally_form_id: str = Field(
default="npGe0q",
description="Tally form ID for signup business understanding form",
)
apollo_api_key: str = Field(default="", description="Apollo API Key")
smartlead_api_key: str = Field(default="", description="SmartLead API Key")
zerobounce_api_key: str = Field(default="", description="ZeroBounce API Key")

View File

@@ -0,0 +1,33 @@
-- AlterTable
ALTER TABLE "LibraryAgent" ADD COLUMN "folderId" TEXT;
-- CreateTable
CREATE TABLE "LibraryFolder" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"userId" TEXT NOT NULL,
"name" TEXT NOT NULL,
"icon" TEXT,
"color" TEXT,
"parentId" TEXT,
"isDeleted" BOOLEAN NOT NULL DEFAULT false,
CONSTRAINT "LibraryFolder_pkey" PRIMARY KEY ("id")
);
-- CreateIndex
CREATE UNIQUE INDEX "LibraryFolder_userId_parentId_name_key" ON "LibraryFolder"("userId", "parentId", "name");
-- CreateIndex
CREATE INDEX "LibraryAgent_folderId_idx" ON "LibraryAgent"("folderId");
-- AddForeignKey
ALTER TABLE "LibraryAgent" ADD CONSTRAINT "LibraryAgent_folderId_fkey" FOREIGN KEY ("folderId") REFERENCES "LibraryFolder"("id") ON DELETE RESTRICT ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "LibraryFolder" ADD CONSTRAINT "LibraryFolder_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "LibraryFolder" ADD CONSTRAINT "LibraryFolder_parentId_fkey" FOREIGN KEY ("parentId") REFERENCES "LibraryFolder"("id") ON DELETE CASCADE ON UPDATE CASCADE;

View File

@@ -0,0 +1,97 @@
-- This migration creates a materialized view for suggested blocks based on execution counts
-- The view aggregates execution counts per block for the last 14 days
--
-- IMPORTANT: For production environments, pg_cron is REQUIRED for automatic refresh
-- Prerequisites for production:
-- 1. pg_cron extension must be installed: CREATE EXTENSION pg_cron;
-- 2. pg_cron must be configured in postgresql.conf:
-- shared_preload_libraries = 'pg_cron'
-- cron.database_name = 'your_database_name'
--
-- For development environments without pg_cron:
-- The migration will succeed but you must manually refresh views with:
-- SET search_path TO platform;
-- SELECT refresh_suggested_blocks_view();
-- Check if pg_cron extension is installed
DO $$
DECLARE
has_pg_cron BOOLEAN;
BEGIN
SELECT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'pg_cron') INTO has_pg_cron;
IF NOT has_pg_cron THEN
RAISE WARNING 'pg_cron is not installed. Materialized view will be created but will NOT refresh automatically. For production, install pg_cron. For development, manually refresh with: SELECT refresh_suggested_blocks_view();';
END IF;
END
$$;
-- Create materialized view for suggested blocks based on execution counts in last 14 days
-- The 14-day threshold is hardcoded to ensure consistent behavior
CREATE MATERIALIZED VIEW IF NOT EXISTS "mv_suggested_blocks" AS
SELECT
agent_node."agentBlockId" AS block_id,
COUNT(execution.id) AS execution_count
FROM "AgentNodeExecution" execution
JOIN "AgentNode" agent_node ON execution."agentNodeId" = agent_node.id
WHERE execution."endedTime" >= (NOW() - INTERVAL '14 days')
GROUP BY agent_node."agentBlockId"
ORDER BY execution_count DESC;
-- Create unique index for concurrent refresh support
CREATE UNIQUE INDEX IF NOT EXISTS "idx_mv_suggested_blocks_block_id" ON "mv_suggested_blocks"("block_id");
-- Create refresh function
CREATE OR REPLACE FUNCTION refresh_suggested_blocks_view()
RETURNS void
LANGUAGE plpgsql
AS $$
DECLARE
target_schema text := current_schema();
BEGIN
-- Use CONCURRENTLY for better performance during refresh
REFRESH MATERIALIZED VIEW CONCURRENTLY "mv_suggested_blocks";
RAISE NOTICE 'Suggested blocks materialized view refreshed in schema % at %', target_schema, NOW();
EXCEPTION
WHEN OTHERS THEN
-- Fallback to non-concurrent refresh if concurrent fails
REFRESH MATERIALIZED VIEW "mv_suggested_blocks";
RAISE NOTICE 'Suggested blocks materialized view refreshed (non-concurrent) in schema % at %. Concurrent refresh failed due to: %', target_schema, NOW(), SQLERRM;
END;
$$;
-- Initial refresh of the materialized view
SELECT refresh_suggested_blocks_view();
-- Schedule automatic refresh every hour (only if pg_cron is available)
DO $$
DECLARE
has_pg_cron BOOLEAN;
current_schema_name text := current_schema();
job_name text;
BEGIN
-- Check if pg_cron extension exists
SELECT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'pg_cron') INTO has_pg_cron;
IF has_pg_cron THEN
job_name := format('refresh-suggested-blocks_%s', current_schema_name);
-- Try to unschedule existing job (ignore errors if it doesn't exist)
BEGIN
PERFORM cron.unschedule(job_name);
EXCEPTION WHEN OTHERS THEN
NULL;
END;
-- Schedule the new job to run every hour
PERFORM cron.schedule(
job_name,
'0 * * * *', -- Every hour at minute 0
format('SET search_path TO %I; SELECT refresh_suggested_blocks_view();', current_schema_name)
);
RAISE NOTICE 'Scheduled job %; runs every hour for schema %', job_name, current_schema_name;
ELSE
RAISE WARNING 'Automatic refresh NOT configured - pg_cron is not available. Manually refresh with: SELECT refresh_suggested_blocks_view();';
END IF;
END;
$$;

View File

@@ -0,0 +1,7 @@
-- This migration adds more than one value to an enum.
-- With PostgreSQL versions 11 and earlier, this is not possible
-- in a single migration. This can be worked around by creating
-- multiple migrations, each migration adding only one value to
-- the enum.
ALTER TYPE "APIKeyPermission" ADD VALUE 'WRITE_GRAPH';
ALTER TYPE "APIKeyPermission" ADD VALUE 'WRITE_LIBRARY';

View File

@@ -1610,6 +1610,101 @@ mccabe = ">=0.7.0,<0.8.0"
pycodestyle = ">=2.14.0,<2.15.0"
pyflakes = ">=3.4.0,<3.5.0"
[[package]]
name = "fonttools"
version = "4.61.1"
description = "Tools to manipulate font files"
optional = false
python-versions = ">=3.10"
groups = ["main"]
files = [
{file = "fonttools-4.61.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:7c7db70d57e5e1089a274cbb2b1fd635c9a24de809a231b154965d415d6c6d24"},
{file = "fonttools-4.61.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5fe9fd43882620017add5eabb781ebfbc6998ee49b35bd7f8f79af1f9f99a958"},
{file = "fonttools-4.61.1-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d8db08051fc9e7d8bc622f2112511b8107d8f27cd89e2f64ec45e9825e8288da"},
{file = "fonttools-4.61.1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:a76d4cb80f41ba94a6691264be76435e5f72f2cb3cab0b092a6212855f71c2f6"},
{file = "fonttools-4.61.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:a13fc8aeb24bad755eea8f7f9d409438eb94e82cf86b08fe77a03fbc8f6a96b1"},
{file = "fonttools-4.61.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:b846a1fcf8beadeb9ea4f44ec5bdde393e2f1569e17d700bfc49cd69bde75881"},
{file = "fonttools-4.61.1-cp310-cp310-win32.whl", hash = "sha256:78a7d3ab09dc47ac1a363a493e6112d8cabed7ba7caad5f54dbe2f08676d1b47"},
{file = "fonttools-4.61.1-cp310-cp310-win_amd64.whl", hash = "sha256:eff1ac3cc66c2ac7cda1e64b4e2f3ffef474b7335f92fc3833fc632d595fcee6"},
{file = "fonttools-4.61.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:c6604b735bb12fef8e0efd5578c9fb5d3d8532d5001ea13a19cddf295673ee09"},
{file = "fonttools-4.61.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:5ce02f38a754f207f2f06557523cd39a06438ba3aafc0639c477ac409fc64e37"},
{file = "fonttools-4.61.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:77efb033d8d7ff233385f30c62c7c79271c8885d5c9657d967ede124671bbdfb"},
{file = "fonttools-4.61.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:75c1a6dfac6abd407634420c93864a1e274ebc1c7531346d9254c0d8f6ca00f9"},
{file = "fonttools-4.61.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:0de30bfe7745c0d1ffa2b0b7048fb7123ad0d71107e10ee090fa0b16b9452e87"},
{file = "fonttools-4.61.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:58b0ee0ab5b1fc9921eccfe11d1435added19d6494dde14e323f25ad2bc30c56"},
{file = "fonttools-4.61.1-cp311-cp311-win32.whl", hash = "sha256:f79b168428351d11e10c5aeb61a74e1851ec221081299f4cf56036a95431c43a"},
{file = "fonttools-4.61.1-cp311-cp311-win_amd64.whl", hash = "sha256:fe2efccb324948a11dd09d22136fe2ac8a97d6c1347cf0b58a911dcd529f66b7"},
{file = "fonttools-4.61.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:f3cb4a569029b9f291f88aafc927dd53683757e640081ca8c412781ea144565e"},
{file = "fonttools-4.61.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:41a7170d042e8c0024703ed13b71893519a1a6d6e18e933e3ec7507a2c26a4b2"},
{file = "fonttools-4.61.1-cp312-cp312-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:10d88e55330e092940584774ee5e8a6971b01fc2f4d3466a1d6c158230880796"},
{file = "fonttools-4.61.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:15acc09befd16a0fb8a8f62bc147e1a82817542d72184acca9ce6e0aeda9fa6d"},
{file = "fonttools-4.61.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:e6bcdf33aec38d16508ce61fd81838f24c83c90a1d1b8c68982857038673d6b8"},
{file = "fonttools-4.61.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:5fade934607a523614726119164ff621e8c30e8fa1ffffbbd358662056ba69f0"},
{file = "fonttools-4.61.1-cp312-cp312-win32.whl", hash = "sha256:75da8f28eff26defba42c52986de97b22106cb8f26515b7c22443ebc9c2d3261"},
{file = "fonttools-4.61.1-cp312-cp312-win_amd64.whl", hash = "sha256:497c31ce314219888c0e2fce5ad9178ca83fe5230b01a5006726cdf3ac9f24d9"},
{file = "fonttools-4.61.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:8c56c488ab471628ff3bfa80964372fc13504ece601e0d97a78ee74126b2045c"},
{file = "fonttools-4.61.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:dc492779501fa723b04d0ab1f5be046797fee17d27700476edc7ee9ae535a61e"},
{file = "fonttools-4.61.1-cp313-cp313-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:64102ca87e84261419c3747a0d20f396eb024bdbeb04c2bfb37e2891f5fadcb5"},
{file = "fonttools-4.61.1-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4c1b526c8d3f615a7b1867f38a9410849c8f4aef078535742198e942fba0e9bd"},
{file = "fonttools-4.61.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:41ed4b5ec103bd306bb68f81dc166e77409e5209443e5773cb4ed837bcc9b0d3"},
{file = "fonttools-4.61.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:b501c862d4901792adaec7c25b1ecc749e2662543f68bb194c42ba18d6eec98d"},
{file = "fonttools-4.61.1-cp313-cp313-win32.whl", hash = "sha256:4d7092bb38c53bbc78e9255a59158b150bcdc115a1e3b3ce0b5f267dc35dd63c"},
{file = "fonttools-4.61.1-cp313-cp313-win_amd64.whl", hash = "sha256:21e7c8d76f62ab13c9472ccf74515ca5b9a761d1bde3265152a6dc58700d895b"},
{file = "fonttools-4.61.1-cp314-cp314-macosx_10_15_universal2.whl", hash = "sha256:fff4f534200a04b4a36e7ae3cb74493afe807b517a09e99cb4faa89a34ed6ecd"},
{file = "fonttools-4.61.1-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:d9203500f7c63545b4ce3799319fe4d9feb1a1b89b28d3cb5abd11b9dd64147e"},
{file = "fonttools-4.61.1-cp314-cp314-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:fa646ecec9528bef693415c79a86e733c70a4965dd938e9a226b0fc64c9d2e6c"},
{file = "fonttools-4.61.1-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:11f35ad7805edba3aac1a3710d104592df59f4b957e30108ae0ba6c10b11dd75"},
{file = "fonttools-4.61.1-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:b931ae8f62db78861b0ff1ac017851764602288575d65b8e8ff1963fed419063"},
{file = "fonttools-4.61.1-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:b148b56f5de675ee16d45e769e69f87623a4944f7443850bf9a9376e628a89d2"},
{file = "fonttools-4.61.1-cp314-cp314-win32.whl", hash = "sha256:9b666a475a65f4e839d3d10473fad6d47e0a9db14a2f4a224029c5bfde58ad2c"},
{file = "fonttools-4.61.1-cp314-cp314-win_amd64.whl", hash = "sha256:4f5686e1fe5fce75d82d93c47a438a25bf0d1319d2843a926f741140b2b16e0c"},
{file = "fonttools-4.61.1-cp314-cp314t-macosx_10_15_universal2.whl", hash = "sha256:e76ce097e3c57c4bcb67c5aa24a0ecdbd9f74ea9219997a707a4061fbe2707aa"},
{file = "fonttools-4.61.1-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:9cfef3ab326780c04d6646f68d4b4742aae222e8b8ea1d627c74e38afcbc9d91"},
{file = "fonttools-4.61.1-cp314-cp314t-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:a75c301f96db737e1c5ed5fd7d77d9c34466de16095a266509e13da09751bd19"},
{file = "fonttools-4.61.1-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:91669ccac46bbc1d09e9273546181919064e8df73488ea087dcac3e2968df9ba"},
{file = "fonttools-4.61.1-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:c33ab3ca9d3ccd581d58e989d67554e42d8d4ded94ab3ade3508455fe70e65f7"},
{file = "fonttools-4.61.1-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:664c5a68ec406f6b1547946683008576ef8b38275608e1cee6c061828171c118"},
{file = "fonttools-4.61.1-cp314-cp314t-win32.whl", hash = "sha256:aed04cabe26f30c1647ef0e8fbb207516fd40fe9472e9439695f5c6998e60ac5"},
{file = "fonttools-4.61.1-cp314-cp314t-win_amd64.whl", hash = "sha256:2180f14c141d2f0f3da43f3a81bc8aa4684860f6b0e6f9e165a4831f24e6a23b"},
{file = "fonttools-4.61.1-py3-none-any.whl", hash = "sha256:17d2bf5d541add43822bcf0c43d7d847b160c9bb01d15d5007d84e2217aaa371"},
{file = "fonttools-4.61.1.tar.gz", hash = "sha256:6675329885c44657f826ef01d9e4fb33b9158e9d93c537d84ad8399539bc6f69"},
]
[package.extras]
all = ["brotli (>=1.0.1) ; platform_python_implementation == \"CPython\"", "brotlicffi (>=0.8.0) ; platform_python_implementation != \"CPython\"", "lxml (>=4.0)", "lz4 (>=1.7.4.2)", "matplotlib", "munkres ; platform_python_implementation == \"PyPy\"", "pycairo", "scipy ; platform_python_implementation != \"PyPy\"", "skia-pathops (>=0.5.0)", "sympy", "uharfbuzz (>=0.45.0)", "unicodedata2 (>=17.0.0) ; python_version <= \"3.14\"", "xattr ; sys_platform == \"darwin\"", "zopfli (>=0.1.4)"]
graphite = ["lz4 (>=1.7.4.2)"]
interpolatable = ["munkres ; platform_python_implementation == \"PyPy\"", "pycairo", "scipy ; platform_python_implementation != \"PyPy\""]
lxml = ["lxml (>=4.0)"]
pathops = ["skia-pathops (>=0.5.0)"]
plot = ["matplotlib"]
repacker = ["uharfbuzz (>=0.45.0)"]
symfont = ["sympy"]
type1 = ["xattr ; sys_platform == \"darwin\""]
unicode = ["unicodedata2 (>=17.0.0) ; python_version <= \"3.14\""]
woff = ["brotli (>=1.0.1) ; platform_python_implementation == \"CPython\"", "brotlicffi (>=0.8.0) ; platform_python_implementation != \"CPython\"", "zopfli (>=0.1.4)"]
[[package]]
name = "fpdf2"
version = "2.8.6"
description = "Simple & fast PDF generation for Python"
optional = false
python-versions = ">=3.10"
groups = ["main"]
files = [
{file = "fpdf2-2.8.6-py3-none-any.whl", hash = "sha256:464658b896c6b0fcbf883abb316b8f0a52d582eb959d71822ba254d6c790bfdd"},
{file = "fpdf2-2.8.6.tar.gz", hash = "sha256:5132f26bbeee69a7ca6a292e4da1eb3241147b5aea9348b35e780ecd02bf5fc2"},
]
[package.dependencies]
defusedxml = "*"
fonttools = ">=4.34.0"
Pillow = ">=8.3.2,<9.2.dev0 || >=9.3.dev0"
[package.extras]
dev = ["bandit", "black", "mypy", "pre-commit", "pylint", "pyright", "semgrep", "zizmor"]
docs = ["lxml", "mkdocs", "mkdocs-git-revision-date-localized-plugin", "mkdocs-include-markdown-plugin", "mkdocs-macros-plugin", "mkdocs-material", "mkdocs-minify-plugin", "mkdocs-redirects", "mkdocs-with-pdf", "mknotebooks", "pdoc3"]
test = ["brotli", "camelot-py[base]", "endesive[full]", "pytest", "pytest-cov", "qrcode", "tabula-py", "typing-extensions (>=4.0) ; python_version < \"3.11\"", "uharfbuzz"]
[[package]]
name = "frozenlist"
version = "1.8.0"
@@ -8530,4 +8625,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt
[metadata]
lock-version = "2.1"
python-versions = ">=3.10,<3.14"
content-hash = "3ef62836d8321b9a3b8e897dade8dc6ca9022fd9468c53f384b0871b521ab343"
content-hash = "3869bc3fb8ea50e7101daffce13edbe563c8af568cb751adfa31fb9bb5c8318a"

View File

@@ -89,6 +89,7 @@ croniter = "^6.0.0"
stagehand = "^0.5.1"
gravitas-md2gdocs = "^0.1.0"
posthog = "^7.6.0"
fpdf2 = "^2.8.6"
[tool.poetry.group.dev.dependencies]
aiohappyeyeballs = "^2.6.1"

View File

@@ -51,6 +51,7 @@ model User {
ChatSessions ChatSession[]
AgentPresets AgentPreset[]
LibraryAgents LibraryAgent[]
LibraryFolders LibraryFolder[]
Profile Profile[]
UserOnboarding UserOnboarding?
@@ -395,6 +396,9 @@ model LibraryAgent {
creatorId String?
Creator Profile? @relation(fields: [creatorId], references: [id])
folderId String?
Folder LibraryFolder? @relation(fields: [folderId], references: [id], onDelete: Restrict)
useGraphIsActiveVersion Boolean @default(false)
isFavorite Boolean @default(false)
@@ -407,6 +411,30 @@ model LibraryAgent {
@@unique([userId, agentGraphId, agentGraphVersion])
@@index([agentGraphId, agentGraphVersion])
@@index([creatorId])
@@index([folderId])
}
model LibraryFolder {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @default(now()) @updatedAt
userId String
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
name String
icon String?
color String?
parentId String?
Parent LibraryFolder? @relation("FolderHierarchy", fields: [parentId], references: [id], onDelete: Cascade)
Children LibraryFolder[] @relation("FolderHierarchy")
isDeleted Boolean @default(false)
LibraryAgents LibraryAgent[]
@@unique([userId, parentId, name]) // Name unique per parent per user
}
////////////////////////////////////////////////////////////
@@ -920,6 +948,17 @@ view mv_review_stats {
// Refresh uses CONCURRENTLY to avoid blocking reads
}
// Note: This is actually a MATERIALIZED VIEW in the database
// Refreshed automatically every hour via pg_cron (with fallback to manual refresh)
view mv_suggested_blocks {
block_id String @unique
execution_count Int
// Pre-aggregated execution counts per block for the last 14 days
// Used by builder suggestions for ordering blocks by popularity
// Refresh uses CONCURRENTLY to avoid blocking reads
}
model StoreListing {
id String @id @default(uuid())
createdAt DateTime @default(now())
@@ -1091,9 +1130,11 @@ enum APIKeyPermission {
IDENTITY // Info about the authenticated user
EXECUTE_GRAPH // Can execute agent graphs
READ_GRAPH // Can get graph versions and details
WRITE_GRAPH // Can create and update agent graphs
EXECUTE_BLOCK // Can execute individual blocks
READ_BLOCK // Can get block information
READ_STORE // Can read store agents and creators
WRITE_LIBRARY // Can add agents to library
USE_TOOLS // Can use chat tools via external API
MANAGE_INTEGRATIONS // Can initiate OAuth flows and complete them
READ_INTEGRATIONS // Can list credentials and providers

View File

@@ -38,6 +38,8 @@
"can_access_graph": true,
"is_latest_version": true,
"is_favorite": false,
"folder_id": null,
"folder_name": null,
"recommended_schedule_cron": null,
"settings": {
"human_in_the_loop_safe_mode": true,
@@ -83,6 +85,8 @@
"can_access_graph": false,
"is_latest_version": true,
"is_favorite": false,
"folder_id": null,
"folder_name": null,
"recommended_schedule_cron": null,
"settings": {
"human_in_the_loop_safe_mode": true,

View File

@@ -109,7 +109,7 @@ class TestGenerateAgent:
instructions = {"type": "instructions", "steps": ["Step 1"]}
result = await core.generate_agent(instructions)
mock_external.assert_called_once_with(instructions, None, None, None)
mock_external.assert_called_once_with(instructions, None)
assert result is not None
assert result["name"] == "Test Agent"
assert "id" in result
@@ -173,9 +173,7 @@ class TestGenerateAgentPatch:
current_agent = {"nodes": [], "links": []}
result = await core.generate_agent_patch("Add a node", current_agent)
mock_external.assert_called_once_with(
"Add a node", current_agent, None, None, None
)
mock_external.assert_called_once_with("Add a node", current_agent, None)
assert result == expected_result
@pytest.mark.asyncio

View File

@@ -1,349 +0,0 @@
#!/usr/bin/env python3
"""
Integration test for the requeue fix implementation.
Tests actual RabbitMQ behavior to verify that republishing sends messages to back of queue.
"""
import json
import time
from threading import Event
from typing import List
from backend.data.rabbitmq import SyncRabbitMQ
from backend.executor.utils import create_execution_queue_config
class QueueOrderTester:
"""Helper class to test message ordering in RabbitMQ using a dedicated test queue."""
def __init__(self):
self.received_messages: List[dict] = []
self.stop_consuming = Event()
self.queue_client = SyncRabbitMQ(create_execution_queue_config())
self.queue_client.connect()
# Use a dedicated test queue name to avoid conflicts
self.test_queue_name = "test_requeue_ordering"
self.test_exchange = "test_exchange"
self.test_routing_key = "test.requeue"
def setup_queue(self):
"""Set up a dedicated test queue for testing."""
channel = self.queue_client.get_channel()
# Declare test exchange
channel.exchange_declare(
exchange=self.test_exchange, exchange_type="direct", durable=True
)
# Declare test queue
channel.queue_declare(
queue=self.test_queue_name, durable=True, auto_delete=False
)
# Bind queue to exchange
channel.queue_bind(
exchange=self.test_exchange,
queue=self.test_queue_name,
routing_key=self.test_routing_key,
)
# Purge the queue to start fresh
channel.queue_purge(self.test_queue_name)
print(f"✅ Test queue {self.test_queue_name} setup and purged")
def create_test_message(self, message_id: str, user_id: str = "test-user") -> str:
"""Create a test graph execution message."""
return json.dumps(
{
"graph_exec_id": f"exec-{message_id}",
"graph_id": f"graph-{message_id}",
"user_id": user_id,
"execution_context": {"timezone": "UTC"},
"nodes_input_masks": {},
"starting_nodes_input": [],
}
)
def publish_message(self, message: str):
"""Publish a message to the test queue."""
channel = self.queue_client.get_channel()
channel.basic_publish(
exchange=self.test_exchange,
routing_key=self.test_routing_key,
body=message,
)
def consume_messages(self, max_messages: int = 10, timeout: float = 5.0):
"""Consume messages and track their order."""
def callback(ch, method, properties, body):
try:
message_data = json.loads(body.decode())
self.received_messages.append(message_data)
ch.basic_ack(delivery_tag=method.delivery_tag)
if len(self.received_messages) >= max_messages:
self.stop_consuming.set()
except Exception as e:
print(f"Error processing message: {e}")
ch.basic_nack(delivery_tag=method.delivery_tag, requeue=False)
# Use synchronous consumption with blocking
channel = self.queue_client.get_channel()
# Check if there are messages in the queue first
method_frame, header_frame, body = channel.basic_get(
queue=self.test_queue_name, auto_ack=False
)
if method_frame:
# There are messages, set up consumer
channel.basic_nack(
delivery_tag=method_frame.delivery_tag, requeue=True
) # Put message back
# Set up consumer
channel.basic_consume(
queue=self.test_queue_name,
on_message_callback=callback,
)
# Consume with timeout
start_time = time.time()
while (
not self.stop_consuming.is_set()
and (time.time() - start_time) < timeout
and len(self.received_messages) < max_messages
):
try:
channel.connection.process_data_events(time_limit=0.1)
except Exception as e:
print(f"Error during consumption: {e}")
break
# Cancel the consumer
try:
channel.cancel()
except Exception:
pass
else:
# No messages in queue - this might be expected for some tests
pass
return self.received_messages
def cleanup(self):
"""Clean up test resources."""
try:
channel = self.queue_client.get_channel()
channel.queue_delete(queue=self.test_queue_name)
channel.exchange_delete(exchange=self.test_exchange)
print(f"✅ Test queue {self.test_queue_name} cleaned up")
except Exception as e:
print(f"⚠️ Cleanup issue: {e}")
def test_queue_ordering_behavior():
"""
Integration test to verify that our republishing method sends messages to back of queue.
This tests the actual fix for the rate limiting queue blocking issue.
"""
tester = QueueOrderTester()
try:
tester.setup_queue()
print("🧪 Testing actual RabbitMQ queue ordering behavior...")
# Test 1: Normal FIFO behavior
print("1. Testing normal FIFO queue behavior")
# Publish messages in order: A, B, C
msg_a = tester.create_test_message("A")
msg_b = tester.create_test_message("B")
msg_c = tester.create_test_message("C")
tester.publish_message(msg_a)
tester.publish_message(msg_b)
tester.publish_message(msg_c)
# Consume and verify FIFO order: A, B, C
tester.received_messages = []
tester.stop_consuming.clear()
messages = tester.consume_messages(max_messages=3)
assert len(messages) == 3, f"Expected 3 messages, got {len(messages)}"
assert (
messages[0]["graph_exec_id"] == "exec-A"
), f"First message should be A, got {messages[0]['graph_exec_id']}"
assert (
messages[1]["graph_exec_id"] == "exec-B"
), f"Second message should be B, got {messages[1]['graph_exec_id']}"
assert (
messages[2]["graph_exec_id"] == "exec-C"
), f"Third message should be C, got {messages[2]['graph_exec_id']}"
print("✅ FIFO order confirmed: A -> B -> C")
# Test 2: Rate limiting simulation - the key test!
print("2. Testing rate limiting fix scenario")
# Simulate the scenario where user1 is rate limited
user1_msg = tester.create_test_message("RATE-LIMITED", "user1")
user2_msg1 = tester.create_test_message("USER2-1", "user2")
user2_msg2 = tester.create_test_message("USER2-2", "user2")
# Initially publish user1 message (gets consumed, then rate limited on retry)
tester.publish_message(user1_msg)
# Other users publish their messages
tester.publish_message(user2_msg1)
tester.publish_message(user2_msg2)
# Now simulate: user1 message gets "requeued" using our new republishing method
# This is what happens in manager.py when requeue_by_republishing=True
tester.publish_message(user1_msg) # Goes to back via our method
# Expected order: RATE-LIMITED, USER2-1, USER2-2, RATE-LIMITED (republished to back)
# This shows that user2 messages get processed instead of being blocked
tester.received_messages = []
tester.stop_consuming.clear()
messages = tester.consume_messages(max_messages=4)
assert len(messages) == 4, f"Expected 4 messages, got {len(messages)}"
# The key verification: user2 messages are NOT blocked by user1's rate-limited message
user2_messages = [msg for msg in messages if msg["user_id"] == "user2"]
assert len(user2_messages) == 2, "Both user2 messages should be processed"
assert user2_messages[0]["graph_exec_id"] == "exec-USER2-1"
assert user2_messages[1]["graph_exec_id"] == "exec-USER2-2"
print("✅ Rate limiting fix confirmed: user2 executions NOT blocked by user1")
# Test 3: Verify our method behaves like going to back of queue
print("3. Testing republishing sends messages to back")
# Start with message X in queue
msg_x = tester.create_test_message("X")
tester.publish_message(msg_x)
# Add message Y
msg_y = tester.create_test_message("Y")
tester.publish_message(msg_y)
# Republish X (simulates requeue using our method)
tester.publish_message(msg_x)
# Expected: X, Y, X (X was republished to back)
tester.received_messages = []
tester.stop_consuming.clear()
messages = tester.consume_messages(max_messages=3)
assert len(messages) == 3
# Y should come before the republished X
y_index = next(
i for i, msg in enumerate(messages) if msg["graph_exec_id"] == "exec-Y"
)
republished_x_index = next(
i
for i, msg in enumerate(messages[1:], 1)
if msg["graph_exec_id"] == "exec-X"
)
assert (
y_index < republished_x_index
), f"Y should come before republished X, but got order: {[m['graph_exec_id'] for m in messages]}"
print("✅ Republishing confirmed: messages go to back of queue")
print("🎉 All integration tests passed!")
print("🎉 Our republishing method works correctly with real RabbitMQ")
print("🎉 Queue blocking issue is fixed!")
finally:
tester.cleanup()
def test_traditional_requeue_behavior():
"""
Test that traditional requeue (basic_nack with requeue=True) sends messages to FRONT of queue.
This validates our hypothesis about why queue blocking occurs.
"""
tester = QueueOrderTester()
try:
tester.setup_queue()
print("🧪 Testing traditional requeue behavior (basic_nack with requeue=True)")
# Step 1: Publish message A
msg_a = tester.create_test_message("A")
tester.publish_message(msg_a)
# Step 2: Publish message B
msg_b = tester.create_test_message("B")
tester.publish_message(msg_b)
# Step 3: Consume message A and requeue it using traditional method
channel = tester.queue_client.get_channel()
method_frame, header_frame, body = channel.basic_get(
queue=tester.test_queue_name, auto_ack=False
)
assert method_frame is not None, "Should have received message A"
consumed_msg = json.loads(body.decode())
assert (
consumed_msg["graph_exec_id"] == "exec-A"
), f"Should have consumed message A, got {consumed_msg['graph_exec_id']}"
# Traditional requeue: basic_nack with requeue=True (sends to FRONT)
channel.basic_nack(delivery_tag=method_frame.delivery_tag, requeue=True)
print(f"🔄 Traditional requeue (to FRONT): {consumed_msg['graph_exec_id']}")
# Step 4: Consume all messages using basic_get for reliability
received_messages = []
# Get first message
method_frame, header_frame, body = channel.basic_get(
queue=tester.test_queue_name, auto_ack=True
)
if method_frame:
msg = json.loads(body.decode())
received_messages.append(msg)
# Get second message
method_frame, header_frame, body = channel.basic_get(
queue=tester.test_queue_name, auto_ack=True
)
if method_frame:
msg = json.loads(body.decode())
received_messages.append(msg)
# CRITICAL ASSERTION: Traditional requeue should put A at FRONT
# Expected order: A (requeued to front), B
assert (
len(received_messages) == 2
), f"Expected 2 messages, got {len(received_messages)}"
first_msg = received_messages[0]["graph_exec_id"]
second_msg = received_messages[1]["graph_exec_id"]
# This is the critical test: requeued message A should come BEFORE B
assert (
first_msg == "exec-A"
), f"Traditional requeue should put A at FRONT, but first message was: {first_msg}"
assert (
second_msg == "exec-B"
), f"B should come after requeued A, but second message was: {second_msg}"
print(
"✅ HYPOTHESIS CONFIRMED: Traditional requeue sends messages to FRONT of queue"
)
print(f" Order: {first_msg} (requeued to front) → {second_msg}")
print(" This explains why rate-limited messages block other users!")
finally:
tester.cleanup()
if __name__ == "__main__":
test_queue_ordering_behavior()

View File

@@ -6,6 +6,7 @@ const config: StorybookConfig = {
"../src/components/tokens/**/*.stories.@(js|jsx|mjs|ts|tsx)",
"../src/components/atoms/**/*.stories.@(js|jsx|mjs|ts|tsx)",
"../src/components/molecules/**/*.stories.@(js|jsx|mjs|ts|tsx)",
"../src/components/ai-elements/**/*.stories.@(js|jsx|mjs|ts|tsx)",
],
addons: [
"@storybook/addon-a11y",

View File

@@ -32,6 +32,7 @@
"dependencies": {
"@ai-sdk/react": "3.0.61",
"@faker-js/faker": "10.0.0",
"@ferrucc-io/emoji-picker": "0.0.48",
"@hookform/resolvers": "5.2.2",
"@next/third-parties": "15.4.6",
"@phosphor-icons/react": "2.1.10",

View File

@@ -18,6 +18,9 @@ importers:
'@faker-js/faker':
specifier: 10.0.0
version: 10.0.0
'@ferrucc-io/emoji-picker':
specifier: 0.0.48
version: 0.0.48(@babel/core@7.28.5)(@babel/template@7.27.2)(@types/react@18.3.17)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(tailwindcss@3.4.17)
'@hookform/resolvers':
specifier: 5.2.2
version: 5.2.2(react-hook-form@7.66.0(react@18.3.1))
@@ -1507,6 +1510,14 @@ packages:
resolution: {integrity: sha512-UollFEUkVXutsaP+Vndjxar40Gs5JL2HeLcl8xO1QAjJgOdhc3OmBFWyEylS+RddWaaBiAzH+5/17PLQJwDiLw==}
engines: {node: ^20.19.0 || ^22.13.0 || ^23.5.0 || >=24.0.0, npm: '>=10'}
'@ferrucc-io/emoji-picker@0.0.48':
resolution: {integrity: sha512-DJ5u+6VLF9OK7x+S/luwrVb5CHC6W16jL5b8vBUYNpxKWSuFgyliDHVtw1SGe6+dr5RUbf8WQwPJdKZmU3Ittg==}
engines: {node: '>=18'}
peerDependencies:
react: ^18.2.0 || ^19.0.0
react-dom: ^18.2.0 || ^19.0.0
tailwindcss: '>=3.0.0'
'@floating-ui/core@1.7.3':
resolution: {integrity: sha512-sGnvb5dmrJaKEZ+LDIpguvdX3bDlEllmv4/ClQ9awcmCZrlx5jQyyMWFM5kBI+EyNOCDDiKk8il0zeuX3Zlg/w==}
@@ -3114,6 +3125,10 @@ packages:
'@shikijs/vscode-textmate@10.0.2':
resolution: {integrity: sha512-83yeghZ2xxin3Nj8z1NMd/NCuca+gsYXswywDy5bHvwlWL8tpTQmzGeUuHd9FC3E/SBEMvzJRwWEOz5gGes9Qg==}
'@sindresorhus/is@4.6.0':
resolution: {integrity: sha512-t09vSN3MdfsyCHoFcTRCH/iUtG7OJ0CsjzB8cjAmKc/va/kIgeDI/TxsigdncE/4be734m0cvIYwNaV4i2XqAw==}
engines: {node: '>=10'}
'@standard-schema/spec@1.0.0':
resolution: {integrity: sha512-m2bOd0f2RT9k8QJx1JN85cZYyH1RqFBdlwtkSlf4tBDYLCiiZnv1fIIwacK6cqwXavOydf0NPToMQgpKq+dVlA==}
@@ -3376,10 +3391,19 @@ packages:
react: '>=16.8'
react-dom: '>=16.8'
'@tanstack/react-virtual@3.13.18':
resolution: {integrity: sha512-dZkhyfahpvlaV0rIKnvQiVoWPyURppl6w4m9IwMDpuIjcJ1sD9YGWrt0wISvgU7ewACXx2Ct46WPgI6qAD4v6A==}
peerDependencies:
react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0
react-dom: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0
'@tanstack/table-core@8.21.3':
resolution: {integrity: sha512-ldZXEhOBb8Is7xLs01fR3YEc3DERiz5silj8tnGkFZytt1abEvl/GhUmCE0PMLaMPTa3Jk4HbKmRlHmu+gCftg==}
engines: {node: '>=12'}
'@tanstack/virtual-core@3.13.18':
resolution: {integrity: sha512-Mx86Hqu1k39icq2Zusq+Ey2J6dDWTjDvEv43PJtRCoEYTLyfaPnxIQ6iy7YAOK0NV/qOEmZQ/uCufrppZxTgcg==}
'@testing-library/dom@10.4.1':
resolution: {integrity: sha512-o4PXJQidqJl82ckFaXUeoAW+XysPLauYI43Abki5hABd853iMhitooc6znOnczgbTYmEP6U6/y1ZyKAIsvMKGg==}
engines: {node: '>=18'}
@@ -4373,6 +4397,10 @@ packages:
resolution: {integrity: sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA==}
engines: {node: '>=10'}
char-regex@1.0.2:
resolution: {integrity: sha512-kWWXztvZ5SBQV+eRgKFeh8q5sLuZY2+8WUIzlxWVTg+oGwY14qylx1KbKzHd8P6ZYkAg0xyIDU9JMHhyJMZ1jw==}
engines: {node: '>=10'}
character-entities-html4@2.1.0:
resolution: {integrity: sha512-1v7fgQRj6hnSwFpq1Eu0ynr/CDEw0rXo2B61qXrLNdHZmPKgb7fqS1a2JwF0rISo9q77jDI8VMEHoApn8qDoZA==}
@@ -4990,6 +5018,9 @@ packages:
emoji-regex@9.2.2:
resolution: {integrity: sha512-L18DaJsXSUk2+42pv8mLs5jJT2hqFkFE4j21wOmgbUqsZ2hL72NsUU785g9RXgo3s0ZNgVl42TiHp3ZtOv/Vyg==}
emojilib@2.4.0:
resolution: {integrity: sha512-5U0rVMU5Y2n2+ykNLQqMoqklN9ICBT/KsvC1Gz6vqHbz2AXXGkG+Pm5rMWk/8Vjrr/mY9985Hi8DYzn1F09Nyw==}
emojis-list@3.0.0:
resolution: {integrity: sha512-/kyM18EfinwXZbno9FyUGeFh87KC8HRQBQGildHZbEuRyWFOmv1U10o9BBp8XVZDVNNuQKyIGIu5ZYAAXJ0V2Q==}
engines: {node: '>= 4'}
@@ -5970,6 +6001,24 @@ packages:
resolution: {integrity: sha512-ekilCSN1jwRvIbgeg/57YFh8qQDNbwDb9xT/qu2DAHbFFZUicIl4ygVaAvzveMhMVr3LnpSKTNnwt8PoOfmKhQ==}
hasBin: true
jotai@2.17.1:
resolution: {integrity: sha512-TFNZZDa/0ewCLQyRC/Sq9crtixNj/Xdf/wmj9631xxMuKToVJZDbqcHIYN0OboH+7kh6P6tpIK7uKWClj86PKw==}
engines: {node: '>=12.20.0'}
peerDependencies:
'@babel/core': '>=7.0.0'
'@babel/template': '>=7.0.0'
'@types/react': '>=17.0.0'
react: '>=17.0.0'
peerDependenciesMeta:
'@babel/core':
optional: true
'@babel/template':
optional: true
'@types/react':
optional: true
react:
optional: true
js-tokens@4.0.0:
resolution: {integrity: sha512-RdJUflcE3cUzKiMqQgsCu06FPu9UdIJO0beYbPhHN4k6apgJtifcoCtT9bcxOpYBtpD2kCM6Sbzg4CausW/PKQ==}
@@ -6588,6 +6637,10 @@ packages:
node-abort-controller@3.1.1:
resolution: {integrity: sha512-AGK2yQKIjRuqnc6VkX2Xj5d+QW8xZ87pa1UK6yA6ouUyuxfHuMP6umE5QK7UmTeOAymo+Zx1Fxiuw9rVx8taHQ==}
node-emoji@2.2.0:
resolution: {integrity: sha512-Z3lTE9pLaJF47NyMhd4ww1yFTAP8YhYI8SleJiHzM46Fgpm5cnNzSl9XfzFNqbaz+VlJrIj3fXQ4DeN1Rjm6cw==}
engines: {node: '>=18'}
node-fetch-h2@2.3.0:
resolution: {integrity: sha512-ofRW94Ab0T4AOh5Fk8t0h8OBWrmjb0SSB20xh1H8YnPV9EJ+f5AMoYSUQ2zgJ4Iq2HAK0I2l5/Nequ8YzFS3Hg==}
engines: {node: 4.x || >=6.0.0}
@@ -7686,6 +7739,10 @@ packages:
resolution: {integrity: sha512-LH7FpTAkeD+y5xQC4fzS+tFtaNlvt3Ib1zKzvhjv/Y+cioV4zIuw4IZr2yhRLu67CWL7FR9/6KXKnjRoZTvGGQ==}
engines: {node: '>=12'}
skin-tone@2.0.0:
resolution: {integrity: sha512-kUMbT1oBJCpgrnKoSr0o6wPtvRWT9W9UKvGLwfJYO2WuahZRHOpEyL1ckyMGgMWh0UdpmaoFqKKD29WTomNEGA==}
engines: {node: '>=8'}
slash@3.0.0:
resolution: {integrity: sha512-g9Q1haeby36OSStwb4ntCGGGaKsaVSjQ68fBxoQcutl5fS1vuY18H3wSt3jFyFtrkx+Kz0V1G85A4MyAdDMi2Q==}
engines: {node: '>=8'}
@@ -8163,6 +8220,13 @@ packages:
resolution: {integrity: sha512-dA8WbNeb2a6oQzAQ55YlT5vQAWGV9WXOsi3SskE3bcCdM0P4SDd+24zS/OCacdRq5BkdsRj9q3Pg6YyQoxIGqg==}
engines: {node: '>=4'}
unicode-emoji-json@0.8.0:
resolution: {integrity: sha512-3wDXXvp6YGoKGhS2O2H7+V+bYduOBydN1lnI0uVfr1cIdY02uFFiEH1i3kE5CCE4l6UqbLKVmEFW9USxTAMD1g==}
unicode-emoji-modifier-base@1.0.0:
resolution: {integrity: sha512-yLSH4py7oFH3oG/9K+XWrz1pSi3dfUrWEnInbxMfArOfc1+33BlGPQtLsOYwvdMy11AwUBetYuaRxSPqgkq+8g==}
engines: {node: '>=4'}
unicode-match-property-ecmascript@2.0.0:
resolution: {integrity: sha512-5kaZCrbp5mmbz5ulBkDkbY0SsPOjKqVS35VpL9ulMPfSl0J0Xsm+9Evphv9CoIZFwre7aJoa94AY6seMKGVN5Q==}
engines: {node: '>=4'}
@@ -9772,6 +9836,22 @@ snapshots:
'@faker-js/faker@10.0.0': {}
'@ferrucc-io/emoji-picker@0.0.48(@babel/core@7.28.5)(@babel/template@7.27.2)(@types/react@18.3.17)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(tailwindcss@3.4.17)':
dependencies:
'@tanstack/react-virtual': 3.13.18(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
clsx: 2.1.1
jotai: 2.17.1(@babel/core@7.28.5)(@babel/template@7.27.2)(@types/react@18.3.17)(react@18.3.1)
node-emoji: 2.2.0
react: 18.3.1
react-dom: 18.3.1(react@18.3.1)
tailwind-merge: 2.6.0
tailwindcss: 3.4.17
unicode-emoji-json: 0.8.0
transitivePeerDependencies:
- '@babel/core'
- '@babel/template'
- '@types/react'
'@floating-ui/core@1.7.3':
dependencies:
'@floating-ui/utils': 0.2.10
@@ -11533,6 +11613,8 @@ snapshots:
'@shikijs/vscode-textmate@10.0.2': {}
'@sindresorhus/is@4.6.0': {}
'@standard-schema/spec@1.0.0': {}
'@standard-schema/spec@1.1.0': {}
@@ -12001,8 +12083,16 @@ snapshots:
react: 18.3.1
react-dom: 18.3.1(react@18.3.1)
'@tanstack/react-virtual@3.13.18(react-dom@18.3.1(react@18.3.1))(react@18.3.1)':
dependencies:
'@tanstack/virtual-core': 3.13.18
react: 18.3.1
react-dom: 18.3.1(react@18.3.1)
'@tanstack/table-core@8.21.3': {}
'@tanstack/virtual-core@3.13.18': {}
'@testing-library/dom@10.4.1':
dependencies:
'@babel/code-frame': 7.27.1
@@ -13094,6 +13184,8 @@ snapshots:
ansi-styles: 4.3.0
supports-color: 7.2.0
char-regex@1.0.2: {}
character-entities-html4@2.1.0: {}
character-entities-legacy@3.0.0: {}
@@ -13737,6 +13829,8 @@ snapshots:
emoji-regex@9.2.2: {}
emojilib@2.4.0: {}
emojis-list@3.0.0: {}
endent@2.1.0:
@@ -15018,6 +15112,13 @@ snapshots:
jiti@2.6.1: {}
jotai@2.17.1(@babel/core@7.28.5)(@babel/template@7.27.2)(@types/react@18.3.17)(react@18.3.1):
optionalDependencies:
'@babel/core': 7.28.5
'@babel/template': 7.27.2
'@types/react': 18.3.17
react: 18.3.1
js-tokens@4.0.0: {}
js-yaml@4.1.0:
@@ -15886,6 +15987,13 @@ snapshots:
node-abort-controller@3.1.1: {}
node-emoji@2.2.0:
dependencies:
'@sindresorhus/is': 4.6.0
char-regex: 1.0.2
emojilib: 2.4.0
skin-tone: 2.0.0
node-fetch-h2@2.3.0:
dependencies:
http2-client: 1.3.5
@@ -17186,6 +17294,10 @@ snapshots:
dependencies:
jsep: 1.4.0
skin-tone@2.0.0:
dependencies:
unicode-emoji-modifier-base: 1.0.0
slash@3.0.0: {}
sonner@2.0.7(react-dom@18.3.1(react@18.3.1))(react@18.3.1):
@@ -17701,6 +17813,10 @@ snapshots:
unicode-canonical-property-names-ecmascript@2.0.1: {}
unicode-emoji-json@0.8.0: {}
unicode-emoji-modifier-base@1.0.0: {}
unicode-match-property-ecmascript@2.0.0:
dependencies:
unicode-canonical-property-names-ecmascript: 2.0.1

Binary file not shown.

After

Width:  |  Height:  |  Size: 192 KiB

View File

@@ -19,6 +19,8 @@ const SCOPE_DESCRIPTIONS: { [key in APIKeyPermission]: string } = {
IDENTITY: "View your user ID, e-mail, and timezone",
EXECUTE_GRAPH: "Run your agents",
READ_GRAPH: "View your agents and their configurations",
WRITE_GRAPH: "Create agent graphs",
WRITE_LIBRARY: "Add agents to your library",
EXECUTE_BLOCK: "Execute individual blocks",
READ_BLOCK: "View available blocks",
READ_STORE: "Access the Marketplace",

View File

@@ -63,8 +63,19 @@ const CustomEdge = ({
return (
<>
<path
d={edgePath}
fill="none"
stroke="black"
strokeOpacity={0}
strokeWidth={20}
className="react-flow__edge-interaction cursor-pointer"
onMouseEnter={() => setIsHovered(true)}
onMouseLeave={() => setIsHovered(false)}
/>
<BaseEdge
path={edgePath}
interactionWidth={0}
markerEnd={markerEnd}
className={cn(
isStatic && "!stroke-[1.5px] [stroke-dasharray:6]",

View File

@@ -127,7 +127,10 @@ export const Block: BlockComponent = ({
// preview when user drags it
const dragPreview = document.createElement("div");
dragPreview.style.cssText = blockDragPreviewStyle;
dragPreview.textContent = beautifyString(title || "");
dragPreview.textContent = beautifyString(title || "").replace(
/ Block$/,
"",
);
document.body.appendChild(dragPreview);
e.dataTransfer.setDragImage(dragPreview, 0, 0);
@@ -162,7 +165,10 @@ export const Block: BlockComponent = ({
"line-clamp-1 font-sans text-sm font-medium leading-[1.375rem] text-zinc-800 group-disabled:text-zinc-400",
)}
>
{highlightText(beautifyString(title), highlightedText)}
{highlightText(
beautifyString(title).replace(/ Block$/, ""),
highlightedText,
)}
</span>
)}
{description && (

View File

@@ -2,7 +2,7 @@ import { useBlockMenuStore } from "@/app/(platform)/build/stores/blockMenuStore"
import { FilterChip } from "../FilterChip";
import { categories } from "./constants";
import { FilterSheet } from "../FilterSheet/FilterSheet";
import { GetV2BuilderSearchFilterAnyOfItem } from "@/app/api/__generated__/models/getV2BuilderSearchFilterAnyOfItem";
import { CategoryKey } from "./types";
export const BlockMenuFilters = () => {
const {
@@ -15,7 +15,7 @@ export const BlockMenuFilters = () => {
removeCreator,
} = useBlockMenuStore();
const handleFilterClick = (filter: GetV2BuilderSearchFilterAnyOfItem) => {
const handleFilterClick = (filter: CategoryKey) => {
if (filters.includes(filter)) {
removeFilter(filter);
} else {

View File

@@ -1,15 +1,15 @@
import { GetV2BuilderSearchFilterAnyOfItem } from "@/app/api/__generated__/models/getV2BuilderSearchFilterAnyOfItem";
import { SearchEntryFilterAnyOfItem } from "@/app/api/__generated__/models/searchEntryFilterAnyOfItem";
import { CategoryKey } from "./types";
export const categories: Array<{ key: CategoryKey; name: string }> = [
{ key: GetV2BuilderSearchFilterAnyOfItem.blocks, name: "Blocks" },
{ key: SearchEntryFilterAnyOfItem.blocks, name: "Blocks" },
{
key: GetV2BuilderSearchFilterAnyOfItem.integrations,
key: SearchEntryFilterAnyOfItem.integrations,
name: "Integrations",
},
{
key: GetV2BuilderSearchFilterAnyOfItem.marketplace_agents,
key: SearchEntryFilterAnyOfItem.marketplace_agents,
name: "Marketplace agents",
},
{ key: GetV2BuilderSearchFilterAnyOfItem.my_agents, name: "My agents" },
{ key: SearchEntryFilterAnyOfItem.my_agents, name: "My agents" },
];

View File

@@ -1,4 +1,4 @@
import { GetV2BuilderSearchFilterAnyOfItem } from "@/app/api/__generated__/models/getV2BuilderSearchFilterAnyOfItem";
import { SearchEntryFilterAnyOfItem } from "@/app/api/__generated__/models/searchEntryFilterAnyOfItem";
export type DefaultStateType =
| "suggestion"
@@ -10,7 +10,7 @@ export type DefaultStateType =
| "marketplace_agents"
| "my_agents";
export type CategoryKey = GetV2BuilderSearchFilterAnyOfItem;
export type CategoryKey = SearchEntryFilterAnyOfItem;
export interface Filters {
categories: {

View File

@@ -23,7 +23,7 @@ import { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
import { getQueryClient } from "@/lib/react-query/queryClient";
import { useToast } from "@/components/molecules/Toast/use-toast";
import * as Sentry from "@sentry/nextjs";
import { GetV2BuilderSearchFilterAnyOfItem } from "@/app/api/__generated__/models/getV2BuilderSearchFilterAnyOfItem";
import { CategoryCounts } from "../BlockMenuFilters/types";
export const useBlockMenuSearchContent = () => {
const {
@@ -67,7 +67,7 @@ export const useBlockMenuSearchContent = () => {
page_size: 8,
search_query: searchQuery,
search_id: searchId,
filter: filters.length > 0 ? filters : undefined,
filter: filters.length > 0 ? filters.join(",") : undefined,
by_creator: creators.length > 0 ? creators : undefined,
},
{
@@ -117,10 +117,7 @@ export const useBlockMenuSearchContent = () => {
}
const latestData = okData(searchQueryData.pages.at(-1));
setCategoryCounts(
(latestData?.total_items as Record<
GetV2BuilderSearchFilterAnyOfItem,
number
>) || {
(latestData?.total_items as CategoryCounts) || {
blocks: 0,
integrations: 0,
marketplace_agents: 0,

View File

@@ -1,7 +1,7 @@
import { useBlockMenuStore } from "@/app/(platform)/build/stores/blockMenuStore";
import { useState } from "react";
import { INITIAL_CREATORS_TO_SHOW } from "./constant";
import { GetV2BuilderSearchFilterAnyOfItem } from "@/app/api/__generated__/models/getV2BuilderSearchFilterAnyOfItem";
import { CategoryKey } from "../BlockMenuFilters/types";
export const useFilterSheet = () => {
const { filters, creators_list, creators, setFilters, setCreators } =
@@ -9,15 +9,13 @@ export const useFilterSheet = () => {
const [isOpen, setIsOpen] = useState(false);
const [localCategories, setLocalCategories] =
useState<GetV2BuilderSearchFilterAnyOfItem[]>(filters);
useState<CategoryKey[]>(filters);
const [localCreators, setLocalCreators] = useState<string[]>(creators);
const [displayedCreatorsCount, setDisplayedCreatorsCount] = useState(
INITIAL_CREATORS_TO_SHOW,
);
const handleLocalCategoryChange = (
category: GetV2BuilderSearchFilterAnyOfItem,
) => {
const handleLocalCategoryChange = (category: CategoryKey) => {
setLocalCategories((prev) => {
if (prev.includes(category)) {
return prev.filter((c) => c !== category);

Some files were not shown because too many files have changed in this diff Show More