Compare commits

..

53 Commits

Author SHA1 Message Date
Otto
11fe703489 docs: regenerate block documentation for Linear Search Issues 2026-02-04 14:52:23 +00:00
Otto
f897e8d41f fix(platform): Improve Linear Search Block
SECRT-1880

- Add state field (with id, name, type) to Issue model for duplicate detection
- Add State model for workflow state information
- Update try_search_issues() to return createdAt, state, project, and assignee
- Add max_results parameter (default 10, was ~50) to reduce token usage
- Add team_name filter to scope results to specific team
- Fix try_get_team_by_name() to return descriptive error when team not found
- Add error output to LinearSearchIssuesBlock for graceful error handling
- Add categories to LinearSearchIssuesBlock (PRODUCTIVITY, ISSUE_TRACKING)
2026-02-04 14:28:08 +00:00
Otto
7e5b84cc5c fix(copilot): update homepage copy to focus on problem discovery (#11956)
## Summary
Update the CoPilot homepage to shift from "what do you want to
automate?" to "tell me about your problems." This lowers the barrier to
engagement by letting users describe their work frustrations instead of
requiring them to identify automations themselves.

## Changes
| Element | Before | After |
|---------|--------|-------|
| Headline | "What do you want to automate?" | "Tell me about your work
— I'll find what to automate." |
| Placeholder | "You can search or just ask - e.g. 'create a blog post
outline'" | "What's your role and what eats up most of your day? e.g.
'I'm a real estate agent and I hate...'" |
| Button 1 | "Show me what I can automate" | "I don't know where to
start, just ask me stuff" |
| Button 2 | "Design a custom workflow" | "I do the same thing every
week and it's killing me" |
| Button 3 | "Help me with content creation" | "Help me find where I'm
wasting my time" |
| Container | max-w-2xl | max-w-3xl |

> **Note on container width:** The `max-w-2xl` → `max-w-3xl` change is
just to keep the longer headline on one line. This works but may not be
the ideal solution — @lluis-xai should advise on the proper approach.

## Why This Matters
The current UX assumes users know what they want to automate. In
reality, most users know what frustrates them but can't identify
automations. The current screen blocks Otto from starting the discovery
conversation that leads to useful recommendations.

## Files Changed
- `autogpt_platform/frontend/src/app/(platform)/copilot/page.tsx` —
headline, placeholder, container width
- `autogpt_platform/frontend/src/app/(platform)/copilot/helpers.ts` —
quick action button text

Resolves: [SECRT-1876](https://linear.app/autogpt/issue/SECRT-1876)

---------

Co-authored-by: Lluis Agusti <hi@llu.lu>
2026-02-04 17:38:58 +07:00
Swifty
09cb313211 fix(frontend): Prevent reflected XSS in OAuth callback route (#11963)
## Summary

Fixes a reflected cross-site scripting (XSS) vulnerability in the OAuth
callback route.

**Security Issue:**
https://github.com/Significant-Gravitas/AutoGPT/security/code-scanning/202

### Vulnerability

The OAuth callback route at
`frontend/src/app/(platform)/auth/integrations/oauth_callback/route.ts`
was writing user-controlled data directly into an HTML response without
proper sanitization. This allowed potential attackers to inject
malicious scripts via OAuth callback parameters.

### Fix

Added a `safeJsonStringify()` function that escapes characters that
could break out of the script context:
- `<` → `\u003c`
- `>` → `\u003e`  
- `&` → `\u0026`

This prevents any user-provided values from being interpreted as
HTML/script content when embedded in the response.

### References

- [OWASP XSS Prevention Cheat
Sheet](https://cheatsheetseries.owasp.org/cheatsheets/Cross_Site_Scripting_Prevention_Cheat_Sheet.html)
- [CWE-79: Improper Neutralization of Input During Web Page
Generation](https://cwe.mitre.org/data/definitions/79.html)

## Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Verified the OAuth callback still functions correctly
- [x] Confirmed special characters in OAuth responses are properly
escaped
2026-02-04 10:53:17 +01:00
Krzysztof Czerwinski
c026485023 feat(frontend): Disable auto-opening wallet (#11961)
<!-- Clearly explain the need for these changes: -->

### Changes 🏗️

- Disable auto-opening Wallet for first time user and on credit increase
- Remove no longer needed `lastSeenCredits` state and storage

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Wallet doesn't open automatically
2026-02-04 06:11:41 +00:00
Nicholas Tindle
1eabc60484 Merge commit from fork
Fixes GHSA-rc89-6g7g-v5v7 / CVE-2026-22038

The logger.info() calls were explicitly logging API keys via
get_secret_value(), exposing credentials in plaintext logs.

Changes:
- Replace info-level credential logging with debug-level provider logging
- Remove all explicit secret value logging from observe/act/extract blocks

Co-authored-by: Otto <otto@agpt.co>
2026-02-03 11:16:57 -06:00
Swifty
f4bf492f24 feat(platform): Add Redis-based SSE reconnection for long-running CoPilot operations (#11877)
## Changes 🏗️

Adds Redis-based SSE reconnection support for long-running CoPilot
operations (like Agent Generator), enabling clients to reconnect and
resume receiving updates after disconnection.

### What this does:
- **Stream Registry** - Redis-backed task tracking with message
persistence via Redis Streams
- **SSE Reconnection** - Clients can reconnect to active tasks using
`task_id` and `last_message_id`
- **Duplicate Message Fix** - Filters out in-progress assistant messages
from session response when active stream exists
- **Completion Consumer** - Handles background task completion
notifications via Redis Streams

### Architecture:
```
1. User sends message → Backend creates task in Redis
2. SSE chunks written to Redis Stream for persistence
3. Client receives chunks via SSE subscription
4. If client disconnects → Task continues in background
5. Client reconnects → GET /sessions/{id} returns active_stream info
6. Client subscribes to /tasks/{task_id}/stream with last_message_id
7. Missed messages replayed from Redis Stream
```

### Key endpoints:
- `GET /sessions/{session_id}` - Returns `active_stream` info if task is
running
- `GET /tasks/{task_id}/stream?last_message_id=X` - SSE endpoint for
reconnection
- `GET /tasks/{task_id}` - Get task status
- `POST /operations/{op_id}/complete` - Webhook for external service
completion

### Duplicate message fix:
When `GET /sessions/{id}` detects an active stream:
1. Filters out the in-progress assistant message from response
2. Returns `last_message_id="0-0"` so client replays stream from
beginning
3. Client receives complete response only through SSE (single source of
truth)

### Frontend changes:
- Task persistence in localStorage for cross-tab reconnection
- Stream event dispatcher handles reconnection flow
- Deduplication logic prevents duplicate messages

### Testing:
- Manual testing of reconnection scenarios
- Verified duplicate message fix works correctly

## Related
- Resolves SSE timeout issues for Agent Generator
- Fixes duplicate message bug on reconnection
2026-02-03 16:52:06 +01:00
Zamil Majdy
81e48c00a4 feat(copilot): add customize_agent tool for marketplace templates (#11943)
## Summary

Adds a new copilot tool that allows users to customize
marketplace/template agents using natural language before adding them to
their library.

This exposes the Agent Generator's `/api/template-modification` endpoint
to the copilot, which was previously not available.

## Changes

- **service.py**: Add `customize_template_external` to call Agent
Generator's template modification endpoint
- **core.py**: 
  - Add `customize_template` wrapper function
- Extract `graph_to_json` as a reusable function (was previously inline
in `get_agent_as_json`)
- **customize_agent.py**: New tool that:
  - Takes marketplace agent ID (format: `creator/slug`)
  - Fetches template from store via `store_db.get_agent()`
  - Calls Agent Generator for customization
  - Handles clarifying questions from the generator
  - Saves customized agent to user's library
- **__init__.py**: Register the tool in `TOOL_REGISTRY` for
auto-discovery

## Usage Flow

1. User searches marketplace: *"Find me a newsletter agent"*
2. Copilot calls `find_agent` → returns `autogpt/newsletter-writer`
3. User: *"Customize that agent to post to Discord instead of email"*
4. Copilot calls:
   ```
   customize_agent(
       agent_id="autogpt/newsletter-writer",
       modifications="Post to Discord instead of sending email"
   )
   ```
5. Agent Generator may ask clarifying questions (e.g., "What Discord
channel?")
6. Customized agent is saved to user's library

## Test plan

- [x] Verified tool imports correctly
- [x] Verified tool is registered in `TOOL_REGISTRY`
- [x] Verified OpenAI function schema is valid
- [x] Ran existing tests (`pytest backend/api/features/chat/tools/`) -
all pass
- [x] Type checker (`pyright`) passes with 0 errors
- [ ] Manual testing with copilot (requires Agent Generator service)
2026-02-03 14:59:25 +00:00
Otto
7dc53071e8 fix(backend): Add retry and error handling to block initialization (#11946)
## Summary
Adds retry logic and graceful error handling to `initialize_blocks()` to
prevent transient DB errors from crashing server startup.

## Problem
When a transient database error occurs during block initialization
(e.g., Prisma P1017 "Server has closed the connection"), the entire
server fails to start. This is overly aggressive since:
1. Blocks are already registered in memory
2. The DB sync is primarily for tracking/schema storage
3. One flaky connection shouldn't prevent the server from starting

**Triggered by:** [Sentry
AUTOGPT-SERVER-7PW](https://significant-gravitas.sentry.io/issues/7238733543/)

## Solution
- Add retry decorator (3 attempts with exponential backoff) for DB
operations
- On failure after retries, log a warning and continue to the next block
- Blocks remain available in memory even if DB sync fails
- Log summary of any failed blocks at the end

## Changes
- `autogpt_platform/backend/backend/data/block.py`: Wrap block DB sync
in retry logic with graceful fallback

## Testing
- Existing block initialization behavior unchanged on success
- On transient DB errors: retries up to 3 times, then continues with
warning
2026-02-03 12:43:30 +00:00
Zamil Majdy
4878665c66 Merge branch 'master' into dev 2026-02-03 16:01:23 +04:00
Zamil Majdy
678ddde751 refactor(backend): unify context compression into compress_context() (#11937)
## Background

This PR consolidates and unifies context window management for the
CoPilot backend.

### Problem
The CoPilot backend had **two separate implementations** of context
window management:

1. **`service.py` → `_manage_context_window()`** - Chat service
streaming/continuation
2. **`prompt.py` → `compress_prompt()`** - Sync LLM blocks

This duplication led to inconsistent behavior, maintenance burden, and
duplicate code.

---

## Solution: Unified `compress_context()`

A single async function that handles both use cases:

| Caller | Usage | Behavior |
|--------|-------|----------|
| **Chat service** | `compress_context(msgs, client=openai_client)` |
Summarization → Truncation |
| **LLM blocks** | `compress_context(msgs, client=None)` | Truncation
only (no API call) |

---

## Strategy Order

| Step | Description | Runs When |
|------|-------------|-----------|
| **1. LLM Summarization** | Summarize old messages into single context
message, keep recent 15 | Only if `client` provided |
| **2. Content Truncation** | Progressively truncate message content
(8192→4096→...→128 tokens) | If still over limit |
| **3. Middle-out Deletion** | Delete messages one at a time from center
outward | If still over limit |
| **4. First/Last Trim** | Truncate system prompt and last message
content | Last resort |

### Why This Order?

1. **Summarization first** (if available) - Preserves semantic meaning
of old messages
2. **Content truncation before deletion** - Keeps all conversation
turns, just shorter
3. **Middle-out deletion** - More granular than dropping all old
messages at once
4. **First/last trim** - Only touch system prompt as last resort

---

## Key Fixes

| Issue | Before | After |
|-------|--------|-------|
| **Socket leak** | `AsyncOpenAI` client never closed | `async with`
context manager |
| **Timeout ignored** | `timeout=30` passed to `create()` (invalid) |
`client.with_options(timeout=30)` |
| **OpenAI tool messages** | Not truncated | Properly truncated |
| **Tool pair integrity** | OpenAI format only | Both OpenAI + Anthropic
formats |

---

## Tool Format Support

`_ensure_tool_pairs_intact()` now supports both formats:

### OpenAI Format
```python
# Assistant with tool_calls
{"role": "assistant", "tool_calls": [{"id": "call_1", ...}]}
# Tool response
{"role": "tool", "tool_call_id": "call_1", "content": "result"}
```

### Anthropic Format
```python
# Assistant with tool_use
{"role": "assistant", "content": [{"type": "tool_use", "id": "toolu_1", ...}]}
# Tool result
{"role": "user", "content": [{"type": "tool_result", "tool_use_id": "toolu_1", ...}]}
```

---

## Files Changed

| File | Change |
|------|--------|
| `backend/util/prompt.py` | +450 lines: Add `CompressResult`,
`compress_context()`, helpers |
| `backend/api/features/chat/service.py` | -380 lines: Remove duplicate,
use thin wrapper |
| `backend/blocks/llm.py` | Migrate `llm_call()` to use
`compress_context(client=None)` |
| `backend/util/prompt_test.py` | +400 lines: Comprehensive tests
(OpenAI + Anthropic) |

### Removed
- `compress_prompt()` - Replaced by `compress_context(client=None)`
- `_manage_context_window()` - Replaced by
`compress_context(client=openai_client)`

---

## API

```python
async def compress_context(
    messages: list[dict],
    target_tokens: int = 120_000,
    *,
    model: str = "gpt-4o",
    client: AsyncOpenAI | None = None,  # None = truncation only
    keep_recent: int = 15,
    reserve: int = 2_048,
    start_cap: int = 8_192,
    floor_cap: int = 128,
) -> CompressResult:
    ...

@dataclass
class CompressResult:
    messages: list[dict]
    token_count: int
    was_compacted: bool
    error: str | None = None
    original_token_count: int = 0
    messages_summarized: int = 0
    messages_dropped: int = 0
```

---

## Tests Added

| Test Class | Coverage |
|------------|----------|
| `TestMsgTokens` | Token counting for regular messages, OpenAI tool
calls, Anthropic tool_use |
| `TestTruncateToolMessageContent` | OpenAI + Anthropic tool message
truncation |
| `TestEnsureToolPairsIntact` | OpenAI format (3 tests), Anthropic
format (3 tests), edge cases (3 tests) |
| `TestCompressContext` | No compression, truncation-only, tool pair
preservation, error handling |

---

## Checklist

- [x] Code follows project conventions
- [x] Linting passes (`poetry run format`)
- [x] Type checking passes (`pyright`)
- [x] Tests added for all new functions
- [x] Both OpenAI and Anthropic tool formats supported
- [x] Backward compatible behavior preserved
- [x] All review comments addressed
2026-02-03 10:36:10 +00:00
Otto
aef6f57cfd fix(scheduler): route db calls through DatabaseManager (#11941)
## Summary

Routes `increment_onboarding_runs` and `cleanup_expired_oauth_tokens`
through the DatabaseManager RPC client instead of calling Prisma
directly.

## Problem

The Scheduler service never connects its Prisma client. While
`add_graph_execution()` in `utils.py` has a fallback that routes through
DatabaseManager when Prisma isn't connected, subsequent calls in the
scheduler were hitting Prisma directly:

- `increment_onboarding_runs()` after successful graph execution
- `cleanup_expired_oauth_tokens()` in the scheduled job

These threw `ClientNotConnectedError`, caught by generic exception
handlers but spamming Sentry (~696K events since December per the
original analysis in #11926).

## Solution

Follow the same pattern as `utils.py`:
1. Add `cleanup_expired_oauth_tokens` to `DatabaseManager` and
`DatabaseManagerAsyncClient`
2. Update scheduler to use `get_database_manager_async_client()` for
both calls

## Changes

- **database.py**: Import and expose `cleanup_expired_oauth_tokens` in
both manager classes
- **scheduler.py**: Use `db.increment_onboarding_runs()` and
`db.cleanup_expired_oauth_tokens()` via the async client

## Impact

- Eliminates Sentry error spam from scheduler
- Onboarding run counters now actually increment for scheduled
executions
- OAuth token cleanup now actually runs

## Testing

Deploy to staging with scheduled graphs and verify:
1. No more `ClientNotConnectedError` in scheduler logs
2. `UserOnboarding.agentRuns` increments on scheduled runs
3. Expired OAuth tokens get cleaned up

Refs: #11926 (original fix that was closed)
2026-02-03 09:54:49 +00:00
Krzysztof Czerwinski
14cee1670a fix(backend): Prevent leaking Redis connections in ws_api (#11869)
Fixing
https://github.com/Significant-Gravitas/AutoGPT/pull/11297#discussion_r2496833421

### Changes 🏗️

1. event_bus.py - Added close method to AsyncRedisEventBus
- Added __init__ method to track the _pubsub instance attribute
- Added async def close() method that closes the PubSub connection
safely
- Modified listen_events() to store the pubsub reference in self._pubsub

2. ws_api.py - Added cleanup in event_broadcaster
- Wrapped the worker coroutines in try/finally block
- The finally block calls close() on both event buses to ensure cleanup
happens on any exit (including exceptions before retry)
2026-02-03 08:07:48 +00:00
Zamil Majdy
d81d1ce024 refactor(backend): extract context window management and fix LLM continuation (#11936)
## Summary

Fixes CoPilot becoming unresponsive after long-running tools complete,
and refactors context window management into a reusable function.

## Problem

After `create_agent` completes, `_generate_llm_continuation()` was
sending ALL messages to OpenRouter without any context compaction. When
conversations exceeded ~50 messages, OpenRouter rejected requests with
`provider_name: 'unknown'` (no provider would accept).

**Evidence:** Langfuse session
[44fbb803-092e-4ebd-b288-852959f4faf5](https://cloud.langfuse.com/project/cmk5qhf210003ad079sd8utjt/sessions/44fbb803-092e-4ebd-b288-852959f4faf5)
showed:
- Successful calls: 32-50 messages, known providers
- Failed calls: 52+ messages, `provider: unknown`, `completion: null`

## Changes

### Refactor: Extract reusable `_manage_context_window()`
- Counts tokens and checks against 120k threshold
- Summarizes old messages while keeping recent 15
- Ensures tool_call/tool_response pairs stay intact
- Progressive truncation if still over limit
- Returns `ContextWindowResult` dataclass with messages, token count,
compaction status, and errors
- Helper `_messages_to_dicts()` reduces code duplication

### Fix: Update `_generate_llm_continuation()`
- Now calls `_manage_context_window()` before making LLM calls
- Adds retry logic with exponential backoff (matching
`_stream_chat_chunks` behavior)

### Cleanup: Update `_stream_chat_chunks()`
- Replaced inline context management with call to
`_manage_context_window()`
- Eliminates code duplication between the two functions

## Testing

- Syntax check: 
- Ruff lint: 
- Import verification: 

## Checklist

- [x] My code follows the style guidelines of this project
- [x] I have performed a self-review of my own code
- [x] My changes generate no new warnings
- [x] I have checked that my changes do not break existing functionality

---------

Co-authored-by: Otto <otto@agpt.co>
2026-02-03 04:41:43 +00:00
Zamil Majdy
2dd341c369 refactor: enrich description with context before calling Agent Generator (#11932)
## Summary
Updates the Agent Generator client to enrich the description with
context before calling, instead of sending `user_instruction` as a
separate parameter.

## Context
Companion PR to Significant-Gravitas/AutoGPT-Agent-Generator#105 which
removes unused parameters from the decompose API.

## Changes
- Enrich `description` with `context` (e.g., clarifying question
answers) before sending
- Remove `user_instruction` from request payload

## How it works
Both input boxes and chat box work the same way - the frontend
constructs a formatted message with answers and sends it as a user
message. The backend then enriches the description with this context
before calling the external Agent Generator service.
2026-02-03 02:31:07 +00:00
Otto
f7350c797a fix(copilot): use messages_dict in fallback context compaction (#11922)
## Summary

Fixes a bug where the fallback path in context compaction passes
`recent_messages` (already sliced) instead of `messages_dict` (full
conversation) to `_ensure_tool_pairs_intact`.

This caused the function to fail to find assistant messages that exist
in the original conversation but were outside the sliced window,
resulting in orphan tool_results being sent to Anthropic and rejected
with:

```
messages.66.content.0: unexpected tool_use_id found in tool_result blocks: toolu_vrtx_019bi1PDvEn7o5ByAxcS3VdA
```

## Changes

- Pass `messages_dict` and `slice_start` (relative to full conversation)
instead of `recent_messages` and `reduced_slice_start` (relative to
already-sliced list)

## Testing

This is a targeted fix for the fallback path. The bug only manifests
when:
1. Token count > 120k (triggers compaction)
2. Initial compaction + summary still exceeds limit (triggers fallback)
3. A tool_result's corresponding assistant is in `messages_dict` but not
in `recent_messages`

## Related

- Fixes SECRT-1861
- Related: SECRT-1839 (original fix that missed this code path)
2026-02-02 13:01:05 +00:00
Guofang.Tang
1081590384 feat(backend): cover webhook ingress URL route (#11747)
### Changes 🏗️

- Add a unit test to verify webhook ingress URL generation matches the
FastAPI route.

  ### Checklist 📋

  #### For code changes:

  - [x] I have clearly listed my changes in the PR description
  - [x] I have made a test plan
  - [x] I have tested my changes according to the test plan:
- [x] poetry run pytest backend/integrations/webhooks/utils_test.py
--confcutdir=backend/integrations/webhooks

  #### For configuration changes:

  - [x] .env.default is updated or already compatible with my changes
- [x] docker-compose.yml is updated or already compatible with my
changes
- [x] I have included a list of my configuration changes in the PR
description (under Changes)



<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **Tests**
* Added a unit test that validates webhook ingress URL generation
matches the application's resolved route (scheme, host, and path) for
provider-specific webhook endpoints, improving confidence in routing
behavior and helping prevent regressions.

<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
2026-02-01 20:29:15 +00:00
Otto
7e37de8e30 fix: Include graph schemas for marketplace agents in Agent Generator (#11920)
## Problem

When marketplace agents are included in the `library_agents` payload
sent to the Agent Generator service, they were missing required fields
(`graph_id`, `graph_version`, `input_schema`, `output_schema`). This
caused Pydantic validation to fail with HTTP 422 Unprocessable Entity.

**Root cause:** The `MarketplaceAgentSummary` TypedDict had a different
shape than `LibraryAgentInfo` expected by the Agent Generator:
- Agent Generator expects: `graph_id`, `graph_version`, `name`,
`description`, `input_schema`, `output_schema`
- MarketplaceAgentSummary had: `name`, `description`, `sub_heading`,
`creator`, `is_marketplace_agent`

## Solution

1. **Add `agent_graph_id` to `StoreAgent` model** - The field was
already in the database view but not exposed
2. **Include `agentGraphId` in hybrid search SQL query** - Carry the
field through the search CTEs
3. **Update `search_marketplace_agents_for_generation()`** - Now fetches
full graph schemas using `get_graph()` and returns `LibraryAgentSummary`
(same type as library agents)
4. **Update deduplication logic** - Use `graph_id` instead of name for
more accurate deduplication

## Changes

- `backend/api/features/store/model.py`: Add optional `agent_graph_id`
field to `StoreAgent`
- `backend/api/features/store/hybrid_search.py`: Include `agentGraphId`
in SQL query columns
- `backend/api/features/store/db.py`: Map `agentGraphId` when creating
`StoreAgent` objects
- `backend/api/features/chat/tools/agent_generator/core.py`: Update
`search_marketplace_agents_for_generation()` to fetch and include full
graph schemas

## Testing

- [ ] Agent creation on dev with marketplace agents in context
- [ ] Verify no 422 errors from Agent Generator
- [ ] Verify marketplace agents can be used as sub-agents

Fixes: SECRT-1817

---------

Co-authored-by: majdyz <majdyz@users.noreply.github.com>
Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
2026-01-31 19:17:36 +00:00
Otto
2abbb7fbc8 hotfix(backend): use discriminator for credential matching in run_block (#11908)
Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
Co-authored-by: Nicholas Tindle <ntindle@users.noreply.github.com>
Co-authored-by: Nicholas Tindle <nicholas.tindle@agpt.co>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-30 21:50:21 -06:00
Otto
7ee94d986c docs: add credentials prerequisites to create-basic-agent guide (#11913)
## Summary
Addresses #11785 - users were encountering `openai_api_key_credentials`
errors when following the create-basic-agent guide because it didn't
mention the need to configure API credentials before using AI blocks.

## Changes
Added a **Prerequisites** section to
`docs/platform/create-basic-agent.md` explaining:
- **Cloud users:** Go to Profile → Integrations to add API keys
- **Self-hosted (Docker):** Add keys to `autogpt_platform/backend/.env`
and restart services

Also added a note that the Calculator example doesn't need credentials,
making it a good first test.

## Related
- Issue: #11785
2026-01-31 03:05:31 +00:00
Nicholas Tindle
05b60db554 fix(backend/chat): Include input schema in discovery and validate unknown fields (#11916)
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-30 21:00:43 -06:00
Zamil Majdy
18a1661fa3 feat: add library agent fetching with two-phase search for sub-agent support (#11889)
## Context

When users ask the chat to create agents, they may want to compose
workflows that reuse their existing agents as sub-agents. For this to
work, the Agent Generator service needs to know what agents the user has
available.

**Challenge:** Users can have large libraries with many agents. Fetching
all of them would be slow and provide too much context to the LLM.

## Solution

This PR implements **search-based library agent fetching** with a
**two-phase search** strategy:

1. **Phase 1 (Initial Search):** When the user describes their goal, we
search for relevant library agents using the goal as the search query
2. **Phase 2 (Step-Based Enrichment):** After the goal is decomposed
into steps, we extract keywords from those steps and search for
additional relevant agents

This ensures we find agents that are relevant to both the high-level
goal AND the specific steps identified.

### Example Flow

```
User goal: "Create an agent that fetches weather and sends a summary email"

Phase 1: Search for "weather email summary" → finds "Weather Fetcher" agent
Phase 2: After decomposition identifies steps like "send email notification"
         → searches "send email notification" → finds "Gmail Sender" agent
```

### Changes

**Library Agent Fetching:**
- `get_library_agents_for_generation()` - Search-based fetching from
user's library
- `search_marketplace_agents_for_generation()` - Search public
marketplace
- `get_all_relevant_agents_for_generation()` - Combines both with
deduplication

**Two-Phase Search:**
- `extract_search_terms_from_steps()` - Extracts keywords from
decomposed steps
- `enrich_library_agents_from_steps()` - Searches for additional agents
based on steps
- Integrated into `create_agent.py` as "Step 1.5" after goal
decomposition

**Type Safety:**
- Added `TypedDict` definitions: `LibraryAgentSummary`,
`MarketplaceAgentSummary`, `DecompositionStep`, `DecompositionResult`

### Design Decisions

- **Search-based, not fetch-all:** Scalable for large libraries
- **Library agents prioritized:** They have full schemas; marketplace
agents have basic info only
- **Deduplication by name and graph_id:** Prevents duplicates across
searches
- **Graceful degradation:** Failures don't block agent generation
- **Limited to 3 search terms:** Avoids excessive API calls during
enrichment

## Related PR
- Agent Generator:
https://github.com/Significant-Gravitas/AutoGPT-Agent-Generator/pull/103

## Test plan
- [x] `test_library_agents.py` - 19 tests covering all new functions
- [x] `test_service.py` - 4 tests for library_agents passthrough
- [ ] Integration test: Create agent with library sub-agent composition
2026-01-31 00:18:21 +00:00
Otto
b72521daa9 fix(readme): update broken self-hosting docs link (#11911)
## Summary
The self-hosting guide link in README.md was broken.

**Old link:** `https://docs.agpt.co/platform/getting-started/`
- Redirects to `https://agpt.co/docs/platform/getting-started`
- Returns HTTP 400 

**New link:**
`https://agpt.co/docs/platform/getting-started/getting-started`
- Works correctly 

## Changes
- Updated the self-hosting guide URL in README.md

Fixes #OPEN-2973
2026-01-30 22:59:45 +00:00
Ubbe
cc4839bedb hotfix(frontend): fix home redirect (3) (#11904)
### Changes 🏗️

Further improvements to LaunchDarkly initialisation and homepage
redirect...

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] Run the app locally with the flag disabled/enabled, and the
redirects work

---------

Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
Co-authored-by: Ubbe <0ubbe@users.noreply.github.com>
2026-01-30 20:40:46 +07:00
Otto
dbbff04616 hotfix(frontend): LD remount (#11903)
## Changes 🏗️

Removes the `key` prop from `LDProvider` that was causing full remounts
when user context changed.

### The Problem

The `key={context.key}` prop was forcing React to unmount and remount
the entire LDProvider when switching from anonymous → logged in user:

```
1. Page loads, user loading → key="anonymous" → LD mounts → flags available 
2. User finishes loading → key="user-123" → React sees key changed
3. LDProvider UNMOUNTS → flags become undefined 
4. New LDProvider MOUNTS → initializes again → flags available 
```

This caused the flag values to cycle: `undefined → value → undefined →
value`

### The Fix

Remove the `key` prop. The LDProvider handles context changes internally
via the `context` prop, which triggers `identify()` without remounting
the provider.

## Checklist 📋

- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [ ] I have tested my changes according to the test plan:
  - [ ] Flag values don't flicker on page load
  - [ ] Flag values update correctly when logging in/out
  - [ ] No redirect race conditions

Related: SECRT-1845
2026-01-30 19:08:26 +07:00
Reinier van der Leer
350ad3591b fix(backend/chat): Filter credentials for graph execution by scopes (#11881)
[SECRT-1842: run_agent tool does not correctly use credentials - agents
fail with insufficient auth
scopes](https://linear.app/autogpt/issue/SECRT-1842)

### Changes 🏗️

- Include scopes in credentials filter in
`backend.api.features.chat.tools.utils.match_user_credentials_to_graph`

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - CI must pass
- It's broken now and a simple change so we'll test in the dev
deployment
2026-01-30 11:01:51 +00:00
Ubbe
e6438b9a76 hotfix(frontend): use server redirect (#11900)
### Changes 🏗️

The page used a client-side redirect (`useEffect` + `router.replace`)
which only works after JavaScript loads and hydrates. On deployed sites,
if there's any delay or failure in JS execution, users see an
empty/black page because the component returns null.

**Fix:** Converted to a server-side redirect using redirect() from
next/navigation. This is a server component now, so:

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Tested locally but will see it fully working once deployed
2026-01-30 17:20:03 +07:00
Bently
de0ec3d388 chore(llm): remove deprecated Claude 3.7 Sonnet model with migration and defensive handling (#11841)
## Summary
Remove `claude-3-7-sonnet-20250219` from LLM model definitions ahead of
Anthropic's API retirement, with comprehensive migration and defensive
error handling.

## Background
Anthropic is retiring Claude 3.7 Sonnet (`claude-3-7-sonnet-20250219`)
on **February 19, 2026 at 9:00 AM PT**. This PR removes the model from
the platform and migrates existing users to prevent service
interruptions.

## Changes

### Code Changes
- Remove `CLAUDE_3_7_SONNET` enum member from `LlmModel` in `llm.py`
- Remove corresponding `ModelMetadata` entry
- Remove `CLAUDE_3_7_SONNET` from `StagehandRecommendedLlmModel` enum
- Remove `CLAUDE_3_7_SONNET` from block cost config
- Add `CLAUDE_4_5_SONNET` to `StagehandRecommendedLlmModel` enum
- Update Stagehand block defaults from `CLAUDE_3_7_SONNET` to
`CLAUDE_4_5_SONNET` (staying in Claude family)
- Add defensive error handling in `CredentialsFieldInfo.discriminate()`
for deprecated model values

### Database Migration
- Adds migration `20260126120000_migrate_claude_3_7_to_4_5_sonnet`
- Migrates `AgentNode.constantInput` model references
- Migrates `AgentNodeExecutionInputOutput.data` preset overrides

### Documentation
- Updated `docs/integrations/block-integrations/llm.md` to remove
deprecated model
- Updated `docs/integrations/block-integrations/stagehand/blocks.md` to
remove deprecated model and add Claude 4.5 Sonnet

## Notes
- Agent JSON files in `autogpt_platform/backend/agents/` still reference
this model in their provider mappings. These are auto-generated and
should be regenerated separately.

## Testing
- [ ] Verify LLM block still functions with remaining models
- [ ] Confirm no import errors in affected files
- [ ] Verify migration runs successfully
- [ ] Verify deprecated model gives helpful error message instead of
KeyError
2026-01-30 08:40:55 +00:00
Otto
e10ff8d37f fix(frontend): remove double flag check on homepage redirect (#11894)
## Changes 🏗️

Fixes the hard refresh redirect bug (SECRT-1845) by removing the double
feature flag check.

### Before (buggy)
```
/                    → checks flag → /copilot or /library
/copilot (layout)    → checks flag → /library if OFF
```

On hard refresh, two sequential LD checks created a race condition
window.

### After (fixed)
```
/                    → always redirects to /copilot
/copilot (layout)    → single flag check via FeatureFlagPage
```

Single check point = no double-check race condition.

## Root Cause

As identified by @0ubbe: the root page and copilot layout were both
checking the feature flag. On hard refresh with network latency, the
second check could fire before LaunchDarkly fully initialized, causing
users to be bounced to `/library`.

## Test Plan

- [ ] Hard refresh on `/` → should go to `/copilot` (flag ON)
- [ ] Hard refresh on `/copilot` → should stay on `/copilot` (flag ON)  
- [ ] With flag OFF → should redirect to `/library`
- [ ] Normal navigation still works

Fixes: SECRT-1845

cc @0ubbe
2026-01-30 08:32:50 +00:00
Otto
7cb1e588b0 fix(frontend): Refocus ChatInput after voice transcription completes (#11893)
## Summary
Refocuses the chat input textarea after voice transcription finishes,
allowing users to immediately use `spacebar+enter` to record and send
their prompt.

## Changes
- Added `inputId` parameter to `useVoiceRecording` hook
- After transcription completes, the input is automatically focused
- This improves the voice input UX flow

## Testing
1. Click mic button or press spacebar to record voice
2. Record a message and stop
3. After transcription completes, the input should be focused
4. User can now press Enter to send or spacebar to record again

---------

Co-authored-by: Lluis Agusti <hi@llu.lu>
2026-01-30 14:49:05 +07:00
Otto
582c6cad36 fix(e2e): Make E2E test data deterministic and fix flaky tests (#11890)
## Summary
Fixes flaky E2E marketplace and library tests that were causing PRs to
be removed from the merge queue.

## Root Cause
1. **Test data was probabilistic** - `e2e_test_data.py` used random
chances (40% approve, then 20-50% feature), which could result in 0
featured agents
2. **Library pagination threshold wrong** - Checked `>= 10`, but page
size is 20
3. **Fixed timeouts** - Used `waitForTimeout(2000)` /
`waitForTimeout(10000)` instead of proper waits

## Changes

### Backend (`e2e_test_data.py`)
- Add guaranteed minimums: 8 featured agents, 5 featured creators, 10
top agents
- First N submissions are deterministically approved and featured
- Increase agents per user from 15 → 25 (for pagination with
page_size=20)
- Fix library agent creation to use constants instead of hardcoded `10`

### Frontend Tests
- `library.spec.ts`: Fix pagination threshold to `PAGE_SIZE` (20)
- `library.page.ts`: Replace 2s timeout with `networkidle` +
`waitForFunction`
- `marketplace.page.ts`: Add `networkidle` wait, 30s waits in
`getFirst*` methods
- `marketplace.spec.ts`: Replace 10s timeout with `waitForFunction`
- `marketplace-creator.spec.ts`: Add `networkidle` + element waits

## Related
- Closes SECRT-1848, SECRT-1849
- Should unblock #11841 and other PRs in merge queue

---------

Co-authored-by: Ubbe <hi@ubbe.dev>
2026-01-30 05:12:35 +00:00
Nicholas Tindle
3b822cdaf7 chore(branchlet): Remove docs pip install from postCreateCmd (#11883)
### Changes 🏗️

- Removed `cd docs && pip install -r requirements.txt` from
`postCreateCmd` in `.branchlet.json`
- Docs dependencies will no longer be auto-installed during branchlet
worktree creation

### Rationale

The docs setup step was adding unnecessary overhead to the worktree
creation process. Developers who need to work on documentation can
manually install the docs requirements when needed.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] Verified branchlet worktree creation still works without the docs
pip install step

#### For configuration changes:

- [x] `.env.default` is updated or already compatible with my changes
- [x] `docker-compose.yml` is updated or already compatible with my
changes
- [x] I have included a list of my configuration changes in the PR
description (under **Changes**)
2026-01-30 00:31:34 +00:00
Zamil Majdy
b2eb4831bd feat(chat): improve agent generator error propagation (#11884)
## Summary
- Add helper functions in `service.py` to create standardized error
responses with `error_type` classification
- Update service functions to return error dicts instead of `None`,
preserving error details from the Agent Generator microservice
- Update `core.py` to pass through error responses properly
- Update `create_agent.py` to handle error responses with user-friendly
messages based on error type

## Error Types Now Propagated
| Error Type | Description | User Message |
|------------|-------------|--------------|
| `llm_parse_error` | LLM returned unparseable response | "The AI had
trouble understanding this request" |
| `llm_timeout` / `timeout` | Request timed out | "The request took too
long" |
| `llm_rate_limit` / `rate_limit` | Rate limited | "The service is
currently busy" |
| `validation_error` | Agent validation failed | "The generated agent
failed validation" |
| `connection_error` | Could not connect to Agent Generator | Generic
error message |
| `http_error` | HTTP error from Agent Generator | Generic error message
|
| `unknown` | Unclassified error | Generic error message |

## Motivation
This enables better debugging for issues like SECRT-1817 where
decomposition failed due to transient LLM errors but the root cause was
unclear in the logs. Now:
1. Error details from the Agent Generator microservice are preserved
2. Users get more helpful error messages based on error type
3. Debugging is easier with `error_type` in response details

## Related PR
- Agent Generator side:
https://github.com/Significant-Gravitas/AutoGPT-Agent-Generator/pull/102

## Test Plan
- [ ] Test decomposition with various error scenarios (timeout, parse
error)
- [ ] Verify user-friendly messages are shown based on error type
- [ ] Check that error details are logged properly
2026-01-29 19:53:40 +00:00
Reinier van der Leer
4cd5da678d refactor(claude): Split autogpt_platform/CLAUDE.md into project-specific files (#11788)
Split `autogpt_platform/CLAUDE.md` into project-specific files, to make
the scope of the instructions clearer.

Also, some minor improvements:

- Change references to other Markdown files to @file/path.md syntax that
Claude recognizes
- Update ambiguous/incorrect/outdated instructions
- Remove trailing slashes
- Fix broken file path references in other docs (including comments)
2026-01-29 17:33:02 +00:00
Ubbe
9538992eaf hotfix(frontend): flags copilot redirects (#11878)
## Changes 🏗️

- Refactor homepage redirect logic to always point to `/`
- the `/` route handles whether to redirect to `/copilot` or `/library`
based on flag
- Simplify `useGetFlag` checks
- Add `<FeatureFlagRedirect />` and `<FeatureFlagPage />` wrapper
components
- helpers to do 1 thing or the other, depending on chat enabled/disabled
- avoids boilerplate code, checking flagss and redirects mistakes
(especially around race conditions with LD init )

## Checklist 📋

### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Log in / out of AutoGPT with flag disabled/enabled
  - [x] Sign up to AutoGPT with flag disabled/enabled
  - [x] Redirects to homepage always work `/`
  - [x] Can't access Copilot with disabled flag
2026-01-29 18:13:28 +07:00
Ubbe
b94c83aacc feat(frontend): Copilot speech to text via Whisper model (#11871)
## Changes 🏗️


https://github.com/user-attachments/assets/d9c12ac0-625c-4b38-8834-e494b5eda9c0

Add a "speech to text" feature in the Chat input fox of Copilot, similar
as what you have in ChatGPT.

## Checklist 📋

### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] Run locally and try the speech to text feature as part of the chat
input box

### For configuration changes:

We need to add `OPENAI_API_KEY=` to Vercel ( used in the Front-end )
both in Dev and Prod.

- [x] `.env.default` is updated or already compatible with my changes

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-29 17:46:36 +07:00
Nicholas Tindle
7668c17d9c feat(platform): add User Workspace for persistent CoPilot file storage (#11867)
Implements persistent User Workspace storage for CoPilot, enabling
blocks to save and retrieve files across sessions. Files are stored in
session-scoped virtual paths (`/sessions/{session_id}/`).

Fixes SECRT-1833

### Changes 🏗️

**Database & Storage:**
- Add `UserWorkspace` and `UserWorkspaceFile` Prisma models
- Implement `WorkspaceStorageBackend` abstraction (GCS for cloud, local
filesystem for self-hosted)
- Add `workspace_id` and `session_id` fields to `ExecutionContext`

**Backend API:**
- Add REST endpoints: `GET/POST /api/workspace/files`, `GET/DELETE
/api/workspace/files/{id}`, `GET /api/workspace/files/{id}/download`
- Add CoPilot tools: `list_workspace_files`, `read_workspace_file`,
`write_workspace_file`
- Integrate workspace storage into `store_media_file()` - returns
`workspace://file-id` references

**Block Updates:**
- Refactor all file-handling blocks to use unified `ExecutionContext`
parameter
- Update media-generating blocks to persist outputs to workspace
(AIImageGenerator, AIImageCustomizer, FluxKontext, TalkingHead, FAL
video, Bannerbear, etc.)

**Frontend:**
- Render `workspace://` image references in chat via proxy endpoint
- Add "AI cannot see this image" overlay indicator

**CoPilot Context Mapping:**
- Session = Agent (graph_id) = Run (graph_exec_id)
- Files scoped to `/sessions/{session_id}/`

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [ ] I have tested my changes according to the test plan:
- [ ] Create CoPilot session, generate image with AIImageGeneratorBlock
  - [ ] Verify image returns `workspace://file-id` (not base64)
  - [ ] Verify image renders in chat with visibility indicator
  - [ ] Verify workspace files persist across sessions
  - [ ] Test list/read/write workspace files via CoPilot tools
  - [ ] Test local storage backend for self-hosted deployments

#### For configuration changes:
- [x] `.env.default` is updated or already compatible with my changes
- [x] `docker-compose.yml` is updated or already compatible with my
changes
- [x] I have included a list of my configuration changes in the PR
description (under **Changes**)

🤖 Generated with [Claude Code](https://claude.ai/code)

<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> **Medium Risk**
> Introduces a new persistent file-storage surface area (DB tables,
storage backends, download API, and chat tools) and rewires
`store_media_file()`/block execution context across many blocks, so
regressions could impact file handling, access control, or storage
costs.
> 
> **Overview**
> Adds a **persistent per-user Workspace** (new
`UserWorkspace`/`UserWorkspaceFile` models plus `WorkspaceManager` +
`WorkspaceStorageBackend` with GCS/local implementations) and wires it
into the API via a new `/api/workspace/files/{file_id}/download` route
(including header-sanitized `Content-Disposition`) and shutdown
lifecycle hooks.
> 
> Extends `ExecutionContext` to carry execution identity +
`workspace_id`/`session_id`, updates executor tooling to clone
node-specific contexts, and updates `run_block` (CoPilot) to create a
session-scoped workspace and synthetic graph/run/node IDs.
> 
> Refactors `store_media_file()` to require `execution_context` +
`return_format` and to support `workspace://` references; migrates many
media/file-handling blocks and related tests to the new API and to
persist generated media as `workspace://...` (or fall back to data URIs
outside CoPilot), and adds CoPilot chat tools for
listing/reading/writing/deleting workspace files with safeguards against
context bloat.
> 
> <sup>Written by [Cursor
Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit
6abc70f793. This will update automatically
on new commits. Configure
[here](https://cursor.com/dashboard?tab=bugbot).</sup>
<!-- /CURSOR_SUMMARY -->

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
2026-01-29 05:49:47 +00:00
Nicholas Tindle
27b72062f2 Merge branch 'dev' 2026-01-28 15:17:57 -06:00
Nicholas Tindle
e0dfae5732 fix(platform): evaluate chat flag after auth for correct redirect (#11873)
Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-28 14:58:02 -06:00
Zamil Majdy
9a79a8d257 Merge branch 'dev' of github.com:Significant-Gravitas/AutoGPT 2026-01-28 12:32:17 -06:00
Zamil Majdy
7df867d645 Merge branch 'master' of github.com:Significant-Gravitas/AutoGPT into dev 2026-01-28 12:29:41 -06:00
Zamil Majdy
a9bf08748b Merge branch 'dev' of github.com:Significant-Gravitas/AutoGPT 2026-01-28 12:28:48 -06:00
Zamil Majdy
d855f79874 fix(platform): reduce Sentry alert spam for expected errors (#11872)
## Summary
- Add `InvalidInputError` for validation errors (search term too long,
invalid pagination) - returns 400 instead of 500
- Remove redundant try/catch blocks in library routes - global exception
handlers already handle `ValueError`→400 and `NotFoundError`→404
- Aggregate embedding backfill errors and log once at the end instead of
per content type to prevent Sentry issue spam

## Test plan
- [x] Verify validation errors (search term >100 chars) return 400 Bad
Request
- [x] Verify NotFoundError still returns 404
- [x] Verify embedding errors are logged once at the end with aggregated
counts

Fixes AUTOGPT-SERVER-7K5, BUILDER-6NC

---------

Co-authored-by: Swifty <craigswift13@gmail.com>
2026-01-29 01:28:27 +07:00
Swifty
dac99694fe Merge branch 'release/v0.6.44' 2026-01-28 12:19:13 +01:00
Nicholas Tindle
0953983944 feat(platform): disable onboarding redirects and add $5 signup bonus (#11862)
Disable automatic onboarding redirects on signup/login while keeping the
checklist/wallet functional. Users now receive $5 (500 credits) on their
first visit to /copilot.

### Changes 🏗️

- **Frontend**: `shouldShowOnboarding()` now returns `false`, disabling
auto-redirects to `/onboarding`
- **Backend**: Added `VISIT_COPILOT` onboarding step with 500 credit
($5) reward
- **Frontend**: Copilot page automatically completes `VISIT_COPILOT`
step on mount
- **Database**: Migration to add `VISIT_COPILOT` to `OnboardingStep`
enum

NOTE: /onboarding/1-welcome -> /library now as shouldShowOnboardin is
always false

Users land directly on `/copilot` after signup/login and receive $5
invisibly (not shown in checklist UI).

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] New user signup (email/password) → lands on `/copilot`, wallet
shows 500 credits
- [x] Verified credits are only granted once (idempotent via onboarding
reward mechanism)
- [x] Existing user login (already granted flag set) → lands on
`/copilot`, no duplicate credits
  - [x] Checklist/wallet remains functional

#### For configuration changes:
- [x] `.env.default` is updated or already compatible with my changes
- [x] `docker-compose.yml` is updated or already compatible with my
changes
- [x] I have included a list of my configuration changes in the PR
description (under **Changes**)

No configuration changes required.

---

OPEN-2967

🤖 Generated with [Claude Code](https://claude.ai/code)


<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> Introduces a new onboarding step and adjusts onboarding flow.
> 
> - Adds `VISIT_COPILOT` onboarding step (+500 credits) with DB enum
migration and API/type updates
> - Copilot page auto-completes `VISIT_COPILOT` on mount to grant the
welcome bonus
> - Changes `/onboarding/enabled` to require user context and return
`false` when `CHAT` feature is enabled (skips legacy onboarding)
> - Wallet now refreshes credits on any onboarding `step_completed`
notification; confetti limited to visible tasks
> - Test flows updated to accept redirects to `copilot`/`library` and
verify authenticated state
> 
> <sup>Written by [Cursor
Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit
ec5a5a4dfd. This will update automatically
on new commits. Configure
[here](https://cursor.com/dashboard?tab=bugbot).</sup>
<!-- /CURSOR_SUMMARY -->

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
Co-authored-by: Nicholas Tindle <ntindle@users.noreply.github.com>
2026-01-28 07:22:46 +00:00
Zamil Majdy
0058cd3ba6 fix(frontend): auto-poll for long-running tool completion (#11866)
## Summary
Fixes the issue where the "Creating Agent" spinner doesn't auto-update
when agent generation completes - user had to refresh the browser.

**Changes:**
- **Frontend polling**: Add `onOperationStarted` callback to trigger
polling when `operation_started` is received via SSE
- **Polling backoff**: 2s, 4s, 6s, 8s... up to 30s max
- **Message deduplication**: Use content-based keys (role + content)
instead of timestamps to prevent duplicate messages
- **Message ordering**: Preserve server message order instead of
timestamp-based sorting
- **Debug cleanup**: Remove verbose console.log/console.info statements

## Test plan
- [ ] Start agent generation in copilot
- [ ] Verify "Creating Agent" spinner appears
- [ ] Wait for completion (2-5 min) WITHOUT refreshing
- [ ] Verify agent carousel appears automatically when done
- [ ] Verify no duplicate messages in chat
- [ ] Verify message order is correct (user → assistant → tool_call →
tool_response)
2026-01-28 10:03:21 +07:00
Nicholas Tindle
ea035224bc feat(copilot): Increase max_agent_runs and max_agent_schedules (#11865)
<!-- Clearly explain the need for these changes: -->
Config change to increase the max times an agent can run in the chat and
the max number of scheduels created by copilot in one chat

<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> Increases per-chat operational limits for Copilot.
> 
> - Bumps `max_agent_runs` default from `3` to `30` in `ChatConfig`
> - Bumps `max_agent_schedules` default from `3` to `30` in `ChatConfig`
> 
> <sup>Written by [Cursor
Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit
93cbae6d27. This will update automatically
on new commits. Configure
[here](https://cursor.com/dashboard?tab=bugbot).</sup>
<!-- /CURSOR_SUMMARY -->
2026-01-28 01:08:02 +00:00
Nicholas Tindle
62813a1ea6 Delete backend/blocks/video/__init__.py (#11864)
<!-- Clearly explain the need for these changes: -->
oops file
### Changes 🏗️

<!-- Concisely describe all of the changes made in this pull request:
-->
removes file that should have not been commited

<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> Removes erroneous `backend/blocks/video/__init__.py`, eliminating an
unintended `video` package.
> 
> - Deletes a placeholder comment-only file
> 
> <sup>Written by [Cursor
Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit
3b84576c33. This will update automatically
on new commits. Configure
[here](https://cursor.com/dashboard?tab=bugbot).</sup>
<!-- /CURSOR_SUMMARY -->
2026-01-28 00:58:49 +00:00
Bently
67405f7eb9 fix(copilot): ensure tool_call/tool_response pairs stay intact during context compaction (#11863)
## Summary

Fixes context compaction breaking tool_call/tool_response pairs, causing
API validation errors.

## Problem

When context compaction slices messages with `messages[-KEEP_RECENT:]`,
a naive slice can separate an assistant message containing `tool_calls`
from its corresponding tool response messages. This causes API
validation errors like:

```
messages.0.content.1: unexpected 'tool_use_id' found in 'tool_result' blocks: orphan_12345.
Each 'tool_result' block must have a corresponding 'tool_use' block in the previous message.
```

## Solution

Added `_ensure_tool_pairs_intact()` helper function that:
1. Detects orphan tool responses in a slice (tool messages whose
`tool_call_id` has no matching assistant message)
2. Extends the slice backwards to include the missing assistant messages
3. Falls back to removing orphan tool responses if the assistant cannot
be found (edge case)

Applied this safeguard to:
- The initial `KEEP_RECENT` slice (line ~990)
- The progressive fallback slices when still over token limit (line
~1079)

## Testing

- Syntax validated with `python -m py_compile`
- Logic reviewed for correctness

## Linear

Fixes SECRT-1839

---
*Debugged by Toran & Orion in #agpt Discord*
2026-01-28 00:21:54 +00:00
Zamil Majdy
171ff6e776 feat(backend): persist long-running tool results to survive SSE disconnects (#11856)
## Summary

Agent generation (`create_agent`, `edit_agent`) can take 1-5 minutes.
Previously, if the user closed their browser tab during this time:
1. The SSE connection would die
2. The tool execution would be cancelled via `CancelledError`
3. The result would be lost - even if the agent-generator service
completed successfully

This PR ensures long-running tool operations survive SSE disconnections.

### Changes 🏗️

**Backend:**
- **base.py**: Added `is_long_running` property to `BaseTool` for tools
to opt-in to background execution
- **create_agent.py / edit_agent.py**: Set `is_long_running = True`
- **models.py**: Added `OperationStartedResponse`,
`OperationPendingResponse`, `OperationInProgressResponse` types
- **service.py**: Modified `_yield_tool_call()` to:
  - Check if tool is `is_long_running`
  - Save "pending" message to chat history immediately
  - Spawn background task that runs independently of SSE
  - Return `operation_started` immediately (don't wait)
  - Update chat history with result when background task completes
- Track running operations for idempotency (prevents duplicate ops on
refresh)
- **db.py**: Added `update_tool_message_content()` to update pending
messages
- **model.py**: Added `invalidate_session_cache()` to clear Redis after
background completion

**Frontend:**
- **useChatMessage.ts**: Added operation message types
- **helpers.ts**: Handle `operation_started`, `operation_pending`,
`operation_in_progress` response types
- **PendingOperationWidget**: New component to display operation status
with spinner
- **ChatMessage.tsx**: Render `PendingOperationWidget` for operation
messages

### How It Works

```
User Request → Save "pending" message → Spawn background task → Return immediately
                                              ↓
                                     Task runs independently of SSE
                                              ↓
                                     On completion: Update message in chat history
                                              ↓
                                     User refreshes → Loads history → Sees result
```

### User Experience

1. User requests agent creation
2. Sees "Agent creation started. You can close this tab - check your
library in a few minutes."
3. Can close browser tab safely
4. When they return, chat shows the completed result (or error)

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] pyright passes (0 errors)
  - [x] TypeScript checks pass
  - [x] Formatters applied

### Test Plan

1. Start agent creation in copilot
2. Close browser tab immediately after seeing "operation_started" 
3. Wait 2-3 minutes
4. Reopen chat
5. Verify: Chat history shows completion message and agent appears in
library

---------

Co-authored-by: Ubbe <hi@ubbe.dev>
2026-01-28 05:09:34 +07:00
Lluis Agusti
349b1f9c79 hotfix(frontend): copilot session handling refinements... 2026-01-28 02:53:45 +07:00
Lluis Agusti
277b0537e9 hotfix(frontend): copilot simplication... 2026-01-28 02:10:18 +07:00
Ubbe
071b3bb5cd fix(frontend): more copilot refinements (#11858)
## Changes 🏗️

On the **Copilot** page:

- prevent unnecessary sidebar repaints 
- show a disclaimer when switching chats on the sidebar to terminate a
current stream
- handle loading better
- save streams better when disconnecting


### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Run the app locally and test the above
2026-01-28 00:49:28 +07:00
258 changed files with 15346 additions and 12538 deletions

View File

@@ -29,8 +29,7 @@
"postCreateCmd": [
"cd autogpt_platform/autogpt_libs && poetry install",
"cd autogpt_platform/backend && poetry install && poetry run prisma generate",
"cd autogpt_platform/frontend && pnpm install",
"cd docs && pip install -r requirements.txt"
"cd autogpt_platform/frontend && pnpm install"
],
"terminalCommand": "code .",
"deleteBranchWithWorktree": false

View File

@@ -160,7 +160,7 @@ pnpm storybook # Start component development server
**Backend Entry Points:**
- `backend/backend/server/server.py` - FastAPI application setup
- `backend/backend/api/rest_api.py` - FastAPI application setup
- `backend/backend/data/` - Database models and user management
- `backend/blocks/` - Agent execution blocks and logic
@@ -219,7 +219,7 @@ Agents are built using a visual block-based system where each block performs a s
### API Development
1. Update routes in `/backend/backend/server/routers/`
1. Update routes in `/backend/backend/api/features/`
2. Add/update Pydantic models in same directory
3. Write tests alongside route files
4. For `data/*.py` changes, validate user ID checks
@@ -285,7 +285,7 @@ Agents are built using a visual block-based system where each block performs a s
### Security Guidelines
**Cache Protection Middleware** (`/backend/backend/server/middleware/security.py`):
**Cache Protection Middleware** (`/backend/backend/api/middleware/security.py`):
- Default: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
- Uses allow list approach for cacheable paths (static assets, health checks, public pages)

2
.gitignore vendored
View File

@@ -178,4 +178,6 @@ autogpt_platform/backend/settings.py
*.ign.*
.test-contents
.claude/settings.local.json
CLAUDE.local.md
/autogpt_platform/backend/logs
.next

View File

@@ -16,7 +16,6 @@ See `docs/content/platform/getting-started.md` for setup instructions.
- Format Python code with `poetry run format`.
- Format frontend code using `pnpm format`.
## Frontend guidelines:
See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
@@ -33,14 +32,17 @@ See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
5. **Testing**: Add Storybook stories for new components, Playwright for E2E
6. **Code conventions**: Function declarations (not arrow functions) for components/handlers
- Component props should be `interface Props { ... }` (not exported) unless the interface needs to be used outside the component
- Separate render logic from business logic (component.tsx + useComponent.ts + helpers.ts)
- Colocate state when possible and avoid creating large components, use sub-components ( local `/components` folder next to the parent component ) when sensible
- Avoid large hooks, abstract logic into `helpers.ts` files when sensible
- Use function declarations for components, arrow functions only for callbacks
- No barrel files or `index.ts` re-exports
- Do not use `useCallback` or `useMemo` unless strictly needed
- Avoid comments at all times unless the code is very complex
- Do not use `useCallback` or `useMemo` unless asked to optimise a given function
- Do not type hook returns, let Typescript infer as much as possible
- Never type with `any`, if not types available use `unknown`
## Testing
@@ -49,22 +51,8 @@ See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
Always run the relevant linters and tests before committing.
Use conventional commit messages for all commits (e.g. `feat(backend): add API`).
Types:
- feat
- fix
- refactor
- ci
- dx (developer experience)
Scopes:
- platform
- platform/library
- platform/marketplace
- backend
- backend/executor
- frontend
- frontend/library
- frontend/marketplace
- blocks
Types: - feat - fix - refactor - ci - dx (developer experience)
Scopes: - platform - platform/library - platform/marketplace - backend - backend/executor - frontend - frontend/library - frontend/marketplace - blocks
## Pull requests

View File

@@ -54,7 +54,7 @@ Before proceeding with the installation, ensure your system meets the following
### Updated Setup Instructions:
We've moved to a fully maintained and regularly updated documentation site.
👉 [Follow the official self-hosting guide here](https://docs.agpt.co/platform/getting-started/)
👉 [Follow the official self-hosting guide here](https://agpt.co/docs/platform/getting-started/getting-started)
This tutorial assumes you have Docker, VSCode, git and npm installed.

View File

@@ -6,152 +6,30 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co
AutoGPT Platform is a monorepo containing:
- **Backend** (`/backend`): Python FastAPI server with async support
- **Frontend** (`/frontend`): Next.js React application
- **Shared Libraries** (`/autogpt_libs`): Common Python utilities
- **Backend** (`backend`): Python FastAPI server with async support
- **Frontend** (`frontend`): Next.js React application
- **Shared Libraries** (`autogpt_libs`): Common Python utilities
## Essential Commands
## Component Documentation
### Backend Development
- **Backend**: See @backend/CLAUDE.md for backend-specific commands, architecture, and development tasks
- **Frontend**: See @frontend/CLAUDE.md for frontend-specific commands, architecture, and development patterns
```bash
# Install dependencies
cd backend && poetry install
# Run database migrations
poetry run prisma migrate dev
# Start all services (database, redis, rabbitmq, clamav)
docker compose up -d
# Run the backend server
poetry run serve
# Run tests
poetry run test
# Run specific test
poetry run pytest path/to/test_file.py::test_function_name
# Run block tests (tests that validate all blocks work correctly)
poetry run pytest backend/blocks/test/test_block.py -xvs
# Run tests for a specific block (e.g., GetCurrentTimeBlock)
poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[GetCurrentTimeBlock]' -xvs
# Lint and format
# prefer format if you want to just "fix" it and only get the errors that can't be autofixed
poetry run format # Black + isort
poetry run lint # ruff
```
More details can be found in TESTING.md
#### Creating/Updating Snapshots
When you first write a test or when the expected output changes:
```bash
poetry run pytest path/to/test.py --snapshot-update
```
⚠️ **Important**: Always review snapshot changes before committing! Use `git diff` to verify the changes are expected.
### Frontend Development
```bash
# Install dependencies
cd frontend && pnpm i
# Generate API client from OpenAPI spec
pnpm generate:api
# Start development server
pnpm dev
# Run E2E tests
pnpm test
# Run Storybook for component development
pnpm storybook
# Build production
pnpm build
# Format and lint
pnpm format
# Type checking
pnpm types
```
**📖 Complete Guide**: See `/frontend/CONTRIBUTING.md` and `/frontend/.cursorrules` for comprehensive frontend patterns.
**Key Frontend Conventions:**
- Separate render logic from data/behavior in components
- Use generated API hooks from `@/app/api/__generated__/endpoints/`
- Use function declarations (not arrow functions) for components/handlers
- Use design system components from `src/components/` (atoms, molecules, organisms)
- Only use Phosphor Icons
- Never use `src/components/__legacy__/*` or deprecated `BackendAPI`
## Architecture Overview
### Backend Architecture
- **API Layer**: FastAPI with REST and WebSocket endpoints
- **Database**: PostgreSQL with Prisma ORM, includes pgvector for embeddings
- **Queue System**: RabbitMQ for async task processing
- **Execution Engine**: Separate executor service processes agent workflows
- **Authentication**: JWT-based with Supabase integration
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
### Frontend Architecture
- **Framework**: Next.js 15 App Router (client-first approach)
- **Data Fetching**: Type-safe generated API hooks via Orval + React Query
- **State Management**: React Query for server state, co-located UI state in components/hooks
- **Component Structure**: Separate render logic (`.tsx`) from business logic (`use*.ts` hooks)
- **Workflow Builder**: Visual graph editor using @xyflow/react
- **UI Components**: shadcn/ui (Radix UI primitives) with Tailwind CSS styling
- **Icons**: Phosphor Icons only
- **Feature Flags**: LaunchDarkly integration
- **Error Handling**: ErrorCard for render errors, toast for mutations, Sentry for exceptions
- **Testing**: Playwright for E2E, Storybook for component development
### Key Concepts
## Key Concepts
1. **Agent Graphs**: Workflow definitions stored as JSON, executed by the backend
2. **Blocks**: Reusable components in `/backend/blocks/` that perform specific tasks
2. **Blocks**: Reusable components in `backend/backend/blocks/` that perform specific tasks
3. **Integrations**: OAuth and API connections stored per user
4. **Store**: Marketplace for sharing agent templates
5. **Virus Scanning**: ClamAV integration for file upload security
### Testing Approach
- Backend uses pytest with snapshot testing for API responses
- Test files are colocated with source files (`*_test.py`)
- Frontend uses Playwright for E2E tests
- Component testing via Storybook
### Database Schema
Key models (defined in `/backend/schema.prisma`):
- `User`: Authentication and profile data
- `AgentGraph`: Workflow definitions with version control
- `AgentGraphExecution`: Execution history and results
- `AgentNode`: Individual nodes in a workflow
- `StoreListing`: Marketplace listings for sharing agents
### Environment Configuration
#### Configuration Files
- **Backend**: `/backend/.env.default` (defaults) → `/backend/.env` (user overrides)
- **Frontend**: `/frontend/.env.default` (defaults) → `/frontend/.env` (user overrides)
- **Platform**: `/.env.default` (Supabase/shared defaults) → `/.env` (user overrides)
- **Backend**: `backend/.env.default` (defaults) → `backend/.env` (user overrides)
- **Frontend**: `frontend/.env.default` (defaults) → `frontend/.env` (user overrides)
- **Platform**: `.env.default` (Supabase/shared defaults) → `.env` (user overrides)
#### Docker Environment Loading Order
@@ -167,83 +45,12 @@ Key models (defined in `/backend/schema.prisma`):
- Backend/Frontend services use YAML anchors for consistent configuration
- Supabase services (`db/docker/docker-compose.yml`) follow the same pattern
### Common Development Tasks
**Adding a new block:**
Follow the comprehensive [Block SDK Guide](../../../docs/content/platform/block-sdk-guide.md) which covers:
- Provider configuration with `ProviderBuilder`
- Block schema definition
- Authentication (API keys, OAuth, webhooks)
- Testing and validation
- File organization
Quick steps:
1. Create new file in `/backend/backend/blocks/`
2. Configure provider using `ProviderBuilder` in `_config.py`
3. Inherit from `Block` base class
4. Define input/output schemas using `BlockSchema`
5. Implement async `run` method
6. Generate unique block ID using `uuid.uuid4()`
7. Test with `poetry run pytest backend/blocks/test/test_block.py`
Note: when making many new blocks analyze the interfaces for each of these blocks and picture if they would go well together in a graph based editor or would they struggle to connect productively?
ex: do the inputs and outputs tie well together?
If you get any pushback or hit complex block conditions check the new_blocks guide in the docs.
**Modifying the API:**
1. Update route in `/backend/backend/server/routers/`
2. Add/update Pydantic models in same directory
3. Write tests alongside the route file
4. Run `poetry run test` to verify
### Frontend guidelines:
See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
1. **Pages**: Create in `src/app/(platform)/feature-name/page.tsx`
- Add `usePageName.ts` hook for logic
- Put sub-components in local `components/` folder
2. **Components**: Structure as `ComponentName/ComponentName.tsx` + `useComponentName.ts` + `helpers.ts`
- Use design system components from `src/components/` (atoms, molecules, organisms)
- Never use `src/components/__legacy__/*`
3. **Data fetching**: Use generated API hooks from `@/app/api/__generated__/endpoints/`
- Regenerate with `pnpm generate:api`
- Pattern: `use{Method}{Version}{OperationName}`
4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
5. **Testing**: Add Storybook stories for new components, Playwright for E2E
6. **Code conventions**: Function declarations (not arrow functions) for components/handlers
- Component props should be `interface Props { ... }` (not exported) unless the interface needs to be used outside the component
- Separate render logic from business logic (component.tsx + useComponent.ts + helpers.ts)
- Colocate state when possible and avoid creating large components, use sub-components ( local `/components` folder next to the parent component ) when sensible
- Avoid large hooks, abstract logic into `helpers.ts` files when sensible
- Use function declarations for components, arrow functions only for callbacks
- No barrel files or `index.ts` re-exports
- Do not use `useCallback` or `useMemo` unless strictly needed
- Avoid comments at all times unless the code is very complex
### Security Implementation
**Cache Protection Middleware:**
- Located in `/backend/backend/server/middleware/security.py`
- Default behavior: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
- Uses an allow list approach - only explicitly permitted paths can be cached
- Cacheable paths include: static assets (`/static/*`, `/_next/static/*`), health checks, public store pages, documentation
- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies
- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware
- Applied to both main API server and external API applications
### Creating Pull Requests
- Create the PR aginst the `dev` branch of the repository.
- Ensure the branch name is descriptive (e.g., `feature/add-new-block`)/
- Use conventional commit messages (see below)/
- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description/
- Create the PR against the `dev` branch of the repository.
- Ensure the branch name is descriptive (e.g., `feature/add-new-block`)
- Use conventional commit messages (see below)
- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description
- Run the github pre-commit hooks to ensure code quality.
### Reviewing/Revising Pull Requests

View File

@@ -0,0 +1,170 @@
# CLAUDE.md - Backend
This file provides guidance to Claude Code when working with the backend.
## Essential Commands
To run something with Python package dependencies you MUST use `poetry run ...`.
```bash
# Install dependencies
poetry install
# Run database migrations
poetry run prisma migrate dev
# Start all services (database, redis, rabbitmq, clamav)
docker compose up -d
# Run the backend as a whole
poetry run app
# Run tests
poetry run test
# Run specific test
poetry run pytest path/to/test_file.py::test_function_name
# Run block tests (tests that validate all blocks work correctly)
poetry run pytest backend/blocks/test/test_block.py -xvs
# Run tests for a specific block (e.g., GetCurrentTimeBlock)
poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[GetCurrentTimeBlock]' -xvs
# Lint and format
# prefer format if you want to just "fix" it and only get the errors that can't be autofixed
poetry run format # Black + isort
poetry run lint # ruff
```
More details can be found in @TESTING.md
### Creating/Updating Snapshots
When you first write a test or when the expected output changes:
```bash
poetry run pytest path/to/test.py --snapshot-update
```
⚠️ **Important**: Always review snapshot changes before committing! Use `git diff` to verify the changes are expected.
## Architecture
- **API Layer**: FastAPI with REST and WebSocket endpoints
- **Database**: PostgreSQL with Prisma ORM, includes pgvector for embeddings
- **Queue System**: RabbitMQ for async task processing
- **Execution Engine**: Separate executor service processes agent workflows
- **Authentication**: JWT-based with Supabase integration
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
## Testing Approach
- Uses pytest with snapshot testing for API responses
- Test files are colocated with source files (`*_test.py`)
## Database Schema
Key models (defined in `schema.prisma`):
- `User`: Authentication and profile data
- `AgentGraph`: Workflow definitions with version control
- `AgentGraphExecution`: Execution history and results
- `AgentNode`: Individual nodes in a workflow
- `StoreListing`: Marketplace listings for sharing agents
## Environment Configuration
- **Backend**: `.env.default` (defaults) → `.env` (user overrides)
## Common Development Tasks
### Adding a new block
Follow the comprehensive [Block SDK Guide](@../../docs/content/platform/block-sdk-guide.md) which covers:
- Provider configuration with `ProviderBuilder`
- Block schema definition
- Authentication (API keys, OAuth, webhooks)
- Testing and validation
- File organization
Quick steps:
1. Create new file in `backend/blocks/`
2. Configure provider using `ProviderBuilder` in `_config.py`
3. Inherit from `Block` base class
4. Define input/output schemas using `BlockSchema`
5. Implement async `run` method
6. Generate unique block ID using `uuid.uuid4()`
7. Test with `poetry run pytest backend/blocks/test/test_block.py`
Note: when making many new blocks analyze the interfaces for each of these blocks and picture if they would go well together in a graph-based editor or would they struggle to connect productively?
ex: do the inputs and outputs tie well together?
If you get any pushback or hit complex block conditions check the new_blocks guide in the docs.
#### Handling files in blocks with `store_media_file()`
When blocks need to work with files (images, videos, documents), use `store_media_file()` from `backend.util.file`. The `return_format` parameter determines what you get back:
| Format | Use When | Returns |
|--------|----------|---------|
| `"for_local_processing"` | Processing with local tools (ffmpeg, MoviePy, PIL) | Local file path (e.g., `"image.png"`) |
| `"for_external_api"` | Sending content to external APIs (Replicate, OpenAI) | Data URI (e.g., `"data:image/png;base64,..."`) |
| `"for_block_output"` | Returning output from your block | Smart: `workspace://` in CoPilot, data URI in graphs |
**Examples:**
```python
# INPUT: Need to process file locally with ffmpeg
local_path = await store_media_file(
file=input_data.video,
execution_context=execution_context,
return_format="for_local_processing",
)
# local_path = "video.mp4" - use with Path/ffmpeg/etc
# INPUT: Need to send to external API like Replicate
image_b64 = await store_media_file(
file=input_data.image,
execution_context=execution_context,
return_format="for_external_api",
)
# image_b64 = "data:image/png;base64,iVBORw0..." - send to API
# OUTPUT: Returning result from block
result_url = await store_media_file(
file=generated_image_url,
execution_context=execution_context,
return_format="for_block_output",
)
yield "image_url", result_url
# In CoPilot: result_url = "workspace://abc123"
# In graphs: result_url = "data:image/png;base64,..."
```
**Key points:**
- `for_block_output` is the ONLY format that auto-adapts to execution context
- Always use `for_block_output` for block outputs unless you have a specific reason not to
- Never hardcode workspace checks - let `for_block_output` handle it
### Modifying the API
1. Update route in `backend/api/features/`
2. Add/update Pydantic models in same directory
3. Write tests alongside the route file
4. Run `poetry run test` to verify
## Security Implementation
### Cache Protection Middleware
- Located in `backend/api/middleware/security.py`
- Default behavior: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
- Uses an allow list approach - only explicitly permitted paths can be cached
- Cacheable paths include: static assets (`static/*`, `_next/static/*`), health checks, public store pages, documentation
- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies
- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware
- Applied to both main API server and external API applications

View File

@@ -138,7 +138,7 @@ If the test doesn't need the `user_id` specifically, mocking is not necessary as
#### Using Global Auth Fixtures
Two global auth fixtures are provided by `backend/server/conftest.py`:
Two global auth fixtures are provided by `backend/api/conftest.py`:
- `mock_jwt_user` - Regular user with `test_user_id` ("test-user-id")
- `mock_jwt_admin` - Admin user with `admin_user_id` ("admin-user-id")

View File

@@ -122,24 +122,6 @@ class ConnectionManager:
return len(connections)
async def broadcast_to_all(self, *, method: WSMethod, data: dict) -> int:
"""Broadcast a message to all active websocket connections."""
message = WSMessage(
method=method,
data=data,
).model_dump_json()
connections = tuple(self.active_connections)
if not connections:
return 0
await asyncio.gather(
*(connection.send_text(message) for connection in connections),
return_exceptions=True,
)
return len(connections)
async def _subscribe(self, channel_key: str, websocket: WebSocket) -> str:
if channel_key not in self.subscriptions:
self.subscriptions[channel_key] = set()

View File

@@ -176,64 +176,30 @@ async def get_execution_analytics_config(
# Return with provider prefix for clarity
return f"{provider_name}: {model_name}"
# Get all models from the registry (dynamic, not hardcoded enum)
from backend.data import llm_registry
from backend.server.v2.llm import db as llm_db
# Get the recommended model from the database (configurable via admin UI)
recommended_model_slug = await llm_db.get_recommended_model_slug()
# Build the available models list
first_enabled_slug = None
for registry_model in llm_registry.iter_dynamic_models():
# Only include enabled models in the list
if not registry_model.is_enabled:
continue
# Track first enabled model as fallback
if first_enabled_slug is None:
first_enabled_slug = registry_model.slug
model_enum = LlmModel(registry_model.slug) # Create enum instance from slug
label = generate_model_label(model_enum)
# Include all LlmModel values (no more filtering by hardcoded list)
recommended_model = LlmModel.GPT4O_MINI.value
for model in LlmModel:
label = generate_model_label(model)
# Add "(Recommended)" suffix to the recommended model
if registry_model.slug == recommended_model_slug:
if model.value == recommended_model:
label += " (Recommended)"
available_models.append(
ModelInfo(
value=registry_model.slug,
value=model.value,
label=label,
provider=registry_model.metadata.provider,
provider=model.provider,
)
)
# Sort models by provider and name for better UX
available_models.sort(key=lambda x: (x.provider, x.label))
# Handle case where no models are available
if not available_models:
logger.warning(
"No enabled LLM models found in registry. "
"Ensure models are configured and enabled in the LLM Registry."
)
# Provide a placeholder entry so admins see meaningful feedback
available_models.append(
ModelInfo(
value="",
label="No models available - configure in LLM Registry",
provider="none",
)
)
# Use the DB recommended model, or fallback to first enabled model
final_recommended = recommended_model_slug or first_enabled_slug or ""
return ExecutionAnalyticsConfig(
available_models=available_models,
default_system_prompt=DEFAULT_SYSTEM_PROMPT,
default_user_prompt=DEFAULT_USER_PROMPT,
recommended_model=final_recommended,
recommended_model=recommended_model,
)

View File

@@ -1,595 +0,0 @@
import logging
import autogpt_libs.auth
import fastapi
from backend.data import llm_registry
from backend.data.block_cost_config import refresh_llm_costs
from backend.server.v2.llm import db as llm_db
from backend.server.v2.llm import model as llm_model
logger = logging.getLogger(__name__)
router = fastapi.APIRouter(
tags=["llm", "admin"],
dependencies=[fastapi.Security(autogpt_libs.auth.requires_admin_user)],
)
async def _refresh_runtime_state() -> None:
"""Refresh the LLM registry and clear all related caches to ensure real-time updates."""
logger.info("Refreshing LLM registry runtime state...")
try:
# Refresh registry from database
await llm_registry.refresh_llm_registry()
refresh_llm_costs()
# Clear block schema caches so they're regenerated with updated model options
from backend.data.block import BlockSchema
BlockSchema.clear_all_schema_caches()
logger.info("Cleared all block schema caches")
# Clear the /blocks endpoint cache so frontend gets updated schemas
try:
from backend.api.features.v1 import _get_cached_blocks
_get_cached_blocks.cache_clear()
logger.info("Cleared /blocks endpoint cache")
except Exception as e:
logger.warning("Failed to clear /blocks cache: %s", e)
# Clear the v2 builder caches (if they exist)
try:
from backend.api.features.builder import db as builder_db
if hasattr(builder_db, "_get_all_providers"):
builder_db._get_all_providers.cache_clear()
logger.info("Cleared v2 builder providers cache")
if hasattr(builder_db, "_build_cached_search_results"):
builder_db._build_cached_search_results.cache_clear()
logger.info("Cleared v2 builder search results cache")
except Exception as e:
logger.debug("Could not clear v2 builder cache: %s", e)
# Notify all executor services to refresh their registry cache
from backend.data.llm_registry import publish_registry_refresh_notification
await publish_registry_refresh_notification()
logger.info("Published registry refresh notification")
except Exception as exc:
logger.exception(
"LLM runtime state refresh failed; caches may be stale: %s", exc
)
@router.get(
"/providers",
summary="List LLM providers",
response_model=llm_model.LlmProvidersResponse,
)
async def list_llm_providers(include_models: bool = True):
providers = await llm_db.list_providers(include_models=include_models)
return llm_model.LlmProvidersResponse(providers=providers)
@router.post(
"/providers",
summary="Create LLM provider",
response_model=llm_model.LlmProvider,
)
async def create_llm_provider(request: llm_model.UpsertLlmProviderRequest):
provider = await llm_db.upsert_provider(request=request)
await _refresh_runtime_state()
return provider
@router.patch(
"/providers/{provider_id}",
summary="Update LLM provider",
response_model=llm_model.LlmProvider,
)
async def update_llm_provider(
provider_id: str,
request: llm_model.UpsertLlmProviderRequest,
):
provider = await llm_db.upsert_provider(request=request, provider_id=provider_id)
await _refresh_runtime_state()
return provider
@router.delete(
"/providers/{provider_id}",
summary="Delete LLM provider",
response_model=dict,
)
async def delete_llm_provider(provider_id: str):
"""
Delete an LLM provider.
A provider can only be deleted if it has no associated models.
Delete all models from the provider first before deleting the provider.
"""
try:
await llm_db.delete_provider(provider_id)
await _refresh_runtime_state()
logger.info("Deleted LLM provider '%s'", provider_id)
return {"success": True, "message": "Provider deleted successfully"}
except ValueError as e:
logger.warning("Failed to delete provider '%s': %s", provider_id, e)
raise fastapi.HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.exception("Failed to delete provider '%s': %s", provider_id, e)
raise fastapi.HTTPException(status_code=500, detail=str(e))
@router.get(
"/models",
summary="List LLM models",
response_model=llm_model.LlmModelsResponse,
)
async def list_llm_models(
provider_id: str | None = fastapi.Query(default=None),
page: int = fastapi.Query(default=1, ge=1, description="Page number (1-indexed)"),
page_size: int = fastapi.Query(
default=50, ge=1, le=100, description="Number of models per page"
),
):
return await llm_db.list_models(
provider_id=provider_id, page=page, page_size=page_size
)
@router.post(
"/models",
summary="Create LLM model",
response_model=llm_model.LlmModel,
)
async def create_llm_model(request: llm_model.CreateLlmModelRequest):
model = await llm_db.create_model(request=request)
await _refresh_runtime_state()
return model
@router.patch(
"/models/{model_id}",
summary="Update LLM model",
response_model=llm_model.LlmModel,
)
async def update_llm_model(
model_id: str,
request: llm_model.UpdateLlmModelRequest,
):
model = await llm_db.update_model(model_id=model_id, request=request)
await _refresh_runtime_state()
return model
@router.patch(
"/models/{model_id}/toggle",
summary="Toggle LLM model availability",
response_model=llm_model.ToggleLlmModelResponse,
)
async def toggle_llm_model(
model_id: str,
request: llm_model.ToggleLlmModelRequest,
):
"""
Toggle a model's enabled status, optionally migrating workflows when disabling.
If disabling a model and `migrate_to_slug` is provided, all workflows using
this model will be migrated to the specified replacement model before disabling.
A migration record is created which can be reverted later using the revert endpoint.
Optional fields:
- `migration_reason`: Reason for the migration (e.g., "Provider outage")
- `custom_credit_cost`: Custom pricing override for billing during migration
"""
try:
result = await llm_db.toggle_model(
model_id=model_id,
is_enabled=request.is_enabled,
migrate_to_slug=request.migrate_to_slug,
migration_reason=request.migration_reason,
custom_credit_cost=request.custom_credit_cost,
)
await _refresh_runtime_state()
if result.nodes_migrated > 0:
logger.info(
"Toggled model '%s' to %s and migrated %d nodes to '%s' (migration_id=%s)",
result.model.slug,
"enabled" if request.is_enabled else "disabled",
result.nodes_migrated,
result.migrated_to_slug,
result.migration_id,
)
return result
except ValueError as exc:
logger.warning("Model toggle validation failed: %s", exc)
raise fastapi.HTTPException(status_code=400, detail=str(exc)) from exc
except Exception as exc:
logger.exception("Failed to toggle LLM model %s: %s", model_id, exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to toggle model availability",
) from exc
@router.get(
"/models/{model_id}/usage",
summary="Get model usage count",
response_model=llm_model.LlmModelUsageResponse,
)
async def get_llm_model_usage(model_id: str):
"""Get the number of workflow nodes using this model."""
try:
return await llm_db.get_model_usage(model_id=model_id)
except ValueError as exc:
raise fastapi.HTTPException(status_code=404, detail=str(exc)) from exc
except Exception as exc:
logger.exception("Failed to get model usage %s: %s", model_id, exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to get model usage",
) from exc
@router.delete(
"/models/{model_id}",
summary="Delete LLM model and migrate workflows",
response_model=llm_model.DeleteLlmModelResponse,
)
async def delete_llm_model(
model_id: str,
replacement_model_slug: str | None = fastapi.Query(
default=None,
description="Slug of the model to migrate existing workflows to (required only if workflows use this model)",
),
):
"""
Delete a model and optionally migrate workflows using it to a replacement model.
If no workflows are using this model, it can be deleted without providing a
replacement. If workflows exist, replacement_model_slug is required.
This endpoint:
1. Counts how many workflow nodes use the model being deleted
2. If nodes exist, validates the replacement model and migrates them
3. Deletes the model record
4. Refreshes all caches and notifies executors
Example: DELETE /admin/llm/models/{id}?replacement_model_slug=gpt-4o
Example (no usage): DELETE /admin/llm/models/{id}
"""
try:
result = await llm_db.delete_model(
model_id=model_id, replacement_model_slug=replacement_model_slug
)
await _refresh_runtime_state()
logger.info(
"Deleted model '%s' and migrated %d nodes to '%s'",
result.deleted_model_slug,
result.nodes_migrated,
result.replacement_model_slug,
)
return result
except ValueError as exc:
# Validation errors (model not found, replacement invalid, etc.)
logger.warning("Model deletion validation failed: %s", exc)
raise fastapi.HTTPException(status_code=400, detail=str(exc)) from exc
except Exception as exc:
logger.exception("Failed to delete LLM model %s: %s", model_id, exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to delete model and migrate workflows",
) from exc
# ============================================================================
# Migration Management Endpoints
# ============================================================================
@router.get(
"/migrations",
summary="List model migrations",
response_model=llm_model.LlmMigrationsResponse,
)
async def list_llm_migrations(
include_reverted: bool = fastapi.Query(
default=False, description="Include reverted migrations in the list"
),
):
"""
List all model migrations.
Migrations are created when disabling a model with the migrate_to_slug option.
They can be reverted to restore the original model configuration.
"""
try:
migrations = await llm_db.list_migrations(include_reverted=include_reverted)
return llm_model.LlmMigrationsResponse(migrations=migrations)
except Exception as exc:
logger.exception("Failed to list migrations: %s", exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to list migrations",
) from exc
@router.get(
"/migrations/{migration_id}",
summary="Get migration details",
response_model=llm_model.LlmModelMigration,
)
async def get_llm_migration(migration_id: str):
"""Get details of a specific migration."""
try:
migration = await llm_db.get_migration(migration_id)
if not migration:
raise fastapi.HTTPException(
status_code=404, detail=f"Migration '{migration_id}' not found"
)
return migration
except fastapi.HTTPException:
raise
except Exception as exc:
logger.exception("Failed to get migration %s: %s", migration_id, exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to get migration",
) from exc
@router.post(
"/migrations/{migration_id}/revert",
summary="Revert a model migration",
response_model=llm_model.RevertMigrationResponse,
)
async def revert_llm_migration(
migration_id: str,
request: llm_model.RevertMigrationRequest | None = None,
):
"""
Revert a model migration, restoring affected workflows to their original model.
This only reverts the specific nodes that were part of the migration.
The source model must exist for the revert to succeed.
Options:
- `re_enable_source_model`: Whether to re-enable the source model if disabled (default: True)
Response includes:
- `nodes_reverted`: Number of nodes successfully reverted
- `nodes_already_changed`: Number of nodes that were modified since migration (not reverted)
- `source_model_re_enabled`: Whether the source model was re-enabled
Requirements:
- Migration must not already be reverted
- Source model must exist
"""
try:
re_enable = request.re_enable_source_model if request else True
result = await llm_db.revert_migration(
migration_id,
re_enable_source_model=re_enable,
)
await _refresh_runtime_state()
logger.info(
"Reverted migration '%s': %d nodes restored from '%s' to '%s' "
"(%d already changed, source re-enabled=%s)",
migration_id,
result.nodes_reverted,
result.target_model_slug,
result.source_model_slug,
result.nodes_already_changed,
result.source_model_re_enabled,
)
return result
except ValueError as exc:
logger.warning("Migration revert validation failed: %s", exc)
raise fastapi.HTTPException(status_code=400, detail=str(exc)) from exc
except Exception as exc:
logger.exception("Failed to revert migration %s: %s", migration_id, exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to revert migration",
) from exc
# ============================================================================
# Creator Management Endpoints
# ============================================================================
@router.get(
"/creators",
summary="List model creators",
response_model=llm_model.LlmCreatorsResponse,
)
async def list_llm_creators():
"""
List all model creators.
Creators are organizations that create/train models (e.g., OpenAI, Meta, Anthropic).
This is distinct from providers who host/serve the models (e.g., OpenRouter).
"""
try:
creators = await llm_db.list_creators()
return llm_model.LlmCreatorsResponse(creators=creators)
except Exception as exc:
logger.exception("Failed to list creators: %s", exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to list creators",
) from exc
@router.get(
"/creators/{creator_id}",
summary="Get creator details",
response_model=llm_model.LlmModelCreator,
)
async def get_llm_creator(creator_id: str):
"""Get details of a specific model creator."""
try:
creator = await llm_db.get_creator(creator_id)
if not creator:
raise fastapi.HTTPException(
status_code=404, detail=f"Creator '{creator_id}' not found"
)
return creator
except fastapi.HTTPException:
raise
except Exception as exc:
logger.exception("Failed to get creator %s: %s", creator_id, exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to get creator",
) from exc
@router.post(
"/creators",
summary="Create model creator",
response_model=llm_model.LlmModelCreator,
)
async def create_llm_creator(request: llm_model.UpsertLlmCreatorRequest):
"""
Create a new model creator.
A creator represents an organization that creates/trains AI models,
such as OpenAI, Anthropic, Meta, or Google.
"""
try:
creator = await llm_db.upsert_creator(request=request)
await _refresh_runtime_state()
logger.info("Created model creator '%s' (%s)", creator.display_name, creator.id)
return creator
except Exception as exc:
logger.exception("Failed to create creator: %s", exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to create creator",
) from exc
@router.patch(
"/creators/{creator_id}",
summary="Update model creator",
response_model=llm_model.LlmModelCreator,
)
async def update_llm_creator(
creator_id: str,
request: llm_model.UpsertLlmCreatorRequest,
):
"""Update an existing model creator."""
try:
creator = await llm_db.upsert_creator(request=request, creator_id=creator_id)
await _refresh_runtime_state()
logger.info("Updated model creator '%s' (%s)", creator.display_name, creator_id)
return creator
except Exception as exc:
logger.exception("Failed to update creator %s: %s", creator_id, exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to update creator",
) from exc
@router.delete(
"/creators/{creator_id}",
summary="Delete model creator",
response_model=dict,
)
async def delete_llm_creator(creator_id: str):
"""
Delete a model creator.
This will remove the creator association from all models that reference it
(sets creatorId to NULL), but will not delete the models themselves.
"""
try:
await llm_db.delete_creator(creator_id)
await _refresh_runtime_state()
logger.info("Deleted model creator '%s'", creator_id)
return {"success": True, "message": f"Creator '{creator_id}' deleted"}
except ValueError as exc:
logger.warning("Creator deletion validation failed: %s", exc)
raise fastapi.HTTPException(status_code=404, detail=str(exc)) from exc
except Exception as exc:
logger.exception("Failed to delete creator %s: %s", creator_id, exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to delete creator",
) from exc
# ============================================================================
# Recommended Model Endpoints
# ============================================================================
@router.get(
"/recommended-model",
summary="Get recommended model",
response_model=llm_model.RecommendedModelResponse,
)
async def get_recommended_model():
"""
Get the currently recommended LLM model.
The recommended model is shown to users as the default/suggested option
in model selection dropdowns.
"""
try:
model = await llm_db.get_recommended_model()
return llm_model.RecommendedModelResponse(
model=model,
slug=model.slug if model else None,
)
except Exception as exc:
logger.exception("Failed to get recommended model: %s", exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to get recommended model",
) from exc
@router.post(
"/recommended-model",
summary="Set recommended model",
response_model=llm_model.SetRecommendedModelResponse,
)
async def set_recommended_model(request: llm_model.SetRecommendedModelRequest):
"""
Set a model as the recommended model.
This clears the recommended flag from any other model and sets it on
the specified model. The model must be enabled to be set as recommended.
The recommended model is displayed to users as the default/suggested
option in model selection dropdowns throughout the platform.
"""
try:
model, previous_slug = await llm_db.set_recommended_model(request.model_id)
await _refresh_runtime_state()
logger.info(
"Set recommended model to '%s' (previous: %s)",
model.slug,
previous_slug or "none",
)
return llm_model.SetRecommendedModelResponse(
model=model,
previous_recommended_slug=previous_slug,
message=f"Model '{model.display_name}' is now the recommended model",
)
except ValueError as exc:
logger.warning("Set recommended model validation failed: %s", exc)
raise fastapi.HTTPException(status_code=400, detail=str(exc)) from exc
except Exception as exc:
logger.exception("Failed to set recommended model: %s", exc)
raise fastapi.HTTPException(
status_code=500,
detail="Failed to set recommended model",
) from exc

View File

@@ -1,491 +0,0 @@
import json
from unittest.mock import AsyncMock
import fastapi
import fastapi.testclient
import pytest
import pytest_mock
from autogpt_libs.auth.jwt_utils import get_jwt_payload
from pytest_snapshot.plugin import Snapshot
import backend.api.features.admin.llm_routes as llm_routes
from backend.server.v2.llm import model as llm_model
from backend.util.models import Pagination
app = fastapi.FastAPI()
app.include_router(llm_routes.router, prefix="/admin/llm")
client = fastapi.testclient.TestClient(app)
@pytest.fixture(autouse=True)
def setup_app_admin_auth(mock_jwt_admin):
"""Setup admin auth overrides for all tests in this module"""
app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
yield
app.dependency_overrides.clear()
def test_list_llm_providers_success(
mocker: pytest_mock.MockFixture,
configured_snapshot: Snapshot,
) -> None:
"""Test successful listing of LLM providers"""
# Mock the database function
mock_providers = [
{
"id": "provider-1",
"name": "openai",
"display_name": "OpenAI",
"description": "OpenAI LLM provider",
"supports_tools": True,
"supports_json_output": True,
"supports_reasoning": False,
"supports_parallel_tool": True,
"metadata": {},
"models": [],
},
{
"id": "provider-2",
"name": "anthropic",
"display_name": "Anthropic",
"description": "Anthropic LLM provider",
"supports_tools": True,
"supports_json_output": True,
"supports_reasoning": False,
"supports_parallel_tool": True,
"metadata": {},
"models": [],
},
]
mocker.patch(
"backend.api.features.admin.llm_routes.llm_db.list_providers",
new=AsyncMock(return_value=mock_providers),
)
response = client.get("/admin/llm/providers")
assert response.status_code == 200
response_data = response.json()
assert len(response_data["providers"]) == 2
assert response_data["providers"][0]["name"] == "openai"
# Snapshot test the response (must be string)
configured_snapshot.assert_match(
json.dumps(response_data, indent=2, sort_keys=True),
"list_llm_providers_success.json",
)
def test_list_llm_models_success(
mocker: pytest_mock.MockFixture,
configured_snapshot: Snapshot,
) -> None:
"""Test successful listing of LLM models with pagination"""
# Mock the database function - now returns LlmModelsResponse
mock_model = llm_model.LlmModel(
id="model-1",
slug="gpt-4o",
display_name="GPT-4o",
description="GPT-4 Optimized",
provider_id="provider-1",
context_window=128000,
max_output_tokens=16384,
is_enabled=True,
capabilities={},
metadata={},
costs=[
llm_model.LlmModelCost(
id="cost-1",
credit_cost=10,
credential_provider="openai",
metadata={},
)
],
)
mock_response = llm_model.LlmModelsResponse(
models=[mock_model],
pagination=Pagination(
total_items=1,
total_pages=1,
current_page=1,
page_size=50,
),
)
mocker.patch(
"backend.api.features.admin.llm_routes.llm_db.list_models",
new=AsyncMock(return_value=mock_response),
)
response = client.get("/admin/llm/models")
assert response.status_code == 200
response_data = response.json()
assert len(response_data["models"]) == 1
assert response_data["models"][0]["slug"] == "gpt-4o"
assert response_data["pagination"]["total_items"] == 1
assert response_data["pagination"]["page_size"] == 50
# Snapshot test the response (must be string)
configured_snapshot.assert_match(
json.dumps(response_data, indent=2, sort_keys=True),
"list_llm_models_success.json",
)
def test_create_llm_provider_success(
mocker: pytest_mock.MockFixture,
configured_snapshot: Snapshot,
) -> None:
"""Test successful creation of LLM provider"""
mock_provider = {
"id": "new-provider-id",
"name": "groq",
"display_name": "Groq",
"description": "Groq LLM provider",
"supports_tools": True,
"supports_json_output": True,
"supports_reasoning": False,
"supports_parallel_tool": False,
"metadata": {},
}
mocker.patch(
"backend.api.features.admin.llm_routes.llm_db.upsert_provider",
new=AsyncMock(return_value=mock_provider),
)
mock_refresh = mocker.patch(
"backend.api.features.admin.llm_routes._refresh_runtime_state",
new=AsyncMock(),
)
request_data = {
"name": "groq",
"display_name": "Groq",
"description": "Groq LLM provider",
"supports_tools": True,
"supports_json_output": True,
"supports_reasoning": False,
"supports_parallel_tool": False,
"metadata": {},
}
response = client.post("/admin/llm/providers", json=request_data)
assert response.status_code == 200
response_data = response.json()
assert response_data["name"] == "groq"
assert response_data["display_name"] == "Groq"
# Verify refresh was called
mock_refresh.assert_called_once()
# Snapshot test the response (must be string)
configured_snapshot.assert_match(
json.dumps(response_data, indent=2, sort_keys=True),
"create_llm_provider_success.json",
)
def test_create_llm_model_success(
mocker: pytest_mock.MockFixture,
configured_snapshot: Snapshot,
) -> None:
"""Test successful creation of LLM model"""
mock_model = {
"id": "new-model-id",
"slug": "gpt-4.1-mini",
"display_name": "GPT-4.1 Mini",
"description": "Latest GPT-4.1 Mini model",
"provider_id": "provider-1",
"context_window": 128000,
"max_output_tokens": 16384,
"is_enabled": True,
"capabilities": {},
"metadata": {},
"costs": [
{
"id": "cost-id",
"credit_cost": 5,
"credential_provider": "openai",
"metadata": {},
}
],
}
mocker.patch(
"backend.api.features.admin.llm_routes.llm_db.create_model",
new=AsyncMock(return_value=mock_model),
)
mock_refresh = mocker.patch(
"backend.api.features.admin.llm_routes._refresh_runtime_state",
new=AsyncMock(),
)
request_data = {
"slug": "gpt-4.1-mini",
"display_name": "GPT-4.1 Mini",
"description": "Latest GPT-4.1 Mini model",
"provider_id": "provider-1",
"context_window": 128000,
"max_output_tokens": 16384,
"is_enabled": True,
"capabilities": {},
"metadata": {},
"costs": [
{
"credit_cost": 5,
"credential_provider": "openai",
"metadata": {},
}
],
}
response = client.post("/admin/llm/models", json=request_data)
assert response.status_code == 200
response_data = response.json()
assert response_data["slug"] == "gpt-4.1-mini"
assert response_data["is_enabled"] is True
# Verify refresh was called
mock_refresh.assert_called_once()
# Snapshot test the response (must be string)
configured_snapshot.assert_match(
json.dumps(response_data, indent=2, sort_keys=True),
"create_llm_model_success.json",
)
def test_update_llm_model_success(
mocker: pytest_mock.MockFixture,
configured_snapshot: Snapshot,
) -> None:
"""Test successful update of LLM model"""
mock_model = {
"id": "model-1",
"slug": "gpt-4o",
"display_name": "GPT-4o Updated",
"description": "Updated description",
"provider_id": "provider-1",
"context_window": 256000,
"max_output_tokens": 32768,
"is_enabled": True,
"capabilities": {},
"metadata": {},
"costs": [
{
"id": "cost-1",
"credit_cost": 15,
"credential_provider": "openai",
"metadata": {},
}
],
}
mocker.patch(
"backend.api.features.admin.llm_routes.llm_db.update_model",
new=AsyncMock(return_value=mock_model),
)
mock_refresh = mocker.patch(
"backend.api.features.admin.llm_routes._refresh_runtime_state",
new=AsyncMock(),
)
request_data = {
"display_name": "GPT-4o Updated",
"description": "Updated description",
"context_window": 256000,
"max_output_tokens": 32768,
}
response = client.patch("/admin/llm/models/model-1", json=request_data)
assert response.status_code == 200
response_data = response.json()
assert response_data["display_name"] == "GPT-4o Updated"
assert response_data["context_window"] == 256000
# Verify refresh was called
mock_refresh.assert_called_once()
# Snapshot test the response (must be string)
configured_snapshot.assert_match(
json.dumps(response_data, indent=2, sort_keys=True),
"update_llm_model_success.json",
)
def test_toggle_llm_model_success(
mocker: pytest_mock.MockFixture,
configured_snapshot: Snapshot,
) -> None:
"""Test successful toggling of LLM model enabled status"""
# Create a proper mock model object
mock_model = llm_model.LlmModel(
id="model-1",
slug="gpt-4o",
display_name="GPT-4o",
description="GPT-4 Optimized",
provider_id="provider-1",
context_window=128000,
max_output_tokens=16384,
is_enabled=False,
capabilities={},
metadata={},
costs=[],
)
# Create a proper ToggleLlmModelResponse
mock_response = llm_model.ToggleLlmModelResponse(
model=mock_model,
nodes_migrated=0,
migrated_to_slug=None,
migration_id=None,
)
mocker.patch(
"backend.api.features.admin.llm_routes.llm_db.toggle_model",
new=AsyncMock(return_value=mock_response),
)
mock_refresh = mocker.patch(
"backend.api.features.admin.llm_routes._refresh_runtime_state",
new=AsyncMock(),
)
request_data = {"is_enabled": False}
response = client.patch("/admin/llm/models/model-1/toggle", json=request_data)
assert response.status_code == 200
response_data = response.json()
assert response_data["model"]["is_enabled"] is False
# Verify refresh was called
mock_refresh.assert_called_once()
# Snapshot test the response (must be string)
configured_snapshot.assert_match(
json.dumps(response_data, indent=2, sort_keys=True),
"toggle_llm_model_success.json",
)
def test_delete_llm_model_success(
mocker: pytest_mock.MockFixture,
configured_snapshot: Snapshot,
) -> None:
"""Test successful deletion of LLM model with migration"""
# Create a proper DeleteLlmModelResponse
mock_response = llm_model.DeleteLlmModelResponse(
deleted_model_slug="gpt-3.5-turbo",
deleted_model_display_name="GPT-3.5 Turbo",
replacement_model_slug="gpt-4o-mini",
nodes_migrated=42,
message="Successfully deleted model 'GPT-3.5 Turbo' (gpt-3.5-turbo) "
"and migrated 42 workflow node(s) to 'gpt-4o-mini'.",
)
mocker.patch(
"backend.api.features.admin.llm_routes.llm_db.delete_model",
new=AsyncMock(return_value=mock_response),
)
mock_refresh = mocker.patch(
"backend.api.features.admin.llm_routes._refresh_runtime_state",
new=AsyncMock(),
)
response = client.delete(
"/admin/llm/models/model-1?replacement_model_slug=gpt-4o-mini"
)
assert response.status_code == 200
response_data = response.json()
assert response_data["deleted_model_slug"] == "gpt-3.5-turbo"
assert response_data["nodes_migrated"] == 42
assert response_data["replacement_model_slug"] == "gpt-4o-mini"
# Verify refresh was called
mock_refresh.assert_called_once()
# Snapshot test the response (must be string)
configured_snapshot.assert_match(
json.dumps(response_data, indent=2, sort_keys=True),
"delete_llm_model_success.json",
)
def test_delete_llm_model_validation_error(
mocker: pytest_mock.MockFixture,
) -> None:
"""Test deletion fails with proper error when validation fails"""
mocker.patch(
"backend.api.features.admin.llm_routes.llm_db.delete_model",
new=AsyncMock(side_effect=ValueError("Replacement model 'invalid' not found")),
)
response = client.delete("/admin/llm/models/model-1?replacement_model_slug=invalid")
assert response.status_code == 400
assert "Replacement model 'invalid' not found" in response.json()["detail"]
def test_delete_llm_model_no_replacement_with_usage(
mocker: pytest_mock.MockFixture,
) -> None:
"""Test deletion fails when nodes exist but no replacement is provided"""
mocker.patch(
"backend.api.features.admin.llm_routes.llm_db.delete_model",
new=AsyncMock(
side_effect=ValueError(
"Cannot delete model 'test-model': 5 workflow node(s) are using it. "
"Please provide a replacement_model_slug to migrate them."
)
),
)
response = client.delete("/admin/llm/models/model-1")
assert response.status_code == 400
assert "workflow node(s) are using it" in response.json()["detail"]
def test_delete_llm_model_no_replacement_no_usage(
mocker: pytest_mock.MockFixture,
) -> None:
"""Test deletion succeeds when no nodes use the model and no replacement is provided"""
mock_response = llm_model.DeleteLlmModelResponse(
deleted_model_slug="unused-model",
deleted_model_display_name="Unused Model",
replacement_model_slug=None,
nodes_migrated=0,
message="Successfully deleted model 'Unused Model' (unused-model). No workflows were using this model.",
)
mocker.patch(
"backend.api.features.admin.llm_routes.llm_db.delete_model",
new=AsyncMock(return_value=mock_response),
)
mock_refresh = mocker.patch(
"backend.api.features.admin.llm_routes._refresh_runtime_state",
new=AsyncMock(),
)
response = client.delete("/admin/llm/models/model-1")
assert response.status_code == 200
response_data = response.json()
assert response_data["deleted_model_slug"] == "unused-model"
assert response_data["nodes_migrated"] == 0
assert response_data["replacement_model_slug"] is None
mock_refresh.assert_called_once()

View File

@@ -15,7 +15,6 @@ from backend.blocks import load_all_blocks
from backend.blocks.llm import LlmModel
from backend.data.block import AnyBlockSchema, BlockCategory, BlockInfo, BlockSchema
from backend.data.db import query_raw_with_schema
from backend.data.llm_registry import get_all_model_slugs_for_validation
from backend.integrations.providers import ProviderName
from backend.util.cache import cached
from backend.util.models import Pagination
@@ -32,14 +31,7 @@ from .model import (
)
logger = logging.getLogger(__name__)
def _get_llm_models() -> list[str]:
"""Get LLM model names for search matching from the registry."""
return [
slug.lower().replace("-", " ") for slug in get_all_model_slugs_for_validation()
]
llm_models = [name.name.lower().replace("_", " ") for name in LlmModel]
MAX_LIBRARY_AGENT_RESULTS = 100
MAX_MARKETPLACE_AGENT_RESULTS = 100
@@ -504,8 +496,8 @@ async def _get_static_counts():
def _matches_llm_model(schema_cls: type[BlockSchema], query: str) -> bool:
for field in schema_cls.model_fields.values():
if field.annotation == LlmModel:
# Check if query matches any value in llm_models from registry
if any(query in name for name in _get_llm_models()):
# Check if query matches any value in llm_models
if any(query in name for name in llm_models):
return True
return False

View File

@@ -17,7 +17,7 @@ router = fastapi.APIRouter(
)
# Taken from backend/server/v2/store/db.py
# Taken from backend/api/features/store/db.py
def sanitize_query(query: str | None) -> str | None:
if query is None:
return query

View File

@@ -0,0 +1,368 @@
"""Redis Streams consumer for operation completion messages.
This module provides a consumer (ChatCompletionConsumer) that listens for
completion notifications (OperationCompleteMessage) from external services
(like Agent Generator) and triggers the appropriate stream registry and
chat service updates via process_operation_success/process_operation_failure.
Why Redis Streams instead of RabbitMQ?
--------------------------------------
While the project typically uses RabbitMQ for async task queues (e.g., execution
queue), Redis Streams was chosen for chat completion notifications because:
1. **Unified Infrastructure**: The SSE reconnection feature already uses Redis
Streams (via stream_registry) for message persistence and replay. Using Redis
Streams for completion notifications keeps all chat streaming infrastructure
in one system, simplifying operations and reducing cross-system coordination.
2. **Message Replay**: Redis Streams support XREAD with arbitrary message IDs,
allowing consumers to replay missed messages after reconnection. This aligns
with the SSE reconnection pattern where clients can resume from last_message_id.
3. **Consumer Groups with XAUTOCLAIM**: Redis consumer groups provide automatic
load balancing across pods with explicit message claiming (XAUTOCLAIM) for
recovering from dead consumers - ideal for the completion callback pattern.
4. **Lower Latency**: For real-time SSE updates, Redis (already in-memory for
stream_registry) provides lower latency than an additional RabbitMQ hop.
5. **Atomicity with Task State**: Completion processing often needs to update
task metadata stored in Redis. Keeping both in Redis enables simpler
transactional semantics without distributed coordination.
The consumer uses Redis Streams with consumer groups for reliable message
processing across multiple platform pods, with XAUTOCLAIM for reclaiming
stale pending messages from dead consumers.
"""
import asyncio
import logging
import os
import uuid
from typing import Any
import orjson
from prisma import Prisma
from pydantic import BaseModel
from redis.exceptions import ResponseError
from backend.data.redis_client import get_redis_async
from . import stream_registry
from .completion_handler import process_operation_failure, process_operation_success
from .config import ChatConfig
logger = logging.getLogger(__name__)
config = ChatConfig()
class OperationCompleteMessage(BaseModel):
"""Message format for operation completion notifications."""
operation_id: str
task_id: str
success: bool
result: dict | str | None = None
error: str | None = None
class ChatCompletionConsumer:
"""Consumer for chat operation completion messages from Redis Streams.
This consumer initializes its own Prisma client in start() to ensure
database operations work correctly within this async context.
Uses Redis consumer groups to allow multiple platform pods to consume
messages reliably with automatic redelivery on failure.
"""
def __init__(self):
self._consumer_task: asyncio.Task | None = None
self._running = False
self._prisma: Prisma | None = None
self._consumer_name = f"consumer-{uuid.uuid4().hex[:8]}"
async def start(self) -> None:
"""Start the completion consumer."""
if self._running:
logger.warning("Completion consumer already running")
return
# Create consumer group if it doesn't exist
try:
redis = await get_redis_async()
await redis.xgroup_create(
config.stream_completion_name,
config.stream_consumer_group,
id="0",
mkstream=True,
)
logger.info(
f"Created consumer group '{config.stream_consumer_group}' "
f"on stream '{config.stream_completion_name}'"
)
except ResponseError as e:
if "BUSYGROUP" in str(e):
logger.debug(
f"Consumer group '{config.stream_consumer_group}' already exists"
)
else:
raise
self._running = True
self._consumer_task = asyncio.create_task(self._consume_messages())
logger.info(
f"Chat completion consumer started (consumer: {self._consumer_name})"
)
async def _ensure_prisma(self) -> Prisma:
"""Lazily initialize Prisma client on first use."""
if self._prisma is None:
database_url = os.getenv("DATABASE_URL", "postgresql://localhost:5432")
self._prisma = Prisma(datasource={"url": database_url})
await self._prisma.connect()
logger.info("[COMPLETION] Consumer Prisma client connected (lazy init)")
return self._prisma
async def stop(self) -> None:
"""Stop the completion consumer."""
self._running = False
if self._consumer_task:
self._consumer_task.cancel()
try:
await self._consumer_task
except asyncio.CancelledError:
pass
self._consumer_task = None
if self._prisma:
await self._prisma.disconnect()
self._prisma = None
logger.info("[COMPLETION] Consumer Prisma client disconnected")
logger.info("Chat completion consumer stopped")
async def _consume_messages(self) -> None:
"""Main message consumption loop with retry logic."""
max_retries = 10
retry_delay = 5 # seconds
retry_count = 0
block_timeout = 5000 # milliseconds
while self._running and retry_count < max_retries:
try:
redis = await get_redis_async()
# Reset retry count on successful connection
retry_count = 0
while self._running:
# First, claim any stale pending messages from dead consumers
# Redis does NOT auto-redeliver pending messages; we must explicitly
# claim them using XAUTOCLAIM
try:
claimed_result = await redis.xautoclaim(
name=config.stream_completion_name,
groupname=config.stream_consumer_group,
consumername=self._consumer_name,
min_idle_time=config.stream_claim_min_idle_ms,
start_id="0-0",
count=10,
)
# xautoclaim returns: (next_start_id, [(id, data), ...], [deleted_ids])
if claimed_result and len(claimed_result) >= 2:
claimed_entries = claimed_result[1]
if claimed_entries:
logger.info(
f"Claimed {len(claimed_entries)} stale pending messages"
)
for entry_id, data in claimed_entries:
if not self._running:
return
await self._process_entry(redis, entry_id, data)
except Exception as e:
logger.warning(f"XAUTOCLAIM failed (non-fatal): {e}")
# Read new messages from the stream
messages = await redis.xreadgroup(
groupname=config.stream_consumer_group,
consumername=self._consumer_name,
streams={config.stream_completion_name: ">"},
block=block_timeout,
count=10,
)
if not messages:
continue
for stream_name, entries in messages:
for entry_id, data in entries:
if not self._running:
return
await self._process_entry(redis, entry_id, data)
except asyncio.CancelledError:
logger.info("Consumer cancelled")
return
except Exception as e:
retry_count += 1
logger.error(
f"Consumer error (retry {retry_count}/{max_retries}): {e}",
exc_info=True,
)
if self._running and retry_count < max_retries:
await asyncio.sleep(retry_delay)
else:
logger.error("Max retries reached, stopping consumer")
return
async def _process_entry(
self, redis: Any, entry_id: str, data: dict[str, Any]
) -> None:
"""Process a single stream entry and acknowledge it on success.
Args:
redis: Redis client connection
entry_id: The stream entry ID
data: The entry data dict
"""
try:
# Handle the message
message_data = data.get("data")
if message_data:
await self._handle_message(
message_data.encode()
if isinstance(message_data, str)
else message_data
)
# Acknowledge the message after successful processing
await redis.xack(
config.stream_completion_name,
config.stream_consumer_group,
entry_id,
)
except Exception as e:
logger.error(
f"Error processing completion message {entry_id}: {e}",
exc_info=True,
)
# Message remains in pending state and will be claimed by
# XAUTOCLAIM after min_idle_time expires
async def _handle_message(self, body: bytes) -> None:
"""Handle a completion message using our own Prisma client."""
try:
data = orjson.loads(body)
message = OperationCompleteMessage(**data)
except Exception as e:
logger.error(f"Failed to parse completion message: {e}")
return
logger.info(
f"[COMPLETION] Received completion for operation {message.operation_id} "
f"(task_id={message.task_id}, success={message.success})"
)
# Find task in registry
task = await stream_registry.find_task_by_operation_id(message.operation_id)
if task is None:
task = await stream_registry.get_task(message.task_id)
if task is None:
logger.warning(
f"[COMPLETION] Task not found for operation {message.operation_id} "
f"(task_id={message.task_id})"
)
return
logger.info(
f"[COMPLETION] Found task: task_id={task.task_id}, "
f"session_id={task.session_id}, tool_call_id={task.tool_call_id}"
)
# Guard against empty task fields
if not task.task_id or not task.session_id or not task.tool_call_id:
logger.error(
f"[COMPLETION] Task has empty critical fields! "
f"task_id={task.task_id!r}, session_id={task.session_id!r}, "
f"tool_call_id={task.tool_call_id!r}"
)
return
if message.success:
await self._handle_success(task, message)
else:
await self._handle_failure(task, message)
async def _handle_success(
self,
task: stream_registry.ActiveTask,
message: OperationCompleteMessage,
) -> None:
"""Handle successful operation completion."""
prisma = await self._ensure_prisma()
await process_operation_success(task, message.result, prisma)
async def _handle_failure(
self,
task: stream_registry.ActiveTask,
message: OperationCompleteMessage,
) -> None:
"""Handle failed operation completion."""
prisma = await self._ensure_prisma()
await process_operation_failure(task, message.error, prisma)
# Module-level consumer instance
_consumer: ChatCompletionConsumer | None = None
async def start_completion_consumer() -> None:
"""Start the global completion consumer."""
global _consumer
if _consumer is None:
_consumer = ChatCompletionConsumer()
await _consumer.start()
async def stop_completion_consumer() -> None:
"""Stop the global completion consumer."""
global _consumer
if _consumer:
await _consumer.stop()
_consumer = None
async def publish_operation_complete(
operation_id: str,
task_id: str,
success: bool,
result: dict | str | None = None,
error: str | None = None,
) -> None:
"""Publish an operation completion message to Redis Streams.
Args:
operation_id: The operation ID that completed.
task_id: The task ID associated with the operation.
success: Whether the operation succeeded.
result: The result data (for success).
error: The error message (for failure).
"""
message = OperationCompleteMessage(
operation_id=operation_id,
task_id=task_id,
success=success,
result=result,
error=error,
)
redis = await get_redis_async()
await redis.xadd(
config.stream_completion_name,
{"data": message.model_dump_json()},
maxlen=config.stream_max_length,
)
logger.info(f"Published completion for operation {operation_id}")

View File

@@ -0,0 +1,344 @@
"""Shared completion handling for operation success and failure.
This module provides common logic for handling operation completion from both:
- The Redis Streams consumer (completion_consumer.py)
- The HTTP webhook endpoint (routes.py)
"""
import logging
from typing import Any
import orjson
from prisma import Prisma
from . import service as chat_service
from . import stream_registry
from .response_model import StreamError, StreamToolOutputAvailable
from .tools.models import ErrorResponse
logger = logging.getLogger(__name__)
# Tools that produce agent_json that needs to be saved to library
AGENT_GENERATION_TOOLS = {"create_agent", "edit_agent"}
# Keys that should be stripped from agent_json when returning in error responses
SENSITIVE_KEYS = frozenset(
{
"api_key",
"apikey",
"api_secret",
"password",
"secret",
"credentials",
"credential",
"token",
"access_token",
"refresh_token",
"private_key",
"privatekey",
"auth",
"authorization",
}
)
def _sanitize_agent_json(obj: Any) -> Any:
"""Recursively sanitize agent_json by removing sensitive keys.
Args:
obj: The object to sanitize (dict, list, or primitive)
Returns:
Sanitized copy with sensitive keys removed/redacted
"""
if isinstance(obj, dict):
return {
k: "[REDACTED]" if k.lower() in SENSITIVE_KEYS else _sanitize_agent_json(v)
for k, v in obj.items()
}
elif isinstance(obj, list):
return [_sanitize_agent_json(item) for item in obj]
else:
return obj
class ToolMessageUpdateError(Exception):
"""Raised when updating a tool message in the database fails."""
pass
async def _update_tool_message(
session_id: str,
tool_call_id: str,
content: str,
prisma_client: Prisma | None,
) -> None:
"""Update tool message in database.
Args:
session_id: The session ID
tool_call_id: The tool call ID to update
content: The new content for the message
prisma_client: Optional Prisma client. If None, uses chat_service.
Raises:
ToolMessageUpdateError: If the database update fails. The caller should
handle this to avoid marking the task as completed with inconsistent state.
"""
try:
if prisma_client:
# Use provided Prisma client (for consumer with its own connection)
updated_count = await prisma_client.chatmessage.update_many(
where={
"sessionId": session_id,
"toolCallId": tool_call_id,
},
data={"content": content},
)
# Check if any rows were updated - 0 means message not found
if updated_count == 0:
raise ToolMessageUpdateError(
f"No message found with tool_call_id={tool_call_id} in session {session_id}"
)
else:
# Use service function (for webhook endpoint)
await chat_service._update_pending_operation(
session_id=session_id,
tool_call_id=tool_call_id,
result=content,
)
except ToolMessageUpdateError:
raise
except Exception as e:
logger.error(f"[COMPLETION] Failed to update tool message: {e}", exc_info=True)
raise ToolMessageUpdateError(
f"Failed to update tool message for tool_call_id={tool_call_id}: {e}"
) from e
def serialize_result(result: dict | list | str | int | float | bool | None) -> str:
"""Serialize result to JSON string with sensible defaults.
Args:
result: The result to serialize. Can be a dict, list, string,
number, boolean, or None.
Returns:
JSON string representation of the result. Returns '{"status": "completed"}'
only when result is explicitly None.
"""
if isinstance(result, str):
return result
if result is None:
return '{"status": "completed"}'
return orjson.dumps(result).decode("utf-8")
async def _save_agent_from_result(
result: dict[str, Any],
user_id: str | None,
tool_name: str,
) -> dict[str, Any]:
"""Save agent to library if result contains agent_json.
Args:
result: The result dict that may contain agent_json
user_id: The user ID to save the agent for
tool_name: The tool name (create_agent or edit_agent)
Returns:
Updated result dict with saved agent details, or original result if no agent_json
"""
if not user_id:
logger.warning("[COMPLETION] Cannot save agent: no user_id in task")
return result
agent_json = result.get("agent_json")
if not agent_json:
logger.warning(
f"[COMPLETION] {tool_name} completed but no agent_json in result"
)
return result
try:
from .tools.agent_generator import save_agent_to_library
is_update = tool_name == "edit_agent"
created_graph, library_agent = await save_agent_to_library(
agent_json, user_id, is_update=is_update
)
logger.info(
f"[COMPLETION] Saved agent '{created_graph.name}' to library "
f"(graph_id={created_graph.id}, library_agent_id={library_agent.id})"
)
# Return a response similar to AgentSavedResponse
return {
"type": "agent_saved",
"message": f"Agent '{created_graph.name}' has been saved to your library!",
"agent_id": created_graph.id,
"agent_name": created_graph.name,
"library_agent_id": library_agent.id,
"library_agent_link": f"/library/agents/{library_agent.id}",
"agent_page_link": f"/build?flowID={created_graph.id}",
}
except Exception as e:
logger.error(
f"[COMPLETION] Failed to save agent to library: {e}",
exc_info=True,
)
# Return error but don't fail the whole operation
# Sanitize agent_json to remove sensitive keys before returning
return {
"type": "error",
"message": f"Agent was generated but failed to save: {str(e)}",
"error": str(e),
"agent_json": _sanitize_agent_json(agent_json),
}
async def process_operation_success(
task: stream_registry.ActiveTask,
result: dict | str | None,
prisma_client: Prisma | None = None,
) -> None:
"""Handle successful operation completion.
Publishes the result to the stream registry, updates the database,
generates LLM continuation, and marks the task as completed.
Args:
task: The active task that completed
result: The result data from the operation
prisma_client: Optional Prisma client for database operations.
If None, uses chat_service._update_pending_operation instead.
Raises:
ToolMessageUpdateError: If the database update fails. The task will be
marked as failed instead of completed to avoid inconsistent state.
"""
# For agent generation tools, save the agent to library
if task.tool_name in AGENT_GENERATION_TOOLS and isinstance(result, dict):
result = await _save_agent_from_result(result, task.user_id, task.tool_name)
# Serialize result for output (only substitute default when result is exactly None)
result_output = result if result is not None else {"status": "completed"}
output_str = (
result_output
if isinstance(result_output, str)
else orjson.dumps(result_output).decode("utf-8")
)
# Publish result to stream registry
await stream_registry.publish_chunk(
task.task_id,
StreamToolOutputAvailable(
toolCallId=task.tool_call_id,
toolName=task.tool_name,
output=output_str,
success=True,
),
)
# Update pending operation in database
# If this fails, we must not continue to mark the task as completed
result_str = serialize_result(result)
try:
await _update_tool_message(
session_id=task.session_id,
tool_call_id=task.tool_call_id,
content=result_str,
prisma_client=prisma_client,
)
except ToolMessageUpdateError:
# DB update failed - mark task as failed to avoid inconsistent state
logger.error(
f"[COMPLETION] DB update failed for task {task.task_id}, "
"marking as failed instead of completed"
)
await stream_registry.publish_chunk(
task.task_id,
StreamError(errorText="Failed to save operation result to database"),
)
await stream_registry.mark_task_completed(task.task_id, status="failed")
raise
# Generate LLM continuation with streaming
try:
await chat_service._generate_llm_continuation_with_streaming(
session_id=task.session_id,
user_id=task.user_id,
task_id=task.task_id,
)
except Exception as e:
logger.error(
f"[COMPLETION] Failed to generate LLM continuation: {e}",
exc_info=True,
)
# Mark task as completed and release Redis lock
await stream_registry.mark_task_completed(task.task_id, status="completed")
try:
await chat_service._mark_operation_completed(task.tool_call_id)
except Exception as e:
logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")
logger.info(
f"[COMPLETION] Successfully processed completion for task {task.task_id}"
)
async def process_operation_failure(
task: stream_registry.ActiveTask,
error: str | None,
prisma_client: Prisma | None = None,
) -> None:
"""Handle failed operation completion.
Publishes the error to the stream registry, updates the database with
the error response, and marks the task as failed.
Args:
task: The active task that failed
error: The error message from the operation
prisma_client: Optional Prisma client for database operations.
If None, uses chat_service._update_pending_operation instead.
"""
error_msg = error or "Operation failed"
# Publish error to stream registry
await stream_registry.publish_chunk(
task.task_id,
StreamError(errorText=error_msg),
)
# Update pending operation with error
# If this fails, we still continue to mark the task as failed
error_response = ErrorResponse(
message=error_msg,
error=error,
)
try:
await _update_tool_message(
session_id=task.session_id,
tool_call_id=task.tool_call_id,
content=error_response.model_dump_json(),
prisma_client=prisma_client,
)
except ToolMessageUpdateError:
# DB update failed - log but continue with cleanup
logger.error(
f"[COMPLETION] DB update failed while processing failure for task {task.task_id}, "
"continuing with cleanup"
)
# Mark task as failed and release Redis lock
await stream_registry.mark_task_completed(task.task_id, status="failed")
try:
await chat_service._mark_operation_completed(task.tool_call_id)
except Exception as e:
logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")
logger.info(f"[COMPLETION] Processed failure for task {task.task_id}: {error_msg}")

View File

@@ -33,9 +33,57 @@ class ChatConfig(BaseSettings):
stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
max_retries: int = Field(default=3, description="Maximum number of retries")
max_agent_runs: int = Field(default=3, description="Maximum number of agent runs")
max_agent_runs: int = Field(default=30, description="Maximum number of agent runs")
max_agent_schedules: int = Field(
default=3, description="Maximum number of agent schedules"
default=30, description="Maximum number of agent schedules"
)
# Long-running operation configuration
long_running_operation_ttl: int = Field(
default=600,
description="TTL in seconds for long-running operation tracking in Redis (safety net if pod dies)",
)
# Stream registry configuration for SSE reconnection
stream_ttl: int = Field(
default=3600,
description="TTL in seconds for stream data in Redis (1 hour)",
)
stream_max_length: int = Field(
default=10000,
description="Maximum number of messages to store per stream",
)
# Redis Streams configuration for completion consumer
stream_completion_name: str = Field(
default="chat:completions",
description="Redis Stream name for operation completions",
)
stream_consumer_group: str = Field(
default="chat_consumers",
description="Consumer group name for completion stream",
)
stream_claim_min_idle_ms: int = Field(
default=60000,
description="Minimum idle time in milliseconds before claiming pending messages from dead consumers",
)
# Redis key prefixes for stream registry
task_meta_prefix: str = Field(
default="chat:task:meta:",
description="Prefix for task metadata hash keys",
)
task_stream_prefix: str = Field(
default="chat:stream:",
description="Prefix for task message stream keys",
)
task_op_prefix: str = Field(
default="chat:task:op:",
description="Prefix for operation ID to task ID mapping keys",
)
internal_api_key: str | None = Field(
default=None,
description="API key for internal webhook callbacks (env: CHAT_INTERNAL_API_KEY)",
)
# Langfuse Prompt Management Configuration
@@ -76,6 +124,14 @@ class ChatConfig(BaseSettings):
v = "https://openrouter.ai/api/v1"
return v
@field_validator("internal_api_key", mode="before")
@classmethod
def get_internal_api_key(cls, v):
"""Get internal API key from environment if not provided."""
if v is None:
v = os.getenv("CHAT_INTERNAL_API_KEY")
return v
# Prompt paths for different contexts
PROMPT_PATHS: dict[str, str] = {
"default": "prompts/chat_system.md",

View File

@@ -247,3 +247,45 @@ async def get_chat_session_message_count(session_id: str) -> int:
"""Get the number of messages in a chat session."""
count = await PrismaChatMessage.prisma().count(where={"sessionId": session_id})
return count
async def update_tool_message_content(
session_id: str,
tool_call_id: str,
new_content: str,
) -> bool:
"""Update the content of a tool message in chat history.
Used by background tasks to update pending operation messages with final results.
Args:
session_id: The chat session ID.
tool_call_id: The tool call ID to find the message.
new_content: The new content to set.
Returns:
True if a message was updated, False otherwise.
"""
try:
result = await PrismaChatMessage.prisma().update_many(
where={
"sessionId": session_id,
"toolCallId": tool_call_id,
},
data={
"content": new_content,
},
)
if result == 0:
logger.warning(
f"No message found to update for session {session_id}, "
f"tool_call_id {tool_call_id}"
)
return False
return True
except Exception as e:
logger.error(
f"Failed to update tool message for session {session_id}, "
f"tool_call_id {tool_call_id}: {e}"
)
return False

View File

@@ -295,6 +295,21 @@ async def cache_chat_session(session: ChatSession) -> None:
await _cache_session(session)
async def invalidate_session_cache(session_id: str) -> None:
"""Invalidate a chat session from Redis cache.
Used by background tasks to ensure fresh data is loaded on next access.
This is best-effort - Redis failures are logged but don't fail the operation.
"""
try:
redis_key = _get_session_cache_key(session_id)
async_redis = await get_redis_async()
await async_redis.delete(redis_key)
except Exception as e:
# Best-effort: log but don't fail - cache will expire naturally
logger.warning(f"Failed to invalidate session cache for {session_id}: {e}")
async def _get_session_from_db(session_id: str) -> ChatSession | None:
"""Get a chat session from the database."""
prisma_session = await chat_db.get_chat_session(session_id)

View File

@@ -52,6 +52,10 @@ class StreamStart(StreamBaseResponse):
type: ResponseType = ResponseType.START
messageId: str = Field(..., description="Unique message ID")
taskId: str | None = Field(
default=None,
description="Task ID for SSE reconnection. Clients can reconnect using GET /tasks/{taskId}/stream",
)
class StreamFinish(StreamBaseResponse):

View File

@@ -1,19 +1,23 @@
"""Chat API routes for chat session management and streaming via SSE."""
import logging
import uuid as uuid_module
from collections.abc import AsyncGenerator
from typing import Annotated
from autogpt_libs import auth
from fastapi import APIRouter, Depends, Query, Security
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Security
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from backend.util.exceptions import NotFoundError
from . import service as chat_service
from . import stream_registry
from .completion_handler import process_operation_failure, process_operation_success
from .config import ChatConfig
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
from .response_model import StreamFinish, StreamHeartbeat, StreamStart
config = ChatConfig()
@@ -55,6 +59,15 @@ class CreateSessionResponse(BaseModel):
user_id: str | None
class ActiveStreamInfo(BaseModel):
"""Information about an active stream for reconnection."""
task_id: str
last_message_id: str # Redis Stream message ID for resumption
operation_id: str # Operation ID for completion tracking
tool_name: str # Name of the tool being executed
class SessionDetailResponse(BaseModel):
"""Response model providing complete details for a chat session, including messages."""
@@ -63,6 +76,7 @@ class SessionDetailResponse(BaseModel):
updated_at: str
user_id: str | None
messages: list[dict]
active_stream: ActiveStreamInfo | None = None # Present if stream is still active
class SessionSummaryResponse(BaseModel):
@@ -81,6 +95,14 @@ class ListSessionsResponse(BaseModel):
total: int
class OperationCompleteRequest(BaseModel):
"""Request model for external completion webhook."""
success: bool
result: dict | str | None = None
error: str | None = None
# ========== Routes ==========
@@ -166,13 +188,14 @@ async def get_session(
Retrieve the details of a specific chat session.
Looks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.
If there's an active stream for this session, returns the task_id for reconnection.
Args:
session_id: The unique identifier for the desired chat session.
user_id: The optional authenticated user ID, or None for anonymous access.
Returns:
SessionDetailResponse: Details for the requested session, or None if not found.
SessionDetailResponse: Details for the requested session, including active_stream info if applicable.
"""
session = await get_chat_session(session_id, user_id)
@@ -180,11 +203,28 @@ async def get_session(
raise NotFoundError(f"Session {session_id} not found.")
messages = [message.model_dump() for message in session.messages]
logger.info(
f"Returning session {session_id}: "
f"message_count={len(messages)}, "
f"roles={[m.get('role') for m in messages]}"
# Check if there's an active stream for this session
active_stream_info = None
active_task, last_message_id = await stream_registry.get_active_task_for_session(
session_id, user_id
)
if active_task:
# Filter out the in-progress assistant message from the session response.
# The client will receive the complete assistant response through the SSE
# stream replay instead, preventing duplicate content.
if messages and messages[-1].get("role") == "assistant":
messages = messages[:-1]
# Use "0-0" as last_message_id to replay the stream from the beginning.
# Since we filtered out the cached assistant message, the client needs
# the full stream to reconstruct the response.
active_stream_info = ActiveStreamInfo(
task_id=active_task.task_id,
last_message_id="0-0",
operation_id=active_task.operation_id,
tool_name=active_task.tool_name,
)
return SessionDetailResponse(
id=session.session_id,
@@ -192,6 +232,7 @@ async def get_session(
updated_at=session.updated_at.isoformat(),
user_id=session.user_id or None,
messages=messages,
active_stream=active_stream_info,
)
@@ -211,49 +252,112 @@ async def stream_chat_post(
- Tool call UI elements (if invoked)
- Tool execution results
The AI generation runs in a background task that continues even if the client disconnects.
All chunks are written to Redis for reconnection support. If the client disconnects,
they can reconnect using GET /tasks/{task_id}/stream to resume from where they left off.
Args:
session_id: The chat session identifier to associate with the streamed messages.
request: Request body containing message, is_user_message, and optional context.
user_id: Optional authenticated user ID.
Returns:
StreamingResponse: SSE-formatted response chunks.
StreamingResponse: SSE-formatted response chunks. First chunk is a "start" event
containing the task_id for reconnection.
"""
import asyncio
session = await _validate_and_get_session(session_id, user_id)
# Create a task in the stream registry for reconnection support
task_id = str(uuid_module.uuid4())
operation_id = str(uuid_module.uuid4())
await stream_registry.create_task(
task_id=task_id,
session_id=session_id,
user_id=user_id,
tool_call_id="chat_stream", # Not a tool call, but needed for the model
tool_name="chat",
operation_id=operation_id,
)
# Background task that runs the AI generation independently of SSE connection
async def run_ai_generation():
try:
# Emit a start event with task_id for reconnection
start_chunk = StreamStart(messageId=task_id, taskId=task_id)
await stream_registry.publish_chunk(task_id, start_chunk)
async for chunk in chat_service.stream_chat_completion(
session_id,
request.message,
is_user_message=request.is_user_message,
user_id=user_id,
session=session, # Pass pre-fetched session to avoid double-fetch
context=request.context,
):
# Write to Redis (subscribers will receive via XREAD)
await stream_registry.publish_chunk(task_id, chunk)
# Mark task as completed
await stream_registry.mark_task_completed(task_id, "completed")
except Exception as e:
logger.error(
f"Error in background AI generation for session {session_id}: {e}"
)
await stream_registry.mark_task_completed(task_id, "failed")
# Start the AI generation in a background task
bg_task = asyncio.create_task(run_ai_generation())
await stream_registry.set_task_asyncio_task(task_id, bg_task)
# SSE endpoint that subscribes to the task's stream
async def event_generator() -> AsyncGenerator[str, None]:
chunk_count = 0
first_chunk_type: str | None = None
async for chunk in chat_service.stream_chat_completion(
session_id,
request.message,
is_user_message=request.is_user_message,
user_id=user_id,
session=session, # Pass pre-fetched session to avoid double-fetch
context=request.context,
):
if chunk_count < 3:
logger.info(
"Chat stream chunk",
extra={
"session_id": session_id,
"chunk_type": str(chunk.type),
},
)
if not first_chunk_type:
first_chunk_type = str(chunk.type)
chunk_count += 1
yield chunk.to_sse()
logger.info(
"Chat stream completed",
extra={
"session_id": session_id,
"chunk_count": chunk_count,
"first_chunk_type": first_chunk_type,
},
)
# AI SDK protocol termination
yield "data: [DONE]\n\n"
subscriber_queue = None
try:
# Subscribe to the task stream (this replays existing messages + live updates)
subscriber_queue = await stream_registry.subscribe_to_task(
task_id=task_id,
user_id=user_id,
last_message_id="0-0", # Get all messages from the beginning
)
if subscriber_queue is None:
yield StreamFinish().to_sse()
yield "data: [DONE]\n\n"
return
# Read from the subscriber queue and yield to SSE
while True:
try:
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
yield chunk.to_sse()
# Check for finish signal
if isinstance(chunk, StreamFinish):
break
except asyncio.TimeoutError:
# Send heartbeat to keep connection alive
yield StreamHeartbeat().to_sse()
except GeneratorExit:
pass # Client disconnected - background task continues
except Exception as e:
logger.error(f"Error in SSE stream for task {task_id}: {e}")
finally:
# Unsubscribe when client disconnects or stream ends to prevent resource leak
if subscriber_queue is not None:
try:
await stream_registry.unsubscribe_from_task(
task_id, subscriber_queue
)
except Exception as unsub_err:
logger.error(
f"Error unsubscribing from task {task_id}: {unsub_err}",
exc_info=True,
)
# AI SDK protocol termination - always yield even if unsubscribe fails
yield "data: [DONE]\n\n"
return StreamingResponse(
event_generator(),
@@ -366,6 +470,251 @@ async def session_assign_user(
return {"status": "ok"}
# ========== Task Streaming (SSE Reconnection) ==========
@router.get(
"/tasks/{task_id}/stream",
)
async def stream_task(
task_id: str,
user_id: str | None = Depends(auth.get_user_id),
last_message_id: str = Query(
default="0-0",
description="Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.",
),
):
"""
Reconnect to a long-running task's SSE stream.
When a long-running operation (like agent generation) starts, the client
receives a task_id. If the connection drops, the client can reconnect
using this endpoint to resume receiving updates.
Args:
task_id: The task ID from the operation_started response.
user_id: Authenticated user ID for ownership validation.
last_message_id: Last Redis Stream message ID received ("0-0" for full replay).
Returns:
StreamingResponse: SSE-formatted response chunks starting after last_message_id.
Raises:
HTTPException: 404 if task not found, 410 if task expired, 403 if access denied.
"""
# Check task existence and expiry before subscribing
task, error_code = await stream_registry.get_task_with_expiry_info(task_id)
if error_code == "TASK_EXPIRED":
raise HTTPException(
status_code=410,
detail={
"code": "TASK_EXPIRED",
"message": "This operation has expired. Please try again.",
},
)
if error_code == "TASK_NOT_FOUND":
raise HTTPException(
status_code=404,
detail={
"code": "TASK_NOT_FOUND",
"message": f"Task {task_id} not found.",
},
)
# Validate ownership if task has an owner
if task and task.user_id and user_id != task.user_id:
raise HTTPException(
status_code=403,
detail={
"code": "ACCESS_DENIED",
"message": "You do not have access to this task.",
},
)
# Get subscriber queue from stream registry
subscriber_queue = await stream_registry.subscribe_to_task(
task_id=task_id,
user_id=user_id,
last_message_id=last_message_id,
)
if subscriber_queue is None:
raise HTTPException(
status_code=404,
detail={
"code": "TASK_NOT_FOUND",
"message": f"Task {task_id} not found or access denied.",
},
)
async def event_generator() -> AsyncGenerator[str, None]:
import asyncio
heartbeat_interval = 15.0 # Send heartbeat every 15 seconds
try:
while True:
try:
# Wait for next chunk with timeout for heartbeats
chunk = await asyncio.wait_for(
subscriber_queue.get(), timeout=heartbeat_interval
)
yield chunk.to_sse()
# Check for finish signal
if isinstance(chunk, StreamFinish):
break
except asyncio.TimeoutError:
# Send heartbeat to keep connection alive
yield StreamHeartbeat().to_sse()
except Exception as e:
logger.error(f"Error in task stream {task_id}: {e}", exc_info=True)
finally:
# Unsubscribe when client disconnects or stream ends
try:
await stream_registry.unsubscribe_from_task(task_id, subscriber_queue)
except Exception as unsub_err:
logger.error(
f"Error unsubscribing from task {task_id}: {unsub_err}",
exc_info=True,
)
# AI SDK protocol termination - always yield even if unsubscribe fails
yield "data: [DONE]\n\n"
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
"x-vercel-ai-ui-message-stream": "v1",
},
)
@router.get(
"/tasks/{task_id}",
)
async def get_task_status(
task_id: str,
user_id: str | None = Depends(auth.get_user_id),
) -> dict:
"""
Get the status of a long-running task.
Args:
task_id: The task ID to check.
user_id: Authenticated user ID for ownership validation.
Returns:
dict: Task status including task_id, status, tool_name, and operation_id.
Raises:
NotFoundError: If task_id is not found or user doesn't have access.
"""
task = await stream_registry.get_task(task_id)
if task is None:
raise NotFoundError(f"Task {task_id} not found.")
# Validate ownership - if task has an owner, requester must match
if task.user_id and user_id != task.user_id:
raise NotFoundError(f"Task {task_id} not found.")
return {
"task_id": task.task_id,
"session_id": task.session_id,
"status": task.status,
"tool_name": task.tool_name,
"operation_id": task.operation_id,
"created_at": task.created_at.isoformat(),
}
# ========== External Completion Webhook ==========
@router.post(
"/operations/{operation_id}/complete",
status_code=200,
)
async def complete_operation(
operation_id: str,
request: OperationCompleteRequest,
x_api_key: str | None = Header(default=None),
) -> dict:
"""
External completion webhook for long-running operations.
Called by Agent Generator (or other services) when an operation completes.
This triggers the stream registry to publish completion and continue LLM generation.
Args:
operation_id: The operation ID to complete.
request: Completion payload with success status and result/error.
x_api_key: Internal API key for authentication.
Returns:
dict: Status of the completion.
Raises:
HTTPException: If API key is invalid or operation not found.
"""
# Validate internal API key - reject if not configured or invalid
if not config.internal_api_key:
logger.error(
"Operation complete webhook rejected: CHAT_INTERNAL_API_KEY not configured"
)
raise HTTPException(
status_code=503,
detail="Webhook not available: internal API key not configured",
)
if x_api_key != config.internal_api_key:
raise HTTPException(status_code=401, detail="Invalid API key")
# Find task by operation_id
task = await stream_registry.find_task_by_operation_id(operation_id)
if task is None:
raise HTTPException(
status_code=404,
detail=f"Operation {operation_id} not found",
)
logger.info(
f"Received completion webhook for operation {operation_id} "
f"(task_id={task.task_id}, success={request.success})"
)
if request.success:
await process_operation_success(task, request.result)
else:
await process_operation_failure(task, request.error)
return {"status": "ok", "task_id": task.task_id}
# ========== Configuration ==========
@router.get("/config/ttl", status_code=200)
async def get_ttl_config() -> dict:
"""
Get the stream TTL configuration.
Returns the Time-To-Live settings for chat streams, which determines
how long clients can reconnect to an active stream.
Returns:
dict: TTL configuration with seconds and milliseconds values.
"""
return {
"stream_ttl_seconds": config.stream_ttl,
"stream_ttl_ms": config.stream_ttl * 1000,
}
# ========== Health Check ==========

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,704 @@
"""Stream registry for managing reconnectable SSE streams.
This module provides a registry for tracking active streaming tasks and their
messages. It uses Redis for all state management (no in-memory state), making
pods stateless and horizontally scalable.
Architecture:
- Redis Stream: Persists all messages for replay and real-time delivery
- Redis Hash: Task metadata (status, session_id, etc.)
Subscribers:
1. Replay missed messages from Redis Stream (XREAD)
2. Listen for live updates via blocking XREAD
3. No in-memory state required on the subscribing pod
"""
import asyncio
import logging
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, Literal
import orjson
from backend.data.redis_client import get_redis_async
from .config import ChatConfig
from .response_model import StreamBaseResponse, StreamError, StreamFinish
logger = logging.getLogger(__name__)
config = ChatConfig()
# Track background tasks for this pod (just the asyncio.Task reference, not subscribers)
_local_tasks: dict[str, asyncio.Task] = {}
# Track listener tasks per subscriber queue for cleanup
# Maps queue id() to (task_id, asyncio.Task) for proper cleanup on unsubscribe
_listener_tasks: dict[int, tuple[str, asyncio.Task]] = {}
# Timeout for putting chunks into subscriber queues (seconds)
# If the queue is full and doesn't drain within this time, send an overflow error
QUEUE_PUT_TIMEOUT = 5.0
# Lua script for atomic compare-and-swap status update (idempotent completion)
# Returns 1 if status was updated, 0 if already completed/failed
COMPLETE_TASK_SCRIPT = """
local current = redis.call("HGET", KEYS[1], "status")
if current == "running" then
redis.call("HSET", KEYS[1], "status", ARGV[1])
return 1
end
return 0
"""
@dataclass
class ActiveTask:
"""Represents an active streaming task (metadata only, no in-memory queues)."""
task_id: str
session_id: str
user_id: str | None
tool_call_id: str
tool_name: str
operation_id: str
status: Literal["running", "completed", "failed"] = "running"
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
asyncio_task: asyncio.Task | None = None
def _get_task_meta_key(task_id: str) -> str:
"""Get Redis key for task metadata."""
return f"{config.task_meta_prefix}{task_id}"
def _get_task_stream_key(task_id: str) -> str:
"""Get Redis key for task message stream."""
return f"{config.task_stream_prefix}{task_id}"
def _get_operation_mapping_key(operation_id: str) -> str:
"""Get Redis key for operation_id to task_id mapping."""
return f"{config.task_op_prefix}{operation_id}"
async def create_task(
task_id: str,
session_id: str,
user_id: str | None,
tool_call_id: str,
tool_name: str,
operation_id: str,
) -> ActiveTask:
"""Create a new streaming task in Redis.
Args:
task_id: Unique identifier for the task
session_id: Chat session ID
user_id: User ID (may be None for anonymous)
tool_call_id: Tool call ID from the LLM
tool_name: Name of the tool being executed
operation_id: Operation ID for webhook callbacks
Returns:
The created ActiveTask instance (metadata only)
"""
task = ActiveTask(
task_id=task_id,
session_id=session_id,
user_id=user_id,
tool_call_id=tool_call_id,
tool_name=tool_name,
operation_id=operation_id,
)
# Store metadata in Redis
redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id)
op_key = _get_operation_mapping_key(operation_id)
await redis.hset( # type: ignore[misc]
meta_key,
mapping={
"task_id": task_id,
"session_id": session_id,
"user_id": user_id or "",
"tool_call_id": tool_call_id,
"tool_name": tool_name,
"operation_id": operation_id,
"status": task.status,
"created_at": task.created_at.isoformat(),
},
)
await redis.expire(meta_key, config.stream_ttl)
# Create operation_id -> task_id mapping for webhook lookups
await redis.set(op_key, task_id, ex=config.stream_ttl)
logger.debug(f"Created task {task_id} for session {session_id}")
return task
async def publish_chunk(
task_id: str,
chunk: StreamBaseResponse,
) -> str:
"""Publish a chunk to Redis Stream.
All delivery is via Redis Streams - no in-memory state.
Args:
task_id: Task ID to publish to
chunk: The stream response chunk to publish
Returns:
The Redis Stream message ID
"""
chunk_json = chunk.model_dump_json()
message_id = "0-0"
try:
redis = await get_redis_async()
stream_key = _get_task_stream_key(task_id)
# Write to Redis Stream for persistence and real-time delivery
raw_id = await redis.xadd(
stream_key,
{"data": chunk_json},
maxlen=config.stream_max_length,
)
message_id = raw_id if isinstance(raw_id, str) else raw_id.decode()
# Set TTL on stream to match task metadata TTL
await redis.expire(stream_key, config.stream_ttl)
except Exception as e:
logger.error(
f"Failed to publish chunk for task {task_id}: {e}",
exc_info=True,
)
return message_id
async def subscribe_to_task(
task_id: str,
user_id: str | None,
last_message_id: str = "0-0",
) -> asyncio.Queue[StreamBaseResponse] | None:
"""Subscribe to a task's stream with replay of missed messages.
This is fully stateless - uses Redis Stream for replay and pub/sub for live updates.
Args:
task_id: Task ID to subscribe to
user_id: User ID for ownership validation
last_message_id: Last Redis Stream message ID received ("0-0" for full replay)
Returns:
An asyncio Queue that will receive stream chunks, or None if task not found
or user doesn't have access
"""
redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id)
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
if not meta:
logger.debug(f"Task {task_id} not found in Redis")
return None
# Note: Redis client uses decode_responses=True, so keys are strings
task_status = meta.get("status", "")
task_user_id = meta.get("user_id", "") or None
# Validate ownership - if task has an owner, requester must match
if task_user_id:
if user_id != task_user_id:
logger.warning(
f"User {user_id} denied access to task {task_id} "
f"owned by {task_user_id}"
)
return None
subscriber_queue: asyncio.Queue[StreamBaseResponse] = asyncio.Queue()
stream_key = _get_task_stream_key(task_id)
# Step 1: Replay messages from Redis Stream
messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000)
replayed_count = 0
replay_last_id = last_message_id
if messages:
for _stream_name, stream_messages in messages:
for msg_id, msg_data in stream_messages:
replay_last_id = msg_id if isinstance(msg_id, str) else msg_id.decode()
# Note: Redis client uses decode_responses=True, so keys are strings
if "data" in msg_data:
try:
chunk_data = orjson.loads(msg_data["data"])
chunk = _reconstruct_chunk(chunk_data)
if chunk:
await subscriber_queue.put(chunk)
replayed_count += 1
except Exception as e:
logger.warning(f"Failed to replay message: {e}")
logger.debug(f"Task {task_id}: replayed {replayed_count} messages")
# Step 2: If task is still running, start stream listener for live updates
if task_status == "running":
listener_task = asyncio.create_task(
_stream_listener(task_id, subscriber_queue, replay_last_id)
)
# Track listener task for cleanup on unsubscribe
_listener_tasks[id(subscriber_queue)] = (task_id, listener_task)
else:
# Task is completed/failed - add finish marker
await subscriber_queue.put(StreamFinish())
return subscriber_queue
async def _stream_listener(
task_id: str,
subscriber_queue: asyncio.Queue[StreamBaseResponse],
last_replayed_id: str,
) -> None:
"""Listen to Redis Stream for new messages using blocking XREAD.
This approach avoids the duplicate message issue that can occur with pub/sub
when messages are published during the gap between replay and subscription.
Args:
task_id: Task ID to listen for
subscriber_queue: Queue to deliver messages to
last_replayed_id: Last message ID from replay (continue from here)
"""
queue_id = id(subscriber_queue)
# Track the last successfully delivered message ID for recovery hints
last_delivered_id = last_replayed_id
try:
redis = await get_redis_async()
stream_key = _get_task_stream_key(task_id)
current_id = last_replayed_id
while True:
# Block for up to 30 seconds waiting for new messages
# This allows periodic checking if task is still running
messages = await redis.xread(
{stream_key: current_id}, block=30000, count=100
)
if not messages:
# Timeout - check if task is still running
meta_key = _get_task_meta_key(task_id)
status = await redis.hget(meta_key, "status") # type: ignore[misc]
if status and status != "running":
try:
await asyncio.wait_for(
subscriber_queue.put(StreamFinish()),
timeout=QUEUE_PUT_TIMEOUT,
)
except asyncio.TimeoutError:
logger.warning(
f"Timeout delivering finish event for task {task_id}"
)
break
continue
for _stream_name, stream_messages in messages:
for msg_id, msg_data in stream_messages:
current_id = msg_id if isinstance(msg_id, str) else msg_id.decode()
if "data" not in msg_data:
continue
try:
chunk_data = orjson.loads(msg_data["data"])
chunk = _reconstruct_chunk(chunk_data)
if chunk:
try:
await asyncio.wait_for(
subscriber_queue.put(chunk),
timeout=QUEUE_PUT_TIMEOUT,
)
# Update last delivered ID on successful delivery
last_delivered_id = current_id
except asyncio.TimeoutError:
logger.warning(
f"Subscriber queue full for task {task_id}, "
f"message delivery timed out after {QUEUE_PUT_TIMEOUT}s"
)
# Send overflow error with recovery info
try:
overflow_error = StreamError(
errorText="Message delivery timeout - some messages may have been missed",
code="QUEUE_OVERFLOW",
details={
"last_delivered_id": last_delivered_id,
"recovery_hint": f"Reconnect with last_message_id={last_delivered_id}",
},
)
subscriber_queue.put_nowait(overflow_error)
except asyncio.QueueFull:
# Queue is completely stuck, nothing more we can do
logger.error(
f"Cannot deliver overflow error for task {task_id}, "
"queue completely blocked"
)
# Stop listening on finish
if isinstance(chunk, StreamFinish):
return
except Exception as e:
logger.warning(f"Error processing stream message: {e}")
except asyncio.CancelledError:
logger.debug(f"Stream listener cancelled for task {task_id}")
raise # Re-raise to propagate cancellation
except Exception as e:
logger.error(f"Stream listener error for task {task_id}: {e}")
# On error, send finish to unblock subscriber
try:
await asyncio.wait_for(
subscriber_queue.put(StreamFinish()),
timeout=QUEUE_PUT_TIMEOUT,
)
except (asyncio.TimeoutError, asyncio.QueueFull):
logger.warning(
f"Could not deliver finish event for task {task_id} after error"
)
finally:
# Clean up listener task mapping on exit
_listener_tasks.pop(queue_id, None)
async def mark_task_completed(
task_id: str,
status: Literal["completed", "failed"] = "completed",
) -> bool:
"""Mark a task as completed and publish finish event.
This is idempotent - calling multiple times with the same task_id is safe.
Uses atomic compare-and-swap via Lua script to prevent race conditions.
Status is updated first (source of truth), then finish event is published (best-effort).
Args:
task_id: Task ID to mark as completed
status: Final status ("completed" or "failed")
Returns:
True if task was newly marked completed, False if already completed/failed
"""
redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id)
# Atomic compare-and-swap: only update if status is "running"
# This prevents race conditions when multiple callers try to complete simultaneously
result = await redis.eval(COMPLETE_TASK_SCRIPT, 1, meta_key, status) # type: ignore[misc]
if result == 0:
logger.debug(f"Task {task_id} already completed/failed, skipping")
return False
# THEN publish finish event (best-effort - listeners can detect via status polling)
try:
await publish_chunk(task_id, StreamFinish())
except Exception as e:
logger.error(
f"Failed to publish finish event for task {task_id}: {e}. "
"Listeners will detect completion via status polling."
)
# Clean up local task reference if exists
_local_tasks.pop(task_id, None)
return True
async def find_task_by_operation_id(operation_id: str) -> ActiveTask | None:
"""Find a task by its operation ID.
Used by webhook callbacks to locate the task to update.
Args:
operation_id: Operation ID to search for
Returns:
ActiveTask if found, None otherwise
"""
redis = await get_redis_async()
op_key = _get_operation_mapping_key(operation_id)
task_id = await redis.get(op_key)
if not task_id:
return None
task_id_str = task_id.decode() if isinstance(task_id, bytes) else task_id
return await get_task(task_id_str)
async def get_task(task_id: str) -> ActiveTask | None:
"""Get a task by its ID from Redis.
Args:
task_id: Task ID to look up
Returns:
ActiveTask if found, None otherwise
"""
redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id)
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
if not meta:
return None
# Note: Redis client uses decode_responses=True, so keys/values are strings
return ActiveTask(
task_id=meta.get("task_id", ""),
session_id=meta.get("session_id", ""),
user_id=meta.get("user_id", "") or None,
tool_call_id=meta.get("tool_call_id", ""),
tool_name=meta.get("tool_name", ""),
operation_id=meta.get("operation_id", ""),
status=meta.get("status", "running"), # type: ignore[arg-type]
)
async def get_task_with_expiry_info(
task_id: str,
) -> tuple[ActiveTask | None, str | None]:
"""Get a task by its ID with expiration detection.
Returns (task, error_code) where error_code is:
- None if task found
- "TASK_EXPIRED" if stream exists but metadata is gone (TTL expired)
- "TASK_NOT_FOUND" if neither exists
Args:
task_id: Task ID to look up
Returns:
Tuple of (ActiveTask or None, error_code or None)
"""
redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id)
stream_key = _get_task_stream_key(task_id)
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
if not meta:
# Check if stream still has data (metadata expired but stream hasn't)
stream_len = await redis.xlen(stream_key)
if stream_len > 0:
return None, "TASK_EXPIRED"
return None, "TASK_NOT_FOUND"
# Note: Redis client uses decode_responses=True, so keys/values are strings
return (
ActiveTask(
task_id=meta.get("task_id", ""),
session_id=meta.get("session_id", ""),
user_id=meta.get("user_id", "") or None,
tool_call_id=meta.get("tool_call_id", ""),
tool_name=meta.get("tool_name", ""),
operation_id=meta.get("operation_id", ""),
status=meta.get("status", "running"), # type: ignore[arg-type]
),
None,
)
async def get_active_task_for_session(
session_id: str,
user_id: str | None = None,
) -> tuple[ActiveTask | None, str]:
"""Get the active (running) task for a session, if any.
Scans Redis for tasks matching the session_id with status="running".
Args:
session_id: Session ID to look up
user_id: User ID for ownership validation (optional)
Returns:
Tuple of (ActiveTask if found and running, last_message_id from Redis Stream)
"""
redis = await get_redis_async()
# Scan Redis for task metadata keys
cursor = 0
tasks_checked = 0
while True:
cursor, keys = await redis.scan(
cursor, match=f"{config.task_meta_prefix}*", count=100
)
for key in keys:
tasks_checked += 1
meta: dict[Any, Any] = await redis.hgetall(key) # type: ignore[misc]
if not meta:
continue
# Note: Redis client uses decode_responses=True, so keys/values are strings
task_session_id = meta.get("session_id", "")
task_status = meta.get("status", "")
task_user_id = meta.get("user_id", "") or None
task_id = meta.get("task_id", "")
if task_session_id == session_id and task_status == "running":
# Validate ownership - if task has an owner, requester must match
if task_user_id and user_id != task_user_id:
continue
# Get the last message ID from Redis Stream
stream_key = _get_task_stream_key(task_id)
last_id = "0-0"
try:
messages = await redis.xrevrange(stream_key, count=1)
if messages:
msg_id = messages[0][0]
last_id = msg_id if isinstance(msg_id, str) else msg_id.decode()
except Exception as e:
logger.warning(f"Failed to get last message ID: {e}")
return (
ActiveTask(
task_id=task_id,
session_id=task_session_id,
user_id=task_user_id,
tool_call_id=meta.get("tool_call_id", ""),
tool_name=meta.get("tool_name", ""),
operation_id=meta.get("operation_id", ""),
status="running",
),
last_id,
)
if cursor == 0:
break
return None, "0-0"
def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
"""Reconstruct a StreamBaseResponse from JSON data.
Args:
chunk_data: Parsed JSON data from Redis
Returns:
Reconstructed response object, or None if unknown type
"""
from .response_model import (
ResponseType,
StreamError,
StreamFinish,
StreamHeartbeat,
StreamStart,
StreamTextDelta,
StreamTextEnd,
StreamTextStart,
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
StreamUsage,
)
# Map response types to their corresponding classes
type_to_class: dict[str, type[StreamBaseResponse]] = {
ResponseType.START.value: StreamStart,
ResponseType.FINISH.value: StreamFinish,
ResponseType.TEXT_START.value: StreamTextStart,
ResponseType.TEXT_DELTA.value: StreamTextDelta,
ResponseType.TEXT_END.value: StreamTextEnd,
ResponseType.TOOL_INPUT_START.value: StreamToolInputStart,
ResponseType.TOOL_INPUT_AVAILABLE.value: StreamToolInputAvailable,
ResponseType.TOOL_OUTPUT_AVAILABLE.value: StreamToolOutputAvailable,
ResponseType.ERROR.value: StreamError,
ResponseType.USAGE.value: StreamUsage,
ResponseType.HEARTBEAT.value: StreamHeartbeat,
}
chunk_type = chunk_data.get("type")
chunk_class = type_to_class.get(chunk_type) # type: ignore[arg-type]
if chunk_class is None:
logger.warning(f"Unknown chunk type: {chunk_type}")
return None
try:
return chunk_class(**chunk_data)
except Exception as e:
logger.warning(f"Failed to reconstruct chunk of type {chunk_type}: {e}")
return None
async def set_task_asyncio_task(task_id: str, asyncio_task: asyncio.Task) -> None:
"""Track the asyncio.Task for a task (local reference only).
This is just for cleanup purposes - the task state is in Redis.
Args:
task_id: Task ID
asyncio_task: The asyncio Task to track
"""
_local_tasks[task_id] = asyncio_task
async def unsubscribe_from_task(
task_id: str,
subscriber_queue: asyncio.Queue[StreamBaseResponse],
) -> None:
"""Clean up when a subscriber disconnects.
Cancels the XREAD-based listener task associated with this subscriber queue
to prevent resource leaks.
Args:
task_id: Task ID
subscriber_queue: The subscriber's queue used to look up the listener task
"""
queue_id = id(subscriber_queue)
listener_entry = _listener_tasks.pop(queue_id, None)
if listener_entry is None:
logger.debug(
f"No listener task found for task {task_id} queue {queue_id} "
"(may have already completed)"
)
return
stored_task_id, listener_task = listener_entry
if stored_task_id != task_id:
logger.warning(
f"Task ID mismatch in unsubscribe: expected {task_id}, "
f"found {stored_task_id}"
)
if listener_task.done():
logger.debug(f"Listener task for task {task_id} already completed")
return
# Cancel the listener task
listener_task.cancel()
try:
# Wait for the task to be cancelled with a timeout
await asyncio.wait_for(listener_task, timeout=5.0)
except asyncio.CancelledError:
# Expected - the task was successfully cancelled
pass
except asyncio.TimeoutError:
logger.warning(
f"Timeout waiting for listener task cancellation for task {task_id}"
)
except Exception as e:
logger.error(f"Error during listener task cancellation for task {task_id}: {e}")
logger.debug(f"Successfully unsubscribed from task {task_id}")

View File

@@ -0,0 +1,79 @@
# CoPilot Tools - Future Ideas
## Multimodal Image Support for CoPilot
**Problem:** CoPilot uses a vision-capable model but can't "see" workspace images. When a block generates an image and returns `workspace://abc123`, CoPilot can't evaluate it (e.g., checking blog thumbnail quality).
**Backend Solution:**
When preparing messages for the LLM, detect `workspace://` image references and convert them to proper image content blocks:
```python
# Before sending to LLM, scan for workspace image references
# and inject them as image content parts
# Example message transformation:
# FROM: {"role": "assistant", "content": "Generated image: workspace://abc123"}
# TO: {"role": "assistant", "content": [
# {"type": "text", "text": "Generated image: workspace://abc123"},
# {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}
# ]}
```
**Where to implement:**
- In the chat stream handler before calling the LLM
- Or in a message preprocessing step
- Need to fetch image from workspace, convert to base64, add as image content
**Considerations:**
- Only do this for image MIME types (image/png, image/jpeg, etc.)
- May want a size limit (don't pass 10MB images)
- Track which images were "shown" to the AI for frontend indicator
- Cost implications - vision API calls are more expensive
**Frontend Solution:**
Show visual indicator on workspace files in chat:
- If AI saw the image: normal display
- If AI didn't see it: overlay icon saying "AI can't see this image"
Requires response metadata indicating which `workspace://` refs were passed to the model.
---
## Output Post-Processing Layer for run_block
**Problem:** Many blocks produce large outputs that:
- Consume massive context (100KB base64 image = ~133KB tokens)
- Can't fit in conversation
- Break things and cause high LLM costs
**Proposed Solution:** Instead of modifying individual blocks or `store_media_file()`, implement a centralized output processor in `run_block.py` that handles outputs before they're returned to CoPilot.
**Benefits:**
1. **Centralized** - one place to handle all output processing
2. **Future-proof** - new blocks automatically get output processing
3. **Keeps blocks pure** - they don't need to know about context constraints
4. **Handles all large outputs** - not just images
**Processing Rules:**
- Detect base64 data URIs → save to workspace, return `workspace://` reference
- Truncate very long strings (>N chars) with truncation note
- Summarize large arrays/lists (e.g., "Array with 1000 items, first 5: [...]")
- Handle nested large outputs in dicts recursively
- Cap total output size
**Implementation Location:** `run_block.py` after block execution, before returning `BlockOutputResponse`
**Example:**
```python
def _process_outputs_for_context(
outputs: dict[str, list[Any]],
workspace_manager: WorkspaceManager,
max_string_length: int = 10000,
max_array_preview: int = 5,
) -> dict[str, list[Any]]:
"""Process block outputs to prevent context bloat."""
processed = {}
for name, values in outputs.items():
processed[name] = [_process_value(v, workspace_manager) for v in values]
return processed
```

View File

@@ -10,6 +10,7 @@ from .add_understanding import AddUnderstandingTool
from .agent_output import AgentOutputTool
from .base import BaseTool
from .create_agent import CreateAgentTool
from .customize_agent import CustomizeAgentTool
from .edit_agent import EditAgentTool
from .find_agent import FindAgentTool
from .find_block import FindBlockTool
@@ -18,6 +19,12 @@ from .get_doc_page import GetDocPageTool
from .run_agent import RunAgentTool
from .run_block import RunBlockTool
from .search_docs import SearchDocsTool
from .workspace_files import (
DeleteWorkspaceFileTool,
ListWorkspaceFilesTool,
ReadWorkspaceFileTool,
WriteWorkspaceFileTool,
)
if TYPE_CHECKING:
from backend.api.features.chat.response_model import StreamToolOutputAvailable
@@ -28,6 +35,7 @@ logger = logging.getLogger(__name__)
TOOL_REGISTRY: dict[str, BaseTool] = {
"add_understanding": AddUnderstandingTool(),
"create_agent": CreateAgentTool(),
"customize_agent": CustomizeAgentTool(),
"edit_agent": EditAgentTool(),
"find_agent": FindAgentTool(),
"find_block": FindBlockTool(),
@@ -37,6 +45,11 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
"view_agent_output": AgentOutputTool(),
"search_docs": SearchDocsTool(),
"get_doc_page": GetDocPageTool(),
# Workspace tools for CoPilot file operations
"list_workspace_files": ListWorkspaceFilesTool(),
"read_workspace_file": ReadWorkspaceFileTool(),
"write_workspace_file": WriteWorkspaceFileTool(),
"delete_workspace_file": DeleteWorkspaceFileTool(),
}
# Export individual tool instances for backwards compatibility
@@ -49,6 +62,11 @@ tools: list[ChatCompletionToolParam] = [
]
def get_tool(tool_name: str) -> BaseTool | None:
"""Get a tool instance by name."""
return TOOL_REGISTRY.get(tool_name)
async def execute_tool(
tool_name: str,
parameters: dict[str, Any],
@@ -57,7 +75,7 @@ async def execute_tool(
tool_call_id: str,
) -> "StreamToolOutputAvailable":
"""Execute a tool by name."""
tool = TOOL_REGISTRY.get(tool_name)
tool = get_tool(tool_name)
if not tool:
raise ValueError(f"Tool {tool_name} not found")

View File

@@ -2,27 +2,58 @@
from .core import (
AgentGeneratorNotConfiguredError,
AgentJsonValidationError,
AgentSummary,
DecompositionResult,
DecompositionStep,
LibraryAgentSummary,
MarketplaceAgentSummary,
customize_template,
decompose_goal,
enrich_library_agents_from_steps,
extract_search_terms_from_steps,
extract_uuids_from_text,
generate_agent,
generate_agent_patch,
get_agent_as_json,
get_all_relevant_agents_for_generation,
get_library_agent_by_graph_id,
get_library_agent_by_id,
get_library_agents_for_generation,
graph_to_json,
json_to_graph,
save_agent_to_library,
search_marketplace_agents_for_generation,
)
from .errors import get_user_message_for_error
from .service import health_check as check_external_service_health
from .service import is_external_service_configured
__all__ = [
# Core functions
"AgentGeneratorNotConfiguredError",
"AgentJsonValidationError",
"AgentSummary",
"DecompositionResult",
"DecompositionStep",
"LibraryAgentSummary",
"MarketplaceAgentSummary",
"check_external_service_health",
"customize_template",
"decompose_goal",
"enrich_library_agents_from_steps",
"extract_search_terms_from_steps",
"extract_uuids_from_text",
"generate_agent",
"generate_agent_patch",
"save_agent_to_library",
"get_agent_as_json",
"json_to_graph",
# Exceptions
"AgentGeneratorNotConfiguredError",
# Service
"get_all_relevant_agents_for_generation",
"get_library_agent_by_graph_id",
"get_library_agent_by_id",
"get_library_agents_for_generation",
"get_user_message_for_error",
"graph_to_json",
"is_external_service_configured",
"check_external_service_health",
"json_to_graph",
"save_agent_to_library",
"search_marketplace_agents_for_generation",
]

View File

@@ -1,13 +1,25 @@
"""Core agent generation functions."""
import logging
import re
import uuid
from typing import Any
from typing import Any, NotRequired, TypedDict
from backend.api.features.library import db as library_db
from backend.data.graph import Graph, Link, Node, create_graph
from backend.api.features.store import db as store_db
from backend.data.graph import (
Graph,
Link,
Node,
create_graph,
get_graph,
get_graph_all_versions,
get_store_listed_graphs,
)
from backend.util.exceptions import DatabaseError, NotFoundError
from .service import (
customize_template_external,
decompose_goal_external,
generate_agent_external,
generate_agent_patch_external,
@@ -16,6 +28,74 @@ from .service import (
logger = logging.getLogger(__name__)
AGENT_EXECUTOR_BLOCK_ID = "e189baac-8c20-45a1-94a7-55177ea42565"
class ExecutionSummary(TypedDict):
"""Summary of a single execution for quality assessment."""
status: str
correctness_score: NotRequired[float]
activity_summary: NotRequired[str]
class LibraryAgentSummary(TypedDict):
"""Summary of a library agent for sub-agent composition.
Includes recent executions to help the LLM decide whether to use this agent.
Each execution shows status, correctness_score (0-1), and activity_summary.
"""
graph_id: str
graph_version: int
name: str
description: str
input_schema: dict[str, Any]
output_schema: dict[str, Any]
recent_executions: NotRequired[list[ExecutionSummary]]
class MarketplaceAgentSummary(TypedDict):
"""Summary of a marketplace agent for sub-agent composition."""
name: str
description: str
sub_heading: str
creator: str
is_marketplace_agent: bool
class DecompositionStep(TypedDict, total=False):
"""A single step in decomposed instructions."""
description: str
action: str
block_name: str
tool: str
name: str
class DecompositionResult(TypedDict, total=False):
"""Result from decompose_goal - can be instructions, questions, or error."""
type: str
steps: list[DecompositionStep]
questions: list[dict[str, Any]]
error: str
error_type: str
AgentSummary = LibraryAgentSummary | MarketplaceAgentSummary | dict[str, Any]
def _to_dict_list(
agents: list[AgentSummary] | list[dict[str, Any]] | None,
) -> list[dict[str, Any]] | None:
"""Convert typed agent summaries to plain dicts for external service calls."""
if agents is None:
return None
return [dict(a) for a in agents]
class AgentGeneratorNotConfiguredError(Exception):
"""Raised when the external Agent Generator service is not configured."""
@@ -36,15 +116,422 @@ def _check_service_configured() -> None:
)
async def decompose_goal(description: str, context: str = "") -> dict[str, Any] | None:
_UUID_PATTERN = re.compile(
r"[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}",
re.IGNORECASE,
)
def extract_uuids_from_text(text: str) -> list[str]:
"""Extract all UUID v4 strings from text.
Args:
text: Text that may contain UUIDs (e.g., user's goal description)
Returns:
List of unique UUIDs found in the text (lowercase)
"""
matches = _UUID_PATTERN.findall(text)
return list({m.lower() for m in matches})
async def get_library_agent_by_id(
user_id: str, agent_id: str
) -> LibraryAgentSummary | None:
"""Fetch a specific library agent by its ID (library agent ID or graph_id).
This function tries multiple lookup strategies:
1. First tries to find by graph_id (AgentGraph primary key)
2. If not found, tries to find by library agent ID (LibraryAgent primary key)
This handles both cases:
- User provides graph_id (e.g., from AgentExecutorBlock)
- User provides library agent ID (e.g., from library URL)
Args:
user_id: The user ID
agent_id: The ID to look up (can be graph_id or library agent ID)
Returns:
LibraryAgentSummary if found, None otherwise
"""
try:
agent = await library_db.get_library_agent_by_graph_id(user_id, agent_id)
if agent:
logger.debug(f"Found library agent by graph_id: {agent.name}")
return LibraryAgentSummary(
graph_id=agent.graph_id,
graph_version=agent.graph_version,
name=agent.name,
description=agent.description,
input_schema=agent.input_schema,
output_schema=agent.output_schema,
)
except DatabaseError:
raise
except Exception as e:
logger.debug(f"Could not fetch library agent by graph_id {agent_id}: {e}")
try:
agent = await library_db.get_library_agent(agent_id, user_id)
if agent:
logger.debug(f"Found library agent by library_id: {agent.name}")
return LibraryAgentSummary(
graph_id=agent.graph_id,
graph_version=agent.graph_version,
name=agent.name,
description=agent.description,
input_schema=agent.input_schema,
output_schema=agent.output_schema,
)
except NotFoundError:
logger.debug(f"Library agent not found by library_id: {agent_id}")
except DatabaseError:
raise
except Exception as e:
logger.warning(
f"Could not fetch library agent by library_id {agent_id}: {e}",
exc_info=True,
)
return None
get_library_agent_by_graph_id = get_library_agent_by_id
async def get_library_agents_for_generation(
user_id: str,
search_query: str | None = None,
exclude_graph_id: str | None = None,
max_results: int = 15,
) -> list[LibraryAgentSummary]:
"""Fetch user's library agents formatted for Agent Generator.
Uses search-based fetching to return relevant agents instead of all agents.
This is more scalable for users with large libraries.
Includes recent_executions list to help the LLM assess agent quality:
- Each execution has status, correctness_score (0-1), and activity_summary
- This gives the LLM concrete examples of recent performance
Args:
user_id: The user ID
search_query: Optional search term to find relevant agents (user's goal/description)
exclude_graph_id: Optional graph ID to exclude (prevents circular references)
max_results: Maximum number of agents to return (default 15)
Returns:
List of LibraryAgentSummary with schemas and recent executions for sub-agent composition
"""
try:
response = await library_db.list_library_agents(
user_id=user_id,
search_term=search_query,
page=1,
page_size=max_results,
include_executions=True,
)
results: list[LibraryAgentSummary] = []
for agent in response.agents:
if exclude_graph_id is not None and agent.graph_id == exclude_graph_id:
continue
summary = LibraryAgentSummary(
graph_id=agent.graph_id,
graph_version=agent.graph_version,
name=agent.name,
description=agent.description,
input_schema=agent.input_schema,
output_schema=agent.output_schema,
)
if agent.recent_executions:
exec_summaries: list[ExecutionSummary] = []
for ex in agent.recent_executions:
exec_sum = ExecutionSummary(status=ex.status)
if ex.correctness_score is not None:
exec_sum["correctness_score"] = ex.correctness_score
if ex.activity_summary:
exec_sum["activity_summary"] = ex.activity_summary
exec_summaries.append(exec_sum)
summary["recent_executions"] = exec_summaries
results.append(summary)
return results
except DatabaseError:
raise
except Exception as e:
logger.warning(f"Failed to fetch library agents: {e}")
return []
async def search_marketplace_agents_for_generation(
search_query: str,
max_results: int = 10,
) -> list[LibraryAgentSummary]:
"""Search marketplace agents formatted for Agent Generator.
Fetches marketplace agents and their full schemas so they can be used
as sub-agents in generated workflows.
Args:
search_query: Search term to find relevant public agents
max_results: Maximum number of agents to return (default 10)
Returns:
List of LibraryAgentSummary with full input/output schemas
"""
try:
response = await store_db.get_store_agents(
search_query=search_query,
page=1,
page_size=max_results,
)
agents_with_graphs = [
agent for agent in response.agents if agent.agent_graph_id
]
if not agents_with_graphs:
return []
graph_ids = [agent.agent_graph_id for agent in agents_with_graphs]
graphs = await get_store_listed_graphs(*graph_ids)
results: list[LibraryAgentSummary] = []
for agent in agents_with_graphs:
graph_id = agent.agent_graph_id
if graph_id and graph_id in graphs:
graph = graphs[graph_id]
results.append(
LibraryAgentSummary(
graph_id=graph.id,
graph_version=graph.version,
name=agent.agent_name,
description=agent.description,
input_schema=graph.input_schema,
output_schema=graph.output_schema,
)
)
return results
except Exception as e:
logger.warning(f"Failed to search marketplace agents: {e}")
return []
async def get_all_relevant_agents_for_generation(
user_id: str,
search_query: str | None = None,
exclude_graph_id: str | None = None,
include_library: bool = True,
include_marketplace: bool = True,
max_library_results: int = 15,
max_marketplace_results: int = 10,
) -> list[AgentSummary]:
"""Fetch relevant agents from library and/or marketplace.
Searches both user's library and marketplace by default.
Explicitly mentioned UUIDs in the search query are always looked up.
Args:
user_id: The user ID
search_query: Search term to find relevant agents (user's goal/description)
exclude_graph_id: Optional graph ID to exclude (prevents circular references)
include_library: Whether to search user's library (default True)
include_marketplace: Whether to also search marketplace (default True)
max_library_results: Max library agents to return (default 15)
max_marketplace_results: Max marketplace agents to return (default 10)
Returns:
List of AgentSummary with full schemas (both library and marketplace agents)
"""
agents: list[AgentSummary] = []
seen_graph_ids: set[str] = set()
if search_query:
mentioned_uuids = extract_uuids_from_text(search_query)
for graph_id in mentioned_uuids:
if graph_id == exclude_graph_id:
continue
agent = await get_library_agent_by_graph_id(user_id, graph_id)
agent_graph_id = agent.get("graph_id") if agent else None
if agent and agent_graph_id and agent_graph_id not in seen_graph_ids:
agents.append(agent)
seen_graph_ids.add(agent_graph_id)
logger.debug(
f"Found explicitly mentioned agent: {agent.get('name') or 'Unknown'}"
)
if include_library:
library_agents = await get_library_agents_for_generation(
user_id=user_id,
search_query=search_query,
exclude_graph_id=exclude_graph_id,
max_results=max_library_results,
)
for agent in library_agents:
graph_id = agent.get("graph_id")
if graph_id and graph_id not in seen_graph_ids:
agents.append(agent)
seen_graph_ids.add(graph_id)
if include_marketplace and search_query:
marketplace_agents = await search_marketplace_agents_for_generation(
search_query=search_query,
max_results=max_marketplace_results,
)
for agent in marketplace_agents:
graph_id = agent.get("graph_id")
if graph_id and graph_id not in seen_graph_ids:
agents.append(agent)
seen_graph_ids.add(graph_id)
return agents
def extract_search_terms_from_steps(
decomposition_result: DecompositionResult | dict[str, Any],
) -> list[str]:
"""Extract search terms from decomposed instruction steps.
Analyzes the decomposition result to extract relevant keywords
for additional library agent searches.
Args:
decomposition_result: Result from decompose_goal containing steps
Returns:
List of unique search terms extracted from steps
"""
search_terms: list[str] = []
if decomposition_result.get("type") != "instructions":
return search_terms
steps = decomposition_result.get("steps", [])
if not steps:
return search_terms
step_keys: list[str] = ["description", "action", "block_name", "tool", "name"]
for step in steps:
for key in step_keys:
value = step.get(key) # type: ignore[union-attr]
if isinstance(value, str) and len(value) > 3:
search_terms.append(value)
seen: set[str] = set()
unique_terms: list[str] = []
for term in search_terms:
term_lower = term.lower()
if term_lower not in seen:
seen.add(term_lower)
unique_terms.append(term)
return unique_terms
async def enrich_library_agents_from_steps(
user_id: str,
decomposition_result: DecompositionResult | dict[str, Any],
existing_agents: list[AgentSummary] | list[dict[str, Any]],
exclude_graph_id: str | None = None,
include_marketplace: bool = True,
max_additional_results: int = 10,
) -> list[AgentSummary] | list[dict[str, Any]]:
"""Enrich library agents list with additional searches based on decomposed steps.
This implements two-phase search: after decomposition, we search for additional
relevant agents based on the specific steps identified.
Args:
user_id: The user ID
decomposition_result: Result from decompose_goal containing steps
existing_agents: Already fetched library agents from initial search
exclude_graph_id: Optional graph ID to exclude
include_marketplace: Whether to also search marketplace
max_additional_results: Max additional agents per search term (default 10)
Returns:
Combined list of library agents (existing + newly discovered)
"""
search_terms = extract_search_terms_from_steps(decomposition_result)
if not search_terms:
return existing_agents
existing_ids: set[str] = set()
existing_names: set[str] = set()
for agent in existing_agents:
agent_name = agent.get("name")
if agent_name and isinstance(agent_name, str):
existing_names.add(agent_name.lower())
graph_id = agent.get("graph_id") # type: ignore[call-overload]
if graph_id and isinstance(graph_id, str):
existing_ids.add(graph_id)
all_agents: list[AgentSummary] | list[dict[str, Any]] = list(existing_agents)
for term in search_terms[:3]:
try:
additional_agents = await get_all_relevant_agents_for_generation(
user_id=user_id,
search_query=term,
exclude_graph_id=exclude_graph_id,
include_marketplace=include_marketplace,
max_library_results=max_additional_results,
max_marketplace_results=5,
)
for agent in additional_agents:
agent_name = agent.get("name")
if not agent_name or not isinstance(agent_name, str):
continue
agent_name_lower = agent_name.lower()
if agent_name_lower in existing_names:
continue
graph_id = agent.get("graph_id") # type: ignore[call-overload]
if graph_id and graph_id in existing_ids:
continue
all_agents.append(agent)
existing_names.add(agent_name_lower)
if graph_id and isinstance(graph_id, str):
existing_ids.add(graph_id)
except DatabaseError:
logger.error(f"Database error searching for agents with term '{term}'")
raise
except Exception as e:
logger.warning(
f"Failed to search for additional agents with term '{term}': {e}"
)
logger.debug(
f"Enriched library agents: {len(existing_agents)} initial + "
f"{len(all_agents) - len(existing_agents)} additional = {len(all_agents)} total"
)
return all_agents
async def decompose_goal(
description: str,
context: str = "",
library_agents: list[AgentSummary] | None = None,
) -> DecompositionResult | None:
"""Break down a goal into steps or return clarifying questions.
Args:
description: Natural language goal description
context: Additional context (e.g., answers to previous questions)
library_agents: User's library agents available for sub-agent composition
Returns:
Dict with either:
DecompositionResult with either:
- {"type": "clarifying_questions", "questions": [...]}
- {"type": "instructions", "steps": [...]}
Or None on error
@@ -54,26 +541,47 @@ async def decompose_goal(description: str, context: str = "") -> dict[str, Any]
"""
_check_service_configured()
logger.info("Calling external Agent Generator service for decompose_goal")
return await decompose_goal_external(description, context)
result = await decompose_goal_external(
description, context, _to_dict_list(library_agents)
)
return result # type: ignore[return-value]
async def generate_agent(instructions: dict[str, Any]) -> dict[str, Any] | None:
async def generate_agent(
instructions: DecompositionResult | dict[str, Any],
library_agents: list[AgentSummary] | list[dict[str, Any]] | None = None,
operation_id: str | None = None,
task_id: str | None = None,
) -> dict[str, Any] | None:
"""Generate agent JSON from instructions.
Args:
instructions: Structured instructions from decompose_goal
library_agents: User's library agents available for sub-agent composition
operation_id: Operation ID for async processing (enables Redis Streams
completion notification)
task_id: Task ID for async processing (enables Redis Streams persistence
and SSE delivery)
Returns:
Agent JSON dict or None on error
Agent JSON dict, {"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
"""
_check_service_configured()
logger.info("Calling external Agent Generator service for generate_agent")
result = await generate_agent_external(instructions)
result = await generate_agent_external(
dict(instructions), _to_dict_list(library_agents), operation_id, task_id
)
# Don't modify async response
if result and result.get("status") == "accepted":
return result
if result:
# Ensure required fields
if isinstance(result, dict) and result.get("type") == "error":
return result
if "id" not in result:
result["id"] = str(uuid.uuid4())
if "version" not in result:
@@ -83,6 +591,12 @@ async def generate_agent(instructions: dict[str, Any]) -> dict[str, Any] | None:
return result
class AgentJsonValidationError(Exception):
"""Raised when agent JSON is invalid or missing required fields."""
pass
def json_to_graph(agent_json: dict[str, Any]) -> Graph:
"""Convert agent JSON dict to Graph model.
@@ -91,25 +605,55 @@ def json_to_graph(agent_json: dict[str, Any]) -> Graph:
Returns:
Graph ready for saving
Raises:
AgentJsonValidationError: If required fields are missing from nodes or links
"""
nodes = []
for n in agent_json.get("nodes", []):
for idx, n in enumerate(agent_json.get("nodes", [])):
block_id = n.get("block_id")
if not block_id:
node_id = n.get("id", f"index_{idx}")
raise AgentJsonValidationError(
f"Node '{node_id}' is missing required field 'block_id'"
)
node = Node(
id=n.get("id", str(uuid.uuid4())),
block_id=n["block_id"],
block_id=block_id,
input_default=n.get("input_default", {}),
metadata=n.get("metadata", {}),
)
nodes.append(node)
links = []
for link_data in agent_json.get("links", []):
for idx, link_data in enumerate(agent_json.get("links", [])):
source_id = link_data.get("source_id")
sink_id = link_data.get("sink_id")
source_name = link_data.get("source_name")
sink_name = link_data.get("sink_name")
missing_fields = []
if not source_id:
missing_fields.append("source_id")
if not sink_id:
missing_fields.append("sink_id")
if not source_name:
missing_fields.append("source_name")
if not sink_name:
missing_fields.append("sink_name")
if missing_fields:
link_id = link_data.get("id", f"index_{idx}")
raise AgentJsonValidationError(
f"Link '{link_id}' is missing required fields: {', '.join(missing_fields)}"
)
link = Link(
id=link_data.get("id", str(uuid.uuid4())),
source_id=link_data["source_id"],
sink_id=link_data["sink_id"],
source_name=link_data["source_name"],
sink_name=link_data["sink_name"],
source_id=source_id,
sink_id=sink_id,
source_name=source_name,
sink_name=sink_name,
is_static=link_data.get("is_static", False),
)
links.append(link)
@@ -130,22 +674,40 @@ def _reassign_node_ids(graph: Graph) -> None:
This is needed when creating a new version to avoid unique constraint violations.
"""
# Create mapping from old node IDs to new UUIDs
id_map = {node.id: str(uuid.uuid4()) for node in graph.nodes}
# Reassign node IDs
for node in graph.nodes:
node.id = id_map[node.id]
# Update link references to use new node IDs
for link in graph.links:
link.id = str(uuid.uuid4()) # Also give links new IDs
link.id = str(uuid.uuid4())
if link.source_id in id_map:
link.source_id = id_map[link.source_id]
if link.sink_id in id_map:
link.sink_id = id_map[link.sink_id]
def _populate_agent_executor_user_ids(agent_json: dict[str, Any], user_id: str) -> None:
"""Populate user_id in AgentExecutorBlock nodes.
The external agent generator creates AgentExecutorBlock nodes with empty user_id.
This function fills in the actual user_id so sub-agents run with correct permissions.
Args:
agent_json: Agent JSON dict (modified in place)
user_id: User ID to set
"""
for node in agent_json.get("nodes", []):
if node.get("block_id") == AGENT_EXECUTOR_BLOCK_ID:
input_default = node.get("input_default") or {}
if not input_default.get("user_id"):
input_default["user_id"] = user_id
node["input_default"] = input_default
logger.debug(
f"Set user_id for AgentExecutorBlock node {node.get('id')}"
)
async def save_agent_to_library(
agent_json: dict[str, Any], user_id: str, is_update: bool = False
) -> tuple[Graph, Any]:
@@ -159,33 +721,27 @@ async def save_agent_to_library(
Returns:
Tuple of (created Graph, LibraryAgent)
"""
from backend.data.graph import get_graph_all_versions
# Populate user_id in AgentExecutorBlock nodes before conversion
_populate_agent_executor_user_ids(agent_json, user_id)
graph = json_to_graph(agent_json)
if is_update:
# For updates, keep the same graph ID but increment version
# and reassign node/link IDs to avoid conflicts
if graph.id:
existing_versions = await get_graph_all_versions(graph.id, user_id)
if existing_versions:
latest_version = max(v.version for v in existing_versions)
graph.version = latest_version + 1
# Reassign node IDs (but keep graph ID the same)
_reassign_node_ids(graph)
logger.info(f"Updating agent {graph.id} to version {graph.version}")
else:
# For new agents, always generate a fresh UUID to avoid collisions
graph.id = str(uuid.uuid4())
graph.version = 1
# Reassign all node IDs as well
_reassign_node_ids(graph)
logger.info(f"Creating new agent with ID {graph.id}")
# Save to database
created_graph = await create_graph(graph, user_id)
# Add to user's library (or update existing library agent)
library_agents = await library_db.create_library_agent(
graph=created_graph,
user_id=user_id,
@@ -196,26 +752,15 @@ async def save_agent_to_library(
return created_graph, library_agents[0]
async def get_agent_as_json(
graph_id: str, user_id: str | None
) -> dict[str, Any] | None:
"""Fetch an agent and convert to JSON format for editing.
def graph_to_json(graph: Graph) -> dict[str, Any]:
"""Convert a Graph object to JSON format for the agent generator.
Args:
graph_id: Graph ID or library agent ID
user_id: User ID
graph: Graph object to convert
Returns:
Agent as JSON dict or None if not found
Agent as JSON dict
"""
from backend.data.graph import get_graph
# Try to get the graph (version=None gets the active version)
graph = await get_graph(graph_id, version=None, user_id=user_id)
if not graph:
return None
# Convert to JSON format
nodes = []
for node in graph.nodes:
nodes.append(
@@ -252,8 +797,41 @@ async def get_agent_as_json(
}
async def get_agent_as_json(
agent_id: str, user_id: str | None
) -> dict[str, Any] | None:
"""Fetch an agent and convert to JSON format for editing.
Args:
agent_id: Graph ID or library agent ID
user_id: User ID
Returns:
Agent as JSON dict or None if not found
"""
graph = await get_graph(agent_id, version=None, user_id=user_id)
if not graph and user_id:
try:
library_agent = await library_db.get_library_agent(agent_id, user_id)
graph = await get_graph(
library_agent.graph_id, version=None, user_id=user_id
)
except NotFoundError:
pass
if not graph:
return None
return graph_to_json(graph)
async def generate_agent_patch(
update_request: str, current_agent: dict[str, Any]
update_request: str,
current_agent: dict[str, Any],
library_agents: list[AgentSummary] | None = None,
operation_id: str | None = None,
task_id: str | None = None,
) -> dict[str, Any] | None:
"""Update an existing agent using natural language.
@@ -265,13 +843,57 @@ async def generate_agent_patch(
Args:
update_request: Natural language description of changes
current_agent: Current agent JSON
library_agents: User's library agents available for sub-agent composition
operation_id: Operation ID for async processing (enables Redis Streams callback)
task_id: Task ID for async processing (enables Redis Streams callback)
Returns:
Updated agent JSON, clarifying questions dict, or None on error
Updated agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
{"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
"""
_check_service_configured()
logger.info("Calling external Agent Generator service for generate_agent_patch")
return await generate_agent_patch_external(update_request, current_agent)
return await generate_agent_patch_external(
update_request,
current_agent,
_to_dict_list(library_agents),
operation_id,
task_id,
)
async def customize_template(
template_agent: dict[str, Any],
modification_request: str,
context: str = "",
) -> dict[str, Any] | None:
"""Customize a template/marketplace agent using natural language.
This is used when users want to modify a template or marketplace agent
to fit their specific needs before adding it to their library.
The external Agent Generator service handles:
- Understanding the modification request
- Applying changes to the template
- Fixing and validating the result
Args:
template_agent: The template agent JSON to customize
modification_request: Natural language description of customizations
context: Additional context (e.g., answers to previous questions)
Returns:
Customized agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
error dict {"type": "error", ...}, or None on unexpected error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
"""
_check_service_configured()
logger.info("Calling external Agent Generator service for customize_template")
return await customize_template_external(
template_agent, modification_request, context
)

View File

@@ -0,0 +1,95 @@
"""Error handling utilities for agent generator."""
import re
def _sanitize_error_details(details: str) -> str:
"""Sanitize error details to remove sensitive information.
Strips common patterns that could expose internal system info:
- File paths (Unix and Windows)
- Database connection strings
- URLs with credentials
- Stack trace internals
Args:
details: Raw error details string
Returns:
Sanitized error details safe for user display
"""
sanitized = re.sub(
r"/[a-zA-Z0-9_./\-]+\.(py|js|ts|json|yaml|yml)", "[path]", details
)
sanitized = re.sub(r"[A-Z]:\\[a-zA-Z0-9_\\.\\-]+", "[path]", sanitized)
sanitized = re.sub(
r"(postgres|mysql|mongodb|redis)://[^\s]+", "[database_url]", sanitized
)
sanitized = re.sub(r"https?://[^:]+:[^@]+@[^\s]+", "[url]", sanitized)
sanitized = re.sub(r", line \d+", "", sanitized)
sanitized = re.sub(r'File "[^"]+",?', "", sanitized)
return sanitized.strip()
def get_user_message_for_error(
error_type: str,
operation: str = "process the request",
llm_parse_message: str | None = None,
validation_message: str | None = None,
error_details: str | None = None,
) -> str:
"""Get a user-friendly error message based on error type.
This function maps internal error types to user-friendly messages,
providing a consistent experience across different agent operations.
Args:
error_type: The error type from the external service
(e.g., "llm_parse_error", "timeout", "rate_limit")
operation: Description of what operation failed, used in the default
message (e.g., "analyze the goal", "generate the agent")
llm_parse_message: Custom message for llm_parse_error type
validation_message: Custom message for validation_error type
error_details: Optional additional details about the error
Returns:
User-friendly error message suitable for display to the user
"""
base_message = ""
if error_type == "llm_parse_error":
base_message = (
llm_parse_message
or "The AI had trouble processing this request. Please try again."
)
elif error_type == "validation_error":
base_message = (
validation_message
or "The generated agent failed validation. "
"This usually happens when the agent structure doesn't match "
"what the platform expects. Please try simplifying your goal "
"or breaking it into smaller parts."
)
elif error_type == "patch_error":
base_message = (
"Failed to apply the changes. The modification couldn't be "
"validated. Please try a different approach or simplify the change."
)
elif error_type in ("timeout", "llm_timeout"):
base_message = (
"The request took too long to process. This can happen with "
"complex agents. Please try again or simplify your goal."
)
elif error_type in ("rate_limit", "llm_rate_limit"):
base_message = "The service is currently busy. Please try again in a moment."
else:
base_message = f"Failed to {operation}. Please try again."
if error_details:
details = _sanitize_error_details(error_details)
if len(details) > 200:
details = details[:200] + "..."
base_message += f"\n\nTechnical details: {details}"
return base_message

View File

@@ -14,6 +14,70 @@ from backend.util.settings import Settings
logger = logging.getLogger(__name__)
def _create_error_response(
error_message: str,
error_type: str = "unknown",
details: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""Create a standardized error response dict.
Args:
error_message: Human-readable error message
error_type: Machine-readable error type
details: Optional additional error details
Returns:
Error dict with type="error" and error details
"""
response: dict[str, Any] = {
"type": "error",
"error": error_message,
"error_type": error_type,
}
if details:
response["details"] = details
return response
def _classify_http_error(e: httpx.HTTPStatusError) -> tuple[str, str]:
"""Classify an HTTP error into error_type and message.
Args:
e: The HTTP status error
Returns:
Tuple of (error_type, error_message)
"""
status = e.response.status_code
if status == 429:
return "rate_limit", f"Agent Generator rate limited: {e}"
elif status == 503:
return "service_unavailable", f"Agent Generator unavailable: {e}"
elif status == 504 or status == 408:
return "timeout", f"Agent Generator timed out: {e}"
else:
return "http_error", f"HTTP error calling Agent Generator: {e}"
def _classify_request_error(e: httpx.RequestError) -> tuple[str, str]:
"""Classify a request error into error_type and message.
Args:
e: The request error
Returns:
Tuple of (error_type, error_message)
"""
error_str = str(e).lower()
if "timeout" in error_str or "timed out" in error_str:
return "timeout", f"Agent Generator request timed out: {e}"
elif "connect" in error_str:
return "connection_error", f"Could not connect to Agent Generator: {e}"
else:
return "request_error", f"Request error calling Agent Generator: {e}"
_client: httpx.AsyncClient | None = None
_settings: Settings | None = None
@@ -53,13 +117,16 @@ def _get_client() -> httpx.AsyncClient:
async def decompose_goal_external(
description: str, context: str = ""
description: str,
context: str = "",
library_agents: list[dict[str, Any]] | None = None,
) -> dict[str, Any] | None:
"""Call the external service to decompose a goal.
Args:
description: Natural language goal description
context: Additional context (e.g., answers to previous questions)
library_agents: User's library agents available for sub-agent composition
Returns:
Dict with either:
@@ -67,15 +134,17 @@ async def decompose_goal_external(
- {"type": "instructions", "steps": [...]}
- {"type": "unachievable_goal", ...}
- {"type": "vague_goal", ...}
Or None on error
- {"type": "error", "error": "...", "error_type": "..."} on error
Or None on unexpected error
"""
client = _get_client()
# Build the request payload
payload: dict[str, Any] = {"description": description}
if context:
# The external service uses user_instruction for additional context
payload["user_instruction"] = context
description = f"{description}\n\nAdditional context from user:\n{context}"
payload: dict[str, Any] = {"description": description}
if library_agents:
payload["library_agents"] = library_agents
try:
response = await client.post("/api/decompose-description", json=payload)
@@ -83,8 +152,13 @@ async def decompose_goal_external(
data = response.json()
if not data.get("success"):
logger.error(f"External service returned error: {data.get('error')}")
return None
error_msg = data.get("error", "Unknown error from Agent Generator")
error_type = data.get("error_type", "unknown")
logger.error(
f"Agent Generator decomposition failed: {error_msg} "
f"(type: {error_type})"
)
return _create_error_response(error_msg, error_type)
# Map the response to the expected format
response_type = data.get("type")
@@ -106,88 +180,162 @@ async def decompose_goal_external(
"type": "vague_goal",
"suggested_goal": data.get("suggested_goal"),
}
elif response_type == "error":
# Pass through error from the service
return _create_error_response(
data.get("error", "Unknown error"),
data.get("error_type", "unknown"),
)
else:
logger.error(
f"Unknown response type from external service: {response_type}"
)
return None
return _create_error_response(
f"Unknown response type from Agent Generator: {response_type}",
"invalid_response",
)
except httpx.HTTPStatusError as e:
logger.error(f"HTTP error calling external agent generator: {e}")
return None
error_type, error_msg = _classify_http_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except httpx.RequestError as e:
logger.error(f"Request error calling external agent generator: {e}")
return None
error_type, error_msg = _classify_request_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except Exception as e:
logger.error(f"Unexpected error calling external agent generator: {e}")
return None
error_msg = f"Unexpected error calling Agent Generator: {e}"
logger.error(error_msg)
return _create_error_response(error_msg, "unexpected_error")
async def generate_agent_external(
instructions: dict[str, Any]
instructions: dict[str, Any],
library_agents: list[dict[str, Any]] | None = None,
operation_id: str | None = None,
task_id: str | None = None,
) -> dict[str, Any] | None:
"""Call the external service to generate an agent from instructions.
Args:
instructions: Structured instructions from decompose_goal
library_agents: User's library agents available for sub-agent composition
operation_id: Operation ID for async processing (enables Redis Streams callback)
task_id: Task ID for async processing (enables Redis Streams callback)
Returns:
Agent JSON dict or None on error
Agent JSON dict, {"status": "accepted"} for async, or error dict {"type": "error", ...} on error
"""
client = _get_client()
# Build request payload
payload: dict[str, Any] = {"instructions": instructions}
if library_agents:
payload["library_agents"] = library_agents
if operation_id and task_id:
payload["operation_id"] = operation_id
payload["task_id"] = task_id
try:
response = await client.post(
"/api/generate-agent", json={"instructions": instructions}
)
response = await client.post("/api/generate-agent", json=payload)
# Handle 202 Accepted for async processing
if response.status_code == 202:
logger.info(
f"Agent Generator accepted async request "
f"(operation_id={operation_id}, task_id={task_id})"
)
return {
"status": "accepted",
"operation_id": operation_id,
"task_id": task_id,
}
response.raise_for_status()
data = response.json()
if not data.get("success"):
logger.error(f"External service returned error: {data.get('error')}")
return None
error_msg = data.get("error", "Unknown error from Agent Generator")
error_type = data.get("error_type", "unknown")
logger.error(
f"Agent Generator generation failed: {error_msg} (type: {error_type})"
)
return _create_error_response(error_msg, error_type)
return data.get("agent_json")
except httpx.HTTPStatusError as e:
logger.error(f"HTTP error calling external agent generator: {e}")
return None
error_type, error_msg = _classify_http_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except httpx.RequestError as e:
logger.error(f"Request error calling external agent generator: {e}")
return None
error_type, error_msg = _classify_request_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except Exception as e:
logger.error(f"Unexpected error calling external agent generator: {e}")
return None
error_msg = f"Unexpected error calling Agent Generator: {e}"
logger.error(error_msg)
return _create_error_response(error_msg, "unexpected_error")
async def generate_agent_patch_external(
update_request: str, current_agent: dict[str, Any]
update_request: str,
current_agent: dict[str, Any],
library_agents: list[dict[str, Any]] | None = None,
operation_id: str | None = None,
task_id: str | None = None,
) -> dict[str, Any] | None:
"""Call the external service to generate a patch for an existing agent.
Args:
update_request: Natural language description of changes
current_agent: Current agent JSON
library_agents: User's library agents available for sub-agent composition
operation_id: Operation ID for async processing (enables Redis Streams callback)
task_id: Task ID for async processing (enables Redis Streams callback)
Returns:
Updated agent JSON, clarifying questions dict, or None on error
Updated agent JSON, clarifying questions dict, {"status": "accepted"} for async, or error dict on error
"""
client = _get_client()
# Build request payload
payload: dict[str, Any] = {
"update_request": update_request,
"current_agent_json": current_agent,
}
if library_agents:
payload["library_agents"] = library_agents
if operation_id and task_id:
payload["operation_id"] = operation_id
payload["task_id"] = task_id
try:
response = await client.post(
"/api/update-agent",
json={
"update_request": update_request,
"current_agent_json": current_agent,
},
)
response = await client.post("/api/update-agent", json=payload)
# Handle 202 Accepted for async processing
if response.status_code == 202:
logger.info(
f"Agent Generator accepted async update request "
f"(operation_id={operation_id}, task_id={task_id})"
)
return {
"status": "accepted",
"operation_id": operation_id,
"task_id": task_id,
}
response.raise_for_status()
data = response.json()
if not data.get("success"):
logger.error(f"External service returned error: {data.get('error')}")
return None
error_msg = data.get("error", "Unknown error from Agent Generator")
error_type = data.get("error_type", "unknown")
logger.error(
f"Agent Generator patch generation failed: {error_msg} "
f"(type: {error_type})"
)
return _create_error_response(error_msg, error_type)
# Check if it's clarifying questions
if data.get("type") == "clarifying_questions":
@@ -196,18 +344,99 @@ async def generate_agent_patch_external(
"questions": data.get("questions", []),
}
# Check if it's an error passed through
if data.get("type") == "error":
return _create_error_response(
data.get("error", "Unknown error"),
data.get("error_type", "unknown"),
)
# Otherwise return the updated agent JSON
return data.get("agent_json")
except httpx.HTTPStatusError as e:
logger.error(f"HTTP error calling external agent generator: {e}")
return None
error_type, error_msg = _classify_http_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except httpx.RequestError as e:
logger.error(f"Request error calling external agent generator: {e}")
return None
error_type, error_msg = _classify_request_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except Exception as e:
logger.error(f"Unexpected error calling external agent generator: {e}")
return None
error_msg = f"Unexpected error calling Agent Generator: {e}"
logger.error(error_msg)
return _create_error_response(error_msg, "unexpected_error")
async def customize_template_external(
template_agent: dict[str, Any],
modification_request: str,
context: str = "",
) -> dict[str, Any] | None:
"""Call the external service to customize a template/marketplace agent.
Args:
template_agent: The template agent JSON to customize
modification_request: Natural language description of customizations
context: Additional context (e.g., answers to previous questions)
Returns:
Customized agent JSON, clarifying questions dict, or error dict on error
"""
client = _get_client()
request = modification_request
if context:
request = f"{modification_request}\n\nAdditional context from user:\n{context}"
payload: dict[str, Any] = {
"template_agent_json": template_agent,
"modification_request": request,
}
try:
response = await client.post("/api/template-modification", json=payload)
response.raise_for_status()
data = response.json()
if not data.get("success"):
error_msg = data.get("error", "Unknown error from Agent Generator")
error_type = data.get("error_type", "unknown")
logger.error(
f"Agent Generator template customization failed: {error_msg} "
f"(type: {error_type})"
)
return _create_error_response(error_msg, error_type)
# Check if it's clarifying questions
if data.get("type") == "clarifying_questions":
return {
"type": "clarifying_questions",
"questions": data.get("questions", []),
}
# Check if it's an error passed through
if data.get("type") == "error":
return _create_error_response(
data.get("error", "Unknown error"),
data.get("error_type", "unknown"),
)
# Otherwise return the customized agent JSON
return data.get("agent_json")
except httpx.HTTPStatusError as e:
error_type, error_msg = _classify_http_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except httpx.RequestError as e:
error_type, error_msg = _classify_request_error(e)
logger.error(error_msg)
return _create_error_response(error_msg, error_type)
except Exception as e:
error_msg = f"Unexpected error calling Agent Generator: {e}"
logger.error(error_msg)
return _create_error_response(error_msg, "unexpected_error")
async def get_blocks_external() -> list[dict[str, Any]] | None:

View File

@@ -1,6 +1,7 @@
"""Shared agent search functionality for find_agent and find_library_agent tools."""
import logging
import re
from typing import Literal
from backend.api.features.library import db as library_db
@@ -19,6 +20,85 @@ logger = logging.getLogger(__name__)
SearchSource = Literal["marketplace", "library"]
_UUID_PATTERN = re.compile(
r"^[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}$",
re.IGNORECASE,
)
def _is_uuid(text: str) -> bool:
"""Check if text is a valid UUID v4."""
return bool(_UUID_PATTERN.match(text.strip()))
async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | None:
"""Fetch a library agent by ID (library agent ID or graph_id).
Tries multiple lookup strategies:
1. First by graph_id (AgentGraph primary key)
2. Then by library agent ID (LibraryAgent primary key)
Args:
user_id: The user ID
agent_id: The ID to look up (can be graph_id or library agent ID)
Returns:
AgentInfo if found, None otherwise
"""
try:
agent = await library_db.get_library_agent_by_graph_id(user_id, agent_id)
if agent:
logger.debug(f"Found library agent by graph_id: {agent.name}")
return AgentInfo(
id=agent.id,
name=agent.name,
description=agent.description or "",
source="library",
in_library=True,
creator=agent.creator_name,
status=agent.status.value,
can_access_graph=agent.can_access_graph,
has_external_trigger=agent.has_external_trigger,
new_output=agent.new_output,
graph_id=agent.graph_id,
)
except DatabaseError:
raise
except Exception as e:
logger.warning(
f"Could not fetch library agent by graph_id {agent_id}: {e}",
exc_info=True,
)
try:
agent = await library_db.get_library_agent(agent_id, user_id)
if agent:
logger.debug(f"Found library agent by library_id: {agent.name}")
return AgentInfo(
id=agent.id,
name=agent.name,
description=agent.description or "",
source="library",
in_library=True,
creator=agent.creator_name,
status=agent.status.value,
can_access_graph=agent.can_access_graph,
has_external_trigger=agent.has_external_trigger,
new_output=agent.new_output,
graph_id=agent.graph_id,
)
except NotFoundError:
logger.debug(f"Library agent not found by library_id: {agent_id}")
except DatabaseError:
raise
except Exception as e:
logger.warning(
f"Could not fetch library agent by library_id {agent_id}: {e}",
exc_info=True,
)
return None
async def search_agents(
query: str,
@@ -69,29 +149,37 @@ async def search_agents(
is_featured=False,
)
)
else: # library
logger.info(f"Searching user library for: {query}")
results = await library_db.list_library_agents(
user_id=user_id, # type: ignore[arg-type]
search_term=query,
page_size=10,
)
for agent in results.agents:
agents.append(
AgentInfo(
id=agent.id,
name=agent.name,
description=agent.description or "",
source="library",
in_library=True,
creator=agent.creator_name,
status=agent.status.value,
can_access_graph=agent.can_access_graph,
has_external_trigger=agent.has_external_trigger,
new_output=agent.new_output,
graph_id=agent.graph_id,
)
else:
if _is_uuid(query):
logger.info(f"Query looks like UUID, trying direct lookup: {query}")
agent = await _get_library_agent_by_id(user_id, query) # type: ignore[arg-type]
if agent:
agents.append(agent)
logger.info(f"Found agent by direct ID lookup: {agent.name}")
if not agents:
logger.info(f"Searching user library for: {query}")
results = await library_db.list_library_agents(
user_id=user_id, # type: ignore[arg-type]
search_term=query,
page_size=10,
)
for agent in results.agents:
agents.append(
AgentInfo(
id=agent.id,
name=agent.name,
description=agent.description or "",
source="library",
in_library=True,
creator=agent.creator_name,
status=agent.status.value,
can_access_graph=agent.can_access_graph,
has_external_trigger=agent.has_external_trigger,
new_output=agent.new_output,
graph_id=agent.graph_id,
)
)
logger.info(f"Found {len(agents)} agents in {source}")
except NotFoundError:
pass

View File

@@ -36,6 +36,16 @@ class BaseTool:
"""Whether this tool requires authentication."""
return False
@property
def is_long_running(self) -> bool:
"""Whether this tool is long-running and should execute in background.
Long-running tools (like agent generation) are executed via background
tasks to survive SSE disconnections. The result is persisted to chat
history and visible when the user refreshes.
"""
return False
def as_openai_tool(self) -> ChatCompletionToolParam:
"""Convert to OpenAI tool format."""
return ChatCompletionToolParam(

View File

@@ -8,13 +8,17 @@ from backend.api.features.chat.model import ChatSession
from .agent_generator import (
AgentGeneratorNotConfiguredError,
decompose_goal,
enrich_library_agents_from_steps,
generate_agent,
get_all_relevant_agents_for_generation,
get_user_message_for_error,
save_agent_to_library,
)
from .base import BaseTool
from .models import (
AgentPreviewResponse,
AgentSavedResponse,
AsyncProcessingResponse,
ClarificationNeededResponse,
ClarifyingQuestion,
ErrorResponse,
@@ -42,6 +46,10 @@ class CreateAgentTool(BaseTool):
def requires_auth(self) -> bool:
return True
@property
def is_long_running(self) -> bool:
return True
@property
def parameters(self) -> dict[str, Any]:
return {
@@ -91,6 +99,10 @@ class CreateAgentTool(BaseTool):
save = kwargs.get("save", True)
session_id = session.session_id if session else None
# Extract async processing params (passed by long-running tool handler)
operation_id = kwargs.get("_operation_id")
task_id = kwargs.get("_task_id")
if not description:
return ErrorResponse(
message="Please provide a description of what the agent should do.",
@@ -98,9 +110,24 @@ class CreateAgentTool(BaseTool):
session_id=session_id,
)
# Step 1: Decompose goal into steps
library_agents = None
if user_id:
try:
library_agents = await get_all_relevant_agents_for_generation(
user_id=user_id,
search_query=description,
include_marketplace=True,
)
logger.debug(
f"Found {len(library_agents)} relevant agents for sub-agent composition"
)
except Exception as e:
logger.warning(f"Failed to fetch library agents: {e}")
try:
decomposition_result = await decompose_goal(description, context)
decomposition_result = await decompose_goal(
description, context, library_agents
)
except AgentGeneratorNotConfiguredError:
return ErrorResponse(
message=(
@@ -113,15 +140,31 @@ class CreateAgentTool(BaseTool):
if decomposition_result is None:
return ErrorResponse(
message="Failed to analyze the goal. The agent generation service may be unavailable or timed out. Please try again.",
message="Failed to analyze the goal. The agent generation service may be unavailable. Please try again.",
error="decomposition_failed",
details={
"description": description[:100]
}, # Include context for debugging
details={"description": description[:100]},
session_id=session_id,
)
if decomposition_result.get("type") == "error":
error_msg = decomposition_result.get("error", "Unknown error")
error_type = decomposition_result.get("error_type", "unknown")
user_message = get_user_message_for_error(
error_type,
operation="analyze the goal",
llm_parse_message="The AI had trouble understanding this request. Please try rephrasing your goal.",
)
return ErrorResponse(
message=user_message,
error=f"decomposition_failed:{error_type}",
details={
"description": description[:100],
"service_error": error_msg,
"error_type": error_type,
},
session_id=session_id,
)
# Check if LLM returned clarifying questions
if decomposition_result.get("type") == "clarifying_questions":
questions = decomposition_result.get("questions", [])
return ClarificationNeededResponse(
@@ -140,7 +183,6 @@ class CreateAgentTool(BaseTool):
session_id=session_id,
)
# Check for unachievable/vague goals
if decomposition_result.get("type") == "unachievable_goal":
suggested = decomposition_result.get("suggested_goal", "")
reason = decomposition_result.get("reason", "")
@@ -167,9 +209,27 @@ class CreateAgentTool(BaseTool):
session_id=session_id,
)
# Step 2: Generate agent JSON (external service handles fixing and validation)
if user_id and library_agents is not None:
try:
library_agents = await enrich_library_agents_from_steps(
user_id=user_id,
decomposition_result=decomposition_result,
existing_agents=library_agents,
include_marketplace=True,
)
logger.debug(
f"After enrichment: {len(library_agents)} total agents for sub-agent composition"
)
except Exception as e:
logger.warning(f"Failed to enrich library agents from steps: {e}")
try:
agent_json = await generate_agent(decomposition_result)
agent_json = await generate_agent(
decomposition_result,
library_agents,
operation_id=operation_id,
task_id=task_id,
)
except AgentGeneratorNotConfiguredError:
return ErrorResponse(
message=(
@@ -182,11 +242,47 @@ class CreateAgentTool(BaseTool):
if agent_json is None:
return ErrorResponse(
message="Failed to generate the agent. The agent generation service may be unavailable or timed out. Please try again.",
message="Failed to generate the agent. The agent generation service may be unavailable. Please try again.",
error="generation_failed",
details={"description": description[:100]},
session_id=session_id,
)
if isinstance(agent_json, dict) and agent_json.get("type") == "error":
error_msg = agent_json.get("error", "Unknown error")
error_type = agent_json.get("error_type", "unknown")
user_message = get_user_message_for_error(
error_type,
operation="generate the agent",
llm_parse_message="The AI had trouble generating the agent. Please try again or simplify your goal.",
validation_message=(
"I wasn't able to create a valid agent for this request. "
"The generated workflow had some structural issues. "
"Please try simplifying your goal or breaking it into smaller steps."
),
error_details=error_msg,
)
return ErrorResponse(
message=user_message,
error=f"generation_failed:{error_type}",
details={
"description": description[:100]
}, # Include context for debugging
"description": description[:100],
"service_error": error_msg,
"error_type": error_type,
},
session_id=session_id,
)
# Check if Agent Generator accepted for async processing
if agent_json.get("status") == "accepted":
logger.info(
f"Agent generation delegated to async processing "
f"(operation_id={operation_id}, task_id={task_id})"
)
return AsyncProcessingResponse(
message="Agent generation started. You'll be notified when it's complete.",
operation_id=operation_id,
task_id=task_id,
session_id=session_id,
)
@@ -195,7 +291,6 @@ class CreateAgentTool(BaseTool):
node_count = len(agent_json.get("nodes", []))
link_count = len(agent_json.get("links", []))
# Step 3: Preview or save
if not save:
return AgentPreviewResponse(
message=(
@@ -210,7 +305,6 @@ class CreateAgentTool(BaseTool):
session_id=session_id,
)
# Save to library
if not user_id:
return ErrorResponse(
message="You must be logged in to save agents.",
@@ -228,7 +322,7 @@ class CreateAgentTool(BaseTool):
agent_id=created_graph.id,
agent_name=created_graph.name,
library_agent_id=library_agent.id,
library_agent_link=f"/library/{library_agent.id}",
library_agent_link=f"/library/agents/{library_agent.id}",
agent_page_link=f"/build?flowID={created_graph.id}",
session_id=session_id,
)

View File

@@ -0,0 +1,337 @@
"""CustomizeAgentTool - Customizes marketplace/template agents using natural language."""
import logging
from typing import Any
from backend.api.features.chat.model import ChatSession
from backend.api.features.store import db as store_db
from backend.api.features.store.exceptions import AgentNotFoundError
from .agent_generator import (
AgentGeneratorNotConfiguredError,
customize_template,
get_user_message_for_error,
graph_to_json,
save_agent_to_library,
)
from .base import BaseTool
from .models import (
AgentPreviewResponse,
AgentSavedResponse,
ClarificationNeededResponse,
ClarifyingQuestion,
ErrorResponse,
ToolResponseBase,
)
logger = logging.getLogger(__name__)
class CustomizeAgentTool(BaseTool):
"""Tool for customizing marketplace/template agents using natural language."""
@property
def name(self) -> str:
return "customize_agent"
@property
def description(self) -> str:
return (
"Customize a marketplace or template agent using natural language. "
"Takes an existing agent from the marketplace and modifies it based on "
"the user's requirements before adding to their library."
)
@property
def requires_auth(self) -> bool:
return True
@property
def is_long_running(self) -> bool:
return True
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"agent_id": {
"type": "string",
"description": (
"The marketplace agent ID in format 'creator/slug' "
"(e.g., 'autogpt/newsletter-writer'). "
"Get this from find_agent results."
),
},
"modifications": {
"type": "string",
"description": (
"Natural language description of how to customize the agent. "
"Be specific about what changes you want to make."
),
},
"context": {
"type": "string",
"description": (
"Additional context or answers to previous clarifying questions."
),
},
"save": {
"type": "boolean",
"description": (
"Whether to save the customized agent to the user's library. "
"Default is true. Set to false for preview only."
),
"default": True,
},
},
"required": ["agent_id", "modifications"],
}
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
"""Execute the customize_agent tool.
Flow:
1. Parse the agent ID to get creator/slug
2. Fetch the template agent from the marketplace
3. Call customize_template with the modification request
4. Preview or save based on the save parameter
"""
agent_id = kwargs.get("agent_id", "").strip()
modifications = kwargs.get("modifications", "").strip()
context = kwargs.get("context", "")
save = kwargs.get("save", True)
session_id = session.session_id if session else None
if not agent_id:
return ErrorResponse(
message="Please provide the marketplace agent ID (e.g., 'creator/agent-name').",
error="missing_agent_id",
session_id=session_id,
)
if not modifications:
return ErrorResponse(
message="Please describe how you want to customize this agent.",
error="missing_modifications",
session_id=session_id,
)
# Parse agent_id in format "creator/slug"
parts = [p.strip() for p in agent_id.split("/")]
if len(parts) != 2 or not parts[0] or not parts[1]:
return ErrorResponse(
message=(
f"Invalid agent ID format: '{agent_id}'. "
"Expected format is 'creator/agent-name' "
"(e.g., 'autogpt/newsletter-writer')."
),
error="invalid_agent_id_format",
session_id=session_id,
)
creator_username, agent_slug = parts
# Fetch the marketplace agent details
try:
agent_details = await store_db.get_store_agent_details(
username=creator_username, agent_name=agent_slug
)
except AgentNotFoundError:
return ErrorResponse(
message=(
f"Could not find marketplace agent '{agent_id}'. "
"Please check the agent ID and try again."
),
error="agent_not_found",
session_id=session_id,
)
except Exception as e:
logger.error(f"Error fetching marketplace agent {agent_id}: {e}")
return ErrorResponse(
message="Failed to fetch the marketplace agent. Please try again.",
error="fetch_error",
session_id=session_id,
)
if not agent_details.store_listing_version_id:
return ErrorResponse(
message=(
f"The agent '{agent_id}' does not have an available version. "
"Please try a different agent."
),
error="no_version_available",
session_id=session_id,
)
# Get the full agent graph
try:
graph = await store_db.get_agent(agent_details.store_listing_version_id)
template_agent = graph_to_json(graph)
except Exception as e:
logger.error(f"Error fetching agent graph for {agent_id}: {e}")
return ErrorResponse(
message="Failed to fetch the agent configuration. Please try again.",
error="graph_fetch_error",
session_id=session_id,
)
# Call customize_template
try:
result = await customize_template(
template_agent=template_agent,
modification_request=modifications,
context=context,
)
except AgentGeneratorNotConfiguredError:
return ErrorResponse(
message=(
"Agent customization is not available. "
"The Agent Generator service is not configured."
),
error="service_not_configured",
session_id=session_id,
)
except Exception as e:
logger.error(f"Error calling customize_template for {agent_id}: {e}")
return ErrorResponse(
message=(
"Failed to customize the agent due to a service error. "
"Please try again."
),
error="customization_service_error",
session_id=session_id,
)
if result is None:
return ErrorResponse(
message=(
"Failed to customize the agent. "
"The agent generation service may be unavailable or timed out. "
"Please try again."
),
error="customization_failed",
session_id=session_id,
)
# Handle error response
if isinstance(result, dict) and result.get("type") == "error":
error_msg = result.get("error", "Unknown error")
error_type = result.get("error_type", "unknown")
user_message = get_user_message_for_error(
error_type,
operation="customize the agent",
llm_parse_message=(
"The AI had trouble customizing the agent. "
"Please try again or simplify your request."
),
validation_message=(
"The customized agent failed validation. "
"Please try rephrasing your request."
),
error_details=error_msg,
)
return ErrorResponse(
message=user_message,
error=f"customization_failed:{error_type}",
session_id=session_id,
)
# Handle clarifying questions
if isinstance(result, dict) and result.get("type") == "clarifying_questions":
questions = result.get("questions") or []
if not isinstance(questions, list):
logger.error(
f"Unexpected clarifying questions format: {type(questions)}"
)
questions = []
return ClarificationNeededResponse(
message=(
"I need some more information to customize this agent. "
"Please answer the following questions:"
),
questions=[
ClarifyingQuestion(
question=q.get("question", ""),
keyword=q.get("keyword", ""),
example=q.get("example"),
)
for q in questions
if isinstance(q, dict)
],
session_id=session_id,
)
# Result should be the customized agent JSON
if not isinstance(result, dict):
logger.error(f"Unexpected customize_template response type: {type(result)}")
return ErrorResponse(
message="Failed to customize the agent due to an unexpected response.",
error="unexpected_response_type",
session_id=session_id,
)
customized_agent = result
agent_name = customized_agent.get(
"name", f"Customized {agent_details.agent_name}"
)
agent_description = customized_agent.get("description", "")
nodes = customized_agent.get("nodes")
links = customized_agent.get("links")
node_count = len(nodes) if isinstance(nodes, list) else 0
link_count = len(links) if isinstance(links, list) else 0
if not save:
return AgentPreviewResponse(
message=(
f"I've customized the agent '{agent_details.agent_name}'. "
f"The customized agent has {node_count} blocks. "
f"Review it and call customize_agent with save=true to save it."
),
agent_json=customized_agent,
agent_name=agent_name,
description=agent_description,
node_count=node_count,
link_count=link_count,
session_id=session_id,
)
if not user_id:
return ErrorResponse(
message="You must be logged in to save agents.",
error="auth_required",
session_id=session_id,
)
# Save to user's library
try:
created_graph, library_agent = await save_agent_to_library(
customized_agent, user_id, is_update=False
)
return AgentSavedResponse(
message=(
f"Customized agent '{created_graph.name}' "
f"(based on '{agent_details.agent_name}') "
f"has been saved to your library!"
),
agent_id=created_graph.id,
agent_name=created_graph.name,
library_agent_id=library_agent.id,
library_agent_link=f"/library/agents/{library_agent.id}",
agent_page_link=f"/build?flowID={created_graph.id}",
session_id=session_id,
)
except Exception as e:
logger.error(f"Error saving customized agent: {e}")
return ErrorResponse(
message="Failed to save the customized agent. Please try again.",
error="save_failed",
session_id=session_id,
)

View File

@@ -9,12 +9,15 @@ from .agent_generator import (
AgentGeneratorNotConfiguredError,
generate_agent_patch,
get_agent_as_json,
get_all_relevant_agents_for_generation,
get_user_message_for_error,
save_agent_to_library,
)
from .base import BaseTool
from .models import (
AgentPreviewResponse,
AgentSavedResponse,
AsyncProcessingResponse,
ClarificationNeededResponse,
ClarifyingQuestion,
ErrorResponse,
@@ -42,6 +45,10 @@ class EditAgentTool(BaseTool):
def requires_auth(self) -> bool:
return True
@property
def is_long_running(self) -> bool:
return True
@property
def parameters(self) -> dict[str, Any]:
return {
@@ -98,6 +105,10 @@ class EditAgentTool(BaseTool):
save = kwargs.get("save", True)
session_id = session.session_id if session else None
# Extract async processing params (passed by long-running tool handler)
operation_id = kwargs.get("_operation_id")
task_id = kwargs.get("_task_id")
if not agent_id:
return ErrorResponse(
message="Please provide the agent ID to edit.",
@@ -112,7 +123,6 @@ class EditAgentTool(BaseTool):
session_id=session_id,
)
# Step 1: Fetch current agent
current_agent = await get_agent_as_json(agent_id, user_id)
if current_agent is None:
@@ -122,14 +132,34 @@ class EditAgentTool(BaseTool):
session_id=session_id,
)
# Build the update request with context
library_agents = None
if user_id:
try:
graph_id = current_agent.get("id")
library_agents = await get_all_relevant_agents_for_generation(
user_id=user_id,
search_query=changes,
exclude_graph_id=graph_id,
include_marketplace=True,
)
logger.debug(
f"Found {len(library_agents)} relevant agents for sub-agent composition"
)
except Exception as e:
logger.warning(f"Failed to fetch library agents: {e}")
update_request = changes
if context:
update_request = f"{changes}\n\nAdditional context:\n{context}"
# Step 2: Generate updated agent (external service handles fixing and validation)
try:
result = await generate_agent_patch(update_request, current_agent)
result = await generate_agent_patch(
update_request,
current_agent,
library_agents,
operation_id=operation_id,
task_id=task_id,
)
except AgentGeneratorNotConfiguredError:
return ErrorResponse(
message=(
@@ -148,7 +178,42 @@ class EditAgentTool(BaseTool):
session_id=session_id,
)
# Check if LLM returned clarifying questions
# Check if Agent Generator accepted for async processing
if result.get("status") == "accepted":
logger.info(
f"Agent edit delegated to async processing "
f"(operation_id={operation_id}, task_id={task_id})"
)
return AsyncProcessingResponse(
message="Agent edit started. You'll be notified when it's complete.",
operation_id=operation_id,
task_id=task_id,
session_id=session_id,
)
# Check if the result is an error from the external service
if isinstance(result, dict) and result.get("type") == "error":
error_msg = result.get("error", "Unknown error")
error_type = result.get("error_type", "unknown")
user_message = get_user_message_for_error(
error_type,
operation="generate the changes",
llm_parse_message="The AI had trouble generating the changes. Please try again or simplify your request.",
validation_message="The generated changes failed validation. Please try rephrasing your request.",
error_details=error_msg,
)
return ErrorResponse(
message=user_message,
error=f"update_generation_failed:{error_type}",
details={
"agent_id": agent_id,
"changes": changes[:100],
"service_error": error_msg,
"error_type": error_type,
},
session_id=session_id,
)
if result.get("type") == "clarifying_questions":
questions = result.get("questions", [])
return ClarificationNeededResponse(
@@ -167,7 +232,6 @@ class EditAgentTool(BaseTool):
session_id=session_id,
)
# Result is the updated agent JSON
updated_agent = result
agent_name = updated_agent.get("name", "Updated Agent")
@@ -175,7 +239,6 @@ class EditAgentTool(BaseTool):
node_count = len(updated_agent.get("nodes", []))
link_count = len(updated_agent.get("links", []))
# Step 3: Preview or save
if not save:
return AgentPreviewResponse(
message=(
@@ -191,7 +254,6 @@ class EditAgentTool(BaseTool):
session_id=session_id,
)
# Save to library (creates a new version)
if not user_id:
return ErrorResponse(
message="You must be logged in to save agents.",
@@ -209,7 +271,7 @@ class EditAgentTool(BaseTool):
agent_id=created_graph.id,
agent_name=created_graph.name,
library_agent_id=library_agent.id,
library_agent_link=f"/library/{library_agent.id}",
library_agent_link=f"/library/agents/{library_agent.id}",
agent_page_link=f"/build?flowID={created_graph.id}",
session_id=session_id,
)

View File

@@ -28,6 +28,18 @@ class ResponseType(str, Enum):
BLOCK_OUTPUT = "block_output"
DOC_SEARCH_RESULTS = "doc_search_results"
DOC_PAGE = "doc_page"
# Workspace response types
WORKSPACE_FILE_LIST = "workspace_file_list"
WORKSPACE_FILE_CONTENT = "workspace_file_content"
WORKSPACE_FILE_METADATA = "workspace_file_metadata"
WORKSPACE_FILE_WRITTEN = "workspace_file_written"
WORKSPACE_FILE_DELETED = "workspace_file_deleted"
# Long-running operation types
OPERATION_STARTED = "operation_started"
OPERATION_PENDING = "operation_pending"
OPERATION_IN_PROGRESS = "operation_in_progress"
# Input validation
INPUT_VALIDATION_ERROR = "input_validation_error"
# Base response model
@@ -58,6 +70,10 @@ class AgentInfo(BaseModel):
has_external_trigger: bool | None = None
new_output: bool | None = None
graph_id: str | None = None
inputs: dict[str, Any] | None = Field(
default=None,
description="Input schema for the agent, including field names, types, and defaults",
)
class AgentsFoundResponse(ToolResponseBase):
@@ -184,6 +200,20 @@ class ErrorResponse(ToolResponseBase):
details: dict[str, Any] | None = None
class InputValidationErrorResponse(ToolResponseBase):
"""Response when run_agent receives unknown input fields."""
type: ResponseType = ResponseType.INPUT_VALIDATION_ERROR
unrecognized_fields: list[str] = Field(
description="List of input field names that were not recognized"
)
inputs: dict[str, Any] = Field(
description="The agent's valid input schema for reference"
)
graph_id: str | None = None
graph_version: int | None = None
# Agent output models
class ExecutionOutputInfo(BaseModel):
"""Summary of a single execution's outputs."""
@@ -334,3 +364,60 @@ class BlockOutputResponse(ToolResponseBase):
block_name: str
outputs: dict[str, list[Any]]
success: bool = True
# Long-running operation models
class OperationStartedResponse(ToolResponseBase):
"""Response when a long-running operation has been started in the background.
This is returned immediately to the client while the operation continues
to execute. The user can close the tab and check back later.
The task_id can be used to reconnect to the SSE stream via
GET /chat/tasks/{task_id}/stream?last_idx=0
"""
type: ResponseType = ResponseType.OPERATION_STARTED
operation_id: str
tool_name: str
task_id: str | None = None # For SSE reconnection
class OperationPendingResponse(ToolResponseBase):
"""Response stored in chat history while a long-running operation is executing.
This is persisted to the database so users see a pending state when they
refresh before the operation completes.
"""
type: ResponseType = ResponseType.OPERATION_PENDING
operation_id: str
tool_name: str
class OperationInProgressResponse(ToolResponseBase):
"""Response when an operation is already in progress.
Returned for idempotency when the same tool_call_id is requested again
while the background task is still running.
"""
type: ResponseType = ResponseType.OPERATION_IN_PROGRESS
tool_call_id: str
class AsyncProcessingResponse(ToolResponseBase):
"""Response when an operation has been delegated to async processing.
This is returned by tools when the external service accepts the request
for async processing (HTTP 202 Accepted). The Redis Streams completion
consumer will handle the result when the external service completes.
The status field is specifically "accepted" to allow the long-running tool
handler to detect this response and skip LLM continuation.
"""
type: ResponseType = ResponseType.OPERATION_STARTED
status: str = "accepted" # Must be "accepted" for detection
operation_id: str | None = None
task_id: str | None = None

View File

@@ -30,6 +30,7 @@ from .models import (
ErrorResponse,
ExecutionOptions,
ExecutionStartedResponse,
InputValidationErrorResponse,
SetupInfo,
SetupRequirementsResponse,
ToolResponseBase,
@@ -273,6 +274,22 @@ class RunAgentTool(BaseTool):
input_properties = graph.input_schema.get("properties", {})
required_fields = set(graph.input_schema.get("required", []))
provided_inputs = set(params.inputs.keys())
valid_fields = set(input_properties.keys())
# Check for unknown input fields
unrecognized_fields = provided_inputs - valid_fields
if unrecognized_fields:
return InputValidationErrorResponse(
message=(
f"Unknown input field(s) provided: {', '.join(sorted(unrecognized_fields))}. "
f"Agent was not executed. Please use the correct field names from the schema."
),
session_id=session_id,
unrecognized_fields=sorted(unrecognized_fields),
inputs=graph.input_schema,
graph_id=graph.id,
graph_version=graph.version,
)
# If agent has inputs but none were provided AND use_defaults is not set,
# always show what's available first so user can decide

View File

@@ -402,3 +402,42 @@ async def test_run_agent_schedule_without_name(setup_test_data):
# Should return error about missing schedule_name
assert result_data.get("type") == "error"
assert "schedule_name" in result_data["message"].lower()
@pytest.mark.asyncio(loop_scope="session")
async def test_run_agent_rejects_unknown_input_fields(setup_test_data):
"""Test that run_agent returns input_validation_error for unknown input fields."""
user = setup_test_data["user"]
store_submission = setup_test_data["store_submission"]
tool = RunAgentTool()
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
session = make_session(user_id=user.id)
# Execute with unknown input field names
response = await tool.execute(
user_id=user.id,
session_id=str(uuid.uuid4()),
tool_call_id=str(uuid.uuid4()),
username_agent_slug=agent_marketplace_id,
inputs={
"unknown_field": "some value",
"another_unknown": "another value",
},
session=session,
)
assert response is not None
assert hasattr(response, "output")
assert isinstance(response.output, str)
result_data = orjson.loads(response.output)
# Should return input_validation_error type with unrecognized fields
assert result_data.get("type") == "input_validation_error"
assert "unrecognized_fields" in result_data
assert set(result_data["unrecognized_fields"]) == {
"another_unknown",
"unknown_field",
}
assert "inputs" in result_data # Contains the valid schema
assert "Agent was not executed" in result_data["message"]

View File

@@ -1,13 +1,17 @@
"""Tool for executing blocks directly."""
import logging
import uuid
from collections import defaultdict
from typing import Any
from pydantic_core import PydanticUndefined
from backend.api.features.chat.model import ChatSession
from backend.data.block import get_block
from backend.data.execution import ExecutionContext
from backend.data.model import CredentialsMetaInput
from backend.data.workspace import get_or_create_workspace
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util.exceptions import BlockError
@@ -73,15 +77,22 @@ class RunBlockTool(BaseTool):
self,
user_id: str,
block: Any,
input_data: dict[str, Any] | None = None,
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
"""
Check if user has required credentials for a block.
Args:
user_id: User ID
block: Block to check credentials for
input_data: Input data for the block (used to determine provider via discriminator)
Returns:
tuple[matched_credentials, missing_credentials]
"""
matched_credentials: dict[str, CredentialsMetaInput] = {}
missing_credentials: list[CredentialsMetaInput] = []
input_data = input_data or {}
# Get credential field info from block's input schema
credentials_fields_info = block.input_schema.get_credentials_fields_info()
@@ -94,14 +105,33 @@ class RunBlockTool(BaseTool):
available_creds = await creds_manager.store.get_all_creds(user_id)
for field_name, field_info in credentials_fields_info.items():
# field_info.provider is a frozenset of acceptable providers
# field_info.supported_types is a frozenset of acceptable types
effective_field_info = field_info
if field_info.discriminator and field_info.discriminator_mapping:
# Get discriminator from input, falling back to schema default
discriminator_value = input_data.get(field_info.discriminator)
if discriminator_value is None:
field = block.input_schema.model_fields.get(
field_info.discriminator
)
if field and field.default is not PydanticUndefined:
discriminator_value = field.default
if (
discriminator_value
and discriminator_value in field_info.discriminator_mapping
):
effective_field_info = field_info.discriminate(discriminator_value)
logger.debug(
f"Discriminated provider for {field_name}: "
f"{discriminator_value} -> {effective_field_info.provider}"
)
matching_cred = next(
(
cred
for cred in available_creds
if cred.provider in field_info.provider
and cred.type in field_info.supported_types
if cred.provider in effective_field_info.provider
and cred.type in effective_field_info.supported_types
),
None,
)
@@ -115,8 +145,8 @@ class RunBlockTool(BaseTool):
)
else:
# Create a placeholder for the missing credential
provider = next(iter(field_info.provider), "unknown")
cred_type = next(iter(field_info.supported_types), "api_key")
provider = next(iter(effective_field_info.provider), "unknown")
cred_type = next(iter(effective_field_info.supported_types), "api_key")
missing_credentials.append(
CredentialsMetaInput(
id=field_name,
@@ -184,10 +214,9 @@ class RunBlockTool(BaseTool):
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
# Check credentials
creds_manager = IntegrationCredentialsManager()
matched_credentials, missing_credentials = await self._check_block_credentials(
user_id, block
user_id, block, input_data
)
if missing_credentials:
@@ -223,11 +252,48 @@ class RunBlockTool(BaseTool):
)
try:
# Fetch actual credentials and prepare kwargs for block execution
# Create execution context with defaults (blocks may require it)
# Get or create user's workspace for CoPilot file operations
workspace = await get_or_create_workspace(user_id)
# Generate synthetic IDs for CoPilot context
# Each chat session is treated as its own agent with one continuous run
# This means:
# - graph_id (agent) = session (memories scoped to session when limit_to_agent=True)
# - graph_exec_id (run) = session (memories scoped to session when limit_to_run=True)
# - node_exec_id = unique per block execution
synthetic_graph_id = f"copilot-session-{session.session_id}"
synthetic_graph_exec_id = f"copilot-session-{session.session_id}"
synthetic_node_id = f"copilot-node-{block_id}"
synthetic_node_exec_id = (
f"copilot-{session.session_id}-{uuid.uuid4().hex[:8]}"
)
# Create unified execution context with all required fields
execution_context = ExecutionContext(
# Execution identity
user_id=user_id,
graph_id=synthetic_graph_id,
graph_exec_id=synthetic_graph_exec_id,
graph_version=1, # Versions are 1-indexed
node_id=synthetic_node_id,
node_exec_id=synthetic_node_exec_id,
# Workspace with session scoping
workspace_id=workspace.id,
session_id=session.session_id,
)
# Prepare kwargs for block execution
# Keep individual kwargs for backwards compatibility with existing blocks
exec_kwargs: dict[str, Any] = {
"user_id": user_id,
"execution_context": ExecutionContext(),
"execution_context": execution_context,
# Legacy: individual kwargs for blocks not yet using execution_context
"workspace_id": workspace.id,
"graph_exec_id": synthetic_graph_exec_id,
"node_exec_id": synthetic_node_exec_id,
"node_id": synthetic_node_id,
"graph_version": 1, # Versions are 1-indexed
"graph_id": synthetic_graph_id,
}
for field_name, cred_meta in matched_credentials.items():

View File

@@ -8,7 +8,7 @@ from backend.api.features.library import model as library_model
from backend.api.features.store import db as store_db
from backend.data import graph as graph_db
from backend.data.graph import GraphModel
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
from backend.data.model import Credentials, CredentialsFieldInfo, CredentialsMetaInput
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util.exceptions import NotFoundError
@@ -266,13 +266,14 @@ async def match_user_credentials_to_graph(
credential_requirements,
_node_fields,
) in aggregated_creds.items():
# Find first matching credential by provider and type
# Find first matching credential by provider, type, and scopes
matching_cred = next(
(
cred
for cred in available_creds
if cred.provider in credential_requirements.provider
and cred.type in credential_requirements.supported_types
and _credential_has_required_scopes(cred, credential_requirements)
),
None,
)
@@ -296,10 +297,17 @@ async def match_user_credentials_to_graph(
f"{credential_field_name} (validation failed: {e})"
)
else:
# Build a helpful error message including scope requirements
error_parts = [
f"provider in {list(credential_requirements.provider)}",
f"type in {list(credential_requirements.supported_types)}",
]
if credential_requirements.required_scopes:
error_parts.append(
f"scopes including {list(credential_requirements.required_scopes)}"
)
missing_creds.append(
f"{credential_field_name} "
f"(requires provider in {list(credential_requirements.provider)}, "
f"type in {list(credential_requirements.supported_types)})"
f"{credential_field_name} (requires {', '.join(error_parts)})"
)
logger.info(
@@ -309,6 +317,28 @@ async def match_user_credentials_to_graph(
return graph_credentials_inputs, missing_creds
def _credential_has_required_scopes(
credential: Credentials,
requirements: CredentialsFieldInfo,
) -> bool:
"""
Check if a credential has all the scopes required by the block.
For OAuth2 credentials, verifies that the credential's scopes are a superset
of the required scopes. For other credential types, returns True (no scope check).
"""
# Only OAuth2 credentials have scopes to check
if credential.type != "oauth2":
return True
# If no scopes are required, any credential matches
if not requirements.required_scopes:
return True
# Check that credential scopes are a superset of required scopes
return set(credential.scopes).issuperset(requirements.required_scopes)
async def check_user_has_required_credentials(
user_id: str,
required_credentials: list[CredentialsMetaInput],

View File

@@ -0,0 +1,620 @@
"""CoPilot tools for workspace file operations."""
import base64
import logging
from typing import Any, Optional
from pydantic import BaseModel
from backend.api.features.chat.model import ChatSession
from backend.data.workspace import get_or_create_workspace
from backend.util.settings import Config
from backend.util.virus_scanner import scan_content_safe
from backend.util.workspace import WorkspaceManager
from .base import BaseTool
from .models import ErrorResponse, ResponseType, ToolResponseBase
logger = logging.getLogger(__name__)
class WorkspaceFileInfoData(BaseModel):
"""Data model for workspace file information (not a response itself)."""
file_id: str
name: str
path: str
mime_type: str
size_bytes: int
class WorkspaceFileListResponse(ToolResponseBase):
"""Response containing list of workspace files."""
type: ResponseType = ResponseType.WORKSPACE_FILE_LIST
files: list[WorkspaceFileInfoData]
total_count: int
class WorkspaceFileContentResponse(ToolResponseBase):
"""Response containing workspace file content (legacy, for small text files)."""
type: ResponseType = ResponseType.WORKSPACE_FILE_CONTENT
file_id: str
name: str
path: str
mime_type: str
content_base64: str
class WorkspaceFileMetadataResponse(ToolResponseBase):
"""Response containing workspace file metadata and download URL (prevents context bloat)."""
type: ResponseType = ResponseType.WORKSPACE_FILE_METADATA
file_id: str
name: str
path: str
mime_type: str
size_bytes: int
download_url: str
preview: str | None = None # First 500 chars for text files
class WorkspaceWriteResponse(ToolResponseBase):
"""Response after writing a file to workspace."""
type: ResponseType = ResponseType.WORKSPACE_FILE_WRITTEN
file_id: str
name: str
path: str
size_bytes: int
class WorkspaceDeleteResponse(ToolResponseBase):
"""Response after deleting a file from workspace."""
type: ResponseType = ResponseType.WORKSPACE_FILE_DELETED
file_id: str
success: bool
class ListWorkspaceFilesTool(BaseTool):
"""Tool for listing files in user's workspace."""
@property
def name(self) -> str:
return "list_workspace_files"
@property
def description(self) -> str:
return (
"List files in the user's workspace. "
"Returns file names, paths, sizes, and metadata. "
"Optionally filter by path prefix."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"path_prefix": {
"type": "string",
"description": (
"Optional path prefix to filter files "
"(e.g., '/documents/' to list only files in documents folder). "
"By default, only files from the current session are listed."
),
},
"limit": {
"type": "integer",
"description": "Maximum number of files to return (default 50, max 100)",
"minimum": 1,
"maximum": 100,
},
"include_all_sessions": {
"type": "boolean",
"description": (
"If true, list files from all sessions. "
"Default is false (only current session's files)."
),
},
},
"required": [],
}
@property
def requires_auth(self) -> bool:
return True
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
session_id = session.session_id
if not user_id:
return ErrorResponse(
message="Authentication required",
session_id=session_id,
)
path_prefix: Optional[str] = kwargs.get("path_prefix")
limit = min(kwargs.get("limit", 50), 100)
include_all_sessions: bool = kwargs.get("include_all_sessions", False)
try:
workspace = await get_or_create_workspace(user_id)
# Pass session_id for session-scoped file access
manager = WorkspaceManager(user_id, workspace.id, session_id)
files = await manager.list_files(
path=path_prefix,
limit=limit,
include_all_sessions=include_all_sessions,
)
total = await manager.get_file_count(
path=path_prefix,
include_all_sessions=include_all_sessions,
)
file_infos = [
WorkspaceFileInfoData(
file_id=f.id,
name=f.name,
path=f.path,
mime_type=f.mimeType,
size_bytes=f.sizeBytes,
)
for f in files
]
scope_msg = "all sessions" if include_all_sessions else "current session"
return WorkspaceFileListResponse(
files=file_infos,
total_count=total,
message=f"Found {len(files)} files in workspace ({scope_msg})",
session_id=session_id,
)
except Exception as e:
logger.error(f"Error listing workspace files: {e}", exc_info=True)
return ErrorResponse(
message=f"Failed to list workspace files: {str(e)}",
error=str(e),
session_id=session_id,
)
class ReadWorkspaceFileTool(BaseTool):
"""Tool for reading file content from workspace."""
# Size threshold for returning full content vs metadata+URL
# Files larger than this return metadata with download URL to prevent context bloat
MAX_INLINE_SIZE_BYTES = 32 * 1024 # 32KB
# Preview size for text files
PREVIEW_SIZE = 500
@property
def name(self) -> str:
return "read_workspace_file"
@property
def description(self) -> str:
return (
"Read a file from the user's workspace. "
"Specify either file_id or path to identify the file. "
"For small text files, returns content directly. "
"For large or binary files, returns metadata and a download URL. "
"Paths are scoped to the current session by default. "
"Use /sessions/<session_id>/... for cross-session access."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"file_id": {
"type": "string",
"description": "The file's unique ID (from list_workspace_files)",
},
"path": {
"type": "string",
"description": (
"The virtual file path (e.g., '/documents/report.pdf'). "
"Scoped to current session by default."
),
},
"force_download_url": {
"type": "boolean",
"description": (
"If true, always return metadata+URL instead of inline content. "
"Default is false (auto-selects based on file size/type)."
),
},
},
"required": [], # At least one must be provided
}
@property
def requires_auth(self) -> bool:
return True
def _is_text_mime_type(self, mime_type: str) -> bool:
"""Check if the MIME type is a text-based type."""
text_types = [
"text/",
"application/json",
"application/xml",
"application/javascript",
"application/x-python",
"application/x-sh",
]
return any(mime_type.startswith(t) for t in text_types)
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
session_id = session.session_id
if not user_id:
return ErrorResponse(
message="Authentication required",
session_id=session_id,
)
file_id: Optional[str] = kwargs.get("file_id")
path: Optional[str] = kwargs.get("path")
force_download_url: bool = kwargs.get("force_download_url", False)
if not file_id and not path:
return ErrorResponse(
message="Please provide either file_id or path",
session_id=session_id,
)
try:
workspace = await get_or_create_workspace(user_id)
# Pass session_id for session-scoped file access
manager = WorkspaceManager(user_id, workspace.id, session_id)
# Get file info
if file_id:
file_info = await manager.get_file_info(file_id)
if file_info is None:
return ErrorResponse(
message=f"File not found: {file_id}",
session_id=session_id,
)
target_file_id = file_id
else:
# path is guaranteed to be non-None here due to the check above
assert path is not None
file_info = await manager.get_file_info_by_path(path)
if file_info is None:
return ErrorResponse(
message=f"File not found at path: {path}",
session_id=session_id,
)
target_file_id = file_info.id
# Decide whether to return inline content or metadata+URL
is_small_file = file_info.sizeBytes <= self.MAX_INLINE_SIZE_BYTES
is_text_file = self._is_text_mime_type(file_info.mimeType)
# Return inline content for small text files (unless force_download_url)
if is_small_file and is_text_file and not force_download_url:
content = await manager.read_file_by_id(target_file_id)
content_b64 = base64.b64encode(content).decode("utf-8")
return WorkspaceFileContentResponse(
file_id=file_info.id,
name=file_info.name,
path=file_info.path,
mime_type=file_info.mimeType,
content_base64=content_b64,
message=f"Successfully read file: {file_info.name}",
session_id=session_id,
)
# Return metadata + workspace:// reference for large or binary files
# This prevents context bloat (100KB file = ~133KB as base64)
# Use workspace:// format so frontend urlTransform can add proxy prefix
download_url = f"workspace://{target_file_id}"
# Generate preview for text files
preview: str | None = None
if is_text_file:
try:
content = await manager.read_file_by_id(target_file_id)
preview_text = content[: self.PREVIEW_SIZE].decode(
"utf-8", errors="replace"
)
if len(content) > self.PREVIEW_SIZE:
preview_text += "..."
preview = preview_text
except Exception:
pass # Preview is optional
return WorkspaceFileMetadataResponse(
file_id=file_info.id,
name=file_info.name,
path=file_info.path,
mime_type=file_info.mimeType,
size_bytes=file_info.sizeBytes,
download_url=download_url,
preview=preview,
message=f"File: {file_info.name} ({file_info.sizeBytes} bytes). Use download_url to retrieve content.",
session_id=session_id,
)
except FileNotFoundError as e:
return ErrorResponse(
message=str(e),
session_id=session_id,
)
except Exception as e:
logger.error(f"Error reading workspace file: {e}", exc_info=True)
return ErrorResponse(
message=f"Failed to read workspace file: {str(e)}",
error=str(e),
session_id=session_id,
)
class WriteWorkspaceFileTool(BaseTool):
"""Tool for writing files to workspace."""
@property
def name(self) -> str:
return "write_workspace_file"
@property
def description(self) -> str:
return (
"Write or create a file in the user's workspace. "
"Provide the content as a base64-encoded string. "
f"Maximum file size is {Config().max_file_size_mb}MB. "
"Files are saved to the current session's folder by default. "
"Use /sessions/<session_id>/... for cross-session access."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"filename": {
"type": "string",
"description": "Name for the file (e.g., 'report.pdf')",
},
"content_base64": {
"type": "string",
"description": "Base64-encoded file content",
},
"path": {
"type": "string",
"description": (
"Optional virtual path where to save the file "
"(e.g., '/documents/report.pdf'). "
"Defaults to '/{filename}'. Scoped to current session."
),
},
"mime_type": {
"type": "string",
"description": (
"Optional MIME type of the file. "
"Auto-detected from filename if not provided."
),
},
"overwrite": {
"type": "boolean",
"description": "Whether to overwrite if file exists at path (default: false)",
},
},
"required": ["filename", "content_base64"],
}
@property
def requires_auth(self) -> bool:
return True
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
session_id = session.session_id
if not user_id:
return ErrorResponse(
message="Authentication required",
session_id=session_id,
)
filename: str = kwargs.get("filename", "")
content_b64: str = kwargs.get("content_base64", "")
path: Optional[str] = kwargs.get("path")
mime_type: Optional[str] = kwargs.get("mime_type")
overwrite: bool = kwargs.get("overwrite", False)
if not filename:
return ErrorResponse(
message="Please provide a filename",
session_id=session_id,
)
if not content_b64:
return ErrorResponse(
message="Please provide content_base64",
session_id=session_id,
)
# Decode content
try:
content = base64.b64decode(content_b64)
except Exception:
return ErrorResponse(
message="Invalid base64-encoded content",
session_id=session_id,
)
# Check size
max_file_size = Config().max_file_size_mb * 1024 * 1024
if len(content) > max_file_size:
return ErrorResponse(
message=f"File too large. Maximum size is {Config().max_file_size_mb}MB",
session_id=session_id,
)
try:
# Virus scan
await scan_content_safe(content, filename=filename)
workspace = await get_or_create_workspace(user_id)
# Pass session_id for session-scoped file access
manager = WorkspaceManager(user_id, workspace.id, session_id)
file_record = await manager.write_file(
content=content,
filename=filename,
path=path,
mime_type=mime_type,
overwrite=overwrite,
)
return WorkspaceWriteResponse(
file_id=file_record.id,
name=file_record.name,
path=file_record.path,
size_bytes=file_record.sizeBytes,
message=f"Successfully wrote file: {file_record.name}",
session_id=session_id,
)
except ValueError as e:
return ErrorResponse(
message=str(e),
session_id=session_id,
)
except Exception as e:
logger.error(f"Error writing workspace file: {e}", exc_info=True)
return ErrorResponse(
message=f"Failed to write workspace file: {str(e)}",
error=str(e),
session_id=session_id,
)
class DeleteWorkspaceFileTool(BaseTool):
"""Tool for deleting files from workspace."""
@property
def name(self) -> str:
return "delete_workspace_file"
@property
def description(self) -> str:
return (
"Delete a file from the user's workspace. "
"Specify either file_id or path to identify the file. "
"Paths are scoped to the current session by default. "
"Use /sessions/<session_id>/... for cross-session access."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"file_id": {
"type": "string",
"description": "The file's unique ID (from list_workspace_files)",
},
"path": {
"type": "string",
"description": (
"The virtual file path (e.g., '/documents/report.pdf'). "
"Scoped to current session by default."
),
},
},
"required": [], # At least one must be provided
}
@property
def requires_auth(self) -> bool:
return True
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
session_id = session.session_id
if not user_id:
return ErrorResponse(
message="Authentication required",
session_id=session_id,
)
file_id: Optional[str] = kwargs.get("file_id")
path: Optional[str] = kwargs.get("path")
if not file_id and not path:
return ErrorResponse(
message="Please provide either file_id or path",
session_id=session_id,
)
try:
workspace = await get_or_create_workspace(user_id)
# Pass session_id for session-scoped file access
manager = WorkspaceManager(user_id, workspace.id, session_id)
# Determine the file_id to delete
target_file_id: str
if file_id:
target_file_id = file_id
else:
# path is guaranteed to be non-None here due to the check above
assert path is not None
file_info = await manager.get_file_info_by_path(path)
if file_info is None:
return ErrorResponse(
message=f"File not found at path: {path}",
session_id=session_id,
)
target_file_id = file_info.id
success = await manager.delete_file(target_file_id)
if not success:
return ErrorResponse(
message=f"File not found: {target_file_id}",
session_id=session_id,
)
return WorkspaceDeleteResponse(
file_id=target_file_id,
success=True,
message="File deleted successfully",
session_id=session_id,
)
except Exception as e:
logger.error(f"Error deleting workspace file: {e}", exc_info=True)
return ErrorResponse(
message=f"Failed to delete workspace file: {str(e)}",
error=str(e),
session_id=session_id,
)

View File

@@ -21,7 +21,7 @@ from backend.data.model import CredentialsMetaInput
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.webhooks.graph_lifecycle_hooks import on_graph_activate
from backend.util.clients import get_scheduler_client
from backend.util.exceptions import DatabaseError, NotFoundError
from backend.util.exceptions import DatabaseError, InvalidInputError, NotFoundError
from backend.util.json import SafeJson
from backend.util.models import Pagination
from backend.util.settings import Config
@@ -39,6 +39,7 @@ async def list_library_agents(
sort_by: library_model.LibraryAgentSort = library_model.LibraryAgentSort.UPDATED_AT,
page: int = 1,
page_size: int = 50,
include_executions: bool = False,
) -> library_model.LibraryAgentResponse:
"""
Retrieves a paginated list of LibraryAgent records for a given user.
@@ -49,6 +50,9 @@ async def list_library_agents(
sort_by: Sorting field (createdAt, updatedAt, isFavorite, isCreatedByUser).
page: Current page (1-indexed).
page_size: Number of items per page.
include_executions: Whether to include execution data for status calculation.
Defaults to False for performance (UI fetches status separately).
Set to True when accurate status/metrics are needed (e.g., agent generator).
Returns:
A LibraryAgentResponse containing the list of agents and pagination details.
@@ -64,11 +68,11 @@ async def list_library_agents(
if page < 1 or page_size < 1:
logger.warning(f"Invalid pagination: page={page}, page_size={page_size}")
raise DatabaseError("Invalid pagination input")
raise InvalidInputError("Invalid pagination input")
if search_term and len(search_term.strip()) > 100:
logger.warning(f"Search term too long: {repr(search_term)}")
raise DatabaseError("Search term is too long")
raise InvalidInputError("Search term is too long")
where_clause: prisma.types.LibraryAgentWhereInput = {
"userId": user_id,
@@ -76,7 +80,6 @@ async def list_library_agents(
"isArchived": False,
}
# Build search filter if applicable
if search_term:
where_clause["OR"] = [
{
@@ -93,7 +96,6 @@ async def list_library_agents(
},
]
# Determine sorting
order_by: prisma.types.LibraryAgentOrderByInput | None = None
if sort_by == library_model.LibraryAgentSort.CREATED_AT:
@@ -105,7 +107,7 @@ async def list_library_agents(
library_agents = await prisma.models.LibraryAgent.prisma().find_many(
where=where_clause,
include=library_agent_include(
user_id, include_nodes=False, include_executions=False
user_id, include_nodes=False, include_executions=include_executions
),
order=order_by,
skip=(page - 1) * page_size,
@@ -175,7 +177,7 @@ async def list_favorite_library_agents(
if page < 1 or page_size < 1:
logger.warning(f"Invalid pagination: page={page}, page_size={page_size}")
raise DatabaseError("Invalid pagination input")
raise InvalidInputError("Invalid pagination input")
where_clause: prisma.types.LibraryAgentWhereInput = {
"userId": user_id,

View File

@@ -9,6 +9,7 @@ import pydantic
from backend.data.block import BlockInput
from backend.data.graph import GraphModel, GraphSettings, GraphTriggerInfo
from backend.data.model import CredentialsMetaInput, is_credentials_field_name
from backend.util.json import loads as json_loads
from backend.util.models import Pagination
if TYPE_CHECKING:
@@ -16,10 +17,10 @@ if TYPE_CHECKING:
class LibraryAgentStatus(str, Enum):
COMPLETED = "COMPLETED" # All runs completed
HEALTHY = "HEALTHY" # Agent is running (not all runs have completed)
WAITING = "WAITING" # Agent is queued or waiting to start
ERROR = "ERROR" # Agent is in an error state
COMPLETED = "COMPLETED"
HEALTHY = "HEALTHY"
WAITING = "WAITING"
ERROR = "ERROR"
class MarketplaceListingCreator(pydantic.BaseModel):
@@ -39,6 +40,30 @@ class MarketplaceListing(pydantic.BaseModel):
creator: MarketplaceListingCreator
class RecentExecution(pydantic.BaseModel):
"""Summary of a recent execution for quality assessment.
Used by the LLM to understand the agent's recent performance with specific examples
rather than just aggregate statistics.
"""
status: str
correctness_score: float | None = None
activity_summary: str | None = None
def _parse_settings(settings: dict | str | None) -> GraphSettings:
"""Parse settings from database, handling both dict and string formats."""
if settings is None:
return GraphSettings()
try:
if isinstance(settings, str):
settings = json_loads(settings)
return GraphSettings.model_validate(settings)
except Exception:
return GraphSettings()
class LibraryAgent(pydantic.BaseModel):
"""
Represents an agent in the library, including metadata for display and
@@ -48,7 +73,7 @@ class LibraryAgent(pydantic.BaseModel):
id: str
graph_id: str
graph_version: int
owner_user_id: str # ID of user who owns/created this agent graph
owner_user_id: str
image_url: str | None
@@ -64,7 +89,7 @@ class LibraryAgent(pydantic.BaseModel):
description: str
instructions: str | None = None
input_schema: dict[str, Any] # Should be BlockIOObjectSubSchema in frontend
input_schema: dict[str, Any]
output_schema: dict[str, Any]
credentials_input_schema: dict[str, Any] | None = pydantic.Field(
description="Input schema for credentials required by the agent",
@@ -81,25 +106,19 @@ class LibraryAgent(pydantic.BaseModel):
)
trigger_setup_info: Optional[GraphTriggerInfo] = None
# Indicates whether there's a new output (based on recent runs)
new_output: bool
# Whether the user can access the underlying graph
execution_count: int = 0
success_rate: float | None = None
avg_correctness_score: float | None = None
recent_executions: list[RecentExecution] = pydantic.Field(
default_factory=list,
description="List of recent executions with status, score, and summary",
)
can_access_graph: bool
# Indicates if this agent is the latest version
is_latest_version: bool
# Whether the agent is marked as favorite by the user
is_favorite: bool
# Recommended schedule cron (from marketplace agents)
recommended_schedule_cron: str | None = None
# User-specific settings for this library agent
settings: GraphSettings = pydantic.Field(default_factory=GraphSettings)
# Marketplace listing information if the agent has been published
marketplace_listing: Optional["MarketplaceListing"] = None
@staticmethod
@@ -123,7 +142,6 @@ class LibraryAgent(pydantic.BaseModel):
agent_updated_at = agent.AgentGraph.updatedAt
lib_agent_updated_at = agent.updatedAt
# Compute updated_at as the latest between library agent and graph
updated_at = (
max(agent_updated_at, lib_agent_updated_at)
if agent_updated_at
@@ -136,7 +154,6 @@ class LibraryAgent(pydantic.BaseModel):
creator_name = agent.Creator.name or "Unknown"
creator_image_url = agent.Creator.avatarUrl or ""
# Logic to calculate status and new_output
week_ago = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(
days=7
)
@@ -145,13 +162,55 @@ class LibraryAgent(pydantic.BaseModel):
status = status_result.status
new_output = status_result.new_output
# Check if user can access the graph
can_access_graph = agent.AgentGraph.userId == agent.userId
execution_count = len(executions)
success_rate: float | None = None
avg_correctness_score: float | None = None
if execution_count > 0:
success_count = sum(
1
for e in executions
if e.executionStatus == prisma.enums.AgentExecutionStatus.COMPLETED
)
success_rate = (success_count / execution_count) * 100
# Hard-coded to True until a method to check is implemented
correctness_scores = []
for e in executions:
if e.stats and isinstance(e.stats, dict):
score = e.stats.get("correctness_score")
if score is not None and isinstance(score, (int, float)):
correctness_scores.append(float(score))
if correctness_scores:
avg_correctness_score = sum(correctness_scores) / len(
correctness_scores
)
recent_executions: list[RecentExecution] = []
for e in executions:
exec_score: float | None = None
exec_summary: str | None = None
if e.stats and isinstance(e.stats, dict):
score = e.stats.get("correctness_score")
if score is not None and isinstance(score, (int, float)):
exec_score = float(score)
summary = e.stats.get("activity_status")
if summary is not None and isinstance(summary, str):
exec_summary = summary
exec_status = (
e.executionStatus.value
if hasattr(e.executionStatus, "value")
else str(e.executionStatus)
)
recent_executions.append(
RecentExecution(
status=exec_status,
correctness_score=exec_score,
activity_summary=exec_summary,
)
)
can_access_graph = agent.AgentGraph.userId == agent.userId
is_latest_version = True
# Build marketplace_listing if available
marketplace_listing_data = None
if store_listing and store_listing.ActiveVersion and profile:
creator_data = MarketplaceListingCreator(
@@ -190,11 +249,15 @@ class LibraryAgent(pydantic.BaseModel):
has_sensitive_action=graph.has_sensitive_action,
trigger_setup_info=graph.trigger_setup_info,
new_output=new_output,
execution_count=execution_count,
success_rate=success_rate,
avg_correctness_score=avg_correctness_score,
recent_executions=recent_executions,
can_access_graph=can_access_graph,
is_latest_version=is_latest_version,
is_favorite=agent.isFavorite,
recommended_schedule_cron=agent.AgentGraph.recommendedScheduleCron,
settings=GraphSettings.model_validate(agent.settings),
settings=_parse_settings(agent.settings),
marketplace_listing=marketplace_listing_data,
)
@@ -220,18 +283,15 @@ def _calculate_agent_status(
if not executions:
return AgentStatusResult(status=LibraryAgentStatus.COMPLETED, new_output=False)
# Track how many times each execution status appears
status_counts = {status: 0 for status in prisma.enums.AgentExecutionStatus}
new_output = False
for execution in executions:
# Check if there's a completed run more recent than `recent_threshold`
if execution.createdAt >= recent_threshold:
if execution.executionStatus == prisma.enums.AgentExecutionStatus.COMPLETED:
new_output = True
status_counts[execution.executionStatus] += 1
# Determine the final status based on counts
if status_counts[prisma.enums.AgentExecutionStatus.FAILED] > 0:
return AgentStatusResult(status=LibraryAgentStatus.ERROR, new_output=new_output)
elif status_counts[prisma.enums.AgentExecutionStatus.QUEUED] > 0:

View File

@@ -1,4 +1,3 @@
import logging
from typing import Literal, Optional
import autogpt_libs.auth as autogpt_auth_lib
@@ -6,15 +5,11 @@ from fastapi import APIRouter, Body, HTTPException, Query, Security, status
from fastapi.responses import Response
from prisma.enums import OnboardingStep
import backend.api.features.store.exceptions as store_exceptions
from backend.data.onboarding import complete_onboarding_step
from backend.util.exceptions import DatabaseError, NotFoundError
from .. import db as library_db
from .. import model as library_model
logger = logging.getLogger(__name__)
router = APIRouter(
prefix="/agents",
tags=["library", "private"],
@@ -26,10 +21,6 @@ router = APIRouter(
"",
summary="List Library Agents",
response_model=library_model.LibraryAgentResponse,
responses={
200: {"description": "List of library agents"},
500: {"description": "Server error", "content": {"application/json": {}}},
},
)
async def list_library_agents(
user_id: str = Security(autogpt_auth_lib.get_user_id),
@@ -53,43 +44,19 @@ async def list_library_agents(
) -> library_model.LibraryAgentResponse:
"""
Get all agents in the user's library (both created and saved).
Args:
user_id: ID of the authenticated user.
search_term: Optional search term to filter agents by name/description.
filter_by: List of filters to apply (favorites, created by user).
sort_by: List of sorting criteria (created date, updated date).
page: Page number to retrieve.
page_size: Number of agents per page.
Returns:
A LibraryAgentResponse containing agents and pagination metadata.
Raises:
HTTPException: If a server/database error occurs.
"""
try:
return await library_db.list_library_agents(
user_id=user_id,
search_term=search_term,
sort_by=sort_by,
page=page,
page_size=page_size,
)
except Exception as e:
logger.error(f"Could not list library agents for user #{user_id}: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e),
) from e
return await library_db.list_library_agents(
user_id=user_id,
search_term=search_term,
sort_by=sort_by,
page=page,
page_size=page_size,
)
@router.get(
"/favorites",
summary="List Favorite Library Agents",
responses={
500: {"description": "Server error", "content": {"application/json": {}}},
},
)
async def list_favorite_library_agents(
user_id: str = Security(autogpt_auth_lib.get_user_id),
@@ -106,30 +73,12 @@ async def list_favorite_library_agents(
) -> library_model.LibraryAgentResponse:
"""
Get all favorite agents in the user's library.
Args:
user_id: ID of the authenticated user.
page: Page number to retrieve.
page_size: Number of agents per page.
Returns:
A LibraryAgentResponse containing favorite agents and pagination metadata.
Raises:
HTTPException: If a server/database error occurs.
"""
try:
return await library_db.list_favorite_library_agents(
user_id=user_id,
page=page,
page_size=page_size,
)
except Exception as e:
logger.error(f"Could not list favorite library agents for user #{user_id}: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e),
) from e
return await library_db.list_favorite_library_agents(
user_id=user_id,
page=page,
page_size=page_size,
)
@router.get("/{library_agent_id}", summary="Get Library Agent")
@@ -162,10 +111,6 @@ async def get_library_agent_by_graph_id(
summary="Get Agent By Store ID",
tags=["store", "library"],
response_model=library_model.LibraryAgent | None,
responses={
200: {"description": "Library agent found"},
404: {"description": "Agent not found"},
},
)
async def get_library_agent_by_store_listing_version_id(
store_listing_version_id: str,
@@ -174,32 +119,15 @@ async def get_library_agent_by_store_listing_version_id(
"""
Get Library Agent from Store Listing Version ID.
"""
try:
return await library_db.get_library_agent_by_store_version_id(
store_listing_version_id, user_id
)
except NotFoundError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=str(e),
)
except Exception as e:
logger.error(f"Could not fetch library agent from store version ID: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e),
) from e
return await library_db.get_library_agent_by_store_version_id(
store_listing_version_id, user_id
)
@router.post(
"",
summary="Add Marketplace Agent",
status_code=status.HTTP_201_CREATED,
responses={
201: {"description": "Agent added successfully"},
404: {"description": "Store listing version not found"},
500: {"description": "Server error"},
},
)
async def add_marketplace_agent_to_library(
store_listing_version_id: str = Body(embed=True),
@@ -210,59 +138,19 @@ async def add_marketplace_agent_to_library(
) -> library_model.LibraryAgent:
"""
Add an agent from the marketplace to the user's library.
Args:
store_listing_version_id: ID of the store listing version to add.
user_id: ID of the authenticated user.
Returns:
library_model.LibraryAgent: Agent added to the library
Raises:
HTTPException(404): If the listing version is not found.
HTTPException(500): If a server/database error occurs.
"""
try:
agent = await library_db.add_store_agent_to_library(
store_listing_version_id=store_listing_version_id,
user_id=user_id,
)
if source != "onboarding":
await complete_onboarding_step(
user_id, OnboardingStep.MARKETPLACE_ADD_AGENT
)
return agent
except store_exceptions.AgentNotFoundError as e:
logger.warning(
f"Could not find store listing version {store_listing_version_id} "
"to add to library"
)
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
except DatabaseError as e:
logger.error(f"Database error while adding agent to library: {e}", e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={"message": str(e), "hint": "Inspect DB logs for details."},
) from e
except Exception as e:
logger.error(f"Unexpected error while adding agent to library: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={
"message": str(e),
"hint": "Check server logs for more information.",
},
) from e
agent = await library_db.add_store_agent_to_library(
store_listing_version_id=store_listing_version_id,
user_id=user_id,
)
if source != "onboarding":
await complete_onboarding_step(user_id, OnboardingStep.MARKETPLACE_ADD_AGENT)
return agent
@router.patch(
"/{library_agent_id}",
summary="Update Library Agent",
responses={
200: {"description": "Agent updated successfully"},
500: {"description": "Server error"},
},
)
async def update_library_agent(
library_agent_id: str,
@@ -271,52 +159,21 @@ async def update_library_agent(
) -> library_model.LibraryAgent:
"""
Update the library agent with the given fields.
Args:
library_agent_id: ID of the library agent to update.
payload: Fields to update (auto_update_version, is_favorite, etc.).
user_id: ID of the authenticated user.
Raises:
HTTPException(500): If a server/database error occurs.
"""
try:
return await library_db.update_library_agent(
library_agent_id=library_agent_id,
user_id=user_id,
auto_update_version=payload.auto_update_version,
graph_version=payload.graph_version,
is_favorite=payload.is_favorite,
is_archived=payload.is_archived,
settings=payload.settings,
)
except NotFoundError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=str(e),
) from e
except DatabaseError as e:
logger.error(f"Database error while updating library agent: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={"message": str(e), "hint": "Verify DB connection."},
) from e
except Exception as e:
logger.error(f"Unexpected error while updating library agent: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={"message": str(e), "hint": "Check server logs."},
) from e
return await library_db.update_library_agent(
library_agent_id=library_agent_id,
user_id=user_id,
auto_update_version=payload.auto_update_version,
graph_version=payload.graph_version,
is_favorite=payload.is_favorite,
is_archived=payload.is_archived,
settings=payload.settings,
)
@router.delete(
"/{library_agent_id}",
summary="Delete Library Agent",
responses={
204: {"description": "Agent deleted successfully"},
404: {"description": "Agent not found"},
500: {"description": "Server error"},
},
)
async def delete_library_agent(
library_agent_id: str,
@@ -324,28 +181,11 @@ async def delete_library_agent(
) -> Response:
"""
Soft-delete the specified library agent.
Args:
library_agent_id: ID of the library agent to delete.
user_id: ID of the authenticated user.
Returns:
204 No Content if successful.
Raises:
HTTPException(404): If the agent does not exist.
HTTPException(500): If a server/database error occurs.
"""
try:
await library_db.delete_library_agent(
library_agent_id=library_agent_id, user_id=user_id
)
return Response(status_code=status.HTTP_204_NO_CONTENT)
except NotFoundError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=str(e),
) from e
await library_db.delete_library_agent(
library_agent_id=library_agent_id, user_id=user_id
)
return Response(status_code=status.HTTP_204_NO_CONTENT)
@router.post("/{library_agent_id}/fork", summary="Fork Library Agent")

View File

@@ -118,21 +118,6 @@ async def test_get_library_agents_success(
)
def test_get_library_agents_error(mocker: pytest_mock.MockFixture, test_user_id: str):
mock_db_call = mocker.patch("backend.api.features.library.db.list_library_agents")
mock_db_call.side_effect = Exception("Test error")
response = client.get("/agents?search_term=test")
assert response.status_code == 500
mock_db_call.assert_called_once_with(
user_id=test_user_id,
search_term="test",
sort_by=library_model.LibraryAgentSort.UPDATED_AT,
page=1,
page_size=15,
)
@pytest.mark.asyncio
async def test_get_favorite_library_agents_success(
mocker: pytest_mock.MockFixture,
@@ -190,23 +175,6 @@ async def test_get_favorite_library_agents_success(
)
def test_get_favorite_library_agents_error(
mocker: pytest_mock.MockFixture, test_user_id: str
):
mock_db_call = mocker.patch(
"backend.api.features.library.db.list_favorite_library_agents"
)
mock_db_call.side_effect = Exception("Test error")
response = client.get("/agents/favorites")
assert response.status_code == 500
mock_db_call.assert_called_once_with(
user_id=test_user_id,
page=1,
page_size=15,
)
def test_add_agent_to_library_success(
mocker: pytest_mock.MockFixture, test_user_id: str
):
@@ -258,19 +226,3 @@ def test_add_agent_to_library_success(
store_listing_version_id="test-version-id", user_id=test_user_id
)
mock_complete_onboarding.assert_awaited_once()
def test_add_agent_to_library_error(mocker: pytest_mock.MockFixture, test_user_id: str):
mock_db_call = mocker.patch(
"backend.api.features.library.db.add_store_agent_to_library"
)
mock_db_call.side_effect = Exception("Test error")
response = client.post(
"/agents", json={"store_listing_version_id": "test-version-id"}
)
assert response.status_code == 500
assert "detail" in response.json() # Verify error response structure
mock_db_call.assert_called_once_with(
store_listing_version_id="test-version-id", user_id=test_user_id
)

View File

@@ -112,6 +112,7 @@ async def get_store_agents(
description=agent["description"],
runs=agent["runs"],
rating=agent["rating"],
agent_graph_id=agent.get("agentGraphId", ""),
)
store_agents.append(store_agent)
except Exception as e:
@@ -170,6 +171,7 @@ async def get_store_agents(
description=agent.description,
runs=agent.runs,
rating=agent.rating,
agent_graph_id=agent.agentGraphId,
)
# Add to the list only if creation was successful
store_agents.append(store_agent)

View File

@@ -454,6 +454,7 @@ async def backfill_all_content_types(batch_size: int = 10) -> dict[str, Any]:
total_processed = 0
total_success = 0
total_failed = 0
all_errors: dict[str, int] = {} # Aggregate errors across all content types
# Process content types in explicit order
processing_order = [
@@ -499,23 +500,12 @@ async def backfill_all_content_types(batch_size: int = 10) -> dict[str, Any]:
success = sum(1 for result in results if result is True)
failed = len(results) - success
# Aggregate unique errors to avoid Sentry spam
# Aggregate errors across all content types
if failed > 0:
# Group errors by type and message
error_summary: dict[str, int] = {}
for result in results:
if isinstance(result, Exception):
error_key = f"{type(result).__name__}: {str(result)}"
error_summary[error_key] = error_summary.get(error_key, 0) + 1
# Log aggregated error summary
error_details = ", ".join(
f"{error} ({count}x)" for error, count in error_summary.items()
)
logger.error(
f"{content_type.value}: {failed}/{len(results)} embeddings failed. "
f"Errors: {error_details}"
)
all_errors[error_key] = all_errors.get(error_key, 0) + 1
results_by_type[content_type.value] = {
"processed": len(missing_items),
@@ -542,6 +532,13 @@ async def backfill_all_content_types(batch_size: int = 10) -> dict[str, Any]:
"error": str(e),
}
# Log aggregated errors once at the end
if all_errors:
error_details = ", ".join(
f"{error} ({count}x)" for error, count in all_errors.items()
)
logger.error(f"Embedding backfill errors: {error_details}")
return {
"by_type": results_by_type,
"totals": {

View File

@@ -454,6 +454,9 @@ async def test_unified_hybrid_search_pagination(
cleanup_embeddings: list,
):
"""Test unified search pagination works correctly."""
# Use a unique search term to avoid matching other test data
unique_term = f"xyzpagtest{uuid.uuid4().hex[:8]}"
# Create multiple items
content_ids = []
for i in range(5):
@@ -465,14 +468,14 @@ async def test_unified_hybrid_search_pagination(
content_type=ContentType.BLOCK,
content_id=content_id,
embedding=mock_embedding,
searchable_text=f"pagination test item number {i}",
searchable_text=f"{unique_term} item number {i}",
metadata={"index": i},
user_id=None,
)
# Get first page
page1_results, total1 = await unified_hybrid_search(
query="pagination test",
query=unique_term,
content_types=[ContentType.BLOCK],
page=1,
page_size=2,
@@ -480,7 +483,7 @@ async def test_unified_hybrid_search_pagination(
# Get second page
page2_results, total2 = await unified_hybrid_search(
query="pagination test",
query=unique_term,
content_types=[ContentType.BLOCK],
page=2,
page_size=2,

View File

@@ -600,6 +600,7 @@ async def hybrid_search(
sa.featured,
sa.is_available,
sa.updated_at,
sa."agentGraphId",
-- Searchable text for BM25 reranking
COALESCE(sa.agent_name, '') || ' ' || COALESCE(sa.sub_heading, '') || ' ' || COALESCE(sa.description, '') as searchable_text,
-- Semantic score
@@ -659,6 +660,7 @@ async def hybrid_search(
featured,
is_available,
updated_at,
"agentGraphId",
searchable_text,
semantic_score,
lexical_score,

View File

@@ -38,6 +38,7 @@ class StoreAgent(pydantic.BaseModel):
description: str
runs: int
rating: float
agent_graph_id: str
class StoreAgentsResponse(pydantic.BaseModel):

View File

@@ -26,11 +26,13 @@ def test_store_agent():
description="Test description",
runs=50,
rating=4.5,
agent_graph_id="test-graph-id",
)
assert agent.slug == "test-agent"
assert agent.agent_name == "Test Agent"
assert agent.runs == 50
assert agent.rating == 4.5
assert agent.agent_graph_id == "test-graph-id"
def test_store_agents_response():
@@ -46,6 +48,7 @@ def test_store_agents_response():
description="Test description",
runs=50,
rating=4.5,
agent_graph_id="test-graph-id",
)
],
pagination=store_model.Pagination(

View File

@@ -393,7 +393,6 @@ async def get_creators(
@router.get(
"/creator/{username}",
summary="Get creator details",
operation_id="getV2GetCreatorDetails",
tags=["store", "public"],
response_model=store_model.CreatorDetails,
)

View File

@@ -82,6 +82,7 @@ def test_get_agents_featured(
description="Featured agent description",
runs=100,
rating=4.5,
agent_graph_id="test-graph-1",
)
],
pagination=store_model.Pagination(
@@ -127,6 +128,7 @@ def test_get_agents_by_creator(
description="Creator agent description",
runs=50,
rating=4.0,
agent_graph_id="test-graph-2",
)
],
pagination=store_model.Pagination(
@@ -172,6 +174,7 @@ def test_get_agents_sorted(
description="Top agent description",
runs=1000,
rating=5.0,
agent_graph_id="test-graph-3",
)
],
pagination=store_model.Pagination(
@@ -217,6 +220,7 @@ def test_get_agents_search(
description="Specific search term description",
runs=75,
rating=4.2,
agent_graph_id="test-graph-search",
)
],
pagination=store_model.Pagination(
@@ -262,6 +266,7 @@ def test_get_agents_category(
description="Category agent description",
runs=60,
rating=4.1,
agent_graph_id="test-graph-category",
)
],
pagination=store_model.Pagination(
@@ -306,6 +311,7 @@ def test_get_agents_pagination(
description=f"Agent {i} description",
runs=i * 10,
rating=4.0,
agent_graph_id="test-graph-2",
)
for i in range(5)
],

View File

@@ -33,6 +33,7 @@ class TestCacheDeletion:
description="Test description",
runs=100,
rating=4.5,
agent_graph_id="test-graph-id",
)
],
pagination=Pagination(

View File

@@ -261,14 +261,36 @@ async def get_onboarding_agents(
return await get_recommended_agents(user_id)
class OnboardingStatusResponse(pydantic.BaseModel):
"""Response for onboarding status check."""
is_onboarding_enabled: bool
is_chat_enabled: bool
@v1_router.get(
"/onboarding/enabled",
summary="Is onboarding enabled",
tags=["onboarding", "public"],
dependencies=[Security(requires_user)],
response_model=OnboardingStatusResponse,
)
async def is_onboarding_enabled() -> bool:
return await onboarding_enabled()
async def is_onboarding_enabled(
user_id: Annotated[str, Security(get_user_id)],
) -> OnboardingStatusResponse:
# Check if chat is enabled for user
is_chat_enabled = await is_feature_enabled(Flag.CHAT, user_id, False)
# If chat is enabled, skip legacy onboarding
if is_chat_enabled:
return OnboardingStatusResponse(
is_onboarding_enabled=False,
is_chat_enabled=True,
)
return OnboardingStatusResponse(
is_onboarding_enabled=await onboarding_enabled(),
is_chat_enabled=False,
)
@v1_router.post(

View File

@@ -0,0 +1 @@
# Workspace API feature module

View File

@@ -0,0 +1,122 @@
"""
Workspace API routes for managing user file storage.
"""
import logging
import re
from typing import Annotated
from urllib.parse import quote
import fastapi
from autogpt_libs.auth.dependencies import get_user_id, requires_user
from fastapi.responses import Response
from backend.data.workspace import get_workspace, get_workspace_file
from backend.util.workspace_storage import get_workspace_storage
def _sanitize_filename_for_header(filename: str) -> str:
"""
Sanitize filename for Content-Disposition header to prevent header injection.
Removes/replaces characters that could break the header or inject new headers.
Uses RFC5987 encoding for non-ASCII characters.
"""
# Remove CR, LF, and null bytes (header injection prevention)
sanitized = re.sub(r"[\r\n\x00]", "", filename)
# Escape quotes
sanitized = sanitized.replace('"', '\\"')
# For non-ASCII, use RFC5987 filename* parameter
# Check if filename has non-ASCII characters
try:
sanitized.encode("ascii")
return f'attachment; filename="{sanitized}"'
except UnicodeEncodeError:
# Use RFC5987 encoding for UTF-8 filenames
encoded = quote(sanitized, safe="")
return f"attachment; filename*=UTF-8''{encoded}"
logger = logging.getLogger(__name__)
router = fastapi.APIRouter(
dependencies=[fastapi.Security(requires_user)],
)
def _create_streaming_response(content: bytes, file) -> Response:
"""Create a streaming response for file content."""
return Response(
content=content,
media_type=file.mimeType,
headers={
"Content-Disposition": _sanitize_filename_for_header(file.name),
"Content-Length": str(len(content)),
},
)
async def _create_file_download_response(file) -> Response:
"""
Create a download response for a workspace file.
Handles both local storage (direct streaming) and GCS (signed URL redirect
with fallback to streaming).
"""
storage = await get_workspace_storage()
# For local storage, stream the file directly
if file.storagePath.startswith("local://"):
content = await storage.retrieve(file.storagePath)
return _create_streaming_response(content, file)
# For GCS, try to redirect to signed URL, fall back to streaming
try:
url = await storage.get_download_url(file.storagePath, expires_in=300)
# If we got back an API path (fallback), stream directly instead
if url.startswith("/api/"):
content = await storage.retrieve(file.storagePath)
return _create_streaming_response(content, file)
return fastapi.responses.RedirectResponse(url=url, status_code=302)
except Exception as e:
# Log the signed URL failure with context
logger.error(
f"Failed to get signed URL for file {file.id} "
f"(storagePath={file.storagePath}): {e}",
exc_info=True,
)
# Fall back to streaming directly from GCS
try:
content = await storage.retrieve(file.storagePath)
return _create_streaming_response(content, file)
except Exception as fallback_error:
logger.error(
f"Fallback streaming also failed for file {file.id} "
f"(storagePath={file.storagePath}): {fallback_error}",
exc_info=True,
)
raise
@router.get(
"/files/{file_id}/download",
summary="Download file by ID",
)
async def download_file(
user_id: Annotated[str, fastapi.Security(get_user_id)],
file_id: str,
) -> Response:
"""
Download a file by its ID.
Returns the file content directly or redirects to a signed URL for GCS.
"""
workspace = await get_workspace(user_id)
if workspace is None:
raise fastapi.HTTPException(status_code=404, detail="Workspace not found")
file = await get_workspace_file(file_id, workspace.id)
if file is None:
raise fastapi.HTTPException(status_code=404, detail="File not found")
return await _create_file_download_response(file)

View File

@@ -18,7 +18,6 @@ from prisma.errors import PrismaError
import backend.api.features.admin.credit_admin_routes
import backend.api.features.admin.execution_analytics_routes
import backend.api.features.admin.llm_routes
import backend.api.features.admin.store_admin_routes
import backend.api.features.builder
import backend.api.features.builder.routes
@@ -33,16 +32,19 @@ import backend.api.features.postmark.postmark
import backend.api.features.store.model
import backend.api.features.store.routes
import backend.api.features.v1
import backend.api.features.workspace.routes as workspace_routes
import backend.data.block
import backend.data.db
import backend.data.graph
import backend.data.user
import backend.integrations.webhooks.utils
import backend.server.v2.llm.routes as public_llm_routes
import backend.util.service
import backend.util.settings
from backend.data import llm_registry
from backend.data.block_cost_config import refresh_llm_costs
from backend.api.features.chat.completion_consumer import (
start_completion_consumer,
stop_completion_consumer,
)
from backend.blocks.llm import DEFAULT_LLM_MODEL
from backend.data.model import Credentials
from backend.integrations.providers import ProviderName
from backend.monitoring.instrumentation import instrument_fastapi
@@ -55,6 +57,7 @@ from backend.util.exceptions import (
)
from backend.util.feature_flag import initialize_launchdarkly, shutdown_launchdarkly
from backend.util.service import UnhealthyServiceError
from backend.util.workspace_storage import shutdown_workspace_storage
from .external.fastapi_app import external_api
from .features.analytics import router as analytics_router
@@ -112,37 +115,38 @@ async def lifespan_context(app: fastapi.FastAPI):
AutoRegistry.patch_integrations()
# Refresh LLM registry before initializing blocks so blocks can use registry data
await llm_registry.refresh_llm_registry()
refresh_llm_costs()
# Clear block schema caches so they're regenerated with updated discriminator_mapping
from backend.data.block import BlockSchema
BlockSchema.clear_all_schema_caches()
await backend.data.block.initialize_blocks()
await backend.data.user.migrate_and_encrypt_user_integrations()
await backend.data.graph.fix_llm_provider_credentials()
# migrate_llm_models uses registry default model
from backend.blocks.llm import LlmModel
default_model_slug = llm_registry.get_default_model_slug()
if default_model_slug:
await backend.data.graph.migrate_llm_models(LlmModel(default_model_slug))
else:
logger.warning("Skipping LLM model migration: no default model available")
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
# Start chat completion consumer for Redis Streams notifications
try:
await start_completion_consumer()
except Exception as e:
logger.warning(f"Could not start chat completion consumer: {e}")
with launch_darkly_context():
yield
# Stop chat completion consumer
try:
await stop_completion_consumer()
except Exception as e:
logger.warning(f"Error stopping chat completion consumer: {e}")
try:
await shutdown_cloud_storage_handler()
except Exception as e:
logger.warning(f"Error shutting down cloud storage handler: {e}")
try:
await shutdown_workspace_storage()
except Exception as e:
logger.warning(f"Error shutting down workspace storage: {e}")
await backend.data.db.disconnect()
@@ -317,16 +321,6 @@ app.include_router(
tags=["v2", "executions", "review"],
prefix="/api/review",
)
app.include_router(
backend.api.features.admin.llm_routes.router,
tags=["v2", "admin", "llm"],
prefix="/api/llm/admin",
)
app.include_router(
public_llm_routes.router,
tags=["v2", "llm"],
prefix="/api",
)
app.include_router(
backend.api.features.library.routes.router, tags=["v2"], prefix="/api/library"
)
@@ -344,6 +338,11 @@ app.include_router(
tags=["v2", "chat"],
prefix="/api/chat",
)
app.include_router(
workspace_routes.router,
tags=["workspace"],
prefix="/api/workspace",
)
app.include_router(
backend.api.features.oauth.router,
tags=["oauth"],

View File

@@ -66,50 +66,24 @@ async def event_broadcaster(manager: ConnectionManager):
execution_bus = AsyncRedisExecutionEventBus()
notification_bus = AsyncRedisNotificationEventBus()
async def execution_worker():
async for event in execution_bus.listen("*"):
await manager.send_execution_update(event)
try:
async def notification_worker():
async for notification in notification_bus.listen("*"):
await manager.send_notification(
user_id=notification.user_id,
payload=notification.payload,
)
async def execution_worker():
async for event in execution_bus.listen("*"):
await manager.send_execution_update(event)
async def registry_refresh_worker():
"""Listen for LLM registry refresh notifications and broadcast to all clients."""
from backend.data.llm_registry import REGISTRY_REFRESH_CHANNEL
from backend.data.redis_client import connect_async
redis = await connect_async()
pubsub = redis.pubsub()
await pubsub.subscribe(REGISTRY_REFRESH_CHANNEL)
logger.info(
"Subscribed to LLM registry refresh notifications for WebSocket broadcast"
)
async for message in pubsub.listen():
if (
message["type"] == "message"
and message["channel"] == REGISTRY_REFRESH_CHANNEL
):
logger.info(
"Broadcasting LLM registry refresh to all WebSocket clients"
)
await manager.broadcast_to_all(
method=WSMethod.NOTIFICATION,
data={
"type": "LLM_REGISTRY_REFRESH",
"event": "registry_updated",
},
async def notification_worker():
async for notification in notification_bus.listen("*"):
await manager.send_notification(
user_id=notification.user_id,
payload=notification.payload,
)
await asyncio.gather(
execution_worker(),
notification_worker(),
registry_refresh_worker(),
)
await asyncio.gather(execution_worker(), notification_worker())
finally:
# Ensure PubSub connections are closed on any exit to prevent leaks
await execution_bus.close()
await notification_bus.close()
async def authenticate_websocket(websocket: WebSocket) -> str:

View File

@@ -1,6 +1,7 @@
from typing import Any
from backend.blocks.llm import (
DEFAULT_LLM_MODEL,
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
AIBlockBase,
@@ -9,7 +10,6 @@ from backend.blocks.llm import (
LlmModel,
LLMResponse,
llm_call,
llm_model_schema_extra,
)
from backend.data.block import (
BlockCategory,
@@ -50,10 +50,9 @@ class AIConditionBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default_factory=LlmModel.default,
default=DEFAULT_LLM_MODEL,
description="The language model to use for evaluating the condition.",
advanced=False,
json_schema_extra=llm_model_schema_extra(),
)
credentials: AICredentials = AICredentialsField()
@@ -83,7 +82,7 @@ class AIConditionBlock(AIBlockBase):
"condition": "the input is an email address",
"yes_value": "Valid email",
"no_value": "Not an email",
"model": "gpt-4o", # Using string value - enum accepts any model slug dynamically
"model": DEFAULT_LLM_MODEL,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,

View File

@@ -13,6 +13,7 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
@@ -117,11 +118,13 @@ class AIImageCustomizerBlock(Block):
"credentials": TEST_CREDENTIALS_INPUT,
},
test_output=[
("image_url", "https://replicate.delivery/generated-image.jpg"),
# Output will be a workspace ref or data URI depending on context
("image_url", lambda x: x.startswith(("workspace://", "data:"))),
],
test_mock={
# Use data URI to avoid HTTP requests during tests
"run_model": lambda *args, **kwargs: MediaFileType(
"https://replicate.delivery/generated-image.jpg"
"data:image/jpeg;base64,/9j/4AAQSkZJRgABAgAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQkJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCAABAAEDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwD3+iiigD//2Q=="
),
},
test_credentials=TEST_CREDENTIALS,
@@ -132,8 +135,7 @@ class AIImageCustomizerBlock(Block):
input_data: Input,
*,
credentials: APIKeyCredentials,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
try:
@@ -141,10 +143,9 @@ class AIImageCustomizerBlock(Block):
processed_images = await asyncio.gather(
*(
store_media_file(
graph_exec_id=graph_exec_id,
file=img,
user_id=user_id,
return_content=True,
execution_context=execution_context,
return_format="for_external_api", # Get content for Replicate API
)
for img in input_data.images
)
@@ -158,7 +159,14 @@ class AIImageCustomizerBlock(Block):
aspect_ratio=input_data.aspect_ratio.value,
output_format=input_data.output_format.value,
)
yield "image_url", result
# Store the generated image to the user's workspace for persistence
stored_url = await store_media_file(
file=result,
execution_context=execution_context,
return_format="for_block_output",
)
yield "image_url", stored_url
except Exception as e:
yield "error", str(e)

View File

@@ -6,6 +6,7 @@ from replicate.client import Client as ReplicateClient
from replicate.helpers import FileOutput
from backend.data.block import Block, BlockCategory, BlockSchemaInput, BlockSchemaOutput
from backend.data.execution import ExecutionContext
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
@@ -13,6 +14,8 @@ from backend.data.model import (
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.file import store_media_file
from backend.util.type import MediaFileType
class ImageSize(str, Enum):
@@ -165,11 +168,13 @@ class AIImageGeneratorBlock(Block):
test_output=[
(
"image_url",
"https://replicate.delivery/generated-image.webp",
# Test output is a data URI since we now store images
lambda x: x.startswith("data:image/"),
),
],
test_mock={
"_run_client": lambda *args, **kwargs: "https://replicate.delivery/generated-image.webp"
# Return a data URI directly so store_media_file doesn't need to download
"_run_client": lambda *args, **kwargs: "data:image/webp;base64,UklGRiQAAABXRUJQVlA4IBgAAAAwAQCdASoBAAEAAQAcJYgCdAEO"
},
)
@@ -318,11 +323,24 @@ class AIImageGeneratorBlock(Block):
style_text = style_map.get(style, "")
return f"{style_text} of" if style_text else ""
async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
async def run(
self,
input_data: Input,
*,
credentials: APIKeyCredentials,
execution_context: ExecutionContext,
**kwargs,
):
try:
url = await self.generate_image(input_data, credentials)
if url:
yield "image_url", url
# Store the generated image to the user's workspace/execution folder
stored_url = await store_media_file(
file=MediaFileType(url),
execution_context=execution_context,
return_format="for_block_output",
)
yield "image_url", stored_url
else:
yield "error", "Image generation returned an empty result."
except Exception as e:

View File

@@ -13,6 +13,7 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
@@ -21,7 +22,9 @@ from backend.data.model import (
)
from backend.integrations.providers import ProviderName
from backend.util.exceptions import BlockExecutionError
from backend.util.file import store_media_file
from backend.util.request import Requests
from backend.util.type import MediaFileType
TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
@@ -271,7 +274,10 @@ class AIShortformVideoCreatorBlock(Block):
"voice": Voice.LILY,
"video_style": VisualMediaType.STOCK_VIDEOS,
},
test_output=("video_url", "https://example.com/video.mp4"),
test_output=(
"video_url",
lambda x: x.startswith(("workspace://", "data:")),
),
test_mock={
"create_webhook": lambda *args, **kwargs: (
"test_uuid",
@@ -280,15 +286,21 @@ class AIShortformVideoCreatorBlock(Block):
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
"check_video_status": lambda *args, **kwargs: {
"status": "ready",
"videoUrl": "https://example.com/video.mp4",
"videoUrl": "data:video/mp4;base64,AAAA",
},
"wait_for_video": lambda *args, **kwargs: "https://example.com/video.mp4",
# Use data URI to avoid HTTP requests during tests
"wait_for_video": lambda *args, **kwargs: "data:video/mp4;base64,AAAA",
},
test_credentials=TEST_CREDENTIALS,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
self,
input_data: Input,
*,
credentials: APIKeyCredentials,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
# Create a new Webhook.site URL
webhook_token, webhook_url = await self.create_webhook()
@@ -340,7 +352,13 @@ class AIShortformVideoCreatorBlock(Block):
)
video_url = await self.wait_for_video(credentials.api_key, pid)
logger.debug(f"Video ready: {video_url}")
yield "video_url", video_url
# Store the generated video to the user's workspace for persistence
stored_url = await store_media_file(
file=MediaFileType(video_url),
execution_context=execution_context,
return_format="for_block_output",
)
yield "video_url", stored_url
class AIAdMakerVideoCreatorBlock(Block):
@@ -447,7 +465,10 @@ class AIAdMakerVideoCreatorBlock(Block):
"https://cdn.revid.ai/uploads/1747076315114-image.png",
],
},
test_output=("video_url", "https://example.com/ad.mp4"),
test_output=(
"video_url",
lambda x: x.startswith(("workspace://", "data:")),
),
test_mock={
"create_webhook": lambda *args, **kwargs: (
"test_uuid",
@@ -456,14 +477,21 @@ class AIAdMakerVideoCreatorBlock(Block):
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
"check_video_status": lambda *args, **kwargs: {
"status": "ready",
"videoUrl": "https://example.com/ad.mp4",
"videoUrl": "data:video/mp4;base64,AAAA",
},
"wait_for_video": lambda *args, **kwargs: "https://example.com/ad.mp4",
"wait_for_video": lambda *args, **kwargs: "data:video/mp4;base64,AAAA",
},
test_credentials=TEST_CREDENTIALS,
)
async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
async def run(
self,
input_data: Input,
*,
credentials: APIKeyCredentials,
execution_context: ExecutionContext,
**kwargs,
):
webhook_token, webhook_url = await self.create_webhook()
payload = {
@@ -531,7 +559,13 @@ class AIAdMakerVideoCreatorBlock(Block):
raise RuntimeError("Failed to create video: No project ID returned")
video_url = await self.wait_for_video(credentials.api_key, pid)
yield "video_url", video_url
# Store the generated video to the user's workspace for persistence
stored_url = await store_media_file(
file=MediaFileType(video_url),
execution_context=execution_context,
return_format="for_block_output",
)
yield "video_url", stored_url
class AIScreenshotToVideoAdBlock(Block):
@@ -626,7 +660,10 @@ class AIScreenshotToVideoAdBlock(Block):
"script": "Amazing numbers!",
"screenshot_url": "https://cdn.revid.ai/uploads/1747080376028-image.png",
},
test_output=("video_url", "https://example.com/screenshot.mp4"),
test_output=(
"video_url",
lambda x: x.startswith(("workspace://", "data:")),
),
test_mock={
"create_webhook": lambda *args, **kwargs: (
"test_uuid",
@@ -635,14 +672,21 @@ class AIScreenshotToVideoAdBlock(Block):
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
"check_video_status": lambda *args, **kwargs: {
"status": "ready",
"videoUrl": "https://example.com/screenshot.mp4",
"videoUrl": "data:video/mp4;base64,AAAA",
},
"wait_for_video": lambda *args, **kwargs: "https://example.com/screenshot.mp4",
"wait_for_video": lambda *args, **kwargs: "data:video/mp4;base64,AAAA",
},
test_credentials=TEST_CREDENTIALS,
)
async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
async def run(
self,
input_data: Input,
*,
credentials: APIKeyCredentials,
execution_context: ExecutionContext,
**kwargs,
):
webhook_token, webhook_url = await self.create_webhook()
payload = {
@@ -710,4 +754,10 @@ class AIScreenshotToVideoAdBlock(Block):
raise RuntimeError("Failed to create video: No project ID returned")
video_url = await self.wait_for_video(credentials.api_key, pid)
yield "video_url", video_url
# Store the generated video to the user's workspace for persistence
stored_url = await store_media_file(
file=MediaFileType(video_url),
execution_context=execution_context,
return_format="for_block_output",
)
yield "video_url", stored_url

View File

@@ -6,6 +6,7 @@ if TYPE_CHECKING:
from pydantic import SecretStr
from backend.data.execution import ExecutionContext
from backend.sdk import (
APIKeyCredentials,
Block,
@@ -17,6 +18,8 @@ from backend.sdk import (
Requests,
SchemaField,
)
from backend.util.file import store_media_file
from backend.util.type import MediaFileType
from ._config import bannerbear
@@ -135,15 +138,17 @@ class BannerbearTextOverlayBlock(Block):
},
test_output=[
("success", True),
("image_url", "https://cdn.bannerbear.com/test-image.jpg"),
# Output will be a workspace ref or data URI depending on context
("image_url", lambda x: x.startswith(("workspace://", "data:"))),
("uid", "test-uid-123"),
("status", "completed"),
],
test_mock={
# Use data URI to avoid HTTP requests during tests
"_make_api_request": lambda *args, **kwargs: {
"uid": "test-uid-123",
"status": "completed",
"image_url": "https://cdn.bannerbear.com/test-image.jpg",
"image_url": "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/wAALCAABAAEBAREA/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/9oACAEBAAA/APn+v//Z",
}
},
test_credentials=TEST_CREDENTIALS,
@@ -177,7 +182,12 @@ class BannerbearTextOverlayBlock(Block):
raise Exception(error_msg)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
self,
input_data: Input,
*,
credentials: APIKeyCredentials,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
# Build the modifications array
modifications = []
@@ -234,6 +244,18 @@ class BannerbearTextOverlayBlock(Block):
# Synchronous request - image should be ready
yield "success", True
yield "image_url", data.get("image_url", "")
# Store the generated image to workspace for persistence
image_url = data.get("image_url", "")
if image_url:
stored_url = await store_media_file(
file=MediaFileType(image_url),
execution_context=execution_context,
return_format="for_block_output",
)
yield "image_url", stored_url
else:
yield "image_url", ""
yield "uid", data.get("uid", "")
yield "status", data.get("status", "completed")

View File

@@ -9,6 +9,7 @@ from backend.data.block import (
BlockSchemaOutput,
BlockType,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import store_media_file
from backend.util.type import MediaFileType, convert
@@ -17,10 +18,10 @@ from backend.util.type import MediaFileType, convert
class FileStoreBlock(Block):
class Input(BlockSchemaInput):
file_in: MediaFileType = SchemaField(
description="The file to store in the temporary directory, it can be a URL, data URI, or local path."
description="The file to download and store. Can be a URL (https://...), data URI, or local path."
)
base_64: bool = SchemaField(
description="Whether produce an output in base64 format (not recommended, you can pass the string path just fine accross blocks).",
description="Whether to produce output in base64 format (not recommended, you can pass the file reference across blocks).",
default=False,
advanced=True,
title="Produce Base64 Output",
@@ -28,13 +29,18 @@ class FileStoreBlock(Block):
class Output(BlockSchemaOutput):
file_out: MediaFileType = SchemaField(
description="The relative path to the stored file in the temporary directory."
description="Reference to the stored file. In CoPilot: workspace:// URI (visible in list_workspace_files). In graphs: data URI for passing to other blocks."
)
def __init__(self):
super().__init__(
id="cbb50872-625b-42f0-8203-a2ae78242d8a",
description="Stores the input file in the temporary directory.",
description=(
"Downloads and stores a file from a URL, data URI, or local path. "
"Use this to fetch images, documents, or other files for processing. "
"In CoPilot: saves to workspace (use list_workspace_files to see it). "
"In graphs: outputs a data URI to pass to other blocks."
),
categories={BlockCategory.BASIC, BlockCategory.MULTIMEDIA},
input_schema=FileStoreBlock.Input,
output_schema=FileStoreBlock.Output,
@@ -45,15 +51,18 @@ class FileStoreBlock(Block):
self,
input_data: Input,
*,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
# Determine return format based on user preference
# for_external_api: always returns data URI (base64) - honors "Produce Base64 Output"
# for_block_output: smart format - workspace:// in CoPilot, data URI in graphs
return_format = "for_external_api" if input_data.base_64 else "for_block_output"
yield "file_out", await store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.file_in,
user_id=user_id,
return_content=input_data.base_64,
execution_context=execution_context,
return_format=return_format,
)

View File

@@ -15,6 +15,7 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import APIKeyCredentials, SchemaField
from backend.util.file import store_media_file
from backend.util.request import Requests
@@ -666,8 +667,7 @@ class SendDiscordFileBlock(Block):
file: MediaFileType,
filename: str,
message_content: str,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
) -> dict:
intents = discord.Intents.default()
intents.guilds = True
@@ -731,10 +731,9 @@ class SendDiscordFileBlock(Block):
# Local file path - read from stored media file
# This would be a path from a previous block's output
stored_file = await store_media_file(
graph_exec_id=graph_exec_id,
file=file,
user_id=user_id,
return_content=True, # Get as data URI
execution_context=execution_context,
return_format="for_external_api", # Get content to send to Discord
)
# Now process as data URI
header, encoded = stored_file.split(",", 1)
@@ -781,8 +780,7 @@ class SendDiscordFileBlock(Block):
input_data: Input,
*,
credentials: APIKeyCredentials,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
try:
@@ -793,8 +791,7 @@ class SendDiscordFileBlock(Block):
file=input_data.file,
filename=input_data.filename,
message_content=input_data.message_content,
graph_exec_id=graph_exec_id,
user_id=user_id,
execution_context=execution_context,
)
yield "status", result.get("status", "Unknown error")

View File

@@ -17,8 +17,11 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import store_media_file
from backend.util.request import ClientResponseError, Requests
from backend.util.type import MediaFileType
logger = logging.getLogger(__name__)
@@ -64,9 +67,13 @@ class AIVideoGeneratorBlock(Block):
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[("video_url", "https://fal.media/files/example/video.mp4")],
test_output=[
# Output will be a workspace ref or data URI depending on context
("video_url", lambda x: x.startswith(("workspace://", "data:"))),
],
test_mock={
"generate_video": lambda *args, **kwargs: "https://fal.media/files/example/video.mp4"
# Use data URI to avoid HTTP requests during tests
"generate_video": lambda *args, **kwargs: "data:video/mp4;base64,AAAA"
},
)
@@ -208,11 +215,22 @@ class AIVideoGeneratorBlock(Block):
raise RuntimeError(f"API request failed: {str(e)}")
async def run(
self, input_data: Input, *, credentials: FalCredentials, **kwargs
self,
input_data: Input,
*,
credentials: FalCredentials,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
try:
video_url = await self.generate_video(input_data, credentials)
yield "video_url", video_url
# Store the generated video to the user's workspace for persistence
stored_url = await store_media_file(
file=MediaFileType(video_url),
execution_context=execution_context,
return_format="for_block_output",
)
yield "video_url", stored_url
except Exception as e:
error_message = str(e)
yield "error", error_message

View File

@@ -12,6 +12,7 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
@@ -121,10 +122,12 @@ class AIImageEditorBlock(Block):
"credentials": TEST_CREDENTIALS_INPUT,
},
test_output=[
("output_image", "https://replicate.com/output/edited-image.png"),
# Output will be a workspace ref or data URI depending on context
("output_image", lambda x: x.startswith(("workspace://", "data:"))),
],
test_mock={
"run_model": lambda *args, **kwargs: "https://replicate.com/output/edited-image.png",
# Use data URI to avoid HTTP requests during tests
"run_model": lambda *args, **kwargs: "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==",
},
test_credentials=TEST_CREDENTIALS,
)
@@ -134,8 +137,7 @@ class AIImageEditorBlock(Block):
input_data: Input,
*,
credentials: APIKeyCredentials,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
result = await self.run_model(
@@ -144,20 +146,25 @@ class AIImageEditorBlock(Block):
prompt=input_data.prompt,
input_image_b64=(
await store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.input_image,
user_id=user_id,
return_content=True,
execution_context=execution_context,
return_format="for_external_api", # Get content for Replicate API
)
if input_data.input_image
else None
),
aspect_ratio=input_data.aspect_ratio.value,
seed=input_data.seed,
user_id=user_id,
graph_exec_id=graph_exec_id,
user_id=execution_context.user_id or "",
graph_exec_id=execution_context.graph_exec_id or "",
)
yield "output_image", result
# Store the generated image to the user's workspace for persistence
stored_url = await store_media_file(
file=result,
execution_context=execution_context,
return_format="for_block_output",
)
yield "output_image", stored_url
async def run_model(
self,

View File

@@ -21,6 +21,7 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
from backend.util.settings import Settings
@@ -95,8 +96,7 @@ def _make_mime_text(
async def create_mime_message(
input_data,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
) -> str:
"""Create a MIME message with attachments and return base64-encoded raw message."""
@@ -117,12 +117,12 @@ async def create_mime_message(
if input_data.attachments:
for attach in input_data.attachments:
local_path = await store_media_file(
user_id=user_id,
graph_exec_id=graph_exec_id,
file=attach,
return_content=False,
execution_context=execution_context,
return_format="for_local_processing",
)
abs_path = get_exec_file_path(graph_exec_id, local_path)
assert execution_context.graph_exec_id # Validated by store_media_file
abs_path = get_exec_file_path(execution_context.graph_exec_id, local_path)
part = MIMEBase("application", "octet-stream")
with open(abs_path, "rb") as f:
part.set_payload(f.read())
@@ -582,27 +582,25 @@ class GmailSendBlock(GmailBase):
input_data: Input,
*,
credentials: GoogleCredentials,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
service = self._build_service(credentials, **kwargs)
result = await self._send_email(
service,
input_data,
graph_exec_id,
user_id,
execution_context,
)
yield "result", result
async def _send_email(
self, service, input_data: Input, graph_exec_id: str, user_id: str
self, service, input_data: Input, execution_context: ExecutionContext
) -> dict:
if not input_data.to or not input_data.subject or not input_data.body:
raise ValueError(
"At least one recipient, subject, and body are required for sending an email"
)
raw_message = await create_mime_message(input_data, graph_exec_id, user_id)
raw_message = await create_mime_message(input_data, execution_context)
sent_message = await asyncio.to_thread(
lambda: service.users()
.messages()
@@ -692,30 +690,28 @@ class GmailCreateDraftBlock(GmailBase):
input_data: Input,
*,
credentials: GoogleCredentials,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
service = self._build_service(credentials, **kwargs)
result = await self._create_draft(
service,
input_data,
graph_exec_id,
user_id,
execution_context,
)
yield "result", GmailDraftResult(
id=result["id"], message_id=result["message"]["id"], status="draft_created"
)
async def _create_draft(
self, service, input_data: Input, graph_exec_id: str, user_id: str
self, service, input_data: Input, execution_context: ExecutionContext
) -> dict:
if not input_data.to or not input_data.subject:
raise ValueError(
"At least one recipient and subject are required for creating a draft"
)
raw_message = await create_mime_message(input_data, graph_exec_id, user_id)
raw_message = await create_mime_message(input_data, execution_context)
draft = await asyncio.to_thread(
lambda: service.users()
.drafts()
@@ -1100,7 +1096,7 @@ class GmailGetThreadBlock(GmailBase):
async def _build_reply_message(
service, input_data, graph_exec_id: str, user_id: str
service, input_data, execution_context: ExecutionContext
) -> tuple[str, str]:
"""
Builds a reply MIME message for Gmail threads.
@@ -1190,12 +1186,12 @@ async def _build_reply_message(
# Handle attachments
for attach in input_data.attachments:
local_path = await store_media_file(
user_id=user_id,
graph_exec_id=graph_exec_id,
file=attach,
return_content=False,
execution_context=execution_context,
return_format="for_local_processing",
)
abs_path = get_exec_file_path(graph_exec_id, local_path)
assert execution_context.graph_exec_id # Validated by store_media_file
abs_path = get_exec_file_path(execution_context.graph_exec_id, local_path)
part = MIMEBase("application", "octet-stream")
with open(abs_path, "rb") as f:
part.set_payload(f.read())
@@ -1311,16 +1307,14 @@ class GmailReplyBlock(GmailBase):
input_data: Input,
*,
credentials: GoogleCredentials,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
service = self._build_service(credentials, **kwargs)
message = await self._reply(
service,
input_data,
graph_exec_id,
user_id,
execution_context,
)
yield "messageId", message["id"]
yield "threadId", message.get("threadId", input_data.threadId)
@@ -1343,11 +1337,11 @@ class GmailReplyBlock(GmailBase):
yield "email", email
async def _reply(
self, service, input_data: Input, graph_exec_id: str, user_id: str
self, service, input_data: Input, execution_context: ExecutionContext
) -> dict:
# Build the reply message using the shared helper
raw, thread_id = await _build_reply_message(
service, input_data, graph_exec_id, user_id
service, input_data, execution_context
)
# Send the message
@@ -1441,16 +1435,14 @@ class GmailDraftReplyBlock(GmailBase):
input_data: Input,
*,
credentials: GoogleCredentials,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
service = self._build_service(credentials, **kwargs)
draft = await self._create_draft_reply(
service,
input_data,
graph_exec_id,
user_id,
execution_context,
)
yield "draftId", draft["id"]
yield "messageId", draft["message"]["id"]
@@ -1458,11 +1450,11 @@ class GmailDraftReplyBlock(GmailBase):
yield "status", "draft_created"
async def _create_draft_reply(
self, service, input_data: Input, graph_exec_id: str, user_id: str
self, service, input_data: Input, execution_context: ExecutionContext
) -> dict:
# Build the reply message using the shared helper
raw, thread_id = await _build_reply_message(
service, input_data, graph_exec_id, user_id
service, input_data, execution_context
)
# Create draft with proper thread association
@@ -1629,23 +1621,21 @@ class GmailForwardBlock(GmailBase):
input_data: Input,
*,
credentials: GoogleCredentials,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
service = self._build_service(credentials, **kwargs)
result = await self._forward_message(
service,
input_data,
graph_exec_id,
user_id,
execution_context,
)
yield "messageId", result["id"]
yield "threadId", result.get("threadId", "")
yield "status", "forwarded"
async def _forward_message(
self, service, input_data: Input, graph_exec_id: str, user_id: str
self, service, input_data: Input, execution_context: ExecutionContext
) -> dict:
if not input_data.to:
raise ValueError("At least one recipient is required for forwarding")
@@ -1727,12 +1717,12 @@ To: {original_to}
# Add any additional attachments
for attach in input_data.additionalAttachments:
local_path = await store_media_file(
user_id=user_id,
graph_exec_id=graph_exec_id,
file=attach,
return_content=False,
execution_context=execution_context,
return_format="for_local_processing",
)
abs_path = get_exec_file_path(graph_exec_id, local_path)
assert execution_context.graph_exec_id # Validated by store_media_file
abs_path = get_exec_file_path(execution_context.graph_exec_id, local_path)
part = MIMEBase("application", "octet-stream")
with open(abs_path, "rb") as f:
part.set_payload(f.read())

View File

@@ -15,6 +15,7 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import (
CredentialsField,
CredentialsMetaInput,
@@ -116,10 +117,9 @@ class SendWebRequestBlock(Block):
@staticmethod
async def _prepare_files(
graph_exec_id: str,
execution_context: ExecutionContext,
files_name: str,
files: list[MediaFileType],
user_id: str,
) -> list[tuple[str, tuple[str, BytesIO, str]]]:
"""
Prepare files for the request by storing them and reading their content.
@@ -127,11 +127,16 @@ class SendWebRequestBlock(Block):
(files_name, (filename, BytesIO, mime_type))
"""
files_payload: list[tuple[str, tuple[str, BytesIO, str]]] = []
graph_exec_id = execution_context.graph_exec_id
if graph_exec_id is None:
raise ValueError("graph_exec_id is required for file operations")
for media in files:
# Normalise to a list so we can repeat the same key
rel_path = await store_media_file(
graph_exec_id, media, user_id, return_content=False
file=media,
execution_context=execution_context,
return_format="for_local_processing",
)
abs_path = get_exec_file_path(graph_exec_id, rel_path)
async with aiofiles.open(abs_path, "rb") as f:
@@ -143,7 +148,7 @@ class SendWebRequestBlock(Block):
return files_payload
async def run(
self, input_data: Input, *, graph_exec_id: str, user_id: str, **kwargs
self, input_data: Input, *, execution_context: ExecutionContext, **kwargs
) -> BlockOutput:
# ─── Parse/normalise body ────────────────────────────────────
body = input_data.body
@@ -174,7 +179,7 @@ class SendWebRequestBlock(Block):
files_payload: list[tuple[str, tuple[str, BytesIO, str]]] = []
if use_files:
files_payload = await self._prepare_files(
graph_exec_id, input_data.files_name, input_data.files, user_id
execution_context, input_data.files_name, input_data.files
)
# Enforce body format rules
@@ -238,9 +243,8 @@ class SendAuthenticatedWebRequestBlock(SendWebRequestBlock):
self,
input_data: Input,
*,
graph_exec_id: str,
execution_context: ExecutionContext,
credentials: HostScopedCredentials,
user_id: str,
**kwargs,
) -> BlockOutput:
# Create SendWebRequestBlock.Input from our input (removing credentials field)
@@ -271,6 +275,6 @@ class SendAuthenticatedWebRequestBlock(SendWebRequestBlock):
# Use parent class run method
async for output_name, output_data in super().run(
base_input, graph_exec_id=graph_exec_id, user_id=user_id, **kwargs
base_input, execution_context=execution_context, **kwargs
):
yield output_name, output_data

View File

@@ -12,6 +12,7 @@ from backend.data.block import (
BlockSchemaInput,
BlockType,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import store_media_file
from backend.util.mock import MockObject
@@ -462,18 +463,21 @@ class AgentFileInputBlock(AgentInputBlock):
self,
input_data: Input,
*,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
if not input_data.value:
return
# Determine return format based on user preference
# for_external_api: always returns data URI (base64) - honors "Produce Base64 Output"
# for_block_output: smart format - workspace:// in CoPilot, data URI in graphs
return_format = "for_external_api" if input_data.base_64 else "for_block_output"
yield "result", await store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.value,
user_id=user_id,
return_content=input_data.base_64,
execution_context=execution_context,
return_format=return_format,
)

View File

@@ -162,8 +162,16 @@ class LinearClient:
"searchTerm": team_name,
}
team_id = await self.query(query, variables)
return team_id["teams"]["nodes"][0]["id"]
result = await self.query(query, variables)
nodes = result["teams"]["nodes"]
if not nodes:
raise LinearAPIException(
f"Team '{team_name}' not found. Check the team name or key and try again.",
status_code=404,
)
return nodes[0]["id"]
except LinearAPIException as e:
raise e
@@ -240,17 +248,44 @@ class LinearClient:
except LinearAPIException as e:
raise e
async def try_search_issues(self, term: str) -> list[Issue]:
async def try_search_issues(
self,
term: str,
max_results: int = 10,
team_id: str | None = None,
) -> list[Issue]:
try:
query = """
query SearchIssues($term: String!, $includeComments: Boolean!) {
searchIssues(term: $term, includeComments: $includeComments) {
query SearchIssues(
$term: String!,
$first: Int,
$teamId: String
) {
searchIssues(
term: $term,
first: $first,
teamId: $teamId
) {
nodes {
id
identifier
title
description
priority
createdAt
state {
id
name
type
}
project {
id
name
}
assignee {
id
name
}
}
}
}
@@ -258,7 +293,8 @@ class LinearClient:
variables: dict[str, Any] = {
"term": term,
"includeComments": True,
"first": max_results,
"teamId": team_id,
}
issues = await self.query(query, variables)

View File

@@ -17,7 +17,7 @@ from ._config import (
LinearScope,
linear,
)
from .models import CreateIssueResponse, Issue
from .models import CreateIssueResponse, Issue, State
class LinearCreateIssueBlock(Block):
@@ -135,9 +135,20 @@ class LinearSearchIssuesBlock(Block):
description="Linear credentials with read permissions",
required_scopes={LinearScope.READ},
)
max_results: int = SchemaField(
description="Maximum number of results to return",
default=10,
ge=1,
le=100,
)
team_name: str | None = SchemaField(
description="Optional team name to filter results (e.g., 'Internal', 'Open Source')",
default=None,
)
class Output(BlockSchemaOutput):
issues: list[Issue] = SchemaField(description="List of issues")
error: str = SchemaField(description="Error message if the search failed")
def __init__(self):
super().__init__(
@@ -145,8 +156,11 @@ class LinearSearchIssuesBlock(Block):
description="Searches for issues on Linear",
input_schema=self.Input,
output_schema=self.Output,
categories={BlockCategory.PRODUCTIVITY, BlockCategory.ISSUE_TRACKING},
test_input={
"term": "Test issue",
"max_results": 10,
"team_name": None,
"credentials": TEST_CREDENTIALS_INPUT_OAUTH,
},
test_credentials=TEST_CREDENTIALS_OAUTH,
@@ -156,10 +170,14 @@ class LinearSearchIssuesBlock(Block):
[
Issue(
id="abc123",
identifier="abc123",
identifier="TST-123",
title="Test issue",
description="Test description",
priority=1,
state=State(
id="state1", name="In Progress", type="started"
),
createdAt="2026-01-15T10:00:00.000Z",
)
],
)
@@ -168,10 +186,12 @@ class LinearSearchIssuesBlock(Block):
"search_issues": lambda *args, **kwargs: [
Issue(
id="abc123",
identifier="abc123",
identifier="TST-123",
title="Test issue",
description="Test description",
priority=1,
state=State(id="state1", name="In Progress", type="started"),
createdAt="2026-01-15T10:00:00.000Z",
)
]
},
@@ -181,10 +201,22 @@ class LinearSearchIssuesBlock(Block):
async def search_issues(
credentials: OAuth2Credentials | APIKeyCredentials,
term: str,
max_results: int = 10,
team_name: str | None = None,
) -> list[Issue]:
client = LinearClient(credentials=credentials)
response: list[Issue] = await client.try_search_issues(term=term)
return response
# Resolve team name to ID if provided
# Raises LinearAPIException with descriptive message if team not found
team_id: str | None = None
if team_name:
team_id = await client.try_get_team_by_name(team_name=team_name)
return await client.try_search_issues(
term=term,
max_results=max_results,
team_id=team_id,
)
async def run(
self,
@@ -196,7 +228,10 @@ class LinearSearchIssuesBlock(Block):
"""Execute the issue search"""
try:
issues = await self.search_issues(
credentials=credentials, term=input_data.term
credentials=credentials,
term=input_data.term,
max_results=input_data.max_results,
team_name=input_data.team_name,
)
yield "issues", issues
except LinearAPIException as e:

View File

@@ -36,12 +36,21 @@ class Project(BaseModel):
content: str | None = None
class State(BaseModel):
id: str
name: str
type: str | None = (
None # Workflow state type (e.g., "triage", "backlog", "started", "completed", "canceled")
)
class Issue(BaseModel):
id: str
identifier: str
title: str
description: str | None
priority: int
state: State | None = None
project: Project | None = None
createdAt: str | None = None
comments: list[Comment] | None = None

View File

@@ -4,19 +4,17 @@ import logging
import re
import secrets
from abc import ABC
from enum import Enum
from enum import Enum, EnumMeta
from json import JSONDecodeError
from typing import Any, Iterable, List, Literal, Optional
from typing import Any, Iterable, List, Literal, NamedTuple, Optional
import anthropic
import ollama
import openai
from anthropic.types import ToolParam
from groq import AsyncGroq
from pydantic import BaseModel, GetCoreSchemaHandler, SecretStr
from pydantic_core import CoreSchema, core_schema
from pydantic import BaseModel, SecretStr
from backend.data import llm_registry
from backend.data.block import (
Block,
BlockCategory,
@@ -24,7 +22,6 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.llm_registry import ModelMetadata
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
@@ -35,7 +32,7 @@ from backend.data.model import (
from backend.integrations.providers import ProviderName
from backend.util import json
from backend.util.logging import TruncatedLogger
from backend.util.prompt import compress_prompt, estimate_token_count
from backend.util.prompt import compress_context, estimate_token_count
from backend.util.text import TextFormatter
logger = TruncatedLogger(logging.getLogger(__name__), "[LLM-Block]")
@@ -69,123 +66,113 @@ TEST_CREDENTIALS_INPUT = {
def AICredentialsField() -> AICredentials:
"""
Returns a CredentialsField for LLM providers.
The discriminator_mapping will be refreshed when the schema is generated
if it's empty, ensuring the LLM registry is loaded.
"""
# Get the mapping now - it may be empty initially, but will be refreshed
# when the schema is generated via CredentialsMetaInput._add_json_schema_extra
mapping = llm_registry.get_llm_discriminator_mapping()
return CredentialsField(
description="API key for the LLM provider.",
discriminator="model",
discriminator_mapping=mapping, # May be empty initially, refreshed later
discriminator_mapping={
model.value: model.metadata.provider for model in LlmModel
},
)
def llm_model_schema_extra() -> dict[str, Any]:
return {"options": llm_registry.get_llm_model_schema_options()}
class ModelMetadata(NamedTuple):
provider: str
context_window: int
max_output_tokens: int | None
display_name: str
provider_name: str
creator_name: str
price_tier: Literal[1, 2, 3]
class LlmModelMeta(type):
"""
Metaclass for LlmModel that enables attribute-style access to dynamic models.
This allows code like `LlmModel.GPT4O` to work by converting the attribute
name to a slug format:
- GPT4O -> gpt-4o
- GPT4O_MINI -> gpt-4o-mini
- CLAUDE_3_5_SONNET -> claude-3-5-sonnet
"""
def __getattr__(cls, name: str):
# Don't intercept private/dunder attributes
if name.startswith("_"):
raise AttributeError(f"type object 'LlmModel' has no attribute '{name}'")
# Convert attribute name to slug format:
# 1. Lowercase: GPT4O -> gpt4o
# 2. Underscores to hyphens: GPT4O_MINI -> gpt4o-mini
slug = name.lower().replace("_", "-")
# Check for exact match in registry first (e.g., "o1" stays "o1")
registry_slugs = llm_registry.get_dynamic_model_slugs()
if slug in registry_slugs:
return cls(slug)
# If no exact match, try inserting hyphen between letter and digit
# e.g., gpt4o -> gpt-4o
transformed_slug = re.sub(r"([a-z])(\d)", r"\1-\2", slug)
return cls(transformed_slug)
def __iter__(cls):
"""Iterate over all models from the registry.
Yields LlmModel instances for each model in the dynamic registry.
Used by __get_pydantic_json_schema__ to build model metadata.
"""
for model in llm_registry.iter_dynamic_models():
yield cls(model.slug)
class LlmModelMeta(EnumMeta):
pass
class LlmModel(str, metaclass=LlmModelMeta):
"""
Dynamic LLM model type that accepts any model slug from the registry.
This is a string subclass (not an Enum) that allows any model slug value.
All models are managed via the LLM Registry in the database.
Usage:
model = LlmModel("gpt-4o") # Direct construction
model = LlmModel.GPT4O # Attribute access (converted to "gpt-4o")
model.value # Returns the slug string
model.provider # Returns the provider from registry
"""
def __new__(cls, value: str):
if isinstance(value, LlmModel):
return value
return str.__new__(cls, value)
@classmethod
def __get_pydantic_core_schema__(
cls, source_type: Any, handler: GetCoreSchemaHandler
) -> CoreSchema:
"""
Tell Pydantic how to validate LlmModel.
Accepts strings and converts them to LlmModel instances.
"""
return core_schema.no_info_after_validator_function(
cls, # The validator function (LlmModel constructor)
core_schema.str_schema(), # Accept string input
serialization=core_schema.to_string_ser_schema(), # Serialize as string
)
@property
def value(self) -> str:
"""Return the model slug (for compatibility with enum-style access)."""
return str(self)
@classmethod
def default(cls) -> "LlmModel":
"""
Get the default model from the registry.
Returns the recommended model if set, otherwise gpt-4o if available
and enabled, otherwise the first enabled model from the registry.
Falls back to "gpt-4o" if registry is empty (e.g., at module import time).
"""
from backend.data.llm_registry import get_default_model_slug
slug = get_default_model_slug()
if slug is None:
# Registry is empty (e.g., at module import time before DB connection).
# Fall back to gpt-4o for backward compatibility.
slug = "gpt-4o"
return cls(slug)
class LlmModel(str, Enum, metaclass=LlmModelMeta):
# OpenAI models
O3_MINI = "o3-mini"
O3 = "o3-2025-04-16"
O1 = "o1"
O1_MINI = "o1-mini"
# GPT-5 models
GPT5_2 = "gpt-5.2-2025-12-11"
GPT5_1 = "gpt-5.1-2025-11-13"
GPT5 = "gpt-5-2025-08-07"
GPT5_MINI = "gpt-5-mini-2025-08-07"
GPT5_NANO = "gpt-5-nano-2025-08-07"
GPT5_CHAT = "gpt-5-chat-latest"
GPT41 = "gpt-4.1-2025-04-14"
GPT41_MINI = "gpt-4.1-mini-2025-04-14"
GPT4O_MINI = "gpt-4o-mini"
GPT4O = "gpt-4o"
GPT4_TURBO = "gpt-4-turbo"
GPT3_5_TURBO = "gpt-3.5-turbo"
# Anthropic models
CLAUDE_4_1_OPUS = "claude-opus-4-1-20250805"
CLAUDE_4_OPUS = "claude-opus-4-20250514"
CLAUDE_4_SONNET = "claude-sonnet-4-20250514"
CLAUDE_4_5_OPUS = "claude-opus-4-5-20251101"
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
CLAUDE_4_5_HAIKU = "claude-haiku-4-5-20251001"
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
# AI/ML API models
AIML_API_QWEN2_5_72B = "Qwen/Qwen2.5-72B-Instruct-Turbo"
AIML_API_LLAMA3_1_70B = "nvidia/llama-3.1-nemotron-70b-instruct"
AIML_API_LLAMA3_3_70B = "meta-llama/Llama-3.3-70B-Instruct-Turbo"
AIML_API_META_LLAMA_3_1_70B = "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo"
AIML_API_LLAMA_3_2_3B = "meta-llama/Llama-3.2-3B-Instruct-Turbo"
# Groq models
LLAMA3_3_70B = "llama-3.3-70b-versatile"
LLAMA3_1_8B = "llama-3.1-8b-instant"
# Ollama models
OLLAMA_LLAMA3_3 = "llama3.3"
OLLAMA_LLAMA3_2 = "llama3.2"
OLLAMA_LLAMA3_8B = "llama3"
OLLAMA_LLAMA3_405B = "llama3.1:405b"
OLLAMA_DOLPHIN = "dolphin-mistral:latest"
# OpenRouter models
OPENAI_GPT_OSS_120B = "openai/gpt-oss-120b"
OPENAI_GPT_OSS_20B = "openai/gpt-oss-20b"
GEMINI_2_5_PRO = "google/gemini-2.5-pro-preview-03-25"
GEMINI_3_PRO_PREVIEW = "google/gemini-3-pro-preview"
GEMINI_2_5_FLASH = "google/gemini-2.5-flash"
GEMINI_2_0_FLASH = "google/gemini-2.0-flash-001"
GEMINI_2_5_FLASH_LITE_PREVIEW = "google/gemini-2.5-flash-lite-preview-06-17"
GEMINI_2_0_FLASH_LITE = "google/gemini-2.0-flash-lite-001"
MISTRAL_NEMO = "mistralai/mistral-nemo"
COHERE_COMMAND_R_08_2024 = "cohere/command-r-08-2024"
COHERE_COMMAND_R_PLUS_08_2024 = "cohere/command-r-plus-08-2024"
DEEPSEEK_CHAT = "deepseek/deepseek-chat" # Actually: DeepSeek V3
DEEPSEEK_R1_0528 = "deepseek/deepseek-r1-0528"
PERPLEXITY_SONAR = "perplexity/sonar"
PERPLEXITY_SONAR_PRO = "perplexity/sonar-pro"
PERPLEXITY_SONAR_DEEP_RESEARCH = "perplexity/sonar-deep-research"
NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B = "nousresearch/hermes-3-llama-3.1-405b"
NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B = "nousresearch/hermes-3-llama-3.1-70b"
AMAZON_NOVA_LITE_V1 = "amazon/nova-lite-v1"
AMAZON_NOVA_MICRO_V1 = "amazon/nova-micro-v1"
AMAZON_NOVA_PRO_V1 = "amazon/nova-pro-v1"
MICROSOFT_WIZARDLM_2_8X22B = "microsoft/wizardlm-2-8x22b"
GRYPHE_MYTHOMAX_L2_13B = "gryphe/mythomax-l2-13b"
META_LLAMA_4_SCOUT = "meta-llama/llama-4-scout"
META_LLAMA_4_MAVERICK = "meta-llama/llama-4-maverick"
GROK_4 = "x-ai/grok-4"
GROK_4_FAST = "x-ai/grok-4-fast"
GROK_4_1_FAST = "x-ai/grok-4.1-fast"
GROK_CODE_FAST_1 = "x-ai/grok-code-fast-1"
KIMI_K2 = "moonshotai/kimi-k2"
QWEN3_235B_A22B_THINKING = "qwen/qwen3-235b-a22b-thinking-2507"
QWEN3_CODER = "qwen/qwen3-coder"
# Llama API models
LLAMA_API_LLAMA_4_SCOUT = "Llama-4-Scout-17B-16E-Instruct-FP8"
LLAMA_API_LLAMA4_MAVERICK = "Llama-4-Maverick-17B-128E-Instruct-FP8"
LLAMA_API_LLAMA3_3_8B = "Llama-3.3-8B-Instruct"
LLAMA_API_LLAMA3_3_70B = "Llama-3.3-70B-Instruct"
# v0 by Vercel models
V0_1_5_MD = "v0-1.5-md"
V0_1_5_LG = "v0-1.5-lg"
V0_1_0_MD = "v0-1.0-md"
@classmethod
def __get_pydantic_json_schema__(cls, schema, handler):
@@ -193,15 +180,7 @@ class LlmModel(str, metaclass=LlmModelMeta):
llm_model_metadata = {}
for model in cls:
model_name = model.value
# Skip disabled models - only show enabled models in the picker
if not llm_registry.is_model_enabled(model_name):
continue
# Use registry directly with None check to gracefully handle
# missing metadata during startup/import before registry is populated
metadata = llm_registry.get_llm_model_metadata(model_name)
if metadata is None:
# Skip models without metadata (registry not yet populated)
continue
metadata = model.metadata
llm_model_metadata[model_name] = {
"creator": metadata.creator_name,
"creator_name": metadata.creator_name,
@@ -217,12 +196,7 @@ class LlmModel(str, metaclass=LlmModelMeta):
@property
def metadata(self) -> ModelMetadata:
metadata = llm_registry.get_llm_model_metadata(self.value)
if metadata:
return metadata
raise ValueError(
f"Missing metadata for model: {self.value}. Model not found in LLM registry."
)
return MODEL_METADATA[self]
@property
def provider(self) -> str:
@@ -237,11 +211,297 @@ class LlmModel(str, metaclass=LlmModelMeta):
return self.metadata.max_output_tokens
# MODEL_METADATA removed - all models now come from the database via llm_registry
MODEL_METADATA = {
# https://platform.openai.com/docs/models
LlmModel.O3: ModelMetadata("openai", 200000, 100000, "O3", "OpenAI", "OpenAI", 2),
LlmModel.O3_MINI: ModelMetadata(
"openai", 200000, 100000, "O3 Mini", "OpenAI", "OpenAI", 1
), # o3-mini-2025-01-31
LlmModel.O1: ModelMetadata(
"openai", 200000, 100000, "O1", "OpenAI", "OpenAI", 3
), # o1-2024-12-17
LlmModel.O1_MINI: ModelMetadata(
"openai", 128000, 65536, "O1 Mini", "OpenAI", "OpenAI", 2
), # o1-mini-2024-09-12
# GPT-5 models
LlmModel.GPT5_2: ModelMetadata(
"openai", 400000, 128000, "GPT-5.2", "OpenAI", "OpenAI", 3
),
LlmModel.GPT5_1: ModelMetadata(
"openai", 400000, 128000, "GPT-5.1", "OpenAI", "OpenAI", 2
),
LlmModel.GPT5: ModelMetadata(
"openai", 400000, 128000, "GPT-5", "OpenAI", "OpenAI", 1
),
LlmModel.GPT5_MINI: ModelMetadata(
"openai", 400000, 128000, "GPT-5 Mini", "OpenAI", "OpenAI", 1
),
LlmModel.GPT5_NANO: ModelMetadata(
"openai", 400000, 128000, "GPT-5 Nano", "OpenAI", "OpenAI", 1
),
LlmModel.GPT5_CHAT: ModelMetadata(
"openai", 400000, 16384, "GPT-5 Chat Latest", "OpenAI", "OpenAI", 2
),
LlmModel.GPT41: ModelMetadata(
"openai", 1047576, 32768, "GPT-4.1", "OpenAI", "OpenAI", 1
),
LlmModel.GPT41_MINI: ModelMetadata(
"openai", 1047576, 32768, "GPT-4.1 Mini", "OpenAI", "OpenAI", 1
),
LlmModel.GPT4O_MINI: ModelMetadata(
"openai", 128000, 16384, "GPT-4o Mini", "OpenAI", "OpenAI", 1
), # gpt-4o-mini-2024-07-18
LlmModel.GPT4O: ModelMetadata(
"openai", 128000, 16384, "GPT-4o", "OpenAI", "OpenAI", 2
), # gpt-4o-2024-08-06
LlmModel.GPT4_TURBO: ModelMetadata(
"openai", 128000, 4096, "GPT-4 Turbo", "OpenAI", "OpenAI", 3
), # gpt-4-turbo-2024-04-09
LlmModel.GPT3_5_TURBO: ModelMetadata(
"openai", 16385, 4096, "GPT-3.5 Turbo", "OpenAI", "OpenAI", 1
), # gpt-3.5-turbo-0125
# https://docs.anthropic.com/en/docs/about-claude/models
LlmModel.CLAUDE_4_1_OPUS: ModelMetadata(
"anthropic", 200000, 32000, "Claude Opus 4.1", "Anthropic", "Anthropic", 3
), # claude-opus-4-1-20250805
LlmModel.CLAUDE_4_OPUS: ModelMetadata(
"anthropic", 200000, 32000, "Claude Opus 4", "Anthropic", "Anthropic", 3
), # claude-4-opus-20250514
LlmModel.CLAUDE_4_SONNET: ModelMetadata(
"anthropic", 200000, 64000, "Claude Sonnet 4", "Anthropic", "Anthropic", 2
), # claude-4-sonnet-20250514
LlmModel.CLAUDE_4_5_OPUS: ModelMetadata(
"anthropic", 200000, 64000, "Claude Opus 4.5", "Anthropic", "Anthropic", 3
), # claude-opus-4-5-20251101
LlmModel.CLAUDE_4_5_SONNET: ModelMetadata(
"anthropic", 200000, 64000, "Claude Sonnet 4.5", "Anthropic", "Anthropic", 3
), # claude-sonnet-4-5-20250929
LlmModel.CLAUDE_4_5_HAIKU: ModelMetadata(
"anthropic", 200000, 64000, "Claude Haiku 4.5", "Anthropic", "Anthropic", 2
), # claude-haiku-4-5-20251001
LlmModel.CLAUDE_3_HAIKU: ModelMetadata(
"anthropic", 200000, 4096, "Claude 3 Haiku", "Anthropic", "Anthropic", 1
), # claude-3-haiku-20240307
# https://docs.aimlapi.com/api-overview/model-database/text-models
LlmModel.AIML_API_QWEN2_5_72B: ModelMetadata(
"aiml_api", 32000, 8000, "Qwen 2.5 72B Instruct Turbo", "AI/ML", "Qwen", 1
),
LlmModel.AIML_API_LLAMA3_1_70B: ModelMetadata(
"aiml_api",
128000,
40000,
"Llama 3.1 Nemotron 70B Instruct",
"AI/ML",
"Nvidia",
1,
),
LlmModel.AIML_API_LLAMA3_3_70B: ModelMetadata(
"aiml_api", 128000, None, "Llama 3.3 70B Instruct Turbo", "AI/ML", "Meta", 1
),
LlmModel.AIML_API_META_LLAMA_3_1_70B: ModelMetadata(
"aiml_api", 131000, 2000, "Llama 3.1 70B Instruct Turbo", "AI/ML", "Meta", 1
),
LlmModel.AIML_API_LLAMA_3_2_3B: ModelMetadata(
"aiml_api", 128000, None, "Llama 3.2 3B Instruct Turbo", "AI/ML", "Meta", 1
),
# https://console.groq.com/docs/models
LlmModel.LLAMA3_3_70B: ModelMetadata(
"groq", 128000, 32768, "Llama 3.3 70B Versatile", "Groq", "Meta", 1
),
LlmModel.LLAMA3_1_8B: ModelMetadata(
"groq", 128000, 8192, "Llama 3.1 8B Instant", "Groq", "Meta", 1
),
# https://ollama.com/library
LlmModel.OLLAMA_LLAMA3_3: ModelMetadata(
"ollama", 8192, None, "Llama 3.3", "Ollama", "Meta", 1
),
LlmModel.OLLAMA_LLAMA3_2: ModelMetadata(
"ollama", 8192, None, "Llama 3.2", "Ollama", "Meta", 1
),
LlmModel.OLLAMA_LLAMA3_8B: ModelMetadata(
"ollama", 8192, None, "Llama 3", "Ollama", "Meta", 1
),
LlmModel.OLLAMA_LLAMA3_405B: ModelMetadata(
"ollama", 8192, None, "Llama 3.1 405B", "Ollama", "Meta", 1
),
LlmModel.OLLAMA_DOLPHIN: ModelMetadata(
"ollama", 32768, None, "Dolphin Mistral Latest", "Ollama", "Mistral AI", 1
),
# https://openrouter.ai/models
LlmModel.GEMINI_2_5_PRO: ModelMetadata(
"open_router",
1050000,
8192,
"Gemini 2.5 Pro Preview 03.25",
"OpenRouter",
"Google",
2,
),
LlmModel.GEMINI_3_PRO_PREVIEW: ModelMetadata(
"open_router", 1048576, 65535, "Gemini 3 Pro Preview", "OpenRouter", "Google", 2
),
LlmModel.GEMINI_2_5_FLASH: ModelMetadata(
"open_router", 1048576, 65535, "Gemini 2.5 Flash", "OpenRouter", "Google", 1
),
LlmModel.GEMINI_2_0_FLASH: ModelMetadata(
"open_router", 1048576, 8192, "Gemini 2.0 Flash 001", "OpenRouter", "Google", 1
),
LlmModel.GEMINI_2_5_FLASH_LITE_PREVIEW: ModelMetadata(
"open_router",
1048576,
65535,
"Gemini 2.5 Flash Lite Preview 06.17",
"OpenRouter",
"Google",
1,
),
LlmModel.GEMINI_2_0_FLASH_LITE: ModelMetadata(
"open_router",
1048576,
8192,
"Gemini 2.0 Flash Lite 001",
"OpenRouter",
"Google",
1,
),
LlmModel.MISTRAL_NEMO: ModelMetadata(
"open_router", 128000, 4096, "Mistral Nemo", "OpenRouter", "Mistral AI", 1
),
LlmModel.COHERE_COMMAND_R_08_2024: ModelMetadata(
"open_router", 128000, 4096, "Command R 08.2024", "OpenRouter", "Cohere", 1
),
LlmModel.COHERE_COMMAND_R_PLUS_08_2024: ModelMetadata(
"open_router", 128000, 4096, "Command R Plus 08.2024", "OpenRouter", "Cohere", 2
),
LlmModel.DEEPSEEK_CHAT: ModelMetadata(
"open_router", 64000, 2048, "DeepSeek Chat", "OpenRouter", "DeepSeek", 1
),
LlmModel.DEEPSEEK_R1_0528: ModelMetadata(
"open_router", 163840, 163840, "DeepSeek R1 0528", "OpenRouter", "DeepSeek", 1
),
LlmModel.PERPLEXITY_SONAR: ModelMetadata(
"open_router", 127000, 8000, "Sonar", "OpenRouter", "Perplexity", 1
),
LlmModel.PERPLEXITY_SONAR_PRO: ModelMetadata(
"open_router", 200000, 8000, "Sonar Pro", "OpenRouter", "Perplexity", 2
),
LlmModel.PERPLEXITY_SONAR_DEEP_RESEARCH: ModelMetadata(
"open_router",
128000,
16000,
"Sonar Deep Research",
"OpenRouter",
"Perplexity",
3,
),
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B: ModelMetadata(
"open_router",
131000,
4096,
"Hermes 3 Llama 3.1 405B",
"OpenRouter",
"Nous Research",
1,
),
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B: ModelMetadata(
"open_router",
12288,
12288,
"Hermes 3 Llama 3.1 70B",
"OpenRouter",
"Nous Research",
1,
),
LlmModel.OPENAI_GPT_OSS_120B: ModelMetadata(
"open_router", 131072, 131072, "GPT-OSS 120B", "OpenRouter", "OpenAI", 1
),
LlmModel.OPENAI_GPT_OSS_20B: ModelMetadata(
"open_router", 131072, 32768, "GPT-OSS 20B", "OpenRouter", "OpenAI", 1
),
LlmModel.AMAZON_NOVA_LITE_V1: ModelMetadata(
"open_router", 300000, 5120, "Nova Lite V1", "OpenRouter", "Amazon", 1
),
LlmModel.AMAZON_NOVA_MICRO_V1: ModelMetadata(
"open_router", 128000, 5120, "Nova Micro V1", "OpenRouter", "Amazon", 1
),
LlmModel.AMAZON_NOVA_PRO_V1: ModelMetadata(
"open_router", 300000, 5120, "Nova Pro V1", "OpenRouter", "Amazon", 1
),
LlmModel.MICROSOFT_WIZARDLM_2_8X22B: ModelMetadata(
"open_router", 65536, 4096, "WizardLM 2 8x22B", "OpenRouter", "Microsoft", 1
),
LlmModel.GRYPHE_MYTHOMAX_L2_13B: ModelMetadata(
"open_router", 4096, 4096, "MythoMax L2 13B", "OpenRouter", "Gryphe", 1
),
LlmModel.META_LLAMA_4_SCOUT: ModelMetadata(
"open_router", 131072, 131072, "Llama 4 Scout", "OpenRouter", "Meta", 1
),
LlmModel.META_LLAMA_4_MAVERICK: ModelMetadata(
"open_router", 1048576, 1000000, "Llama 4 Maverick", "OpenRouter", "Meta", 1
),
LlmModel.GROK_4: ModelMetadata(
"open_router", 256000, 256000, "Grok 4", "OpenRouter", "xAI", 3
),
LlmModel.GROK_4_FAST: ModelMetadata(
"open_router", 2000000, 30000, "Grok 4 Fast", "OpenRouter", "xAI", 1
),
LlmModel.GROK_4_1_FAST: ModelMetadata(
"open_router", 2000000, 30000, "Grok 4.1 Fast", "OpenRouter", "xAI", 1
),
LlmModel.GROK_CODE_FAST_1: ModelMetadata(
"open_router", 256000, 10000, "Grok Code Fast 1", "OpenRouter", "xAI", 1
),
LlmModel.KIMI_K2: ModelMetadata(
"open_router", 131000, 131000, "Kimi K2", "OpenRouter", "Moonshot AI", 1
),
LlmModel.QWEN3_235B_A22B_THINKING: ModelMetadata(
"open_router",
262144,
262144,
"Qwen 3 235B A22B Thinking 2507",
"OpenRouter",
"Qwen",
1,
),
LlmModel.QWEN3_CODER: ModelMetadata(
"open_router", 262144, 262144, "Qwen 3 Coder", "OpenRouter", "Qwen", 3
),
# Llama API models
LlmModel.LLAMA_API_LLAMA_4_SCOUT: ModelMetadata(
"llama_api",
128000,
4028,
"Llama 4 Scout 17B 16E Instruct FP8",
"Llama API",
"Meta",
1,
),
LlmModel.LLAMA_API_LLAMA4_MAVERICK: ModelMetadata(
"llama_api",
128000,
4028,
"Llama 4 Maverick 17B 128E Instruct FP8",
"Llama API",
"Meta",
1,
),
LlmModel.LLAMA_API_LLAMA3_3_8B: ModelMetadata(
"llama_api", 128000, 4028, "Llama 3.3 8B Instruct", "Llama API", "Meta", 1
),
LlmModel.LLAMA_API_LLAMA3_3_70B: ModelMetadata(
"llama_api", 128000, 4028, "Llama 3.3 70B Instruct", "Llama API", "Meta", 1
),
# v0 by Vercel models
LlmModel.V0_1_5_MD: ModelMetadata("v0", 128000, 64000, "v0 1.5 MD", "V0", "V0", 1),
LlmModel.V0_1_5_LG: ModelMetadata("v0", 512000, 64000, "v0 1.5 LG", "V0", "V0", 1),
LlmModel.V0_1_0_MD: ModelMetadata("v0", 128000, 64000, "v0 1.0 MD", "V0", "V0", 1),
}
# Default model constant for backward compatibility
# Uses the dynamic registry to get the default model
DEFAULT_LLM_MODEL = LlmModel.default()
DEFAULT_LLM_MODEL = LlmModel.GPT5_2
for model in LlmModel:
if model not in MODEL_METADATA:
raise ValueError(f"Missing MODEL_METADATA metadata for model: {model}")
class ToolCall(BaseModel):
@@ -334,10 +594,7 @@ def get_parallel_tool_calls_param(
llm_model: LlmModel, parallel_tool_calls: bool | None
):
"""Get the appropriate parallel_tool_calls parameter for OpenAI-compatible APIs."""
# Check for o-series models (o1, o1-mini, o3-mini, etc.) which don't support
# parallel tool calls. Use regex to avoid false positives like "openai/gpt-oss".
is_o_series = re.match(r"^o\d", llm_model) is not None
if is_o_series or parallel_tool_calls is None:
if llm_model.startswith("o") or parallel_tool_calls is None:
return openai.NOT_GIVEN
return parallel_tool_calls
@@ -373,98 +630,26 @@ async def llm_call(
- prompt_tokens: The number of tokens used in the prompt.
- completion_tokens: The number of tokens used in the completion.
"""
# Get model metadata and check if enabled - with fallback support
# The model we'll actually use (may differ if original is disabled)
model_to_use = llm_model.value
# Check if model is in registry and if it's enabled
from backend.data.llm_registry import (
get_fallback_model_for_disabled,
get_model_info,
)
model_info = get_model_info(llm_model.value)
if model_info and not model_info.is_enabled:
# Model is disabled - try to find a fallback from the same provider
fallback = get_fallback_model_for_disabled(llm_model.value)
if fallback:
logger.warning(
f"Model '{llm_model.value}' is disabled. Using fallback model '{fallback.slug}' from the same provider ({fallback.metadata.provider})."
)
model_to_use = fallback.slug
# Use fallback model's metadata
provider = fallback.metadata.provider
context_window = fallback.metadata.context_window
model_max_output = fallback.metadata.max_output_tokens or int(2**15)
else:
# No fallback available - raise error
raise ValueError(
f"LLM model '{llm_model.value}' is disabled and no fallback model "
f"from the same provider is available. Please enable the model or "
f"select a different model in the block configuration."
)
else:
# Model is enabled or not in registry (legacy/static model)
try:
provider = llm_model.metadata.provider
context_window = llm_model.context_window
model_max_output = llm_model.max_output_tokens or int(2**15)
except ValueError:
# Model not in cache - try refreshing the registry once if we have DB access
logger.warning(f"Model {llm_model.value} not found in registry cache")
# Try refreshing the registry if we have database access
from backend.data.db import is_connected
if is_connected():
try:
logger.info(
f"Refreshing LLM registry and retrying lookup for {llm_model.value}"
)
await llm_registry.refresh_llm_registry()
# Try again after refresh
try:
provider = llm_model.metadata.provider
context_window = llm_model.context_window
model_max_output = llm_model.max_output_tokens or int(2**15)
logger.info(
f"Successfully loaded model {llm_model.value} metadata after registry refresh"
)
except ValueError:
# Still not found after refresh
raise ValueError(
f"LLM model '{llm_model.value}' not found in registry after refresh. "
"Please ensure the model is added and enabled in the LLM registry via the admin UI."
)
except Exception as refresh_exc:
logger.error(f"Failed to refresh LLM registry: {refresh_exc}")
raise ValueError(
f"LLM model '{llm_model.value}' not found in registry and failed to refresh. "
"Please ensure the model is added to the LLM registry via the admin UI."
) from refresh_exc
else:
# No DB access (e.g., in executor without direct DB connection)
# The registry should have been loaded on startup
raise ValueError(
f"LLM model '{llm_model.value}' not found in registry cache. "
"The registry may need to be refreshed. Please contact support or try again later."
)
# Create effective model for model-specific parameter resolution (e.g., o-series check)
# This uses the resolved model_to_use which may differ from llm_model if fallback occurred
effective_model = LlmModel(model_to_use)
provider = llm_model.metadata.provider
context_window = llm_model.context_window
if compress_prompt_to_fit:
prompt = compress_prompt(
result = await compress_context(
messages=prompt,
target_tokens=context_window // 2,
lossy_ok=True,
target_tokens=llm_model.context_window // 2,
client=None, # Truncation-only, no LLM summarization
reserve=0, # Caller handles response token budget separately
)
if result.error:
logger.warning(
f"Prompt compression did not meet target: {result.error}. "
f"Proceeding with {result.token_count} tokens."
)
prompt = result.messages
# Calculate available tokens based on context window and input length
estimated_input_tokens = estimate_token_count(prompt)
# model_max_output already set above
model_max_output = llm_model.max_output_tokens or int(2**15)
user_max = max_tokens or model_max_output
available_tokens = max(context_window - estimated_input_tokens, 0)
max_tokens = max(min(available_tokens, model_max_output, user_max), 1)
@@ -475,14 +660,14 @@ async def llm_call(
response_format = None
parallel_tool_calls = get_parallel_tool_calls_param(
effective_model, parallel_tool_calls
llm_model, parallel_tool_calls
)
if force_json_output:
response_format = {"type": "json_object"}
response = await oai_client.chat.completions.create(
model=model_to_use,
model=llm_model.value,
messages=prompt, # type: ignore
response_format=response_format, # type: ignore
max_completion_tokens=max_tokens,
@@ -529,7 +714,7 @@ async def llm_call(
)
try:
resp = await client.messages.create(
model=model_to_use,
model=llm_model.value,
system=sysprompt,
messages=messages,
max_tokens=max_tokens,
@@ -593,7 +778,7 @@ async def llm_call(
client = AsyncGroq(api_key=credentials.api_key.get_secret_value())
response_format = {"type": "json_object"} if force_json_output else None
response = await client.chat.completions.create(
model=model_to_use,
model=llm_model.value,
messages=prompt, # type: ignore
response_format=response_format, # type: ignore
max_tokens=max_tokens,
@@ -615,7 +800,7 @@ async def llm_call(
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
response = await client.generate(
model=model_to_use,
model=llm_model.value,
prompt=f"{sys_messages}\n\n{usr_messages}",
stream=False,
options={"num_ctx": max_tokens},
@@ -637,7 +822,7 @@ async def llm_call(
)
parallel_tool_calls_param = get_parallel_tool_calls_param(
effective_model, parallel_tool_calls
llm_model, parallel_tool_calls
)
response = await client.chat.completions.create(
@@ -645,7 +830,7 @@ async def llm_call(
"HTTP-Referer": "https://agpt.co",
"X-Title": "AutoGPT",
},
model=model_to_use,
model=llm_model.value,
messages=prompt, # type: ignore
max_tokens=max_tokens,
tools=tools_param, # type: ignore
@@ -679,7 +864,7 @@ async def llm_call(
)
parallel_tool_calls_param = get_parallel_tool_calls_param(
effective_model, parallel_tool_calls
llm_model, parallel_tool_calls
)
response = await client.chat.completions.create(
@@ -687,7 +872,7 @@ async def llm_call(
"HTTP-Referer": "https://agpt.co",
"X-Title": "AutoGPT",
},
model=model_to_use,
model=llm_model.value,
messages=prompt, # type: ignore
max_tokens=max_tokens,
tools=tools_param, # type: ignore
@@ -714,7 +899,7 @@ async def llm_call(
reasoning=reasoning,
)
elif provider == "aiml_api":
client = openai.AsyncOpenAI(
client = openai.OpenAI(
base_url="https://api.aimlapi.com/v2",
api_key=credentials.api_key.get_secret_value(),
default_headers={
@@ -724,8 +909,8 @@ async def llm_call(
},
)
completion = await client.chat.completions.create(
model=model_to_use,
completion = client.chat.completions.create(
model=llm_model.value,
messages=prompt, # type: ignore
max_tokens=max_tokens,
)
@@ -753,11 +938,11 @@ async def llm_call(
response_format = {"type": "json_object"}
parallel_tool_calls_param = get_parallel_tool_calls_param(
effective_model, parallel_tool_calls
llm_model, parallel_tool_calls
)
response = await client.chat.completions.create(
model=model_to_use,
model=llm_model.value,
messages=prompt, # type: ignore
response_format=response_format, # type: ignore
max_tokens=max_tokens,
@@ -808,10 +993,9 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default_factory=LlmModel.default,
default=DEFAULT_LLM_MODEL,
description="The language model to use for answering the prompt.",
advanced=False,
json_schema_extra=llm_model_schema_extra(),
)
force_json_output: bool = SchemaField(
title="Restrict LLM to pure JSON output",
@@ -874,7 +1058,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
input_schema=AIStructuredResponseGeneratorBlock.Input,
output_schema=AIStructuredResponseGeneratorBlock.Output,
test_input={
"model": "gpt-4o", # Using string value - enum accepts any model slug dynamically
"model": DEFAULT_LLM_MODEL,
"credentials": TEST_CREDENTIALS_INPUT,
"expected_format": {
"key1": "value1",
@@ -1240,10 +1424,9 @@ class AITextGeneratorBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default_factory=LlmModel.default,
default=DEFAULT_LLM_MODEL,
description="The language model to use for answering the prompt.",
advanced=False,
json_schema_extra=llm_model_schema_extra(),
)
credentials: AICredentials = AICredentialsField()
sys_prompt: str = SchemaField(
@@ -1337,9 +1520,8 @@ class AITextSummarizerBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default_factory=LlmModel.default,
default=DEFAULT_LLM_MODEL,
description="The language model to use for summarizing the text.",
json_schema_extra=llm_model_schema_extra(),
)
focus: str = SchemaField(
title="Focus",
@@ -1555,9 +1737,8 @@ class AIConversationBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default_factory=LlmModel.default,
default=DEFAULT_LLM_MODEL,
description="The language model to use for the conversation.",
json_schema_extra=llm_model_schema_extra(),
)
credentials: AICredentials = AICredentialsField()
max_tokens: int | None = SchemaField(
@@ -1594,7 +1775,7 @@ class AIConversationBlock(AIBlockBase):
},
{"role": "user", "content": "Where was it played?"},
],
"model": "gpt-4o", # Using string value - enum accepts any model slug dynamically
"model": DEFAULT_LLM_MODEL,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
@@ -1657,10 +1838,9 @@ class AIListGeneratorBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default_factory=LlmModel.default,
default=DEFAULT_LLM_MODEL,
description="The language model to use for generating the list.",
advanced=True,
json_schema_extra=llm_model_schema_extra(),
)
credentials: AICredentials = AICredentialsField()
max_retries: int = SchemaField(
@@ -1715,7 +1895,7 @@ class AIListGeneratorBlock(AIBlockBase):
"drawing explorers to uncover its mysteries. Each planet showcases the limitless possibilities of "
"fictional worlds."
),
"model": "gpt-4o", # Using string value - enum accepts any model slug dynamically
"model": DEFAULT_LLM_MODEL,
"credentials": TEST_CREDENTIALS_INPUT,
"max_retries": 3,
"force_json_output": False,

View File

@@ -1,6 +1,6 @@
import os
import tempfile
from typing import Literal, Optional
from typing import Optional
from moviepy.audio.io.AudioFileClip import AudioFileClip
from moviepy.video.fx.Loop import Loop
@@ -13,6 +13,7 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
@@ -46,18 +47,19 @@ class MediaDurationBlock(Block):
self,
input_data: Input,
*,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
# 1) Store the input media locally
local_media_path = await store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.media_in,
user_id=user_id,
return_content=False,
execution_context=execution_context,
return_format="for_local_processing",
)
assert execution_context.graph_exec_id is not None
media_abspath = get_exec_file_path(
execution_context.graph_exec_id, local_media_path
)
media_abspath = get_exec_file_path(graph_exec_id, local_media_path)
# 2) Load the clip
if input_data.is_video:
@@ -88,10 +90,6 @@ class LoopVideoBlock(Block):
default=None,
ge=1,
)
output_return_type: Literal["file_path", "data_uri"] = SchemaField(
description="How to return the output video. Either a relative path or base64 data URI.",
default="file_path",
)
class Output(BlockSchemaOutput):
video_out: str = SchemaField(
@@ -111,17 +109,19 @@ class LoopVideoBlock(Block):
self,
input_data: Input,
*,
node_exec_id: str,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
assert execution_context.graph_exec_id is not None
assert execution_context.node_exec_id is not None
graph_exec_id = execution_context.graph_exec_id
node_exec_id = execution_context.node_exec_id
# 1) Store the input video locally
local_video_path = await store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.video_in,
user_id=user_id,
return_content=False,
execution_context=execution_context,
return_format="for_local_processing",
)
input_abspath = get_exec_file_path(graph_exec_id, local_video_path)
@@ -149,12 +149,11 @@ class LoopVideoBlock(Block):
looped_clip = looped_clip.with_audio(clip.audio)
looped_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac")
# Return as data URI
# Return output - for_block_output returns workspace:// if available, else data URI
video_out = await store_media_file(
graph_exec_id=graph_exec_id,
file=output_filename,
user_id=user_id,
return_content=input_data.output_return_type == "data_uri",
execution_context=execution_context,
return_format="for_block_output",
)
yield "video_out", video_out
@@ -177,10 +176,6 @@ class AddAudioToVideoBlock(Block):
description="Volume scale for the newly attached audio track (1.0 = original).",
default=1.0,
)
output_return_type: Literal["file_path", "data_uri"] = SchemaField(
description="Return the final output as a relative path or base64 data URI.",
default="file_path",
)
class Output(BlockSchemaOutput):
video_out: MediaFileType = SchemaField(
@@ -200,23 +195,24 @@ class AddAudioToVideoBlock(Block):
self,
input_data: Input,
*,
node_exec_id: str,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
assert execution_context.graph_exec_id is not None
assert execution_context.node_exec_id is not None
graph_exec_id = execution_context.graph_exec_id
node_exec_id = execution_context.node_exec_id
# 1) Store the inputs locally
local_video_path = await store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.video_in,
user_id=user_id,
return_content=False,
execution_context=execution_context,
return_format="for_local_processing",
)
local_audio_path = await store_media_file(
graph_exec_id=graph_exec_id,
file=input_data.audio_in,
user_id=user_id,
return_content=False,
execution_context=execution_context,
return_format="for_local_processing",
)
abs_temp_dir = os.path.join(tempfile.gettempdir(), "exec_file", graph_exec_id)
@@ -240,12 +236,11 @@ class AddAudioToVideoBlock(Block):
output_abspath = os.path.join(abs_temp_dir, output_filename)
final_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac")
# 5) Return either path or data URI
# 5) Return output - for_block_output returns workspace:// if available, else data URI
video_out = await store_media_file(
graph_exec_id=graph_exec_id,
file=output_filename,
user_id=user_id,
return_content=input_data.output_return_type == "data_uri",
execution_context=execution_context,
return_format="for_block_output",
)
yield "video_out", video_out

View File

@@ -11,6 +11,7 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
@@ -112,8 +113,7 @@ class ScreenshotWebPageBlock(Block):
@staticmethod
async def take_screenshot(
credentials: APIKeyCredentials,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
url: str,
viewport_width: int,
viewport_height: int,
@@ -155,12 +155,11 @@ class ScreenshotWebPageBlock(Block):
return {
"image": await store_media_file(
graph_exec_id=graph_exec_id,
file=MediaFileType(
f"data:image/{format.value};base64,{b64encode(content).decode('utf-8')}"
),
user_id=user_id,
return_content=True,
execution_context=execution_context,
return_format="for_block_output",
)
}
@@ -169,15 +168,13 @@ class ScreenshotWebPageBlock(Block):
input_data: Input,
*,
credentials: APIKeyCredentials,
graph_exec_id: str,
user_id: str,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
try:
screenshot_data = await self.take_screenshot(
credentials=credentials,
graph_exec_id=graph_exec_id,
user_id=user_id,
execution_context=execution_context,
url=input_data.url,
viewport_width=input_data.viewport_width,
viewport_height=input_data.viewport_height,

View File

@@ -226,10 +226,9 @@ class SmartDecisionMakerBlock(Block):
)
model: llm.LlmModel = SchemaField(
title="LLM Model",
default_factory=llm.LlmModel.default,
default=llm.DEFAULT_LLM_MODEL,
description="The language model to use for answering the prompt.",
advanced=False,
json_schema_extra=llm.llm_model_schema_extra(),
)
credentials: llm.AICredentials = llm.AICredentialsField()
multiple_tool_calls: bool = SchemaField(

View File

@@ -7,6 +7,7 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import ContributorDetails, SchemaField
from backend.util.file import get_exec_file_path, store_media_file
from backend.util.type import MediaFileType
@@ -98,7 +99,7 @@ class ReadSpreadsheetBlock(Block):
)
async def run(
self, input_data: Input, *, graph_exec_id: str, user_id: str, **_kwargs
self, input_data: Input, *, execution_context: ExecutionContext, **_kwargs
) -> BlockOutput:
import csv
from io import StringIO
@@ -106,14 +107,16 @@ class ReadSpreadsheetBlock(Block):
# Determine data source - prefer file_input if provided, otherwise use contents
if input_data.file_input:
stored_file_path = await store_media_file(
user_id=user_id,
graph_exec_id=graph_exec_id,
file=input_data.file_input,
return_content=False,
execution_context=execution_context,
return_format="for_local_processing",
)
# Get full file path
file_path = get_exec_file_path(graph_exec_id, stored_file_path)
assert execution_context.graph_exec_id # Validated by store_media_file
file_path = get_exec_file_path(
execution_context.graph_exec_id, stored_file_path
)
if not Path(file_path).exists():
raise ValueError(f"File does not exist: {file_path}")

View File

@@ -10,13 +10,13 @@ import stagehand.main
from stagehand import Stagehand
from backend.blocks.llm import (
MODEL_METADATA,
AICredentials,
AICredentialsField,
LlmModel,
ModelMetadata,
)
from backend.blocks.stagehand._config import stagehand as stagehand_provider
from backend.data import llm_registry
from backend.sdk import (
APIKeyCredentials,
Block,
@@ -83,7 +83,7 @@ class StagehandRecommendedLlmModel(str, Enum):
GPT41_MINI = "gpt-4.1-mini-2025-04-14"
# Anthropic
CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219"
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
@property
def provider_name(self) -> str:
@@ -91,7 +91,7 @@ class StagehandRecommendedLlmModel(str, Enum):
Returns the provider name for the model in the required format for Stagehand:
provider/model_name
"""
model_metadata = self.metadata
model_metadata = MODEL_METADATA[LlmModel(self.value)]
model_name = self.value
if len(model_name.split("/")) == 1 and not self.value.startswith(
@@ -107,23 +107,19 @@ class StagehandRecommendedLlmModel(str, Enum):
@property
def provider(self) -> str:
return self.metadata.provider
return MODEL_METADATA[LlmModel(self.value)].provider
@property
def metadata(self) -> ModelMetadata:
metadata = llm_registry.get_llm_model_metadata(self.value)
if metadata:
return metadata
# Fallback to LlmModel enum if registry lookup fails
return LlmModel(self.value).metadata
return MODEL_METADATA[LlmModel(self.value)]
@property
def context_window(self) -> int:
return self.metadata.context_window
return MODEL_METADATA[LlmModel(self.value)].context_window
@property
def max_output_tokens(self) -> int | None:
return self.metadata.max_output_tokens
return MODEL_METADATA[LlmModel(self.value)].max_output_tokens
class StagehandObserveBlock(Block):
@@ -141,7 +137,7 @@ class StagehandObserveBlock(Block):
model: StagehandRecommendedLlmModel = SchemaField(
title="LLM Model",
description="LLM to use for Stagehand (provider is inferred)",
default=StagehandRecommendedLlmModel.CLAUDE_3_7_SONNET,
default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
advanced=False,
)
model_credentials: AICredentials = AICredentialsField()
@@ -186,10 +182,7 @@ class StagehandObserveBlock(Block):
**kwargs,
) -> BlockOutput:
logger.info(f"OBSERVE: Stagehand credentials: {stagehand_credentials}")
logger.info(
f"OBSERVE: Model credentials: {model_credentials} for provider {model_credentials.provider} secret: {model_credentials.api_key.get_secret_value()}"
)
logger.debug(f"OBSERVE: Using model provider {model_credentials.provider}")
with disable_signal_handling():
stagehand = Stagehand(
@@ -234,7 +227,7 @@ class StagehandActBlock(Block):
model: StagehandRecommendedLlmModel = SchemaField(
title="LLM Model",
description="LLM to use for Stagehand (provider is inferred)",
default=StagehandRecommendedLlmModel.CLAUDE_3_7_SONNET,
default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
advanced=False,
)
model_credentials: AICredentials = AICredentialsField()
@@ -286,10 +279,7 @@ class StagehandActBlock(Block):
**kwargs,
) -> BlockOutput:
logger.info(f"ACT: Stagehand credentials: {stagehand_credentials}")
logger.info(
f"ACT: Model credentials: {model_credentials} for provider {model_credentials.provider} secret: {model_credentials.api_key.get_secret_value()}"
)
logger.debug(f"ACT: Using model provider {model_credentials.provider}")
with disable_signal_handling():
stagehand = Stagehand(
@@ -334,7 +324,7 @@ class StagehandExtractBlock(Block):
model: StagehandRecommendedLlmModel = SchemaField(
title="LLM Model",
description="LLM to use for Stagehand (provider is inferred)",
default=StagehandRecommendedLlmModel.CLAUDE_3_7_SONNET,
default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
advanced=False,
)
model_credentials: AICredentials = AICredentialsField()
@@ -374,10 +364,7 @@ class StagehandExtractBlock(Block):
**kwargs,
) -> BlockOutput:
logger.info(f"EXTRACT: Stagehand credentials: {stagehand_credentials}")
logger.info(
f"EXTRACT: Model credentials: {model_credentials} for provider {model_credentials.provider} secret: {model_credentials.api_key.get_secret_value()}"
)
logger.debug(f"EXTRACT: Using model provider {model_credentials.provider}")
with disable_signal_handling():
stagehand = Stagehand(

View File

@@ -10,6 +10,7 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
@@ -17,7 +18,9 @@ from backend.data.model import (
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.file import store_media_file
from backend.util.request import Requests
from backend.util.type import MediaFileType
TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
@@ -102,7 +105,7 @@ class CreateTalkingAvatarVideoBlock(Block):
test_output=[
(
"video_url",
"https://d-id.com/api/clips/abcd1234-5678-efgh-ijkl-mnopqrstuvwx/video",
lambda x: x.startswith(("workspace://", "data:")),
),
],
test_mock={
@@ -110,9 +113,10 @@ class CreateTalkingAvatarVideoBlock(Block):
"id": "abcd1234-5678-efgh-ijkl-mnopqrstuvwx",
"status": "created",
},
# Use data URI to avoid HTTP requests during tests
"get_clip_status": lambda *args, **kwargs: {
"status": "done",
"result_url": "https://d-id.com/api/clips/abcd1234-5678-efgh-ijkl-mnopqrstuvwx/video",
"result_url": "data:video/mp4;base64,AAAA",
},
},
test_credentials=TEST_CREDENTIALS,
@@ -138,7 +142,12 @@ class CreateTalkingAvatarVideoBlock(Block):
return response.json()
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
self,
input_data: Input,
*,
credentials: APIKeyCredentials,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
# Create the clip
payload = {
@@ -165,7 +174,14 @@ class CreateTalkingAvatarVideoBlock(Block):
for _ in range(input_data.max_polling_attempts):
status_response = await self.get_clip_status(credentials.api_key, clip_id)
if status_response["status"] == "done":
yield "video_url", status_response["result_url"]
# Store the generated video to the user's workspace for persistence
video_url = status_response["result_url"]
stored_url = await store_media_file(
file=MediaFileType(video_url),
execution_context=execution_context,
return_format="for_block_output",
)
yield "video_url", stored_url
return
elif status_response["status"] == "error":
raise RuntimeError(

View File

@@ -12,6 +12,7 @@ from backend.blocks.iteration import StepThroughItemsBlock
from backend.blocks.llm import AITextSummarizerBlock
from backend.blocks.text import ExtractTextInformationBlock
from backend.blocks.xml_parser import XMLParserBlock
from backend.data.execution import ExecutionContext
from backend.util.file import store_media_file
from backend.util.type import MediaFileType
@@ -233,9 +234,12 @@ class TestStoreMediaFileSecurity:
with pytest.raises(ValueError, match="File too large"):
await store_media_file(
graph_exec_id="test",
file=MediaFileType(large_data_uri),
user_id="test_user",
execution_context=ExecutionContext(
user_id="test_user",
graph_exec_id="test",
),
return_format="for_local_processing",
)
@patch("backend.util.file.Path")
@@ -270,9 +274,12 @@ class TestStoreMediaFileSecurity:
# Should raise an error when directory size exceeds limit
with pytest.raises(ValueError, match="Disk usage limit exceeded"):
await store_media_file(
graph_exec_id="test",
file=MediaFileType(
"data:text/plain;base64,dGVzdA=="
), # Small test file
user_id="test_user",
execution_context=ExecutionContext(
user_id="test_user",
graph_exec_id="test",
),
return_format="for_local_processing",
)

View File

@@ -11,10 +11,22 @@ from backend.blocks.http import (
HttpMethod,
SendAuthenticatedWebRequestBlock,
)
from backend.data.execution import ExecutionContext
from backend.data.model import HostScopedCredentials
from backend.util.request import Response
def make_test_context(
graph_exec_id: str = "test-exec-id",
user_id: str = "test-user-id",
) -> ExecutionContext:
"""Helper to create test ExecutionContext."""
return ExecutionContext(
user_id=user_id,
graph_exec_id=graph_exec_id,
)
class TestHttpBlockWithHostScopedCredentials:
"""Test suite for HTTP block integration with HostScopedCredentials."""
@@ -105,8 +117,7 @@ class TestHttpBlockWithHostScopedCredentials:
async for output_name, output_data in http_block.run(
input_data,
credentials=exact_match_credentials,
graph_exec_id="test-exec-id",
user_id="test-user-id",
execution_context=make_test_context(),
):
result.append((output_name, output_data))
@@ -161,8 +172,7 @@ class TestHttpBlockWithHostScopedCredentials:
async for output_name, output_data in http_block.run(
input_data,
credentials=wildcard_credentials,
graph_exec_id="test-exec-id",
user_id="test-user-id",
execution_context=make_test_context(),
):
result.append((output_name, output_data))
@@ -208,8 +218,7 @@ class TestHttpBlockWithHostScopedCredentials:
async for output_name, output_data in http_block.run(
input_data,
credentials=non_matching_credentials,
graph_exec_id="test-exec-id",
user_id="test-user-id",
execution_context=make_test_context(),
):
result.append((output_name, output_data))
@@ -258,8 +267,7 @@ class TestHttpBlockWithHostScopedCredentials:
async for output_name, output_data in http_block.run(
input_data,
credentials=exact_match_credentials,
graph_exec_id="test-exec-id",
user_id="test-user-id",
execution_context=make_test_context(),
):
result.append((output_name, output_data))
@@ -318,8 +326,7 @@ class TestHttpBlockWithHostScopedCredentials:
async for output_name, output_data in http_block.run(
input_data,
credentials=auto_discovered_creds, # Execution manager found these
graph_exec_id="test-exec-id",
user_id="test-user-id",
execution_context=make_test_context(),
):
result.append((output_name, output_data))
@@ -382,8 +389,7 @@ class TestHttpBlockWithHostScopedCredentials:
async for output_name, output_data in http_block.run(
input_data,
credentials=multi_header_creds,
graph_exec_id="test-exec-id",
user_id="test-user-id",
execution_context=make_test_context(),
):
result.append((output_name, output_data))
@@ -471,8 +477,7 @@ class TestHttpBlockWithHostScopedCredentials:
async for output_name, output_data in http_block.run(
input_data,
credentials=test_creds,
graph_exec_id="test-exec-id",
user_id="test-user-id",
execution_context=make_test_context(),
):
result.append((output_name, output_data))

View File

@@ -11,6 +11,7 @@ from backend.data.block import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util import json, text
from backend.util.file import get_exec_file_path, store_media_file
@@ -444,18 +445,21 @@ class FileReadBlock(Block):
)
async def run(
self, input_data: Input, *, graph_exec_id: str, user_id: str, **_kwargs
self, input_data: Input, *, execution_context: ExecutionContext, **_kwargs
) -> BlockOutput:
# Store the media file properly (handles URLs, data URIs, etc.)
stored_file_path = await store_media_file(
user_id=user_id,
graph_exec_id=graph_exec_id,
file=input_data.file_input,
return_content=False,
execution_context=execution_context,
return_format="for_local_processing",
)
# Get full file path
file_path = get_exec_file_path(graph_exec_id, stored_file_path)
# Get full file path (graph_exec_id validated by store_media_file above)
if not execution_context.graph_exec_id:
raise ValueError("execution_context.graph_exec_id is required")
file_path = get_exec_file_path(
execution_context.graph_exec_id, stored_file_path
)
if not Path(file_path).exists():
raise ValueError(f"File does not exist: {file_path}")

View File

@@ -25,7 +25,6 @@ from prisma.models import AgentBlock
from prisma.types import AgentBlockCreateInput
from pydantic import BaseModel
from backend.data.llm_registry import update_schema_with_llm_registry
from backend.data.model import NodeExecutionStats
from backend.integrations.providers import ProviderName
from backend.util import json
@@ -144,59 +143,35 @@ class BlockInfo(BaseModel):
class BlockSchema(BaseModel):
cached_jsonschema: ClassVar[dict[str, Any] | None] = None
@classmethod
def clear_schema_cache(cls) -> None:
"""Clear the cached JSON schema for this class."""
# Use None instead of {} because {} is truthy and would prevent regeneration
cls.cached_jsonschema = None # type: ignore
@staticmethod
def clear_all_schema_caches() -> None:
"""Clear cached JSON schemas for all BlockSchema subclasses."""
def clear_recursive(cls: type) -> None:
"""Recursively clear cache for class and all subclasses."""
if hasattr(cls, "clear_schema_cache"):
cls.clear_schema_cache()
for subclass in cls.__subclasses__():
clear_recursive(subclass)
clear_recursive(BlockSchema)
cached_jsonschema: ClassVar[dict[str, Any]]
@classmethod
def jsonschema(cls) -> dict[str, Any]:
# Generate schema if not cached
if not cls.cached_jsonschema:
model = jsonref.replace_refs(cls.model_json_schema(), merge_props=True)
if cls.cached_jsonschema:
return cls.cached_jsonschema
def ref_to_dict(obj):
if isinstance(obj, dict):
# OpenAPI <3.1 does not support sibling fields that has a $ref key
# So sometimes, the schema has an "allOf"/"anyOf"/"oneOf" with 1 item.
keys = {"allOf", "anyOf", "oneOf"}
one_key = next(
(k for k in keys if k in obj and len(obj[k]) == 1), None
)
if one_key:
obj.update(obj[one_key][0])
model = jsonref.replace_refs(cls.model_json_schema(), merge_props=True)
return {
key: ref_to_dict(value)
for key, value in obj.items()
if not key.startswith("$") and key != one_key
}
elif isinstance(obj, list):
return [ref_to_dict(item) for item in obj]
def ref_to_dict(obj):
if isinstance(obj, dict):
# OpenAPI <3.1 does not support sibling fields that has a $ref key
# So sometimes, the schema has an "allOf"/"anyOf"/"oneOf" with 1 item.
keys = {"allOf", "anyOf", "oneOf"}
one_key = next((k for k in keys if k in obj and len(obj[k]) == 1), None)
if one_key:
obj.update(obj[one_key][0])
return obj
return {
key: ref_to_dict(value)
for key, value in obj.items()
if not key.startswith("$") and key != one_key
}
elif isinstance(obj, list):
return [ref_to_dict(item) for item in obj]
cls.cached_jsonschema = cast(dict[str, Any], ref_to_dict(model))
return obj
# Always post-process to ensure LLM registry data is up-to-date
# This refreshes model options and discriminator mappings even if schema was cached
update_schema_with_llm_registry(cls.cached_jsonschema, cls)
cls.cached_jsonschema = cast(dict[str, Any], ref_to_dict(model))
return cls.cached_jsonschema
@@ -259,7 +234,7 @@ class BlockSchema(BaseModel):
super().__pydantic_init_subclass__(**kwargs)
# Reset cached JSON schema to prevent inheriting it from parent class
cls.cached_jsonschema = None
cls.cached_jsonschema = {}
credentials_fields = cls.get_credentials_fields()
@@ -898,36 +873,13 @@ def is_block_auth_configured(
async def initialize_blocks() -> None:
# Refresh LLM registry before initializing blocks so blocks can use registry data
# This ensures the registry cache is populated even in executor context
try:
from backend.data import llm_registry
from backend.data.block_cost_config import refresh_llm_costs
# Only refresh if we have DB access (check if Prisma is connected)
from backend.data.db import is_connected
if is_connected():
await llm_registry.refresh_llm_registry()
refresh_llm_costs()
logger.info("LLM registry refreshed during block initialization")
else:
logger.warning(
"Prisma not connected, skipping LLM registry refresh during block initialization"
)
except Exception as exc:
logger.warning(
"Failed to refresh LLM registry during block initialization: %s", exc
)
# First, sync all provider costs to blocks
# Imported here to avoid circular import
from backend.sdk.cost_integration import sync_all_provider_costs
from backend.util.retry import func_retry
sync_all_provider_costs()
for cls in get_blocks().values():
block = cls()
@func_retry
async def sync_block_to_db(block: Block) -> None:
existing_block = await AgentBlock.prisma().find_first(
where={"OR": [{"id": block.id}, {"name": block.name}]}
)
@@ -940,7 +892,7 @@ async def initialize_blocks() -> None:
outputSchema=json.dumps(block.output_schema.jsonschema()),
)
)
continue
return
input_schema = json.dumps(block.input_schema.jsonschema())
output_schema = json.dumps(block.output_schema.jsonschema())
@@ -960,6 +912,25 @@ async def initialize_blocks() -> None:
},
)
failed_blocks: list[str] = []
for cls in get_blocks().values():
block = cls()
try:
await sync_block_to_db(block)
except Exception as e:
logger.warning(
f"Failed to sync block {block.name} to database: {e}. "
"Block is still available in memory.",
exc_info=True,
)
failed_blocks.append(block.name)
if failed_blocks:
logger.error(
f"Failed to sync {len(failed_blocks)} block(s) to database: "
f"{', '.join(failed_blocks)}. These blocks are still available in memory."
)
# Note on the return type annotation: https://github.com/microsoft/pyright/issues/10281
def get_block(block_id: str) -> AnyBlockSchema | None:

View File

@@ -1,4 +1,3 @@
import logging
from typing import Type
from backend.blocks.ai_image_customizer import AIImageCustomizerBlock, GeminiImageModel
@@ -24,18 +23,19 @@ from backend.blocks.ideogram import IdeogramModelBlock
from backend.blocks.jina.embeddings import JinaEmbeddingBlock
from backend.blocks.jina.search import ExtractWebsiteContentBlock, SearchTheWebBlock
from backend.blocks.llm import (
MODEL_METADATA,
AIConversationBlock,
AIListGeneratorBlock,
AIStructuredResponseGeneratorBlock,
AITextGeneratorBlock,
AITextSummarizerBlock,
LlmModel,
)
from backend.blocks.replicate.flux_advanced import ReplicateFluxAdvancedModelBlock
from backend.blocks.replicate.replicate_block import ReplicateModelBlock
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
from backend.blocks.talking_head import CreateTalkingAvatarVideoBlock
from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
from backend.data import llm_registry
from backend.data.block import Block, BlockCost, BlockCostType
from backend.integrations.credentials_store import (
aiml_api_credentials,
@@ -55,63 +55,209 @@ from backend.integrations.credentials_store import (
v0_credentials,
)
logger = logging.getLogger(__name__)
# =============== Configure the cost for each LLM Model call =============== #
PROVIDER_CREDENTIALS = {
"openai": openai_credentials,
"anthropic": anthropic_credentials,
"groq": groq_credentials,
"open_router": open_router_credentials,
"llama_api": llama_api_credentials,
"aiml_api": aiml_api_credentials,
"v0": v0_credentials,
MODEL_COST: dict[LlmModel, int] = {
LlmModel.O3: 4,
LlmModel.O3_MINI: 2,
LlmModel.O1: 16,
LlmModel.O1_MINI: 4,
# GPT-5 models
LlmModel.GPT5_2: 6,
LlmModel.GPT5_1: 5,
LlmModel.GPT5: 2,
LlmModel.GPT5_MINI: 1,
LlmModel.GPT5_NANO: 1,
LlmModel.GPT5_CHAT: 5,
LlmModel.GPT41: 2,
LlmModel.GPT41_MINI: 1,
LlmModel.GPT4O_MINI: 1,
LlmModel.GPT4O: 3,
LlmModel.GPT4_TURBO: 10,
LlmModel.GPT3_5_TURBO: 1,
LlmModel.CLAUDE_4_1_OPUS: 21,
LlmModel.CLAUDE_4_OPUS: 21,
LlmModel.CLAUDE_4_SONNET: 5,
LlmModel.CLAUDE_4_5_HAIKU: 4,
LlmModel.CLAUDE_4_5_OPUS: 14,
LlmModel.CLAUDE_4_5_SONNET: 9,
LlmModel.CLAUDE_3_HAIKU: 1,
LlmModel.AIML_API_QWEN2_5_72B: 1,
LlmModel.AIML_API_LLAMA3_1_70B: 1,
LlmModel.AIML_API_LLAMA3_3_70B: 1,
LlmModel.AIML_API_META_LLAMA_3_1_70B: 1,
LlmModel.AIML_API_LLAMA_3_2_3B: 1,
LlmModel.LLAMA3_3_70B: 1,
LlmModel.LLAMA3_1_8B: 1,
LlmModel.OLLAMA_LLAMA3_3: 1,
LlmModel.OLLAMA_LLAMA3_2: 1,
LlmModel.OLLAMA_LLAMA3_8B: 1,
LlmModel.OLLAMA_LLAMA3_405B: 1,
LlmModel.OLLAMA_DOLPHIN: 1,
LlmModel.OPENAI_GPT_OSS_120B: 1,
LlmModel.OPENAI_GPT_OSS_20B: 1,
LlmModel.GEMINI_2_5_PRO: 4,
LlmModel.GEMINI_3_PRO_PREVIEW: 5,
LlmModel.GEMINI_2_5_FLASH: 1,
LlmModel.GEMINI_2_0_FLASH: 1,
LlmModel.GEMINI_2_5_FLASH_LITE_PREVIEW: 1,
LlmModel.GEMINI_2_0_FLASH_LITE: 1,
LlmModel.MISTRAL_NEMO: 1,
LlmModel.COHERE_COMMAND_R_08_2024: 1,
LlmModel.COHERE_COMMAND_R_PLUS_08_2024: 3,
LlmModel.DEEPSEEK_CHAT: 2,
LlmModel.DEEPSEEK_R1_0528: 1,
LlmModel.PERPLEXITY_SONAR: 1,
LlmModel.PERPLEXITY_SONAR_PRO: 5,
LlmModel.PERPLEXITY_SONAR_DEEP_RESEARCH: 10,
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B: 1,
LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B: 1,
LlmModel.AMAZON_NOVA_LITE_V1: 1,
LlmModel.AMAZON_NOVA_MICRO_V1: 1,
LlmModel.AMAZON_NOVA_PRO_V1: 1,
LlmModel.MICROSOFT_WIZARDLM_2_8X22B: 1,
LlmModel.GRYPHE_MYTHOMAX_L2_13B: 1,
LlmModel.META_LLAMA_4_SCOUT: 1,
LlmModel.META_LLAMA_4_MAVERICK: 1,
LlmModel.LLAMA_API_LLAMA_4_SCOUT: 1,
LlmModel.LLAMA_API_LLAMA4_MAVERICK: 1,
LlmModel.LLAMA_API_LLAMA3_3_8B: 1,
LlmModel.LLAMA_API_LLAMA3_3_70B: 1,
LlmModel.GROK_4: 9,
LlmModel.GROK_4_FAST: 1,
LlmModel.GROK_4_1_FAST: 1,
LlmModel.GROK_CODE_FAST_1: 1,
LlmModel.KIMI_K2: 1,
LlmModel.QWEN3_235B_A22B_THINKING: 1,
LlmModel.QWEN3_CODER: 9,
# v0 by Vercel models
LlmModel.V0_1_5_MD: 1,
LlmModel.V0_1_5_LG: 2,
LlmModel.V0_1_0_MD: 1,
}
# =============== Configure the cost for each LLM Model call =============== #
# All LLM costs now come from the database via llm_registry
LLM_COST: list[BlockCost] = []
for model in LlmModel:
if model not in MODEL_COST:
raise ValueError(f"Missing MODEL_COST for model: {model}")
def _build_llm_costs_from_registry() -> list[BlockCost]:
"""Build BlockCost list from all models in the LLM registry."""
costs: list[BlockCost] = []
for model in llm_registry.iter_dynamic_models():
for cost in model.costs:
credentials = PROVIDER_CREDENTIALS.get(cost.credential_provider)
if not credentials:
logger.warning(
"Skipping cost entry for %s due to unknown credentials provider %s",
model.slug,
cost.credential_provider,
)
continue
cost_filter = {
"model": model.slug,
LLM_COST = (
# Anthropic Models
[
BlockCost(
cost_type=BlockCostType.RUN,
cost_filter={
"model": model,
"credentials": {
"id": credentials.id,
"provider": credentials.provider,
"type": credentials.type,
"id": anthropic_credentials.id,
"provider": anthropic_credentials.provider,
"type": anthropic_credentials.type,
},
}
costs.append(
BlockCost(
cost_type=BlockCostType.RUN,
cost_filter=cost_filter,
cost_amount=cost.credit_cost,
)
)
return costs
def refresh_llm_costs() -> None:
"""Refresh LLM costs from the registry. All costs now come from the database."""
LLM_COST.clear()
LLM_COST.extend(_build_llm_costs_from_registry())
# Initial load will happen after registry is refreshed at startup
# Don't call refresh_llm_costs() here - it will be called after registry refresh
},
cost_amount=cost,
)
for model, cost in MODEL_COST.items()
if MODEL_METADATA[model].provider == "anthropic"
]
# OpenAI Models
+ [
BlockCost(
cost_type=BlockCostType.RUN,
cost_filter={
"model": model,
"credentials": {
"id": openai_credentials.id,
"provider": openai_credentials.provider,
"type": openai_credentials.type,
},
},
cost_amount=cost,
)
for model, cost in MODEL_COST.items()
if MODEL_METADATA[model].provider == "openai"
]
# Groq Models
+ [
BlockCost(
cost_type=BlockCostType.RUN,
cost_filter={
"model": model,
"credentials": {"id": groq_credentials.id},
},
cost_amount=cost,
)
for model, cost in MODEL_COST.items()
if MODEL_METADATA[model].provider == "groq"
]
# Open Router Models
+ [
BlockCost(
cost_type=BlockCostType.RUN,
cost_filter={
"model": model,
"credentials": {
"id": open_router_credentials.id,
"provider": open_router_credentials.provider,
"type": open_router_credentials.type,
},
},
cost_amount=cost,
)
for model, cost in MODEL_COST.items()
if MODEL_METADATA[model].provider == "open_router"
]
# Llama API Models
+ [
BlockCost(
cost_type=BlockCostType.RUN,
cost_filter={
"model": model,
"credentials": {
"id": llama_api_credentials.id,
"provider": llama_api_credentials.provider,
"type": llama_api_credentials.type,
},
},
cost_amount=cost,
)
for model, cost in MODEL_COST.items()
if MODEL_METADATA[model].provider == "llama_api"
]
# v0 by Vercel Models
+ [
BlockCost(
cost_type=BlockCostType.RUN,
cost_filter={
"model": model,
"credentials": {
"id": v0_credentials.id,
"provider": v0_credentials.provider,
"type": v0_credentials.type,
},
},
cost_amount=cost,
)
for model, cost in MODEL_COST.items()
if MODEL_METADATA[model].provider == "v0"
]
# AI/ML Api Models
+ [
BlockCost(
cost_type=BlockCostType.RUN,
cost_filter={
"model": model,
"credentials": {
"id": aiml_api_credentials.id,
"provider": aiml_api_credentials.provider,
"type": aiml_api_credentials.type,
},
},
cost_amount=cost,
)
for model, cost in MODEL_COST.items()
if MODEL_METADATA[model].provider == "aiml_api"
]
)
# =============== This is the exhaustive list of cost for each Block =============== #

View File

@@ -133,10 +133,23 @@ class RedisEventBus(BaseRedisEventBus[M], ABC):
class AsyncRedisEventBus(BaseRedisEventBus[M], ABC):
def __init__(self):
self._pubsub: AsyncPubSub | None = None
@property
async def connection(self) -> redis.AsyncRedis:
return await redis.get_redis_async()
async def close(self) -> None:
"""Close the PubSub connection if it exists."""
if self._pubsub is not None:
try:
await self._pubsub.close()
except Exception:
logger.warning("Failed to close PubSub connection", exc_info=True)
finally:
self._pubsub = None
async def publish_event(self, event: M, channel_key: str):
"""
Publish an event to Redis. Gracefully handles connection failures
@@ -157,6 +170,7 @@ class AsyncRedisEventBus(BaseRedisEventBus[M], ABC):
await self.connection, channel_key
)
assert isinstance(pubsub, AsyncPubSub)
self._pubsub = pubsub
if "*" in channel_key:
await pubsub.psubscribe(full_channel_name)

View File

@@ -83,12 +83,29 @@ class ExecutionContext(BaseModel):
model_config = {"extra": "ignore"}
# Execution identity
user_id: Optional[str] = None
graph_id: Optional[str] = None
graph_exec_id: Optional[str] = None
graph_version: Optional[int] = None
node_id: Optional[str] = None
node_exec_id: Optional[str] = None
# Safety settings
human_in_the_loop_safe_mode: bool = True
sensitive_action_safe_mode: bool = False
# User settings
user_timezone: str = "UTC"
# Execution hierarchy
root_execution_id: Optional[str] = None
parent_execution_id: Optional[str] = None
# Workspace
workspace_id: Optional[str] = None
session_id: Optional[str] = None
# -------------------------- Models -------------------------- #

View File

@@ -1028,6 +1028,39 @@ async def get_graph(
return GraphModel.from_db(graph, for_export)
async def get_store_listed_graphs(*graph_ids: str) -> dict[str, GraphModel]:
"""Batch-fetch multiple store-listed graphs by their IDs.
Only returns graphs that have approved store listings (publicly available).
Does not require permission checks since store-listed graphs are public.
Args:
*graph_ids: Variable number of graph IDs to fetch
Returns:
Dict mapping graph_id to GraphModel for graphs with approved store listings
"""
if not graph_ids:
return {}
store_listings = await StoreListingVersion.prisma().find_many(
where={
"agentGraphId": {"in": list(graph_ids)},
"submissionStatus": SubmissionStatus.APPROVED,
"isDeleted": False,
},
include={"AgentGraph": {"include": AGENT_GRAPH_INCLUDE}},
distinct=["agentGraphId"],
order={"agentGraphVersion": "desc"},
)
return {
listing.agentGraphId: GraphModel.from_db(listing.AgentGraph)
for listing in store_listings
if listing.AgentGraph
}
async def get_graph_as_admin(
graph_id: str,
version: int | None = None,
@@ -1511,10 +1544,8 @@ async def migrate_llm_models(migrate_to: LlmModel):
if field.annotation == LlmModel:
llm_model_fields[block.id] = field_name
# Get all model slugs from the registry (dynamic, not hardcoded enum)
from backend.data import llm_registry
enum_values = list(llm_registry.get_all_model_slugs_for_validation())
# Convert enum values to a list of strings for the SQL query
enum_values = [v.value for v in LlmModel]
escaped_enum_values = repr(tuple(enum_values)) # hack but works
# Update each block

View File

@@ -1,72 +0,0 @@
"""
LLM Registry module for managing LLM models, providers, and costs dynamically.
This module provides a database-driven registry system for LLM models,
replacing hardcoded model configurations with a flexible admin-managed system.
"""
from backend.data.llm_registry.model import ModelMetadata
# Re-export for backwards compatibility
from backend.data.llm_registry.notifications import (
REGISTRY_REFRESH_CHANNEL,
publish_registry_refresh_notification,
subscribe_to_registry_refresh,
)
from backend.data.llm_registry.registry import (
RegistryModel,
RegistryModelCost,
RegistryModelCreator,
get_all_model_slugs_for_validation,
get_default_model_slug,
get_dynamic_model_slugs,
get_fallback_model_for_disabled,
get_llm_discriminator_mapping,
get_llm_model_cost,
get_llm_model_metadata,
get_llm_model_schema_options,
get_model_info,
is_model_enabled,
iter_dynamic_models,
refresh_llm_registry,
register_static_costs,
register_static_metadata,
)
from backend.data.llm_registry.schema_utils import (
is_llm_model_field,
refresh_llm_discriminator_mapping,
refresh_llm_model_options,
update_schema_with_llm_registry,
)
__all__ = [
# Types
"ModelMetadata",
"RegistryModel",
"RegistryModelCost",
"RegistryModelCreator",
# Registry functions
"get_all_model_slugs_for_validation",
"get_default_model_slug",
"get_dynamic_model_slugs",
"get_fallback_model_for_disabled",
"get_llm_discriminator_mapping",
"get_llm_model_cost",
"get_llm_model_metadata",
"get_llm_model_schema_options",
"get_model_info",
"is_model_enabled",
"iter_dynamic_models",
"refresh_llm_registry",
"register_static_costs",
"register_static_metadata",
# Notifications
"REGISTRY_REFRESH_CHANNEL",
"publish_registry_refresh_notification",
"subscribe_to_registry_refresh",
# Schema utilities
"is_llm_model_field",
"refresh_llm_discriminator_mapping",
"refresh_llm_model_options",
"update_schema_with_llm_registry",
]

View File

@@ -1,25 +0,0 @@
"""Type definitions for LLM model metadata."""
from typing import Literal, NamedTuple
class ModelMetadata(NamedTuple):
"""Metadata for an LLM model.
Attributes:
provider: The provider identifier (e.g., "openai", "anthropic")
context_window: Maximum context window size in tokens
max_output_tokens: Maximum output tokens (None if unlimited)
display_name: Human-readable name for the model
provider_name: Human-readable provider name (e.g., "OpenAI", "Anthropic")
creator_name: Name of the organization that created the model
price_tier: Relative cost tier (1=cheapest, 2=medium, 3=expensive)
"""
provider: str
context_window: int
max_output_tokens: int | None
display_name: str
provider_name: str
creator_name: str
price_tier: Literal[1, 2, 3]

View File

@@ -1,89 +0,0 @@
"""
Redis pub/sub notifications for LLM registry updates.
When models are added/updated/removed via the admin UI, this module
publishes notifications to Redis that all executor services subscribe to,
ensuring they refresh their registry cache in real-time.
"""
import asyncio
import logging
from typing import Any
from backend.data.redis_client import connect_async
logger = logging.getLogger(__name__)
# Redis channel name for LLM registry refresh notifications
REGISTRY_REFRESH_CHANNEL = "llm_registry:refresh"
async def publish_registry_refresh_notification() -> None:
"""
Publish a notification to Redis that the LLM registry has been updated.
All executor services subscribed to this channel will refresh their registry.
"""
try:
redis = await connect_async()
await redis.publish(REGISTRY_REFRESH_CHANNEL, "refresh")
logger.info("Published LLM registry refresh notification to Redis")
except Exception as exc:
logger.warning(
"Failed to publish LLM registry refresh notification: %s",
exc,
exc_info=True,
)
async def subscribe_to_registry_refresh(
on_refresh: Any, # Async callable that takes no args
) -> None:
"""
Subscribe to Redis notifications for LLM registry updates.
This runs in a loop and processes messages as they arrive.
Args:
on_refresh: Async callable to execute when a refresh notification is received
"""
try:
redis = await connect_async()
pubsub = redis.pubsub()
await pubsub.subscribe(REGISTRY_REFRESH_CHANNEL)
logger.info(
"Subscribed to LLM registry refresh notifications on channel: %s",
REGISTRY_REFRESH_CHANNEL,
)
# Process messages in a loop
while True:
try:
message = await pubsub.get_message(
ignore_subscribe_messages=True, timeout=1.0
)
if (
message
and message["type"] == "message"
and message["channel"] == REGISTRY_REFRESH_CHANNEL
):
logger.info("Received LLM registry refresh notification")
try:
await on_refresh()
except Exception as exc:
logger.error(
"Error refreshing LLM registry from notification: %s",
exc,
exc_info=True,
)
except Exception as exc:
logger.warning(
"Error processing registry refresh message: %s", exc, exc_info=True
)
# Continue listening even if one message fails
await asyncio.sleep(1)
except Exception as exc:
logger.error(
"Failed to subscribe to LLM registry refresh notifications: %s",
exc,
exc_info=True,
)
raise

View File

@@ -1,388 +0,0 @@
"""Core LLM registry implementation for managing models dynamically."""
from __future__ import annotations
import asyncio
import logging
from dataclasses import dataclass, field
from typing import Any, Iterable
import prisma.models
from backend.data.llm_registry.model import ModelMetadata
logger = logging.getLogger(__name__)
def _json_to_dict(value: Any) -> dict[str, Any]:
"""Convert Prisma Json type to dict, with fallback to empty dict."""
if value is None:
return {}
if isinstance(value, dict):
return value
# Prisma Json type should always be a dict at runtime
return dict(value) if value else {}
@dataclass(frozen=True)
class RegistryModelCost:
"""Cost configuration for an LLM model."""
credit_cost: int
credential_provider: str
credential_id: str | None
credential_type: str | None
currency: str | None
metadata: dict[str, Any]
@dataclass(frozen=True)
class RegistryModelCreator:
"""Creator information for an LLM model."""
id: str
name: str
display_name: str
description: str | None
website_url: str | None
logo_url: str | None
@dataclass(frozen=True)
class RegistryModel:
"""Represents a model in the LLM registry."""
slug: str
display_name: str
description: str | None
metadata: ModelMetadata
capabilities: dict[str, Any]
extra_metadata: dict[str, Any]
provider_display_name: str
is_enabled: bool
is_recommended: bool = False
costs: tuple[RegistryModelCost, ...] = field(default_factory=tuple)
creator: RegistryModelCreator | None = None
_static_metadata: dict[str, ModelMetadata] = {}
_static_costs: dict[str, int] = {}
_dynamic_models: dict[str, RegistryModel] = {}
_schema_options: list[dict[str, str]] = []
_discriminator_mapping: dict[str, str] = {}
_lock = asyncio.Lock()
def register_static_metadata(metadata: dict[Any, ModelMetadata]) -> None:
"""Register static metadata for legacy models (deprecated)."""
_static_metadata.update({str(key): value for key, value in metadata.items()})
_refresh_cached_schema()
def register_static_costs(costs: dict[Any, int]) -> None:
"""Register static costs for legacy models (deprecated)."""
_static_costs.update({str(key): value for key, value in costs.items()})
def _build_schema_options() -> list[dict[str, str]]:
"""Build schema options for model selection dropdown. Only includes enabled models."""
options: list[dict[str, str]] = []
# Only include enabled models in the dropdown options
for model in sorted(_dynamic_models.values(), key=lambda m: m.display_name.lower()):
if model.is_enabled:
options.append(
{
"label": model.display_name,
"value": model.slug,
"group": model.metadata.provider,
"description": model.description or "",
}
)
for slug, metadata in _static_metadata.items():
if slug in _dynamic_models:
continue
options.append(
{
"label": slug,
"value": slug,
"group": metadata.provider,
"description": "",
}
)
return options
async def refresh_llm_registry() -> None:
"""Refresh the LLM registry from the database. Loads all models (enabled and disabled)."""
async with _lock:
try:
records = await prisma.models.LlmModel.prisma().find_many(
include={
"Provider": True,
"Costs": True,
"Creator": True,
}
)
logger.debug("Found %d LLM model records in database", len(records))
except Exception as exc:
logger.error(
"Failed to refresh LLM registry from DB: %s", exc, exc_info=True
)
return
dynamic: dict[str, RegistryModel] = {}
for record in records:
provider_name = (
record.Provider.name if record.Provider else record.providerId
)
provider_display_name = (
record.Provider.displayName if record.Provider else record.providerId
)
# Creator name: prefer Creator.name, fallback to provider display name
creator_name = (
record.Creator.name if record.Creator else provider_display_name
)
# Price tier: default to 1 (cheapest) if not set
price_tier = getattr(record, "priceTier", 1) or 1
# Clamp to valid range 1-3
price_tier = max(1, min(3, price_tier))
metadata = ModelMetadata(
provider=provider_name,
context_window=record.contextWindow,
max_output_tokens=record.maxOutputTokens,
display_name=record.displayName,
provider_name=provider_display_name,
creator_name=creator_name,
price_tier=price_tier, # type: ignore[arg-type]
)
costs = tuple(
RegistryModelCost(
credit_cost=cost.creditCost,
credential_provider=cost.credentialProvider,
credential_id=cost.credentialId,
credential_type=cost.credentialType,
currency=cost.currency,
metadata=_json_to_dict(cost.metadata),
)
for cost in (record.Costs or [])
)
# Map creator if present
creator = None
if record.Creator:
creator = RegistryModelCreator(
id=record.Creator.id,
name=record.Creator.name,
display_name=record.Creator.displayName,
description=record.Creator.description,
website_url=record.Creator.websiteUrl,
logo_url=record.Creator.logoUrl,
)
dynamic[record.slug] = RegistryModel(
slug=record.slug,
display_name=record.displayName,
description=record.description,
metadata=metadata,
capabilities=_json_to_dict(record.capabilities),
extra_metadata=_json_to_dict(record.metadata),
provider_display_name=(
record.Provider.displayName
if record.Provider
else record.providerId
),
is_enabled=record.isEnabled,
is_recommended=record.isRecommended,
costs=costs,
creator=creator,
)
# Atomic swap - build new structures then replace references
# This ensures readers never see partially updated state
global _dynamic_models
_dynamic_models = dynamic
_refresh_cached_schema()
logger.info(
"LLM registry refreshed with %s dynamic models (enabled: %s, disabled: %s)",
len(dynamic),
sum(1 for m in dynamic.values() if m.is_enabled),
sum(1 for m in dynamic.values() if not m.is_enabled),
)
def _refresh_cached_schema() -> None:
"""Refresh cached schema options and discriminator mapping."""
global _schema_options, _discriminator_mapping
# Build new structures
new_options = _build_schema_options()
new_mapping = {
slug: entry.metadata.provider for slug, entry in _dynamic_models.items()
}
for slug, metadata in _static_metadata.items():
new_mapping.setdefault(slug, metadata.provider)
# Atomic swap - replace references to ensure readers see consistent state
_schema_options = new_options
_discriminator_mapping = new_mapping
def get_llm_model_metadata(slug: str) -> ModelMetadata | None:
"""Get model metadata by slug. Checks dynamic models first, then static metadata."""
if slug in _dynamic_models:
return _dynamic_models[slug].metadata
return _static_metadata.get(slug)
def get_llm_model_cost(slug: str) -> tuple[RegistryModelCost, ...]:
"""Get model cost configuration by slug."""
if slug in _dynamic_models:
return _dynamic_models[slug].costs
cost_value = _static_costs.get(slug)
if cost_value is None:
return tuple()
return (
RegistryModelCost(
credit_cost=cost_value,
credential_provider="static",
credential_id=None,
credential_type=None,
currency=None,
metadata={},
),
)
def get_llm_model_schema_options() -> list[dict[str, str]]:
"""
Get schema options for LLM model selection dropdown.
Returns a copy of cached schema options that are refreshed when the registry is
updated via refresh_llm_registry() (called on startup and via Redis pub/sub).
"""
# Return a copy to prevent external mutation
return list(_schema_options)
def get_llm_discriminator_mapping() -> dict[str, str]:
"""
Get discriminator mapping for LLM models.
Returns a copy of cached discriminator mapping that is refreshed when the registry
is updated via refresh_llm_registry() (called on startup and via Redis pub/sub).
"""
# Return a copy to prevent external mutation
return dict(_discriminator_mapping)
def get_dynamic_model_slugs() -> set[str]:
"""Get all dynamic model slugs from the registry."""
return set(_dynamic_models.keys())
def get_all_model_slugs_for_validation() -> set[str]:
"""
Get ALL model slugs (both enabled and disabled) for validation purposes.
This is used for JSON schema enum validation - we need to accept any known
model value (even disabled ones) so that existing graphs don't fail validation.
The actual fallback/enforcement happens at runtime in llm_call().
"""
all_slugs = set(_dynamic_models.keys())
all_slugs.update(_static_metadata.keys())
return all_slugs
def iter_dynamic_models() -> Iterable[RegistryModel]:
"""Iterate over all dynamic models in the registry."""
return tuple(_dynamic_models.values())
def get_fallback_model_for_disabled(disabled_model_slug: str) -> RegistryModel | None:
"""
Find a fallback model when the requested model is disabled.
Looks for an enabled model from the same provider. Prefers models with
similar names or capabilities if possible.
Args:
disabled_model_slug: The slug of the disabled model
Returns:
An enabled RegistryModel from the same provider, or None if no fallback found
"""
disabled_model = _dynamic_models.get(disabled_model_slug)
if not disabled_model:
return None
provider = disabled_model.metadata.provider
# Find all enabled models from the same provider
candidates = [
model
for model in _dynamic_models.values()
if model.is_enabled and model.metadata.provider == provider
]
if not candidates:
return None
# Sort by: prefer models with similar context window, then by name
candidates.sort(
key=lambda m: (
abs(m.metadata.context_window - disabled_model.metadata.context_window),
m.display_name.lower(),
)
)
return candidates[0]
def is_model_enabled(model_slug: str) -> bool:
"""Check if a model is enabled in the registry."""
model = _dynamic_models.get(model_slug)
if not model:
# Model not in registry - assume it's a static/legacy model and allow it
return True
return model.is_enabled
def get_model_info(model_slug: str) -> RegistryModel | None:
"""Get model info from the registry."""
return _dynamic_models.get(model_slug)
def get_default_model_slug() -> str | None:
"""
Get the default model slug to use for block defaults.
Returns the recommended model if set (configured via admin UI),
otherwise returns the first enabled model alphabetically.
Returns None if no models are available or enabled.
"""
# Return the recommended model if one is set and enabled
for model in _dynamic_models.values():
if model.is_recommended and model.is_enabled:
return model.slug
# No recommended model set - find first enabled model alphabetically
for model in sorted(_dynamic_models.values(), key=lambda m: m.display_name.lower()):
if model.is_enabled:
logger.warning(
"No recommended model set, using '%s' as default",
model.slug,
)
return model.slug
# No enabled models available
if _dynamic_models:
logger.error(
"No enabled models found in registry (%d models registered but all disabled)",
len(_dynamic_models),
)
else:
logger.error("No models registered in LLM registry")
return None

View File

@@ -1,130 +0,0 @@
"""
Helper utilities for LLM registry integration with block schemas.
This module handles the dynamic injection of discriminator mappings
and model options from the LLM registry into block schemas.
"""
import logging
from typing import Any
from backend.data.llm_registry.registry import (
get_all_model_slugs_for_validation,
get_default_model_slug,
get_llm_discriminator_mapping,
get_llm_model_schema_options,
)
logger = logging.getLogger(__name__)
def is_llm_model_field(field_name: str, field_info: Any) -> bool:
"""
Check if a field is an LLM model selection field.
Returns True if the field has 'options' in json_schema_extra
(set by llm_model_schema_extra() in blocks/llm.py).
"""
if not hasattr(field_info, "json_schema_extra"):
return False
extra = field_info.json_schema_extra
if isinstance(extra, dict):
return "options" in extra
return False
def refresh_llm_model_options(field_schema: dict[str, Any]) -> None:
"""
Refresh LLM model options from the registry.
Updates 'options' (for frontend dropdown) to show only enabled models,
but keeps the 'enum' (for validation) inclusive of ALL known models.
This is important because:
- Options: What users see in the dropdown (enabled models only)
- Enum: What values pass validation (all known models, including disabled)
Existing graphs may have disabled models selected - they should pass validation
and the fallback logic in llm_call() will handle using an alternative model.
"""
fresh_options = get_llm_model_schema_options()
if not fresh_options:
return
# Update options array (UI dropdown) - only enabled models
if "options" in field_schema:
field_schema["options"] = fresh_options
all_known_slugs = get_all_model_slugs_for_validation()
if all_known_slugs and "enum" in field_schema:
existing_enum = set(field_schema.get("enum", []))
combined_enum = existing_enum | all_known_slugs
field_schema["enum"] = sorted(combined_enum)
# Set the default value from the registry (gpt-4o if available, else first enabled)
# This ensures new blocks have a sensible default pre-selected
default_slug = get_default_model_slug()
if default_slug:
field_schema["default"] = default_slug
def refresh_llm_discriminator_mapping(field_schema: dict[str, Any]) -> None:
"""
Refresh discriminator_mapping for fields that use model-based discrimination.
The discriminator is already set when AICredentialsField() creates the field.
We only need to refresh the mapping when models are added/removed.
"""
if field_schema.get("discriminator") != "model":
return
# Always refresh the mapping to get latest models
fresh_mapping = get_llm_discriminator_mapping()
if fresh_mapping is not None:
field_schema["discriminator_mapping"] = fresh_mapping
def update_schema_with_llm_registry(
schema: dict[str, Any], model_class: type | None = None
) -> None:
"""
Update a JSON schema with current LLM registry data.
Refreshes:
1. Model options for LLM model selection fields (dropdown choices)
2. Discriminator mappings for credentials fields (model → provider)
Args:
schema: The JSON schema to update (mutated in-place)
model_class: The Pydantic model class (optional, for field introspection)
"""
properties = schema.get("properties", {})
for field_name, field_schema in properties.items():
if not isinstance(field_schema, dict):
continue
# Refresh model options for LLM model fields
if model_class and hasattr(model_class, "model_fields"):
field_info = model_class.model_fields.get(field_name)
if field_info and is_llm_model_field(field_name, field_info):
try:
refresh_llm_model_options(field_schema)
except Exception as exc:
logger.warning(
"Failed to refresh LLM options for field %s: %s",
field_name,
exc,
)
# Refresh discriminator mapping for fields that use model discrimination
try:
refresh_llm_discriminator_mapping(field_schema)
except Exception as exc:
logger.warning(
"Failed to refresh discriminator mapping for field %s: %s",
field_name,
exc,
)

View File

@@ -40,7 +40,6 @@ from pydantic_core import (
)
from typing_extensions import TypedDict
from backend.data.llm_registry import update_schema_with_llm_registry
from backend.integrations.providers import ProviderName
from backend.util.json import loads as json_loads
from backend.util.settings import Secrets
@@ -545,9 +544,7 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
else:
schema["credentials_provider"] = allowed_providers
schema["credentials_types"] = model_class.allowed_cred_types()
# Ensure LLM discriminators are populated (delegates to shared helper)
update_schema_with_llm_registry(schema, model_class)
# Do not return anything, just mutate schema in place
model_config = ConfigDict(
json_schema_extra=_add_json_schema_extra, # type: ignore
@@ -669,10 +666,16 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
if not (self.discriminator and self.discriminator_mapping):
return self
try:
provider = self.discriminator_mapping[discriminator_value]
except KeyError:
raise ValueError(
f"Model '{discriminator_value}' is not supported. "
"It may have been deprecated. Please update your agent configuration."
)
return CredentialsFieldInfo(
credentials_provider=frozenset(
[self.discriminator_mapping[discriminator_value]]
),
credentials_provider=frozenset([provider]),
credentials_types=self.supported_types,
credentials_scopes=self.required_scopes,
discriminator=self.discriminator,
@@ -696,20 +699,16 @@ def CredentialsField(
This is enforced by the `BlockSchema` base class.
"""
# Build field_schema_extra - always include discriminator and mapping if discriminator is set
field_schema_extra: dict[str, Any] = {}
# Always include discriminator if provided
if discriminator is not None:
field_schema_extra["discriminator"] = discriminator
# Always include discriminator_mapping when discriminator is set (even if empty initially)
field_schema_extra["discriminator_mapping"] = discriminator_mapping or {}
# Include other optional fields (only if not None)
if required_scopes:
field_schema_extra["credentials_scopes"] = list(required_scopes)
if discriminator_values:
field_schema_extra["discriminator_values"] = discriminator_values
field_schema_extra = {
k: v
for k, v in {
"credentials_scopes": list(required_scopes) or None,
"discriminator": discriminator,
"discriminator_mapping": discriminator_mapping,
"discriminator_values": discriminator_values,
}.items()
if v is not None
}
# Merge any json_schema_extra passed in kwargs
if "json_schema_extra" in kwargs:

View File

@@ -41,6 +41,7 @@ FrontendOnboardingStep = Literal[
OnboardingStep.AGENT_NEW_RUN,
OnboardingStep.AGENT_INPUT,
OnboardingStep.CONGRATS,
OnboardingStep.VISIT_COPILOT,
OnboardingStep.MARKETPLACE_VISIT,
OnboardingStep.BUILDER_OPEN,
]
@@ -122,6 +123,9 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
async def _reward_user(user_id: str, onboarding: UserOnboarding, step: OnboardingStep):
reward = 0
match step:
# Welcome bonus for visiting copilot ($5 = 500 credits)
case OnboardingStep.VISIT_COPILOT:
reward = 500
# Reward user when they clicked New Run during onboarding
# This is because they need credits before scheduling a run (next step)
# This is seen as a reward for the GET_RESULTS step in the wallet

View File

@@ -0,0 +1,276 @@
"""
Database CRUD operations for User Workspace.
This module provides functions for managing user workspaces and workspace files.
"""
import logging
from datetime import datetime, timezone
from typing import Optional
from prisma.models import UserWorkspace, UserWorkspaceFile
from prisma.types import UserWorkspaceFileWhereInput
from backend.util.json import SafeJson
logger = logging.getLogger(__name__)
async def get_or_create_workspace(user_id: str) -> UserWorkspace:
"""
Get user's workspace, creating one if it doesn't exist.
Uses upsert to handle race conditions when multiple concurrent requests
attempt to create a workspace for the same user.
Args:
user_id: The user's ID
Returns:
UserWorkspace instance
"""
workspace = await UserWorkspace.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id},
"update": {}, # No updates needed if exists
},
)
return workspace
async def get_workspace(user_id: str) -> Optional[UserWorkspace]:
"""
Get user's workspace if it exists.
Args:
user_id: The user's ID
Returns:
UserWorkspace instance or None
"""
return await UserWorkspace.prisma().find_unique(where={"userId": user_id})
async def create_workspace_file(
workspace_id: str,
file_id: str,
name: str,
path: str,
storage_path: str,
mime_type: str,
size_bytes: int,
checksum: Optional[str] = None,
metadata: Optional[dict] = None,
) -> UserWorkspaceFile:
"""
Create a new workspace file record.
Args:
workspace_id: The workspace ID
file_id: The file ID (same as used in storage path for consistency)
name: User-visible filename
path: Virtual path (e.g., "/documents/report.pdf")
storage_path: Actual storage path (GCS or local)
mime_type: MIME type of the file
size_bytes: File size in bytes
checksum: Optional SHA256 checksum
metadata: Optional additional metadata
Returns:
Created UserWorkspaceFile instance
"""
# Normalize path to start with /
if not path.startswith("/"):
path = f"/{path}"
file = await UserWorkspaceFile.prisma().create(
data={
"id": file_id,
"workspaceId": workspace_id,
"name": name,
"path": path,
"storagePath": storage_path,
"mimeType": mime_type,
"sizeBytes": size_bytes,
"checksum": checksum,
"metadata": SafeJson(metadata or {}),
}
)
logger.info(
f"Created workspace file {file.id} at path {path} "
f"in workspace {workspace_id}"
)
return file
async def get_workspace_file(
file_id: str,
workspace_id: Optional[str] = None,
) -> Optional[UserWorkspaceFile]:
"""
Get a workspace file by ID.
Args:
file_id: The file ID
workspace_id: Optional workspace ID for validation
Returns:
UserWorkspaceFile instance or None
"""
where_clause: dict = {"id": file_id, "isDeleted": False}
if workspace_id:
where_clause["workspaceId"] = workspace_id
return await UserWorkspaceFile.prisma().find_first(where=where_clause)
async def get_workspace_file_by_path(
workspace_id: str,
path: str,
) -> Optional[UserWorkspaceFile]:
"""
Get a workspace file by its virtual path.
Args:
workspace_id: The workspace ID
path: Virtual path
Returns:
UserWorkspaceFile instance or None
"""
# Normalize path
if not path.startswith("/"):
path = f"/{path}"
return await UserWorkspaceFile.prisma().find_first(
where={
"workspaceId": workspace_id,
"path": path,
"isDeleted": False,
}
)
async def list_workspace_files(
workspace_id: str,
path_prefix: Optional[str] = None,
include_deleted: bool = False,
limit: Optional[int] = None,
offset: int = 0,
) -> list[UserWorkspaceFile]:
"""
List files in a workspace.
Args:
workspace_id: The workspace ID
path_prefix: Optional path prefix to filter (e.g., "/documents/")
include_deleted: Whether to include soft-deleted files
limit: Maximum number of files to return
offset: Number of files to skip
Returns:
List of UserWorkspaceFile instances
"""
where_clause: UserWorkspaceFileWhereInput = {"workspaceId": workspace_id}
if not include_deleted:
where_clause["isDeleted"] = False
if path_prefix:
# Normalize prefix
if not path_prefix.startswith("/"):
path_prefix = f"/{path_prefix}"
where_clause["path"] = {"startswith": path_prefix}
return await UserWorkspaceFile.prisma().find_many(
where=where_clause,
order={"createdAt": "desc"},
take=limit,
skip=offset,
)
async def count_workspace_files(
workspace_id: str,
path_prefix: Optional[str] = None,
include_deleted: bool = False,
) -> int:
"""
Count files in a workspace.
Args:
workspace_id: The workspace ID
path_prefix: Optional path prefix to filter (e.g., "/sessions/abc123/")
include_deleted: Whether to include soft-deleted files
Returns:
Number of files
"""
where_clause: dict = {"workspaceId": workspace_id}
if not include_deleted:
where_clause["isDeleted"] = False
if path_prefix:
# Normalize prefix
if not path_prefix.startswith("/"):
path_prefix = f"/{path_prefix}"
where_clause["path"] = {"startswith": path_prefix}
return await UserWorkspaceFile.prisma().count(where=where_clause)
async def soft_delete_workspace_file(
file_id: str,
workspace_id: Optional[str] = None,
) -> Optional[UserWorkspaceFile]:
"""
Soft-delete a workspace file.
The path is modified to include a deletion timestamp to free up the original
path for new files while preserving the record for potential recovery.
Args:
file_id: The file ID
workspace_id: Optional workspace ID for validation
Returns:
Updated UserWorkspaceFile instance or None if not found
"""
# First verify the file exists and belongs to workspace
file = await get_workspace_file(file_id, workspace_id)
if file is None:
return None
deleted_at = datetime.now(timezone.utc)
# Modify path to free up the unique constraint for new files at original path
# Format: {original_path}__deleted__{timestamp}
deleted_path = f"{file.path}__deleted__{int(deleted_at.timestamp())}"
updated = await UserWorkspaceFile.prisma().update(
where={"id": file_id},
data={
"isDeleted": True,
"deletedAt": deleted_at,
"path": deleted_path,
},
)
logger.info(f"Soft-deleted workspace file {file_id}")
return updated
async def get_workspace_total_size(workspace_id: str) -> int:
"""
Get the total size of all files in a workspace.
Args:
workspace_id: The workspace ID
Returns:
Total size in bytes
"""
files = await list_workspace_files(workspace_id)
return sum(file.sizeBytes for file in files)

View File

@@ -17,6 +17,7 @@ from backend.data.analytics import (
get_accuracy_trends_and_alerts,
get_marketplace_graphs_for_monitoring,
)
from backend.data.auth.oauth import cleanup_expired_oauth_tokens
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
from backend.data.execution import (
create_graph_execution,
@@ -219,6 +220,9 @@ class DatabaseManager(AppService):
# Onboarding
increment_onboarding_runs = _(increment_onboarding_runs)
# OAuth
cleanup_expired_oauth_tokens = _(cleanup_expired_oauth_tokens)
# Store
get_store_agents = _(get_store_agents)
get_store_agent_details = _(get_store_agent_details)
@@ -349,6 +353,9 @@ class DatabaseManagerAsyncClient(AppServiceClient):
# Onboarding
increment_onboarding_runs = d.increment_onboarding_runs
# OAuth
cleanup_expired_oauth_tokens = d.cleanup_expired_oauth_tokens
# Store
get_store_agents = d.get_store_agents
get_store_agent_details = d.get_store_agent_details

View File

@@ -1,66 +0,0 @@
"""
Helper functions for LLM registry initialization in executor context.
These functions handle refreshing the LLM registry when the executor starts
and subscribing to real-time updates via Redis pub/sub.
"""
import logging
from backend.data import db, llm_registry
from backend.data.block import BlockSchema, initialize_blocks
from backend.data.block_cost_config import refresh_llm_costs
from backend.data.llm_registry import subscribe_to_registry_refresh
logger = logging.getLogger(__name__)
async def initialize_registry_for_executor() -> None:
"""
Initialize blocks and refresh LLM registry in the executor context.
This must run in the executor's event loop to have access to the database.
"""
try:
# Connect to database if not already connected
if not db.is_connected():
await db.connect()
logger.info("[GraphExecutor] Connected to database for registry refresh")
# Initialize blocks (internally refreshes LLM registry and costs)
await initialize_blocks()
logger.info("[GraphExecutor] Blocks initialized")
except Exception as exc:
logger.warning(
"[GraphExecutor] Failed to refresh LLM registry on startup: %s",
exc,
exc_info=True,
)
async def refresh_registry_on_notification() -> None:
"""Refresh LLM registry when notified via Redis pub/sub."""
try:
# Ensure DB is connected
if not db.is_connected():
await db.connect()
# Refresh registry and costs
await llm_registry.refresh_llm_registry()
refresh_llm_costs()
# Clear block schema caches so they regenerate with new model options
BlockSchema.clear_all_schema_caches()
logger.info("[GraphExecutor] LLM registry refreshed from notification")
except Exception as exc:
logger.error(
"[GraphExecutor] Failed to refresh LLM registry from notification: %s",
exc,
exc_info=True,
)
async def subscribe_to_registry_updates() -> None:
"""Subscribe to Redis pub/sub for LLM registry refresh notifications."""
await subscribe_to_registry_refresh(refresh_registry_on_notification)

View File

@@ -236,7 +236,14 @@ async def execute_node(
input_size = len(input_data_str)
log_metadata.debug("Executed node with input", input=input_data_str)
# Create node-specific execution context to avoid race conditions
# (multiple nodes can execute concurrently and would otherwise mutate shared state)
execution_context = execution_context.model_copy(
update={"node_id": node_id, "node_exec_id": node_exec_id}
)
# Inject extra execution arguments for the blocks via kwargs
# Keep individual kwargs for backwards compatibility with existing blocks
extra_exec_kwargs: dict = {
"graph_id": graph_id,
"graph_version": graph_version,
@@ -702,20 +709,6 @@ class ExecutionProcessor:
)
self.node_execution_thread.start()
self.node_evaluation_thread.start()
# Initialize LLM registry and subscribe to updates
from backend.executor.llm_registry_init import (
initialize_registry_for_executor,
subscribe_to_registry_updates,
)
asyncio.run_coroutine_threadsafe(
initialize_registry_for_executor(), self.node_execution_loop
)
asyncio.run_coroutine_threadsafe(
subscribe_to_registry_updates(), self.node_execution_loop
)
logger.info(f"[GraphExecutor] {self.tid} started")
@error_logged(swallow=False)

View File

@@ -24,11 +24,9 @@ from dotenv import load_dotenv
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy import MetaData, create_engine
from backend.data.auth.oauth import cleanup_expired_oauth_tokens
from backend.data.block import BlockInput
from backend.data.execution import GraphExecutionWithNodes
from backend.data.model import CredentialsMetaInput
from backend.data.onboarding import increment_onboarding_runs
from backend.executor import utils as execution_utils
from backend.monitoring import (
NotificationJobArgs,
@@ -38,7 +36,11 @@ from backend.monitoring import (
report_execution_accuracy_alerts,
report_late_executions,
)
from backend.util.clients import get_database_manager_client, get_scheduler_client
from backend.util.clients import (
get_database_manager_async_client,
get_database_manager_client,
get_scheduler_client,
)
from backend.util.cloud_storage import cleanup_expired_files_async
from backend.util.exceptions import (
GraphNotFoundError,
@@ -148,6 +150,7 @@ def execute_graph(**kwargs):
async def _execute_graph(**kwargs):
args = GraphExecutionJobArgs(**kwargs)
start_time = asyncio.get_event_loop().time()
db = get_database_manager_async_client()
try:
logger.info(f"Executing recurring job for graph #{args.graph_id}")
graph_exec: GraphExecutionWithNodes = await execution_utils.add_graph_execution(
@@ -157,7 +160,7 @@ async def _execute_graph(**kwargs):
inputs=args.input_data,
graph_credentials_inputs=args.input_credentials,
)
await increment_onboarding_runs(args.user_id)
await db.increment_onboarding_runs(args.user_id)
elapsed = asyncio.get_event_loop().time() - start_time
logger.info(
f"Graph execution started with ID {graph_exec.id} for graph {args.graph_id} "
@@ -246,8 +249,13 @@ def cleanup_expired_files():
def cleanup_oauth_tokens():
"""Clean up expired OAuth tokens from the database."""
# Wait for completion
run_async(cleanup_expired_oauth_tokens())
async def _cleanup():
db = get_database_manager_async_client()
return await db.cleanup_expired_oauth_tokens()
run_async(_cleanup())
def execution_accuracy_alerts():

Some files were not shown because too many files have changed in this diff Show More