Compare commits

...

65 Commits

Author SHA1 Message Date
Lluis Agusti
9b98b2df40 chore: wip 2026-01-17 09:08:15 +07:00
Swifty
e55f05c7a8 feat(backend): add chat search tools and BM25 reranking (#11782)
This PR adds new chat tools for searching blocks and documentation,
along with BM25 reranking for improved search relevance.

### Changes 🏗️

**New Chat Tools:**
- `find_block` - Search for available blocks by name/description using
hybrid search
- `run_block` - Execute a block directly with provided inputs and
credentials
- `search_docs` - Search documentation with section-level granularity  
- `get_doc_page` - Retrieve full documentation page content

**Search Improvements:**
- Added BM25 reranking to hybrid search for better lexical relevance
- Documentation handler now chunks markdown by headings (##) for
finer-grained embeddings
- Section-based content IDs (`doc_path::section_index`) for precise doc
retrieval
- Startup embedding backfill in scheduler for immediate searchability
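
The BM25 reranking listed above can be illustrated with the `rank-bm25` package this PR adds. A minimal sketch, assuming candidates arrive as dicts with a `description` field — the function and field names are illustrative, not the actual implementation:

```python
from rank_bm25 import BM25Okapi

def bm25_rerank(query: str, candidates: list[dict], top_k: int = 10) -> list[dict]:
    # Tokenize candidate descriptions; the real code may index more fields.
    corpus = [c["description"].lower().split() for c in candidates]
    bm25 = BM25Okapi(corpus)
    # Score every candidate against the query tokens and keep the best ones.
    scores = bm25.get_scores(query.lower().split())
    ranked = sorted(zip(candidates, scores), key=lambda pair: pair[1], reverse=True)
    return [candidate for candidate, _ in ranked[:top_k]]
```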

**Other Changes:**
- New response models for block and documentation search results
- Updated orphan cleanup to handle section-based doc embeddings
- Added `rank-bm25` dependency for BM25 scoring
- Removed max message limit check in chat service

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Run find_block tool to search for blocks (e.g., "current time")
  - [x] Run run_block tool to execute a found block
  - [x] Run search_docs tool to search documentation
  - [x] Run get_doc_page tool to retrieve full doc content
  - [x] Verify BM25 reranking improves search relevance for exact term matches
  - [x] Verify documentation sections are properly chunked and embedded

#### For configuration changes:
- [x] `.env.default` is updated or already compatible with my changes
- [x] `docker-compose.yml` is updated or already compatible with my
changes
- [x] I have included a list of my configuration changes in the PR
description (under **Changes**)

**Dependencies added:** `rank-bm25` for BM25 scoring algorithm
2026-01-16 16:18:10 +01:00
Swifty
4a9b13acb6 feat(frontend): extract frontend changes from hackathon/copilot branch (#11717)
Frontend changes extracted from the hackathon/copilot branch for the
copilot feature development.

### Changes 🏗️

- New Chat system with contextual components (`Chat`, `ChatDrawer`,
`ChatContainer`, `ChatMessage`, etc.)
- Form renderer system with RJSF v6 integration and new input renderers
- Enhanced credentials management with improved OAuth flow and
credential selection
- New output renderers for various content types (Code, Image, JSON,
Markdown, Text, Video)
- Scrollable tabs component for better UI organization
- Marketplace update notifications and publishing workflow improvements
- Draft recovery feature with IndexedDB persistence
- Safe mode toggle functionality
- Various UI/UX improvements across the platform

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [ ] Test new Chat components functionality
  - [ ] Verify form renderer with various input types
  - [ ] Test credential management flows
  - [ ] Verify output renderers display correctly
  - [ ] Test draft recovery feature

#### For configuration changes:

- [x] `.env.default` is updated or already compatible with my changes
- [x] `docker-compose.yml` is updated or already compatible with my
changes
- [x] I have included a list of my configuration changes in the PR
description (under **Changes**)

---------

Co-authored-by: Lluis Agusti <hi@llu.lu>
2026-01-16 22:15:39 +07:00
Zamil Majdy
5ff669e999 fix(backend): Make Redis connection lazy in cache module (#11775)
## Summary
- Makes Redis connection lazy in the cache module - connection is only
established when `shared_cache=True` is actually used
- Fixes DatabaseManager failing to start because it imports
`onboarding.py` which imports `cache.py`, triggering Redis connection at
module load time even though it only uses in-memory caching

## Root Cause
Commit `b01ea3fcb` (merged today) added `increment_onboarding_runs` to
DatabaseManager, which imports from `onboarding.py`. That module imports
`@cached` decorator from `cache.py`, which was creating a Redis
connection at module import time:

```python
# Old code - ran at import time!
redis = Redis(connection_pool=_get_cache_pool())
```

Since `onboarding.py` only uses `@cached(shared_cache=False)` (in-memory
caching), it doesn't actually need Redis. But the import triggered the
connection attempt.

## Changes
- Wrapped Redis connection in a singleton class with lazy initialization
- Connection is only established when `_get_redis()` is first called
(i.e., when `shared_cache=True` is used)
- Services using only in-memory caching can now import `cache.py`
without Redis configuration
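
A minimal sketch of the lazy-initialization pattern described above, reusing the `_get_cache_pool()` factory from the old snippet; the class and method names here are illustrative, not the actual code:

```python
from redis import Redis

class _LazyRedisClient:
    """Holds the Redis client and only connects on first use."""

    _client: Redis | None = None

    @classmethod
    def get(cls) -> Redis:
        if cls._client is None:
            # Connection happens here, not at module import time.
            cls._client = Redis(connection_pool=_get_cache_pool())  # existing pool factory in cache.py
        return cls._client

def _get_redis() -> Redis:
    # Only reached when shared_cache=True, so in-memory-only callers never touch Redis.
    return _LazyRedisClient.get()
```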

## Test plan
- [ ] Services using `shared_cache=False` work without Redis configured
- [ ] Services using `shared_cache=True` still work correctly with Redis
- [ ] Existing cache tests pass

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-16 14:28:36 +00:00
Abhimanyu Yadav
ec03a13e26 fix(frontend): improve history tracking, error handling (#11786)
### Changes 🏗️

- **Improved Error Handling**: Enhanced error handling in
`useRunInputDialog.ts` to properly handle cases where node errors are
empty or undefined
- **Fixed Node Collision Resolution**: Updated `Flow.tsx` to use the
current state from the store instead of stale props
- **Enhanced History Management**:
    - Added proper state tracking for edge removal operations
    - Improved undo/redo functionality to prevent duplicate states
    - Fixed edge case where history wasn't properly tracked during node dragging
- **UI Improvements**:
    - Fixed potential null reference in NodeHeader when accessing agent_name
    - Added placeholder for GoogleDrivePicker in INPUT mode
    - Fixed spacing in ArrayFieldTemplate
- **Bug Fixes**:
    - Added proper state tracking before modifying nodes/edges
    - Fixed history tracking to avoid redundant states
    - Improved collision detection and resolution

### Checklist 📋

#### For code changes:

- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
    - [x] Test undo/redo functionality after adding, removing, and moving nodes
    - [x] Test edge creation and deletion with history tracking
    - [x] Verify error handling when graph validation fails
    - [x] Test Google Drive picker in different UI modes
    - [x] Verify node collision resolution works correctly
2026-01-16 13:34:57 +00:00
Abhimanyu Yadav
b08851f5d7 feat(frontend): improve GoogleDrivePickerField with input mode support and array field spacing (#11780)
### Changes 🏗️

- Added a placeholder UI for Google Drive Picker in INPUT block type
- Improved detection of Google Drive file objects in schema validation
- Extracted `isGoogleDrivePickerSchema` function for better code
organization
- Added spacing between array field elements with a gap-2 class
- Added debug logging for preprocessed schema in FormRenderer

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Verified Google Drive Picker shows placeholder in INPUT blocks
  - [x] Confirmed array field elements have proper spacing
  - [x] Tested that Google Drive file objects are properly detected
2026-01-16 13:02:36 +00:00
Abhimanyu Yadav
8b1720e61d feat(frontend): improve graph validation error handling and node navigation (#11779)
### Changes 🏗️

- Enhanced error handling for graph validation failures with detailed
user feedback
- Added automatic viewport navigation to the first node with errors when
validation fails
- Improved node title display to prioritize agent_name from hardcoded
values
- Removed console.log debugging statement from OutputHandler
- Added ApiError import and improved error type handling
- Reorganized imports for better code organization

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] Create a graph with intentional validation errors and verify error
messages display correctly
- [x] Verify the viewport automatically navigates to the first node with
errors
- [x] Check that node titles correctly display customized names or agent
names
- [x] Test error recovery by fixing validation errors and successfully
running the graph
2026-01-16 11:14:00 +00:00
Abhimanyu Yadav
aa5a039c5e feat(frontend): add special rendering for NOTE UI type in FieldTemplate (#11771)
### Changes 🏗️

Added support for Note blocks in the FieldTemplate component by:
- Importing the BlockUIType enum from the build components types
- Extracting the uiType from the registry.formContext
- Adding a conditional rendering check that returns children directly
when the uiType is BlockUIType.NOTE

This change allows Note blocks to render without the standard field
template wrapper, providing a cleaner display for note-type content.


![Screenshot 2026-01-15 at
1.01.03 PM.png](https://app.graphite.com/user-attachments/assets/7d654eed-abbe-4ec3-9c80-24a77a8373e3.png)

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] Created a Note block and verified it renders correctly without
field template wrapper
- [x] Confirmed other block types still render with proper field
template
- [x] Verified that Note blocks maintain proper functionality in the
node graph
2026-01-16 11:10:21 +00:00
Zamil Majdy
8b83bb8647 feat(backend): unified hybrid search with embedding backfill for all content types (#11767)
## Summary

This PR extends the embedding system to support **blocks** and
**documentation** content types in addition to store agents, and
introduces **unified hybrid search** across all content types using a
single `UnifiedContentEmbedding` table.

### Key Changes

1. **Unified Hybrid Search Architecture**
   - Added `search` tsvector column to `UnifiedContentEmbedding` table
- New `unified_hybrid_search()` function searches across all content
types (agents, blocks, docs)
- Updated `hybrid_search()` for store agents to use
`UnifiedContentEmbedding.search`
   - Removed deprecated `search` column from `StoreListingVersion` table

2. **Pluggable Content Handler Architecture**
   - Created abstract `ContentHandler` base class for extensibility
- Implemented handlers: `StoreAgentHandler`, `BlockHandler`,
`DocumentationHandler`
   - Registry pattern for easy addition of new content types

3. **Block Embeddings**
   - Discovers all blocks using `get_blocks()`
- Extracts searchable text from: name, description, categories,
input/output schemas

4. **Documentation Embeddings**
   - Scans `/docs/` directory for `.md` and `.mdx` files
   - Extracts title from first `#` heading or uses filename as fallback

5. **Hybrid Search Graceful Degradation**
- Falls back to lexical-only search if query embedding generation fails
   - Redistributes semantic weight proportionally to other components
   - Logs warning instead of throwing error

6. **Database Migrations**
- `20260115200000_add_unified_search_tsvector`: Adds search column to
UnifiedContentEmbedding with auto-update trigger
- `20260115210000_remove_storelistingversion_search`: Removes deprecated
search column and updates StoreAgent view

7. **Orphan Cleanup**
- `cleanup_orphaned_embeddings()` removes embeddings for deleted content
   - Always runs after backfill, even at 100% coverage
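
A minimal sketch of the pluggable handler idea from point 2 above; the abstract class name matches the PR (`ContentHandler`), but the method names and registry shape here are assumptions for illustration:

```python
from abc import ABC, abstractmethod

class ContentHandler(ABC):
    """One handler per content type (store agents, blocks, docs)."""

    content_type: str

    @abstractmethod
    async def get_items(self) -> list[dict]:
        """Return items to embed, each with a content id and searchable text."""

# Registry pattern: supporting a new content type is just registering another handler.
CONTENT_HANDLERS: dict[str, ContentHandler] = {}

def register_handler(handler: ContentHandler) -> None:
    CONTENT_HANDLERS[handler.content_type] = handler
```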

### Review Comments Addressed

-  SQL parameter index bug when user_id provided (embeddings.py)
-  Early return skipping cleanup at 100% coverage (scheduler.py)
-  Inconsistent return structure across code paths (scheduler.py)
-  SQL UNION syntax error - added parentheses for ORDER BY/LIMIT
(hybrid_search.py)
-  Version numeric ordering in aggregations (migration)
-  Embedding dimension uses EMBEDDING_DIM constant

### Files Changed

- `backend/api/features/store/content_handlers.py` (NEW): Handler
architecture
- `backend/api/features/store/embeddings.py`: Refactored to use handlers
- `backend/api/features/store/hybrid_search.py`: Unified search +
graceful degradation
- `backend/executor/scheduler.py`: Process all content types, consistent
returns
- `migrations/20260115200000_add_unified_search_tsvector/`: Add tsvector
to unified table
- `migrations/20260115210000_remove_storelistingversion_search/`: Remove
old search column
- `schema.prisma`: Updated UnifiedContentEmbedding and
StoreListingVersion models
- `*_test.py`: Added tests for unified_hybrid_search

## Test Plan

1.  All tests passing on Python 3.11, 3.12, 3.13
2.  Types check passing
3.  CodeRabbit and Sentry reviews addressed
4. Deploy to staging and verify:
   - Backfill job processes all content types
   - Search results include blocks and docs
   - Search works without OpenAI API (graceful degradation)

🤖 Generated with [Claude Code](https://claude.ai/code)

---------

Co-authored-by: Swifty <craigswift13@gmail.com>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-16 09:47:19 +01:00
Nicholas Tindle
e80e4d9cbb ci: update dev from gitbook (#11757)
<!-- Clearly explain the need for these changes: -->
gitbook changes via ui

<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> **Docs sync from GitBook**
> 
> - Updates `docs/home/README.md` with a new Developer Platform landing
page (cards, links to Platform, Integrations, Contribute, Discord,
GitHub) and metadata/cover settings
> - Adds `docs/home/SUMMARY.md` defining the table of contents linking
to `README.md`
> - No application/runtime code changes
> 
> <sup>Written by [Cursor
Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit
446c71fec8. This will update automatically
on new commits. Configure
[here](https://cursor.com/dashboard?tab=bugbot).</sup>
<!-- /CURSOR_SUMMARY -->

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
2026-01-15 20:02:48 +00:00
Ubbe
375d33cca9 fix(frontend): agent credentials improvements (#11763)
## Changes 🏗️

### System credentials in Run Modal

We had the issue that "system" credentials were mixed with "user"
credentials in the run agent modal:

#### Before

<img width="400" height="466" alt="Screenshot 2026-01-14 at 19 05 56"
src="https://github.com/user-attachments/assets/9d1ee766-5004-491f-ae14-a0cf89a9118e"
/>

This created confusion among users. These "system" credentials are
supplied by AutoGPT ( _most of the time_ ) and a user running an agent
should not bother with them ( _unless they want to change them_ ). For
example in this case, the credential that matters is the **Google** one
🙇🏽

#### After

<img width="400" height="350" alt="Screenshot 2026-01-14 at 19 04 12"
src="https://github.com/user-attachments/assets/e2bbc015-ce4c-496c-a76f-293c01a11c6f"
/>

<img width="400" height="672" alt="Screenshot 2026-01-14 at 19 04 19"
src="https://github.com/user-attachments/assets/d704dae2-ecb2-4306-bd04-3d812fed4401"
/>

"System" credentials are collapsed by default, reducing noise in the
Task Credentials section. The user can still see and change them by
expanding the accordion.

<img width="400" height="190" alt="Screenshot 2026-01-14 at 19 04 27"
src="https://github.com/user-attachments/assets/edc69612-4588-48e4-981a-f59c26cfa390"
/>

If some "system" credentials are missing, there is a red label
indicating so, it wasn't that obvious with the previous implementation,

<img width="400" height="309" alt="Screenshot 2026-01-14 at 19 04 30"
src="https://github.com/user-attachments/assets/f27081c7-40ad-4757-97b3-f29636616fc2"
/>

### New endpoint

There is a new REST endpoint, `GET /providers/system`, that lists system credential providers so the frontend can easily group them separately from user credentials.

### Other improvements

#### `<CredentialsInput />` refinements

<img width="715" height="200" alt="Screenshot 2026-01-14 at 19 09 31"
src="https://github.com/user-attachments/assets/01b39b16-25f3-428d-a6c8-da608038a38b"
/>

Use a normal browser `<select>` for the Credentials Dropdown ( _when you
have more than one for a provider_ ). This simplifies the UI shenanigans a
lot and provides a better UX on 📱 ( _eventually we should move all our
selects to native ones, as they are much better for mobile and touch
screens and mean less code to maintain on our end_ ).

I also renamed some files for clarity and tidied up some of the existing
logic.

#### Other

- Fix **Open telemetry** warnings on the server console by making the
packages external
- Fix `require-in-the-middle` console warnings
- Prettier tidy ups

## Checklist 📋

### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Run the app locally and test the above
2026-01-15 17:44:44 +07:00
Swifty
3b1b2fe30c feat(backend): Extract backend copilot/chat enhancements from hackathon (#11719)
This PR extracts backend changes from the hackathon/copilot branch,
adding enhanced chat capabilities, agent management tools, store
embeddings, and hybrid search functionality.

### Changes 🏗️

**Chat Features:**
- Added chat database layer (`db.py`) for conversation and message
persistence
- Extended chat models with new types and response structures
- New onboarding system prompt for guided user experiences
- Enhanced chat routes with additional endpoints
- Expanded chat service with more capabilities

**Chat Agent Tools:**
- `agent_output.py` - Handle agent execution outputs
- `create_agent.py` - Tool for creating new agents via chat
- `edit_agent.py` - Tool for modifying existing agents
- `find_library_agent.py` - Search and discover library agents
- Enhanced `run_agent.py` with additional functionality
- New `models.py` for shared tool types

**Store Enhancements:**
- `embeddings.py` - Vector embeddings support for semantic search
- `hybrid_search.py` - Combined keyword and semantic search
- `backfill_embeddings.py` - Utility for backfilling existing data
- Updated store database operations

**Admin:**
- Enhanced store admin routes

**Data Layer:**
- New `understanding.py` module for agent understanding/context

**Database Migrations:**
- `add_chat_tables` - Chat conversation and message tables
- `add_store_embeddings` - Embeddings storage for store items
- `enhance_search` - Search index improvements

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Chat endpoints respond correctly
  - [x] Agent tools (create/edit/find/run) function properly
  - [x] Store embeddings and hybrid search work
  - [x] Database migrations apply cleanly

#### For configuration changes:

- [x] `.env.default` is updated or already compatible with my changes
- [x] `docker-compose.yml` is updated or already compatible with my
changes
- [x] I have included a list of my configuration changes in the PR
description (under **Changes**)

---------

Co-authored-by: Torantulino <40276179@live.napier.ac.uk>
2026-01-15 11:11:36 +01:00
Abhimanyu Yadav
af63b3678e feat(frontend): hide children of connected array and object fields
(#11770)

### Changes 🏗️

- Added conditional rendering for array and object field children based
on connection status
- Implemented `shouldShowChildren` logic in `ArrayFieldTemplate` and
`ObjectFieldTemplate` components
- Modified the `shouldShowChildren` condition in `FieldTemplate` to
handle different schema types
- Imported and utilized `cleanUpHandleId` and `useEdgeStore` to check if
inputs are connected
- Added connection status checks to hide form fields when their inputs
are connected to other nodes

![Screenshot 2026-01-15 at
12.55.32 PM.png](https://app.graphite.com/user-attachments/assets/d3fffade-872e-4fd8-a347-28d1bae3072e.png)

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] Verified that object and array fields hide their children when
connected to other nodes
- [x] Confirmed that unconnected fields display their children properly
- [x] Tested with various schema types to ensure correct rendering
behavior
- [x] Checked that the connection status is properly detected and
applied
2026-01-15 08:10:52 +00:00
Abhimanyu Yadav
631f1bd50a feat(frontend): add interactive tutorial for the new builder interface (#11458)
### Changes 🏗️

This PR adds a comprehensive interactive tutorial for the new Builder UI
to help users learn how to create agents. Key changes include:

- Added a tutorial button to the canvas controls that launches a
step-by-step guide
- Created a Shepherd.js-based tutorial with multiple steps covering:
    - Adding blocks from the Block Menu
    - Understanding input and output handles
    - Configuring block values
    - Connecting blocks together
    - Saving and running agents
- Added data-id attributes to key UI elements for tutorial targeting
- Implemented tutorial state management with a new tutorialStore
- Added helper functions for tutorial navigation and block manipulation
- Created CSS styles for tutorial tooltips and highlights
- Integrated with the Run Input dialog to support tutorial flow
- Added prefetching of tutorial blocks for better performance


https://github.com/user-attachments/assets/3db964b3-855c-4fcc-aa5f-6cd74ab33d7d


### Checklist 📋

#### For code changes:

- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
    - [x] Complete the tutorial from start to finish
    - [x] Test tutorial on different screen sizes
    - [x] Verify all tutorial steps work correctly
    - [x] Ensure tutorial can be canceled and restarted
    - [x] Check that tutorial doesn't interfere with normal builder functionality
2026-01-15 07:47:27 +00:00
Swifty
5ac941fe2f feat(backend): add hybrid search for store listings, docs and blocks (#11721)
This PR adds hybrid search functionality combining semantic embeddings
with traditional text search for improved store listing discovery.

### Changes 🏗️

- Add `embeddings.py` - OpenAI-based embedding generation and similarity
search
- Add `hybrid_search.py` - Combines vector similarity with text matching
for better search results
- Add `backfill_embeddings.py` - Script to generate embeddings for
existing store listings
- Update `db.py` - Integrate hybrid search into store database queries
- Update `schema.prisma` - Add embedding storage fields and indexes
- Add migrations for embedding columns and HNSW index for vector search

### Architecture Decisions 🏛️

**Fail-Fast Approach (No Silent Fallbacks)**

We explicitly chose NOT to implement graceful degradation when hybrid
search fails. Here's why:

 **Benefits:**
- Errors surface immediately → faster fixes
- Tests verify hybrid search actually works (not just fallback)
- Consistent search quality for all users
- Forces proper infrastructure setup (API keys, database)

 **Why Not Fallback:**
- Silent degradation hides production issues
- Users get inconsistent results without knowing why
- Tests can pass even when hybrid search is broken
- Reduces operational visibility

**How We Prevent Failures:**
1. Embedding generation in approval flow (db.py:1545)
2. Error logging with `logger.error` (not warning)
3. Clear error messages (ValueError explains what's wrong)
4. Comprehensive test coverage (9/9 tests passing)

If embeddings fail, it indicates a real infrastructure issue (missing
API key, OpenAI down, database issues) that needs immediate attention,
not silent degradation.
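
As a rough sketch of the fail-fast stance (not the actual code; the function name is illustrative, the setting name comes from the configuration note below):

```python
import logging

logger = logging.getLogger(__name__)

def ensure_embeddings_available(openai_internal_api_key: str | None) -> None:
    """Raise loudly instead of silently falling back to text-only search."""
    if not openai_internal_api_key:
        # logger.error (not warning) so the failure is visible in monitoring.
        logger.error("Hybrid search unavailable: openai_internal_api_key is not configured")
        raise ValueError(
            "Hybrid search requires openai_internal_api_key; refusing to degrade silently"
        )
```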

### Test Coverage 

**All tests passing (1625 total):**
- 9/9 hybrid_search tests (including fail-fast validation)
- 3/3 db search integration tests
- Full schema compatibility (public/platform schemas)
- Error handling verification

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Test hybrid search returns relevant results
  - [x] Test embedding generation for new listings
  - [x] Test backfill script on existing data
  - [x] Verify search performance with embeddings
  - [x] Test fail-fast behavior when embeddings unavailable

#### For configuration changes:

- [x] `.env.default` is updated or already compatible with my changes
- [x] `docker-compose.yml` is updated or already compatible with my
changes
- [x] Configuration: Requires `openai_internal_api_key` in secrets

---------

Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 04:17:03 +00:00
Reinier van der Leer
b01ea3fcbd fix(backend/executor): Centralize increment_runs calls & make add_graph_execution more robust (#11764)
[OPEN-2946: \[Scheduler\] Error executing graph <graph_id> after 19.83s:
ClientNotConnectedError: Client is not connected to the query engine,
you must call `connect()` before attempting to query
data.](https://linear.app/autogpt/issue/OPEN-2946)

- Follow-up to #11375
  <sub>(broken `increment_runs` call)</sub>
- Follow-up to #11380
  <sub>(direct `get_graph_execution` call)</sub>

### Changes 🏗️

- Move `increment_runs` call from `scheduler._execute_graph` to
`executor.utils.add_graph_execution` so it can be made through
`DatabaseManager`
  - Add `increment_onboarding_runs` to `DatabaseManager`
- Remove now-redundant `increment_onboarding_runs` calls in other places
- Make `add_graph_execution` more resilient
  - Split up large try/except block
  - Fix direct `get_graph_execution` call

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - CI + a thorough review
2026-01-15 04:08:19 +00:00
Reinier van der Leer
3b09a94e3f feat(frontend/builder): Add sub-graph update UX (#11631)
[OPEN-2743: Ability to Update Sub-Agents in Graph (Without
Re-Adding)](https://linear.app/autogpt/issue/OPEN-2743/ability-to-update-sub-agents-in-graph-without-re-adding)

Updating sub-graphs is a cumbersome experience at the moment, this
should help. :)

Demo in Builder v2:


https://github.com/user-attachments/assets/df564f32-4d1d-432c-bb91-fe9065068360


https://github.com/user-attachments/assets/f169471a-1f22-46e9-a958-ddb72d3f65af


### Changes 🏗️

- Add sub-graph update banner with I/O incompatibility notification and
resolution mode
  - Red visual indicators for broken inputs/outputs and edges
  - Update bars and tooltips show compatibility details
- Sub-agent update UI with compatibility checks, incompatibility dialog,
and guided resolution workflow
- Resolution mode banner guiding users to remove incompatible
connections
- Visual controls to stage/apply updates and auto-apply when broken
connections are fixed
  
  Technical:
- Builder v1: Add `CustomNode` > `IncompatibilityDialog` +
`SubAgentUpdateBar` sub-components
- Builder v2: Add `SubAgentUpdateFeature` + `ResolutionModeBar` +
`IncompatibleUpdateDialog` + `useSubAgentUpdateState` sub-components
  - Add `useSubAgentUpdate` hook

- Related fixes in Builder v1:
  - Fix static edges not rendering as such
  - Fix edge styling not applying
- Related fixes in Builder v2:
  - Fix excess spacing for nested node input fields

Other:
- "Retry" button in error view now reloads the page instead of
navigating to `/marketplace`

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - CI for existing frontend UX flows
  - [x] Updating to a new sub-agent version with compatibility issues: UX flow works
  - [x] Updating to a new sub-agent version with *no* compatibility issues: works
  - [x] Designer approves of the look

---------

Co-authored-by: abhi1992002 <abhimanyu1992002@gmail.com>
Co-authored-by: Abhimanyu Yadav <122007096+Abhi1992002@users.noreply.github.com>
2026-01-14 13:25:20 +00:00
Zamil Majdy
61efee4139 fix(frontend): Remove hardcoded bypass of billing feature flag (#11762)
## Summary

Fixes a critical security issue where the billing button in the settings
sidebar was always visible to all users, bypassing the
`ENABLE_PLATFORM_PAYMENT` feature flag.

## Changes 🏗️

- Removed hardcoded `|| true` condition in
`frontend/src/app/(platform)/profile/(user)/layout.tsx:32` that was
bypassing the feature flag check
- The billing button is now properly gated by the
`ENABLE_PLATFORM_PAYMENT` feature flag as intended

## Root Cause

The `|| true` was accidentally left in commit
3dbc03e488 (PR #11617 - OAuth API & Single
Sign-On) from December 19, 2025. It was likely added temporarily during
development/testing to always show the billing button, but was not
removed before merging.

## Test Plan

1. Verify feature flag is set to disabled in LaunchDarkly
2. Navigate to settings page (`/profile/settings`)
3. Confirm billing button is NOT visible in the sidebar
4. Enable feature flag in LaunchDarkly
5. Refresh page and confirm billing button IS now visible
6. Verify billing page (`/profile/credits`) is still accessible via
direct URL when feature flag is disabled

## Checklist 📋

### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan

Fixes SECRT-1791

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **Bug Fixes**
* The Billing link in the profile sidebar now respects the payment
feature flag configuration and will only display when payment
functionality is enabled.

<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2026-01-14 03:28:36 +00:00
Bently
e539280e98 fix(blocks): set User-Agent header and URL-encode topic in GetWikipediaSummaryBlock (#11754)
The GetWikipediaSummaryBlock was returning HTTP 403 errors from
Wikipedia's API because it wasn't explicitly setting a User-Agent header
that complies with https://wikitech.wikimedia.org/wiki/Robot_policy.
Additionally, topics with spaces or special characters would cause
malformed URLs.

Fixes: OPEN-2889

Changes 🏗️

- URL-encode the topic parameter using urllib.parse.quote() to handle
spaces and special characters
- Explicitly set required headers per Wikimedia robot policy:
- User-Agent: Platform default user agent (includes app name, URL, and
contact email)
- Accept-Encoding: gzip, deflate: Recommended by Wikimedia to reduce
bandwidth
- Updated test mock to match the new function signature
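
A rough sketch of the two fixes together; the endpoint shown is Wikipedia's public REST summary API and the header values are placeholders, not the platform's actual defaults:

```python
from urllib.parse import quote

def build_summary_request(topic: str) -> tuple[str, dict[str, str]]:
    # quote() handles spaces and special characters, e.g. "Artificial Intelligence".
    url = f"https://en.wikipedia.org/api/rest_v1/page/summary/{quote(topic)}"
    headers = {
        # Wikimedia's robot policy asks for an identifying User-Agent with contact info.
        "User-Agent": "ExampleApp/1.0 (https://example.com; contact@example.com)",
        "Accept-Encoding": "gzip, deflate",
    }
    return url, headers
```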

Checklist 📋

For code changes:

- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Verify code passes syntax check
  - [x] Verify code passes ruff linting
  - [x] Create an agent using GetWikipediaSummaryBlock with a topic containing spaces (e.g., "Artificial Intelligence")
  - [x] Verify the block returns a Wikipedia summary without 403 errors

For configuration changes:

- .env.default is updated or already compatible with my changes
- docker-compose.yml is updated or already compatible with my changes
- I have included a list of my configuration changes in the PR
description (under Changes)
N/A - No configuration changes required.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **Bug Fixes**
* Improved Wikipedia API requests by adding compatible request headers
(including a proper user agent and encoding acceptance) for more
reliable responses.
* Enhanced handling of search topics by URL-encoding terms so queries
with spaces or special characters return correct results.

<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2026-01-13 12:24:51 +00:00
Toran Bruce Richards
db8b43bb3d feat(blocks): Add WordPress Get All Posts block and Publish Post draft toggle (#11003)
**Implements issue #11002**

This PR adds WordPress post management functionality and improves error
handling in DataForSEO blocks.

### Changes 🏗️

1. **New WordPress Blocks:**
- Added `WordPressGetAllPostsBlock` - Fetches posts from WordPress sites
with filtering and pagination support
- Enhanced `WordPressCreatePostBlock` with `publish_as_draft` toggle to
control post publication status

2. **WordPress API Enhancements:**
- Added `get_posts()` function in `_api.py` to retrieve posts with
filtering by status
- Added `PostsResponse` model for handling WordPress posts list API
responses
- Support for pagination with `number` and `offset` parameters (max 100
posts per request)
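
A minimal sketch of how the pagination/filter parameters described above could be assembled (illustrative only; the real `get_posts()` lives in `_api.py`):

```python
def build_get_posts_params(number: int = 20, offset: int = 0, status: str | None = None) -> dict[str, str]:
    # WordPress caps list requests at 100 posts per call.
    params = {"number": str(min(number, 100)), "offset": str(offset)}
    if status:
        # e.g. "publish", "draft", "pending"
        params["status"] = status
    return params
```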

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  
  **Test Plan:**
  - [x] Test `WordPressGetAllPostsBlock` with valid WordPress credentials
  - [x] Verify filtering posts by status (publish, draft, pending, etc.)
  - [x] Test pagination with different number and offset values
  - [x] Test `WordPressCreatePostBlock` with publish_as_draft=True to create draft posts
  - [x] Test `WordPressCreatePostBlock` with publish_as_draft=False to publish posts publicly

#### For configuration changes:

- [x] `.env.default` is updated or already compatible with my changes
- [x] `docker-compose.yml` is updated or already compatible with my
changes
- [x] I have included a list of my configuration changes in the PR
description (under **Changes**)

**Note:** No configuration changes were required for this PR.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **New Features**
* Added a WordPress “Get All Posts” block to fetch posts with optional
status filtering and pagination; returns total found and post details.
* **Enhancements**
* WordPress “Create Post” block now supports a “Publish as draft”
option, allowing posts to be created as drafts or published immediately.
* WordPress blocks are now surfaced consistently in the block catalog
for easier use.
* **Error Handling**
* Clearer error messages when fetching posts fails, aiding
troubleshooting.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->


<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> Introduces WordPress post listing and improves post creation and API
robustness.
> 
> - Adds `WordPressGetAllPostsBlock` to fetch posts with optional
`status` filter and pagination (`number`, `offset`); outputs `found`,
`posts`, and streams each `post`
> - Enhances `WordPressCreatePostBlock` with `publish_as_draft` input
and adds `site` to outputs; sets `status` accordingly
> - WordPress API updates in `_api.py`: new `get_posts`, `Post`,
`PostsResponse`, and `normalize_site`; apply
`Requests(raise_for_status=False)` across OAuth/token/info and post
creation; better error propagation
> 
> <sup>Written by [Cursor
Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit
10be1c4709. This will update automatically
on new commits. Configure
[here](https://cursor.com/dashboard?tab=bugbot).</sup>
<!-- /CURSOR_SUMMARY -->

---------

Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
Co-authored-by: Toran Bruce Richards <Torantulino@users.noreply.github.com>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Co-authored-by: Nicholas Tindle <ntindle@users.noreply.github.com>
Co-authored-by: Nicholas Tindle <nicholas.tindle@agpt.co>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-12 19:57:47 +00:00
Abhimanyu Yadav
923d8baedc feat(frontend): add JsonTextField component for complex nested form data (#11752)
### Changes 🏗️

- Added a new `JsonTextField` component to handle complex nested JSON
types (objects/arrays inside other objects/arrays)
- Created helper functions for JSON parsing, validation, and formatting
- Implemented `useJsonTextField` hook to manage state and validation
- Enhanced `generateUiSchemaForCustomFields` to detect nested complex
types and render them as JSON text fields
- Updated `TextInputExpanderModal` to support JSON-specific styling
- Added `JSON_TEXT_FIELD_ID` constant to custom registry for field
identification

This change improves the user experience by preventing deeply nested
form UIs. Instead, complex nested structures are presented as editable
JSON text fields with proper validation and formatting.

### Before

![Screenshot 2026-01-12 at
1.07.54 PM.png](https://app.graphite.com/user-attachments/assets/dc2b96cc-562a-4e6b-8278-76de941e3bd9.png)

### After

![Screenshot 2026-01-12 at
12.35.19 PM.png](https://app.graphite.com/user-attachments/assets/ea0028a5-c119-43c3-8100-b103484e0b54.png)

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Test with simple JSON objects in forms
  - [x] Test with nested arrays and objects
  - [x] Test with anyOf/oneOf schemas containing complex types
  - [x] Test the expander modal with JSON content

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* New JSON text field with expandable modal editor, inline validation,
and helpful placeholders.
* Complex nested objects/arrays now render as JSON fields to simplify
editing.
* Modal editor uses monospace, smaller text when editing JSON for
improved readability.

* **Chores**
* Added a non-functional runtime debug log (no user-facing behavior
changes).

<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2026-01-12 12:22:41 +00:00
Abhimanyu Yadav
a55b2e02dc feat(frontend): enhance CredentialsInput and CredentialRow components with variant support (#11753)
### Changes 🏗️

- Added a new `variant` prop to `CredentialsInput` component with
options "default" or "node"
- Implemented compact styling for the "node" variant in `CredentialRow`
component
- Modified layout and overflow handling for credential display in node
context
- Added conditional rendering of masked key display based on variant
- Passed the variant prop through the component hierarchy
- Applied the "node" variant to the `CredentialsField` component with
appropriate styling

Before

![Screenshot 2026-01-12 at
4.39.35 PM.png](https://app.graphite.com/user-attachments/assets/2b605b2d-7abf-4e8a-adc5-6a6e8b712ef7.png)

After

![Screenshot 2026-01-12 at
4.55.39 PM.png](https://app.graphite.com/user-attachments/assets/20bb1452-870a-4111-a246-c4e3a3b456ea.png)

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Verified credential selection works correctly in node context
  - [x] Confirmed compact styling is applied properly in node variant
  - [x] Tested overflow handling for long credential names
  - [x] Verified both default and node variants display correctly

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **New Features**
* Credential input and selection components now support multiple
configurable visual variants, enabling better text display handling,
optimized layouts, and improved visual consistency across different
application contexts and specific use cases.

* **Style**
* Credential field displays now feature enhanced text truncation and
overflow management for a more polished and consistent user interface
experience.

<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2026-01-12 12:22:20 +00:00
Abhimanyu Yadav
6b6648b290 feat(frontend): add Table component with TableField renderer for tabular data input (#11751)
### Changes 🏗️

- Added a new `Table` component for handling tabular data input
- Created supporting hooks and helper functions for the Table component
- Added Storybook stories to showcase different Table configurations
- Implemented a custom `TableField` renderer for JSON Schema forms
- Updated type display info to support the new "table" format
- Added schema matcher to detect and render table fields appropriately

![Screenshot 2026-01-12 at
11.29.04 AM.png](https://app.graphite.com/user-attachments/assets/71469d59-469f-4cb0-882b-a49791fe948d.png)

![Screenshot 2026-01-12 at
11.28.54 AM.png](https://app.graphite.com/user-attachments/assets/81193f32-0e16-435e-bb66-5d2aea98266a.png)

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Verified Table component renders correctly with various configurations
  - [x] Tested adding and removing rows in the Table
  - [x] Confirmed data changes are properly tracked and reported via onChange
  - [x] Verified TableField renderer works with JSON Schema forms
  - [x] Checked that table format is properly detected in the schema

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

## Release Notes

* **New Features**
* Added a Table component for displaying and editing tabular data with
support for adding/deleting rows, read-only mode, and customizable
labels.
* Added support for rendering array fields as tables in form inputs with
configurable columns and values.

* **Tests**
* Added comprehensive Storybook stories demonstrating various Table
configurations and behaviors.

<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2026-01-12 10:32:14 +00:00
Abhimanyu Yadav
c0a9c0410b feat(frontend): add MultiSelectField component and improve node title cursor styling (#11744)
## Changes 🏗️

- Added a new `MultiSelectField` component for handling multiple boolean
selections in a dropdown format
- Implemented `useMultiSelectField` hook to manage the state and logic
of the multi-select component
- Added support for custom fields in `AnyOfField` by checking if the
option schema matches a custom field
- Added `isMultiSelectSchema` utility function to detect schemas
suitable for the multi-select component
- Added hover cursor styling to node headers to indicate text
editability

![Screenshot 2026-01-10 at
11.15.12 AM.png](https://app.graphite.com/user-attachments/assets/8254497b-604f-4ccc-a40b-eb8994c073b4.png)

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Verified that multi-select fields render correctly in the UI
  - [x] Confirmed that selecting multiple options works as expected
  - [x] Tested that the node header shows the text cursor on hover
  - [x] Verified that AnyOf fields correctly use custom field renderers when applicable

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Added a multi-select field allowing selection of multiple options with
improved selection UI.
* AnyOf options can now resolve and render custom field types, improving
form composition when schemas map to custom controls.

* **Style**
  * Tooltip header cursor updated for clearer hover feedback.

<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2026-01-12 09:48:58 +00:00
Abhimanyu Yadav
17a77b02c7 fix(frontend): exclude schemas with enum from anyOf detection (#11743)
### Changes 🏗️

Fixed the `isAnyOfSchema` function in schema-utils.ts to exclude schemas
that have an `enum` property. This prevents incorrect schema processing
for enums that also have anyOf definitions. Added a console.log
statement in FormRenderer.tsx to help debug schema preprocessing.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Verified that forms with enum values render correctly
  - [x] Confirmed that anyOf schemas are properly identified and processed
  - [x] Tested with various schema combinations to ensure the fix doesn't break existing functionality

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

## Bug Fixes
* Improved validation logic for form field schemas to correctly handle
edge cases when multiple constraint types are defined.

<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2026-01-12 09:48:47 +00:00
Zamil Majdy
701fce83ca fix(backend): add missing metadata attribute to mock nodes in SmartDecisionMaker tests (#11750)
This PR fixes failing SmartDecisionMaker tests by adding missing
`metadata` attribute to mock nodes.

### Changes 🏗️

Mock nodes in SmartDecisionMaker tests were missing the `metadata = {}`
attribute, which was introduced in commit 4a52b7eca for the
customized_name feature. This caused tests to fail with:

```
TypeError: expected string or bytes-like object, got 'Mock'
```

**Files fixed**:
- `backend/blocks/test/test_smart_decision_maker_dict.py`: Added
`metadata = {}` to mock nodes in all 3 tests
- `backend/blocks/test/test_smart_decision_maker_dynamic_fields.py`:
Added `metadata = {}` to mock nodes in all 8 tests

**Root cause**: The `_create_block_function_signature` method calls
`sink_node.metadata.get("customized_name")`, but mock nodes in tests
didn't have the metadata attribute initialized.
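
The fix in the tests amounts to something like this (sketch; the variable name is illustrative):

```python
from unittest.mock import Mock

mock_node = Mock()
# Without this, sink_node.metadata.get("customized_name") returns another Mock,
# which later fails regex-based name sanitization with the TypeError above.
mock_node.metadata = {}
```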

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Run `poetry run pytest backend/blocks/test/test_smart_decision_maker_dict.py -xvs` - 3 passed
  - [x] Run `poetry run pytest backend/blocks/test/test_smart_decision_maker_dynamic_fields.py -xvs` - 8 passed
  - [x] All tests pass successfully

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

## Release Notes

* **Tests**
* Updated test infrastructure to enhance mock object configuration for
improved test reliability and consistency across test suites.

<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2026-01-11 17:00:36 -06:00
Zamil Majdy
78d89d0faf Merge branch 'master' of github.com:Significant-Gravitas/AutoGPT into dev 2026-01-11 13:09:23 -06:00
Zamil Majdy
f482eb668b hotfix(backend): resolve tool pin name mismatch in SmartDecisionMakerBlock (#11749)
## Root Cause

Execution a40bdb4a-964d-4684-94e8-b148eb6bcfc2 and all similar
executions have been failing since Nov 12, 2025 when tool pin routing
was refactored to use node IDs. The SmartDecisionMakerBlock was
double-sanitizing field names when emitting tool call outputs:

```python
# Original field name from link: "Max Keyword Difficulty"
original_field_name = field_mapping.get(clean_arg_name)  #  Retrieved correctly
sanitized_arg_name = self.cleanup(original_field_name)   #  Sanitized AGAIN!
emit_key = f"tools_^_{node_id}_~_{sanitized_arg_name}"   # Emits "max_keyword_difficulty"
```

But the parser expected original names from graph links:
```python
# Parser expects: "Max Keyword Difficulty" (from link.sink_name)
# Emit provides: "max_keyword_difficulty" (sanitized)
# Result: Mismatch → Tool never executes
```

### Changes 🏗️

**1. Fixed Emit Logic** (`smart_decision_maker.py` line 1135)
- Removed double sanitization: `sanitized_arg_name =
self.cleanup(original_field_name)`
- Now emits with original field names: `emit_key =
f"tools_^_{node_id}_~_{original_field_name}"`

**2. Made Agent Nodes Consistent** (`smart_decision_maker.py` lines
497-530)
- Added `field_mapping` to agent function signatures (was missing)
- Agent signatures now sanitize property keys for Anthropic API (like
block signatures)
- Stores field_mapping for use during emit

### Impact

**Fixes:**
-  All graphs with multi-word field names (e.g., "Max Keyword
Difficulty", "Minimum Volume")
-  All graphs with special characters in field names (e.g., "API-Key")
-  Both block nodes AND agent nodes now work consistently

**Unaffected:**
- Single-word lowercase field names (e.g., "keyword", "url") - these
were already working

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Verified parse_execution_output handles exact match correctly
  - [x] Verified emit uses original field names
  - [x] Verified field_mapping works for both block and agent nodes
  - [x] Re-run execution a40bdb4a-964d-4684-94e8-b148eb6bcfc2 after deployment to verify fix

#### For configuration changes:
- [x] `.env.default` is updated or already compatible with my changes
(no changes)
- [x] `docker-compose.yml` is updated or already compatible with my
changes (no changes)
- [x] No configuration changes in this PR

### Test Plan

1. **Unit test validation** (completed):
- Field name cleanup: "Max Keyword Difficulty" →
"max_keyword_difficulty" 
   - Parse with exact match: Success 
   - Parse with mismatch: Returns None 

2. **Production validation** (to be done after deployment):
   - Re-run execution a40bdb4a-964d-4684-94e8-b148eb6bcfc2
- Verify AgentExecutor (node 767682f5-694f-4b2a-bf52-fbdcad6a4a4f)
executes successfully
   - Verify execution completes with high correctness score (not 0.20)
   - Monitor for any regressions in existing graphs

### Files Changed

- `backend/blocks/smart_decision_maker.py`: Remove double sanitization,
add agent field_mapping

### Related Issues

- Resolves execution failure a40bdb4a-964d-4684-94e8-b148eb6bcfc2
- Fixes bug introduced in commit 536e2a5ec (Nov 12, 2025)

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **Bug Fixes**
* Improved field name mapping consistency in the SmartDecisionMaker
block to ensure proper handling of field names throughout function
signatures and tool execution workflows.

<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2026-01-12 02:08:12 +07:00
Nicholas Tindle
4a52b7eca0 fix(backend): use customized block names in smart decision maker
The SmartDecisionMakerBlock now respects the customized_name field from
node metadata when generating tool function signatures for the LLM.

Previously, the block always used the static block.name from the block
class definition, ignoring any custom names users set in the builder UI.

Changes:
- _create_block_function_signature: Check sink_node.metadata for
  customized_name before falling back to block.name
- _create_agent_function_signature: Check sink_node.metadata for
  customized_name before falling back to sink_graph_meta.name
- Added 4 unit tests for the customized_name feature
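
The fallback logic amounts to roughly this (a sketch, not the actual method):

```python
def resolve_tool_name(sink_node_metadata: dict, fallback_name: str) -> str:
    # Prefer the name the user set in the builder UI; otherwise use the static name
    # (block.name for block nodes, sink_graph_meta.name for agent nodes).
    return sink_node_metadata.get("customized_name") or fallback_name
```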

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-09 16:51:39 -07:00
Zamil Majdy
97847f59f7 feat(backend): add human-in-the-loop review system for blocks requiring approval (#11732)
## Summary
Introduces a comprehensive Human-In-The-Loop (HITL) review system that
allows any block to require human approval before execution. This
extends the existing HITL infrastructure to support automatic review
requests for potentially dangerous operations.

## 🚀 Key Features

### **Automatic HITL for Any Block**
- **Simple opt-in**: Set `self.requires_human_review = True` in any
block constructor
- **Safe mode integration**: Only activates when
`execution_context.safe_mode = True`
- **Seamless workflow**: Blocks pause execution → Human reviews via
existing UI → Execution continues or stops

### **Unified Review Infrastructure**
- **Shared HITLReviewHelper**: Clean, reusable helper class for all
review operations
- **Single API**: `handle_review_decision()` method with structured
return type
- **Type-safe**: Proper typing with non-nullable
`ReviewDecision.review_result`

### **Smart Graph Detection** 
- **Updated `has_human_in_the_loop`**: Now detects both dedicated HITL
blocks and blocks with `requires_human_review = True`
- **Frontend awareness**: UI can properly indicate graphs requiring
human intervention

## 🏗️ Implementation

### **Block Usage**
```python
class MyBlock(Block):
    def __init__(self):
        super().__init__(...)
        self.requires_human_review = True  # Enable automatic HITL
        
    async def run(self, input_data, **kwargs):
        # If we reach here, either safe mode is off OR human approved
        # No additional HITL code needed - handled automatically by base class
        yield "result", "Operation completed"
```

### **Review Workflow**
1. **Block execution starts** → Base class checks
`requires_human_review` flag
2. **Safe mode enabled** → Creates review entry, pauses execution 
3. **Human reviews** → Uses existing review UI to approve/reject
4. **Execution resumes** → Continues if approved, raises error if
rejected
5. **Safe mode disabled** → Executes normally without review

## 🔧 Technical Improvements

### **Code Quality Enhancements**
- **Better naming**: `risky_block` → `requires_human_review` (clearer
intent)
- **Type safety**: Non-nullable `ReviewDecision.review_result`
(eliminates Optional checks)
- **Exhaustive handling**: Proper error handling for unexpected review
statuses
- **Clean exception handling**: Removed redundant try-catch-log-reraise
patterns

### **Architecture Fixes**
- **Circular import resolution**: Fixed `ExecutionContext` import issues
breaking 444+ block tests
- **Early returns**: Cleaner control flow without nested conditionals
- **Defensive programming**: Handles edge cases with clear error
messages

## 📊 Changes Made

### **Core Files**
- **`Block.requires_human_review`**: New flag for marking blocks
requiring approval
- **`HITLReviewHelper`**: Shared helper class with clean, testable API
- **`HumanInTheLoopBlock`**: Refactored to use shared infrastructure
- **`Graph.has_human_in_the_loop`**: Updated to include review-requiring
blocks

### **Quality Improvements**
- **Type hints**: Proper typing throughout with runtime compatibility
- **Error handling**: Exhaustive status handling with descriptive errors
- **Code reduction**: -16 lines through removal of redundant exception
handling
- **Test compatibility**: All 444/445 block tests pass

##  Testing & Validation

- **All tests pass**: 444/445 block tests passing 
- **Type checking**: All pyright/mypy checks pass   
- **Formatting**: All linting and formatting checks pass 
- **Circular imports**: Resolved import issues that were breaking tests

- **Backward compatibility**: Existing HITL functionality unchanged 

## 🎯 Use Cases

This enables automatic human oversight for blocks performing:
- **File operations**: Deletion, modification, system access
- **External API calls**: Payments, data modifications, destructive
operations
- **System commands**: Shell execution, configuration changes
- **Data processing**: Sensitive data handling, compliance-required
operations

## 🔄 Migration Path

**Existing code**: No changes required - fully backward compatible
**New blocks**: Simply set `self.requires_human_review = True` to enable
automatic HITL
**Safe mode**: Controls whether review requests are created (production
vs development)

---

This creates a robust, type-safe foundation for human oversight in
automated workflows while maintaining the existing HITL user experience
and API compatibility.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Human-in-the-loop review support so executions can pause for human
review and resume based on decisions.

* **Improvements**
* Blocks can opt into requiring human review and will use reviewed input
when proceeding.
* Unified review decision flow with clearer approved/rejected outcomes
and messaging.
* Graph detection expanded to recognize nodes that require human review.

* **Chores**
  * Test config adjusted to avoid pytest plugin conflicts.

<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2026-01-09 21:14:37 +00:00
Zamil Majdy
22ca8955c5 fix(backend): library agent creation and version update improvements (#11731)
## Summary
Fixes library agent creation and version update logic to properly handle
both user-created and marketplace agents.

## Changes
- **Remove useGraphIsActiveVersion filter** from
`update_agent_version_in_library` to allow both manual and auto updates
- **Set useGraphIsActiveVersion correctly**:
- `False` for marketplace agents (require manual updates to avoid
breaking workflows)
- `True` for user-created agents (can safely auto-update since user
controls source)
- Update function documentation to reflect new behavior
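
As a rough illustration of the rule being applied (a sketch only; the surrounding Prisma create call is omitted):

```python
# Illustrative sketch of how useGraphIsActiveVersion is chosen.

def use_graph_is_active_version(created_from_marketplace: bool) -> bool:
    # Marketplace agents must be updated manually so a new store version
    # cannot silently change a user's workflow; user-created agents can
    # safely track their own active graph version automatically.
    return not created_from_marketplace


assert use_graph_is_active_version(created_from_marketplace=True) is False
assert use_graph_is_active_version(created_from_marketplace=False) is True
```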

## Problem Solved
- Marketplace agents can now be updated manually via API
- User-created agents maintain auto-update capability  
- Resolves Sentry error AUTOGPT-SERVER-722 about "Expected a record,
found none"
- Fixes store submission modal issues

## Test Plan
- [x] Verify marketplace agents are created with
`useGraphIsActiveVersion: False`
- [x] Verify user agents are created with `useGraphIsActiveVersion:
True`
- [x] Confirm `update_agent_version_in_library` works for both types
- [x] Test store submission flow works without modal issues

## Review Notes
This change ensures proper separation between user-controlled agents
(auto-update) and marketplace agents (manual update), while allowing the
API to service both use cases.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

## Release Notes

* **New Features**
* Enhanced agent publishing workflow with improved version tracking and
change detection for marketplace updates

* **Bug Fixes**
  * Improved error handling when updating agent versions in the library
  * Better detection of unpublished changes before publishing agents

* **Improvements**
* Changes Summary field now supports longer descriptions (up to 500
characters) with multi-line editing capability

<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2026-01-09 21:14:05 +00:00
Nicholas Tindle
43cbe2e011 feat!(blocks): Add Reddit OAuth2 integration and advanced Reddit blocks (#11623)
Replaces user/password Reddit credentials with OAuth2, adds
RedditOAuthHandler, and updates Reddit blocks to support OAuth2
authentication. Introduces new blocks for creating posts, fetching post
details, searching, editing posts, and retrieving subreddit info.
Updates test credentials and input handling to use OAuth2 tokens.

<!-- Clearly explain the need for these changes: -->

### Changes 🏗️
Rebuilds the Reddit blocks to support OAuth2 rather than requiring users to provide their username and password.
This is done by swapping from script-based to web-based authentication on the Reddit side, facilitated by Reddit's approval of an OAuth app on the account `ntindle`.
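
For example, a block can build its `praw` client from the stored refresh token instead of a username/password pair (a simplified sketch; the credential attribute plumbing is assumed):

```python
# Sketch: constructing a praw client from an OAuth2 refresh token.
import praw


def get_praw_client(refresh_token: str, client_id: str, client_secret: str, user_agent: str) -> praw.Reddit:
    # praw mints short-lived access tokens from the refresh token on demand,
    # so blocks never see or store the user's Reddit password.
    return praw.Reddit(
        client_id=client_id,
        client_secret=client_secret,
        refresh_token=refresh_token,
        user_agent=user_agent,
    )
```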
<!-- Concisely describe all of the changes made in this pull request:
-->

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  <!-- Put your test plan here: -->
  - [x] Build a super agent
  - [x] Upload the super agent and a video of it working

<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> Introduces full Reddit OAuth2 support and substantially expands Reddit
capabilities across the platform.
> 
> - Adds `RedditOAuthHandler` with token exchange, refresh, revoke;
registers handler in `integrations/oauth/__init__.py`
> - Refactors Reddit blocks to use `OAuth2Credentials` and `praw` via
refresh tokens; updates models (e.g., `post_id`, richer outputs) and
adds `strip_reddit_prefix`
> - New blocks: create/edit/delete posts, post/get/delete comments,
reply to comments, get post details, user posts (self/others), search,
inbox, subreddit info/rules/flairs, send messages
> - Updates default `settings.config.reddit_user_agent` and test
credentials; minor `.branchlet.json` addition
> - Docs: clarifies block error-handling with
`BlockInputError`/`BlockExecutionError` guidance
> 
> <sup>Written by [Cursor
Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit
4f1f26c7e7. This will update automatically
on new commits. Configure
[here](https://cursor.com/dashboard?tab=bugbot).</sup>
<!-- /CURSOR_SUMMARY -->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

## Release Notes

* **New Features**
* Added OAuth2-based authentication for Reddit integration, replacing
legacy credential methods
* Expanded Reddit capabilities with new blocks for creating posts,
retrieving post details, managing comments, accessing inbox, and
fetching user/subreddit information
* Enhanced data models to support richer Reddit interactions and
chainable workflows

* **Documentation**
* Updated error handling guidance to distinguish between validation
errors and runtime errors with improved exception patterns

<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
2026-01-09 20:53:03 +00:00
Nicholas Tindle
a318832414 feat(docs): update dev from gitbook changes (#11740)
<!-- Clearly explain the need for these changes: -->
The gitbook branch has changes that need to be synced to dev.
### Changes 🏗️
Pull changes from gitbook into dev
<!-- Concisely describe all of the changes made in this pull request:
-->

<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> Migrates documentation to GitBook and removes the old MkDocs setup.
> 
> - Removes MkDocs configuration and infra: `docs/mkdocs.yml`,
`docs/netlify.toml`, `docs/overrides/main.html`,
`docs/requirements.txt`, and JS assets (`_javascript/mathjax.js`,
`_javascript/tablesort.js`)
> - Updates `docs/content/contribute/index.md` to describe GitBook
workflow (gitbook branch, editing, previews, and `SUMMARY.md`)
> - Adds GitBook navigation file `docs/platform/SUMMARY.md` and a new
platform overview page `docs/platform/what-is-autogpt-platform.md`
> 
> <sup>Written by [Cursor
Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit
e7e118b5a8. This will update automatically
on new commits. Configure
[here](https://cursor.com/dashboard?tab=bugbot).</sup>
<!-- /CURSOR_SUMMARY -->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **Documentation**
* Updated contribution guide for new documentation platform and workflow
  * Added new platform overview and navigation documentation

* **Chores**
  * Removed MkDocs configuration and related dependencies
  * Removed deprecated JavaScript integrations and deployment overrides

<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-09 19:22:05 +00:00
Swifty
843c487500 feat(backend): add prisma types stub generator for pyright compatibility (#11736)
Prisma's generated `types.py` file is 57,000+ lines with complex
recursive TypedDict definitions that exhaust Pyright's type inference
budget. This causes random type errors and makes the type checker
unreliable.

### Changes 🏗️

- Add `gen_prisma_types_stub.py` script that generates a lightweight
`.pyi` stub file
- The stub preserves safe types (Literal, TypeVar) while collapsing
complex TypedDicts to `dict[str, Any]`
- Integrate stub generation into all workflows that run `prisma
generate`:
  - `platform-backend-ci.yml`
  - `claude.yml`
  - `claude-dependabot.yml`
  - `copilot-setup-steps.yml`
  - `docker-compose.platform.yml`
  - `Dockerfile`
  - `Makefile` (migrate & reset-db targets)
  - `linter.py` (lint & format commands)
- Add `gen-prisma-stub` poetry script entry
- Fix two pre-existing type errors that were previously masked:
- `store/db.py`: Replace private type
`_StoreListingVersion_version_OrderByInput` with dict literal
  - `airtable/_webhook.py`: Add cast for `Serializable` type
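
Conceptually, the generated stub keeps the cheap-to-check types and collapses the recursive ones, roughly like this (a hand-written illustration of the idea, not the script's actual output):

```python
# Illustrative excerpt of the kind of stub gen_prisma_types_stub.py produces;
# the concrete names below are examples only.
from typing import Any, Literal

# Safe, cheap-to-check types are preserved as-is:
SortOrder = Literal["asc", "desc"]

# Deeply recursive TypedDicts are collapsed so Pyright never has to expand them:
AgentGraphWhereInput = dict[str, Any]
AgentGraphInclude = dict[str, Any]
```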

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Run `poetry run format` - passes with 0 errors (down from 57+)
  - [x] Run `poetry run lint` - passes with 0 errors
  - [x] Run `poetry run gen-prisma-stub` - generates stub successfully
- [x] Verify stub file is created at correct location with proper
content

#### For configuration changes:
- [x] `.env.default` is updated or already compatible with my changes
- [x] `docker-compose.yml` is updated or already compatible with my
changes
- [x] I have included a list of my configuration changes in the PR
description (under **Changes**)

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **Chores**
* Added a lightweight Prisma type-stub generator and integrated it into
build, lint, CI/CD, and container workflows.
* Build, migration, formatting, and lint steps now generate these stubs
to improve type-checking performance and reduce overhead during builds
and deployments.
  * Exposed a project command to run stub generation manually.

<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2026-01-09 16:31:10 +01:00
Nicholas Tindle
47a3a5ef41 feat(backend,frontend): optional credentials flag for blocks at agent level (#11716)
This feature allows agent makers to mark credential fields as optional.
When credentials are not configured for an optional block, the block
will be skipped during execution rather than causing a validation error.

**Use case:** An agent with multiple notification channels (Discord,
Twilio, Slack) where the user only needs to configure one - unconfigured
channels are simply skipped.

### Changes 🏗️

#### Backend

**Data Model Changes:**
- `backend/data/graph.py`: Added `credentials_optional` property to
`Node` model that reads from node metadata
- `backend/data/execution.py`: Added `nodes_to_skip` field to
`GraphExecutionEntry` model to track nodes that should be skipped

**Validation Changes:**
- `backend/executor/utils.py`:
- Updated `_validate_node_input_credentials()` to return a tuple of
`(credential_errors, nodes_to_skip)`
- Nodes with `credentials_optional=True` and missing credentials are
added to `nodes_to_skip` instead of raising validation errors
- Updated `validate_graph_with_credentials()` to propagate
`nodes_to_skip` set
- Updated `validate_and_construct_node_execution_input()` to return
`nodes_to_skip`
- Updated `add_graph_execution()` to pass `nodes_to_skip` to execution
entry

**Execution Changes:**
- `backend/executor/manager.py`:
  - Added skip logic in `_on_graph_execution()` dispatch loop
- When a node is in `nodes_to_skip`, it is marked as `COMPLETED` without
execution
  - No outputs are produced, so downstream nodes won't trigger
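
Roughly, the dispatch-side skip looks like this (a simplified sketch; the helper callables and status enum are stand-ins for the real executor code):

```python
# Simplified sketch of skipping nodes whose optional credentials are unset.
from enum import Enum


class ExecutionStatus(str, Enum):
    COMPLETED = "COMPLETED"


async def dispatch_node(node_id: str, nodes_to_skip: set[str], mark_status, run_node):
    if node_id in nodes_to_skip:
        # Mark the node COMPLETED without running it. It produces no outputs,
        # so none of its downstream nodes are triggered.
        await mark_status(node_id, ExecutionStatus.COMPLETED)
        return
    await run_node(node_id)
```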

#### Frontend

**Node Store:**
- `frontend/src/app/(platform)/build/stores/nodeStore.ts`:
- Added `credentials_optional` to node metadata serialization in
`convertCustomNodeToBackendNode()`
- Added `getCredentialsOptional()` and `setCredentialsOptional()` helper
methods

**Credential Field Component:**
-
`frontend/src/components/renderers/input-renderer/fields/CredentialField/CredentialField.tsx`:
  - Added "Optional - skip block if not configured" switch toggle
  - Switch controls the `credentials_optional` metadata flag
  - Placeholder text updates based on optional state

**Credential Field Hook:**
-
`frontend/src/components/renderers/input-renderer/fields/CredentialField/useCredentialField.ts`:
  - Added `disableAutoSelect` parameter
- When credentials are optional, auto-selection of credentials is
disabled

**Feature Flags:**
- `frontend/src/services/feature-flags/use-get-flag.ts`: Minor refactor
(condition ordering)

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] Build an agent using smart decision maker and downstream blocks to test this

<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> Introduces optional credentials across graph execution and UI,
allowing nodes to be skipped (no outputs, no downstream triggers) when
their credentials are not configured.
> 
> - Backend
> - Adds `Node.credentials_optional` (from node `metadata`) and computes
required credential fields in `Graph.credentials_input_schema` based on
usage.
> - Validates credentials with `_validate_node_input_credentials` →
returns `(errors, nodes_to_skip)`; plumbs `nodes_to_skip` through
`validate_graph_with_credentials`,
`_construct_starting_node_execution_input`,
`validate_and_construct_node_execution_input`, and `add_graph_execution`
into `GraphExecutionEntry`.
> - Executor: dispatch loop skips nodes in `nodes_to_skip` (marks
`COMPLETED`); `execute_node`/`on_node_execution` accept `nodes_to_skip`;
`SmartDecisionMakerBlock.run` filters tool functions whose
`_sink_node_id` is in `nodes_to_skip` and errors only if all tools are
filtered.
> - Models: `GraphExecutionEntry` gains `nodes_to_skip` field. Tests and
snapshots updated accordingly.
> 
> - Frontend
> - Builder: credential field uses `custom/credential_field` with an
"Optional – skip block if not configured" toggle; `nodeStore` persists
`credentials_optional` and history; UI hides optional toggle in run
dialogs.
> - Run dialogs: compute required credentials from
`credentials_input_schema.required`; allow selecting "None"; avoid
auto-select for optional; filter out incomplete creds before execute.
>   - Minor schema/UI wiring updates (`uiSchema`, form context flags).
> 
> <sup>Written by [Cursor
Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit
5e01fd6a3e. This will update automatically
on new commits. Configure
[here](https://cursor.com/dashboard?tab=bugbot).</sup>
<!-- /CURSOR_SUMMARY -->

---------

Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
Co-authored-by: Nicholas Tindle <ntindle@users.noreply.github.com>
2026-01-09 14:11:35 +00:00
Ubbe
ec00aa951a fix(frontend): agent favorites layout (#11733)
## Changes 🏗️

<img width="800" height="744" alt="Screenshot 2026-01-09 at 16 07 08"
src="https://github.com/user-attachments/assets/034c97e2-18f3-441c-a13d-71f668ad672f"
/>

- Remove feature flag for agent favourites ( _keep it always visible_ )
- Fix the layout on the card so the ❤️ icon appears next to the `...`
menu
- Remove icons on toasts

## Checklist 📋

### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Run the app locally and check the above


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Favorites now respond to the current search term and are available to
all users (no feature-flag).

* **UI/UX Improvements**
* Redesigned Favorites section with simplified header, inline agent
counts, updated spacing/dividers, and removal of skeleton placeholders.
  * Favorite button repositioned and visually simplified on agent cards.
* Toast visuals simplified by removing per-type icons and adjusting
close-button positioning.

<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2026-01-09 18:52:07 +07:00
Zamil Majdy
36fb1ea004 fix(platform): store submission validation and marketplace improvements (#11706)
## Summary

Major improvements to AutoGPT Platform store submission deletion,
creator detection, and marketplace functionality. This PR addresses
critical issues with submission management and significantly improves
performance.

### 🔧 **Store Submission Deletion Issues Fixed**

**Problems Solved**:
- **Wrong deletion granularity**: Deleting entire `StoreListing` (all versions) when users expected to delete individual submissions
- **"Graph not found" errors**: Cascade deletion removing AgentGraphs that were still referenced
- **Multiple submissions deleted**: When removing one submission, all submissions for that agent were removed
- **Deletion of approved content**: Users could accidentally remove live store content

**Solutions Implemented**:
- **Granular deletion**: Now deletes individual `StoreListingVersion` records instead of entire listings
- **Protected approved content**: Prevents deletion of approved submissions to keep store content safe
- **Automatic cleanup**: Empty listings are automatically removed when last version is deleted
- **Simplified logic**: Reduced deletion function from 85 lines to 32 lines for better maintainability
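
A condensed sketch of that version-level deletion (the data-access helpers below are stand-ins, not the real `store/db.py` functions):

```python
# Sketch: delete one submission version, protect approved content,
# and clean up a listing that has no versions left.
async def delete_store_submission(user_id: str, store_listing_version_id: str, db) -> None:
    version = await db.get_listing_version(store_listing_version_id)
    if version is None or version.owner_user_id != user_id:
        raise ValueError("Submission not found for this user")
    if version.status == "APPROVED":
        # Approved versions are live store content and must not be deleted.
        raise ValueError("Approved submissions cannot be deleted")
    await db.delete_listing_version(version.id)
    # If that was the listing's last version, remove the now-empty listing.
    if await db.count_listing_versions(version.listing_id) == 0:
        await db.delete_listing(version.listing_id)
```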

### 🔧 **Creator Detection Performance Issues Fixed**

**Problems Solved**:
- **Inefficient API calls**: Fetching ALL user submissions just to check if they own one specific agent
- **Complex logic**: Convoluted creator detection requiring multiple database queries
- **Performance impact**: Especially bad for non-creators who would never need this data

**Solutions Implemented**:
- **Added `owner_user_id` field**: Direct ownership reference in `LibraryAgent` model
- **Simple ownership check**: `owner_user_id === user.id` instead of complex submission fetching
- **90%+ performance improvement**: Massive reduction in unnecessary API calls for non-creators
- **Optimized data fetching**: Only fetch submissions when user is creator AND has marketplace listing

### 🔧 **Original Store Submission Validation Issues (BUILDER-59F)**
Fixes "Agent not found for this user. User ID: ..., Agent ID: , Version:
0" errors:

- **Backend validation**: Added Pydantic validation for `agent_id`
(min_length=1) and `agent_version` (>0)
- **Frontend validation**: Pre-submission validation with user-friendly
error messages
- **Agent selection flow**: Fixed `agentId` not being set from
`selectedAgentId`
- **State management**: Prevented state reset conflicts clearing
selected agent

### 🔧 **Marketplace Display Improvements**
Enhanced version history and changelog display:

- Updated title from "Changelog" to "Version history"
- Added "Last updated X ago" with proper relative time formatting  
- Display version numbers as "Version X.0" format
- Replaced all hardcoded values with dynamic API data
- Improved text sizes and layout structure

### 📁 **Files Changed**

**Backend Changes**:
- `backend/api/features/store/db.py` - Simplified deletion logic, added
approval protection
- `backend/api/features/store/model.py` - Added `listing_id` field,
Pydantic validation
- `backend/api/features/library/model.py` - Added `owner_user_id` field
for efficient creator detection
- All test files - Updated with new required fields

**Frontend Changes**:
- `useMarketplaceUpdate.ts` - Optimized creator detection logic 
- `MainDashboardPage.tsx` - Added `listing_id` mapping for proper type
safety
- `useAgentTableRow.ts` - Updated deletion logic to use
`store_listing_version_id`
- `usePublishAgentModal.ts` - Fixed state reset conflicts
- Marketplace components - Enhanced version history display

### **Benefits**

**Performance**:
- 🚀 **90%+ reduction** in unnecessary API calls for creator detection
- 🚀 **Instant ownership checks** (no database queries needed)
- 🚀 **Optimized submissions fetching** (only when needed)

**User Experience**:
- **Granular submission control** (delete individual versions, not entire listings)
- **Protected approved content** (prevents accidental store content removal)
- **Better error prevention** (no more "Graph not found" errors)
- **Clear validation messages** (user-friendly error feedback)

**Code Quality**:
- **Simplified deletion logic** (85 lines → 32 lines)
- **Better type safety** (proper `listing_id` field usage)
- **Cleaner creator detection** (explicit ownership vs inferred)
- **Automatic cleanup** (empty listings removed automatically)

### 🧪 **Testing**
- [x] Backend validation rejects empty agent_id and zero agent_version
- [x] Frontend TypeScript compilation passes
- [x] Store submission works from both creator dashboard and "become a
creator" flows
- [x] Granular submission deletion works correctly
- [x] Approved submissions are protected from deletion
- [x] Creator detection is fast and accurate
- [x] Marketplace displays version history correctly

**Breaking Changes**: None - All changes are additive and backwards
compatible.

Fixes critical submission deletion issues, improves performance
significantly, and enhances user experience across the platform.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
  * Agent ownership is now tracked and exposed across the platform.
* Store submissions and versions now include a required listing_id to
preserve listing linkage.

* **Bug Fixes**
* Prevent deletion of APPROVED submissions; remove empty listings after
deletions.
* Edits restricted to PENDING submissions with clearer invalid-operation
messages.

* **Improvements**
* Stronger publish validation and UX guards; deduplicated images and
modal open/reset refinements.
* Version history shows relative "Last updated" times and version
badges.

* **Tests**
* E2E tests updated to target pending-submission flows for edit/delete.

<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: Claude <noreply@anthropic.com>
2026-01-08 19:11:38 +00:00
Abhimanyu Yadav
a81ac150da fix(frontend): add word wrapping to CodeRenderer and improve output actions visibility (#11724)
## Changes 🏗️
- Updated the `CodeRenderer` component to add `whitespace-pre-wrap` and
`break-words` CSS classes to the `<code>` element
- This enables proper wrapping of long code lines while preserving
whitespace formatting

Before


![image.png](https://app.graphite.com/user-attachments/assets/aca769cc-0f6f-4e25-8cdd-c491fcbf21bb.png)

After

![Screenshot 2026-01-08 at
3.02.53 PM.png](https://app.graphite.com/user-attachments/assets/99e23efa-be2a-441b-b0d6-50fa2a08cdb0.png)

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Verified code with long lines wraps correctly
  - [x] Confirmed whitespace and indentation are preserved
  - [x] Tested code display in various viewport sizes

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **Bug Fixes**
* Code blocks now preserve whitespace and wrap long lines for improved
readability.
* Output action controls are hidden when there is only a single output
item, reducing unnecessary UI elements.

<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2026-01-08 11:13:47 +00:00
Abhimanyu Yadav
49ee087496 feat(frontend): add new integration images for Webshare and WordPress (#11725)
### Changes 🏗️

Added two new integration icons to the frontend:
- `webshare_proxy.png` - Icon for WebShare Proxy integration
- `wordpress.png` - Icon for WordPress integration

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] Verified both icons display correctly in the integrations section
  - [x] Confirmed icons render properly at different screen sizes
  - [x] Checked that the icons maintain quality when scaled

#### For configuration changes:
- [x] `.env.default` is updated or already compatible with my changes
- [x] `docker-compose.yml` is updated or already compatible with my
changes
2026-01-08 11:13:34 +00:00
Ubbe
fc25e008b3 feat(frontend): update library agent cards to use DS (#11720)
## Changes 🏗️

<img width="700" height="838" alt="Screenshot 2026-01-07 at 16 11 04"
src="https://github.com/user-attachments/assets/0b38d2e1-d4a8-4036-862c-b35c82c496c2"
/>

- Update the agent library cards to new designs
- Update page to use Design System components
- Allow to edit/delete/duplicate agents on the library list page
- Add missing actions on library agent detail page

## Checklist 📋

### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Run locally and test the above


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Marketplace info shown on agent cards and improved favoriting with
optimistic UI and feedback.
  * Delete agent and delete schedule flows with confirmation dialogs.

* **Refactor**
* New composable form system, modernized upload dialog, streamlined
search bar, and multiple library components converted to named exports
with layout tweaks.
  * New agent card menu and favorite button UI.

* **Chores**
  * Removed notification UI and dropped a drag-drop dependency.

* **Tests**
  * Increased timeouts and stabilized upload/pagination flows.

<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2026-01-08 18:28:27 +07:00
Ubbe
b0855e8cf2 feat(frontend): context menu right click new builder (#11703)
## Changes 🏗️

<img width="250" height="504" alt="Screenshot 2026-01-06 at 17 53 26"
src="https://github.com/user-attachments/assets/52013448-f49c-46b6-b86a-39f98270cbc3"
/>

<img width="300" height="544" alt="Screenshot 2026-01-06 at 17 53 29"
src="https://github.com/user-attachments/assets/e6334034-68e4-4346-9092-3774ab3e8445"
/>

On the **New Builder**:
- right-click on a node menu make it show the context menu
- use the same menu for right-click and when clicking on `...`

## Checklist 📋

### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Run locally and test the above



<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Added a custom right-click context menu for nodes with Copy, Open
agent (when available), and Delete actions; browser default menu is
suppressed while preserving zoom/drag/wiring.
* Introduced reusable SecondaryMenu primitives for context and dropdown
menus.

* **Documentation**
* Added Storybook examples demonstrating the context menu and dropdown
menu usage.

* **Style**
* Updated menu styling and icons with improved consistency and dark-mode
support.

<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2026-01-08 17:35:49 +07:00
Abhimanyu Yadav
5e2146dd76 feat(frontend): add CustomSchemaField wrapper for dynamic form field routing
(#11722)

### Changes 🏗️

This PR introduces automatic UI schema generation for custom form
fields, eliminating manual field mapping.

#### 1. **generateUiSchemaForCustomFields Utility**
(`generate-ui-schema.ts`) - New File
   - Auto-generates `ui:field` settings for custom fields
   - Detects custom fields using `findCustomFieldId()` matcher
   - Handles nested objects and array items recursively
   - Merges with existing UI schema without overwriting

#### 2. **FormRenderer Integration** (`FormRenderer.tsx`)
   - Imports and uses `generateUiSchemaForCustomFields`
   - Creates merged UI schema with `useMemo`
   - Passes merged schema to Form component
   - Enables automatic custom field detection

#### 3. **Preprocessor Cleanup** (`input-schema-pre-processor.ts`)
   - Removed manual `$id` assignment for custom fields
   - Removed unused `findCustomFieldId` import
   - Simplified to focus only on type validation

### Why these changes?

- Custom fields now auto-detect without manual `ui:field` configuration
- Uses standard RJSF approach (UI schema) for field routing
- Centralized custom field detection logic improves maintainability

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Verify custom fields render correctly when present in schema
- [x] Verify standard fields continue to render with default SchemaField
- [x] Verify multiple instances of same custom field type have unique
IDs
  - [x] Test form submission with custom fields

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **Bug Fixes**
* Improved custom field rendering in forms by optimizing the UI schema
generation process.

<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2026-01-08 08:47:52 +00:00
Abhimanyu Yadav
103a62c9da feat(frontend/builder): add filters to blocks menu (#11654)
### Changes 🏗️

This PR adds filtering functionality to the new blocks menu, allowing
users to filter search results by category and creator.

**New Components:**
- `BlockMenuFilters`: Main filter component displaying active filters
and filter chips
- `FilterSheet`: Slide-out panel for selecting filters with categories
and creators
- `BlockMenuSearchContent`: Refactored search results display component

**Features Added:**
- Filter by categories: Blocks, Integrations, Marketplace agents, My
agents
- Filter by creator: Shows all available creators from search results
- Category counts: Display number of results per category
- Interactive filter chips with animations (using framer-motion)
- Hover states showing result counts on filter chips
- "All filters" sheet with apply/clear functionality

**State Management:**
- Extended `blockMenuStore` with filter state management
- Added `filters`, `creators`, `creators_list`, and `categoryCounts` to
store
- Integrated filters with search API (`filter` and `by_creator`
parameters)

**Refactoring:**
- Moved search logic from `BlockMenuSearch` to `BlockMenuSearchContent`
- Renamed `useBlockMenuSearch` to `useBlockMenuSearchContent`
- Moved helper functions to `BlockMenuSearchContent` directory

**API Changes:**
- Updated `custom-mutator.ts` to properly handle query parameter
encoding


### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Search for blocks and verify filter chips appear
- [x] Click "All filters" and verify filter sheet opens with categories
- [x] Select/deselect category filters and verify results update
accordingly
  - [x] Filter by creator and verify only blocks from that creator show
  - [x] Clear all filters and verify reset to default state
  - [x] Verify filter counts display correctly
  - [x] Test filter chip hover animations
2026-01-08 08:02:21 +00:00
Bentlybro
fc8434fb30 Merge branch 'master' into dev 2026-01-07 12:02:15 +00:00
Ubbe
3ae08cd48e feat(frontend): use Google Drive Picker on new builder (#11702)
## Changes 🏗️

<img width="600" height="960" alt="Screenshot 2026-01-06 at 17 40 23"
src="https://github.com/user-attachments/assets/61085ec5-a367-45c7-acaa-e3fc0f0af647"
/>

- So when using Google Blocks on the new builder, it shows the Google Drive Picker 🏁

## Checklist 📋

### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
  - [x] Run app locally and test the above


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Added a Google Drive picker field and widget for forms with an
always-visible remove button and improved single/multi selection
handling.

* **Bug Fixes**
* Better validation and normalization of selected files and consolidated
error messaging.
* Adjusted layout spacing around the picker and selected files for
clearer display.

<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2026-01-07 17:07:09 +07:00
Swifty
4db13837b9 Revert "extracted frontend changes out of the hackathon/copilot branch"
This reverts commit df87867625.
2026-01-07 09:27:25 +01:00
Swifty
df87867625 extracted frontend changes out of the hackathon/copilot branch 2026-01-07 09:25:10 +01:00
Abhimanyu Yadav
e503126170 feat(frontend): upgrade RJSF to v6 and implement new FormRenderer system
(#11677)

Fixes #11686

### Changes 🏗️

This PR upgrades the React JSON Schema Form (RJSF) library from v5 to v6
and introduces a complete rewrite of the form rendering system with
improved architecture and new features.

#### Core Library Updates
- Upgraded `@rjsf/core` from 5.24.13 to 6.1.2
- Upgraded `@rjsf/utils` from 5.24.13 to 6.1.2
- Added `@radix-ui/react-slider` 1.3.6 for new slider components

#### New Form Renderer Architecture
- **Base Templates**: Created modular base templates for arrays,
objects, and standard fields
- **AnyOf Support**: Implemented `AnyOfField` component with type
selector for union types
- **Array Fields**: New `ArrayFieldTemplate`, `ArrayFieldItemTemplate`,
and `ArraySchemaField` with context provider
- **Object Fields**: Enhanced `ObjectFieldTemplate` with better support
for additional properties via `WrapIfAdditionalTemplate`
- **Field Templates**: New `TitleField`, `DescriptionField`, and
`FieldTemplate` with improved styling
- **Custom Widgets**: Implemented TextWidget, SelectWidget,
CheckboxWidget, FileWidget, DateWidget, TimeWidget, and DateTimeWidget
- **Button Components**: Custom AddButton, RemoveButton, and CopyButton
components

#### Node Handle System Refactor
- Split `NodeHandle` into `InputNodeHandle` and `OutputNodeHandle` for
better separation of concerns
- Refactored handle ID generation logic in `helpers.ts` with new
`generateHandleIdFromTitleId` function
- Improved handle connection detection using edge store
- Added support for nested output handles (objects within outputs)

#### Edge Store Improvements
- Added `removeEdgesByHandlePrefix` method for bulk edge removal
- Improved `isInputConnected` with handle ID cleanup
- Optimized `updateEdgeBeads` to only update when changes occur
- Better edge management with `applyEdgeChanges`

#### Node Store Enhancements
- Added `syncHardcodedValuesWithHandleIds` method to maintain
consistency between form data and handle connections
- Better handling of additional properties in objects
- Improved path parsing with `parseHandleIdToPath` and
`ensurePathExists`

#### Draft Recovery Improvements
- Added diff calculation with `calculateDraftDiff` to show what changed
- New `formatDiffSummary` to display changes in a readable format (e.g.,
"+2/-1 blocks, +3 connections")
- Better visual feedback for draft changes

#### UI/UX Enhancements
- Fixed node container width to 350px for consistency
- Improved field error display with inline error messages
- Better spacing and styling throughout forms
- Enhanced tooltip support for field descriptions
- Improved array item controls with better button placement
- Context-aware field sizing (small/large)

#### Output Handler Updates
- Recursive rendering of nested output properties
- Better type display with color coding
- Improved handle connections for complex output schemas

#### Migration & Cleanup
- Updated `RunInputDialog` to use new FormRenderer
- Updated `FormCreator` to use new FormRenderer
- Moved OAuth callback types to separate file
- Updated import paths from `input-renderer` to `InputRenderer`
- Removed unused console.log statements
- Added `type="button"` to buttons to prevent form submission

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] Test form rendering with various field types (text, number,
boolean, arrays, objects)
  - [x] Test anyOf field type selector functionality
  - [x] Test array item addition/removal
  - [x] Test nested object fields with additional properties
  - [x] Test input/output node handle connections
  - [x] Test draft recovery with diff display
  - [x] Verify backward compatibility with existing agents
  - [x] Test field validation and error display
  - [x] Verify handle ID generation for complex schemas

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **New Features**
* Improved form field rendering with enhanced support for optional
types, arrays, and nested objects.
* Enhanced draft recovery display showing detailed difference tracking
(added, removed, modified items).
  * Better OAuth popup callback handling with structured message types.

* **Bug Fixes**
  * Improved node handle ID normalization and synchronization.
  * Enhanced edge management for complex field changes.
  * Fixed styling consistency across form components.

* **Dependencies**
  * Updated React JSON Schema Form library to version 6.1.2.
  * Added Radix UI slider component support.

<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2026-01-07 05:06:34 +00:00
Zamil Majdy
7ee28197a3 docs(gitbook): sync documentation updates with dev branch (#11709)
## Summary

Sync GitBook documentation changes from the gitbook branch to dev. This
PR contains comprehensive documentation updates including new assets,
content restructuring, and infrastructure improvements.

## Changes 🏗️

### Documentation Updates
- **New GitBook Assets**: Added 9 new documentation images and
screenshots
  - Platform overview images (AGPT_Platform.png, Banner_image.png)
- Feature illustrations (Contribute.png, Integrations.png, hosted.jpg,
no-code.jpg, api-reference.jpg)
  - Screenshots and examples for better user guidance
- **Content Updates**: Enhanced README.md and SUMMARY.md with improved
structure and navigation
- **Visual Documentation**: Added comprehensive visual guides for
platform features

### Infrastructure 
- **Cloudflare Worker**: Added redirect handler for docs.agpt.co →
agpt.co/docs migration
  - Complete URL mapping for 71+ redirect patterns
  - Handles platform blocks restructuring and edge cases
  - Ready for deployment to Cloudflare Workers

### Merge Conflict Resolution
- **Clean merge from dev**: Successfully merged dev's major backend
restructuring (server/ → api/)
- **File resurrection fix**: Removed files that were accidentally
resurrected during merge conflict resolution
  - Cleaned up BuilderActionButton.tsx (deleted in dev)
  - Cleaned up old PreviewBanner.tsx location (moved in dev)
  - Synced pnpm-lock.yaml and layout.tsx with dev's current state

## Technical Details

This PR represents a careful synchronization that:
1. **Preserves all GitBook documentation work** while staying current
with dev
2. **Maintains clean diff**: Only documentation-related changes remain
after merge cleanup
3. **Resolves merge conflicts**: Handled major backend API restructuring
without breaking docs
4. **Infrastructure ready**: Cloudflare Worker ready for docs migration
deployment

## Files Changed
- `docs/`: GitBook documentation assets and content
- `autogpt_platform/cloudflare_worker.js`: Docs infrastructure for URL
redirects

## Validation
- All TypeScript compilation errors resolved
- Pre-commit hooks passing (Prettier, TypeCheck)
- Only documentation changes remain in diff vs dev
- Cloudflare Worker tested with comprehensive URL mapping
- No non-documentation code changes after cleanup

## Deployment Notes
The Cloudflare Worker can be deployed via:
```bash
# Cloudflare Dashboard → Workers → Create → Paste code → Add route docs.agpt.co/*
```

This completes the GitBook synchronization and prepares for docs site
migration.

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: bobby.gaffin <bobby.gaffin@agpt.co>
Co-authored-by: Bently <Github@bentlybro.com>
Co-authored-by: Abhimanyu Yadav <122007096+Abhi1992002@users.noreply.github.com>
Co-authored-by: Swifty <craigswift13@gmail.com>
Co-authored-by: Ubbe <hi@ubbe.dev>
Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Lluis Agusti <hi@llu.lu>
2026-01-07 02:11:11 +00:00
Nicholas Tindle
818de26d24 fix(platform/blocks): XMLParserBlock list object error (#11517)
<!-- Clearly explain the need for these changes: -->

### Need for these changes 💡

The `XMLParserBlock` was susceptible to crashing with an
`AttributeError: 'List' object has no attribute 'add_text'` when
processing malformed XML inputs, such as documents with multiple root
elements or stray text outside the root. This PR introduces robust
validation to prevent these crashes and provide clear, actionable error
messages to users.

### Changes 🏗️

<!-- Concisely describe all of the changes made in this pull request:
-->

- Added a `_validate_tokens` static method to `XMLParserBlock` to
perform pre-parsing validation on the token stream. This method ensures
the XML input has a single root element and no text content outside of
it.
- Modified the `XMLParserBlock.run` method to call `_validate_tokens`
immediately after tokenization and before passing the tokens to
`gravitasml.Parser`.
- Introduced a new test case, `test_rejects_text_outside_root`, in
`test_blocks_dos_vulnerability.py` to verify that the `XMLParserBlock`
correctly raises a `ValueError` when encountering XML with text outside
the root element.
- Imported `Token` for type hinting in `xml_parser.py`.
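
A rough illustration of the kind of pre-parse check this adds (the `Token` shape below is a stand-in, not gravitasml's actual tokenizer type):

```python
# Rough sketch: exactly one root element, balanced tags, no stray text outside the root.
from dataclasses import dataclass


@dataclass
class Token:
    kind: str   # "open", "close", or "text"
    value: str = ""


def validate_tokens(tokens: list[Token]) -> None:
    depth = 0
    roots_seen = 0
    for token in tokens:
        if token.kind == "open":
            if depth == 0:
                roots_seen += 1
                if roots_seen > 1:
                    raise ValueError("XML must have exactly one root element")
            depth += 1
        elif token.kind == "close":
            depth -= 1
        elif token.kind == "text" and depth == 0 and token.value.strip():
            raise ValueError("Text content is not allowed outside the root element")
    if depth != 0:
        raise ValueError("Unbalanced XML tags")
```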

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  <!-- Put your test plan here: -->
- [x] Confirm that the `test_rejects_text_outside_root` test passes,
asserting that `ValueError` is raised for invalid XML.
  - [x] Confirm that other relevant XML parsing tests continue to pass.


---
Linear Issue:
[OPEN-2835](https://linear.app/autogpt/issue/OPEN-2835/blockunknownerror-raised-by-xmlparserblock-with-message-list-object)



<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> Strengthens XML parsing robustness and error clarity.
> 
> - Adds `_validate_tokens` in `XMLParserBlock` to ensure a single root
element, balanced tags, and no text outside the root before parsing
> - Updates `run` to `list(tokenize(...))` and validate tokens prior to
`Parser.parse()`; maintains 10MB input size guard
> - Introduces `test_rejects_text_outside_root` asserting a readable
`ValueError` for trailing text
> - Bumps `gravitasml` to `0.1.4` in `pyproject.toml` and lockfile
> 
> <sup>Written by [Cursor
Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit
22cc5149c5. This will update automatically
on new commits. Configure
[here](https://cursor.com/dashboard?tab=bugbot).</sup>
<!-- /CURSOR_SUMMARY -->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **Bug Fixes**
* Improved XML parsing validation with stricter enforcement of
single-root elements and prevention of trailing text, providing clearer
error messages for invalid XML input.

* **Tests**
* Added test coverage for XML parser validation of invalid root text
scenarios.

* **Chores**
  * Updated GravitasML dependency to latest compatible version.

<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: Cursor Agent <cursoragent@cursor.com>
Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
Co-authored-by: Nicholas Tindle <ntindle@users.noreply.github.com>
2026-01-06 20:02:53 +00:00
Nicholas Tindle
cb08def96c feat(blocks): Add Google Docs integration blocks (#11608)
Introduces a new module with blocks for Google Docs operations,
including reading, creating, appending, inserting, formatting,
exporting, sharing, and managing public access for Google Docs. Updates
dependencies in pyproject.toml and poetry.lock to support these
features.



https://github.com/user-attachments/assets/3597366b-a9eb-4f8e-8a0a-5a0bc8ebc09b



<!-- Clearly explain the need for these changes: -->

### Changes 🏗️
Adds lots of basic docs tools + a dependency to use them with markdown

Block | Description | Key Features
-- | -- | --
Read & Create | |
GoogleDocsReadBlock | Read content from a Google Doc | Returns text content, title, revision ID
GoogleDocsCreateBlock | Create a new Google Doc | Title, optional initial content
GoogleDocsGetMetadataBlock | Get document metadata | Title, revision ID, locale, suggested modes
GoogleDocsGetStructureBlock | Get document structure with indexes | Flat segments or detailed hierarchy; shows start/end indexes
Plain Text Operations | |
GoogleDocsAppendPlainTextBlock | Append plain text to end | No formatting applied
GoogleDocsInsertPlainTextBlock | Insert plain text at position | Requires index; no formatting
GoogleDocsFindReplacePlainTextBlock | Find and replace plain text | Case-sensitive option; no formatting on replacement
Markdown Operations | (ideal for LLM/AI output) |
GoogleDocsAppendMarkdownBlock | Append Markdown to end | Full formatting via gravitas-md2gdocs
GoogleDocsInsertMarkdownAtBlock | Insert Markdown at position | Requires index
GoogleDocsReplaceAllWithMarkdownBlock | Replace entire doc with Markdown | Clears and rewrites
GoogleDocsReplaceRangeWithMarkdownBlock | Replace index range with Markdown | Requires start/end index
GoogleDocsReplaceContentWithMarkdownBlock | Find text and replace with Markdown | Text-based search; great for templates
Structural Operations | |
GoogleDocsInsertTableBlock | Insert a table | Rows/columns OR content array; optional Markdown in cells
GoogleDocsInsertPageBreakBlock | Insert a page break | Position index (0 = end)
GoogleDocsDeleteContentBlock | Delete content range | Requires start/end index
GoogleDocsFormatTextBlock | Apply formatting to text range | Bold, italic, underline, font size/color, etc.
Export & Sharing | |
GoogleDocsExportBlock | Export to different formats | PDF, DOCX, TXT, HTML, RTF, ODT, EPUB
GoogleDocsShareBlock | Share with specific users | Reader, commenter, writer, owner roles
GoogleDocsSetPublicAccessBlock | Set public access level | Private, anyone with link (view/comment/edit)


<!-- Concisely describe all of the changes made in this pull request:
-->

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  <!-- Put your test plan here: -->
  - [x] Build, run, verify, and upload a block super test
- [x] [Google Docs Super
Agent_v8.json](https://github.com/user-attachments/files/24134215/Google.Docs.Super.Agent_v8.json)
works


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **Chores**
  * Updated backend dependencies.

<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> Adds end-to-end Google Docs capabilities under
`backend/blocks/google/docs.py`, including rich Markdown support.
> 
> - New blocks: read/create docs; plain-text
`append`/`insert`/`find_replace`/`delete`; text `format`;
`insert_table`; `insert_page_break`; `get_metadata`; `get_structure`
> - Markdown-powered blocks (via `gravitas_md2gdocs.to_requests`):
`append_markdown`, `insert_markdown_at`, `replace_all_with_markdown`,
`replace_range_with_markdown`, `replace_content_with_markdown`
> - Export and sharing: `export` (PDF/DOCX/TXT/HTML/RTF/ODT/EPUB),
`share` (user roles), `set_public_access`
> - Dependency updates: add `gravitas-md2gdocs` to `pyproject.toml` and
update `poetry.lock`
> 
> <sup>Written by [Cursor
Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit
73512a95b2. This will update automatically
on new commits. Configure
[here](https://cursor.com/dashboard?tab=bugbot).</sup>
<!-- /CURSOR_SUMMARY -->

---------

Co-authored-by: Cursor Agent <cursoragent@cursor.com>
Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
Co-authored-by: Nicholas Tindle <ntindle@users.noreply.github.com>
2026-01-05 18:36:56 +00:00
Krzysztof Czerwinski
ac2daee5f8 feat(backend): Add GPT-5.2 and update default models (#11652)
### Changes 🏗️

- Add OpenAI `GPT-5.2` with metadata&cost
- Add const `DEFAULT_LLM_MODEL` (set to GPT-5.2) and use it instead of
hardcoded model across llm blocks and tests
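
The pattern is simply a shared constant used instead of a hard-coded literal (a sketch; the enum member value is an assumption):

```python
# Sketch of the shared-default pattern.
from enum import Enum


class LlmModel(str, Enum):
    GPT_5_2 = "gpt-5.2"


DEFAULT_LLM_MODEL = LlmModel.GPT_5_2

# LLM blocks and tests then default to the constant rather than repeating a
# model name, so future default changes touch a single line.
```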

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] GPT-5.2 is set as default and works on llm blocks
2026-01-05 16:13:35 +00:00
lif
266e0d79d4 fix(blocks): add YouTube Shorts URL support (#11659)
## Summary
Added support for parsing YouTube Shorts URLs (`youtube.com/shorts/...`)
in the TranscribeYoutubeVideoBlock to extract video IDs correctly.

## Changes
- Modified `_extract_video_id` method in `youtube.py` to handle Shorts
URL format
- Added test cases for YouTube Shorts URL extraction
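
For illustration, video-ID extraction that also understands the Shorts path might look roughly like this (a sketch, not the exact `_extract_video_id` code):

```python
# Sketch: extract a YouTube video ID, including youtube.com/shorts/<id> URLs.
from urllib.parse import parse_qs, urlparse


def extract_video_id(url: str) -> str:
    parsed = urlparse(url)
    if parsed.hostname == "youtu.be":
        return parsed.path.lstrip("/")
    parts = [p for p in parsed.path.split("/") if p]
    # youtube.com/shorts/<id>, /embed/<id>, /v/<id> carry the ID in the path
    if len(parts) >= 2 and parts[0] in ("shorts", "embed", "v"):
        return parts[1]
    # youtube.com/watch?v=<id> carries it in the query string
    return parse_qs(parsed.query)["v"][0]


assert extract_video_id("https://youtube.com/shorts/abc123XYZ") == "abc123XYZ"
assert extract_video_id("https://www.youtube.com/watch?v=abc123XYZ") == "abc123XYZ"
```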

## Related Issue
Fixes #11500

## Test Plan
- [x] Added unit tests for YouTube Shorts URL extraction
- [x] Verified existing YouTube URL formats still work
- [x] CI should pass all existing tests

---------

Co-authored-by: Ubbe <hi@ubbe.dev>
2026-01-05 16:11:45 +00:00
lif
01f443190e fix(frontend): allow empty values in number inputs and fix AnyOfField toggle (#11661)
<!-- ⚠️ Reminder: Think about your Changeset/Docs changes! -->
<!-- If you are introducing new blocks or features, document them for
users. -->
<!-- Reference:
https://github.com/Significant-Gravitas/AutoGPT/blob/dev/CONTRIBUTING.md
-->

## Summary

This PR fixes two related issues with number/integer inputs in the
frontend:

1. **HTMLType typo fix**: INTEGER input type was incorrectly mapped to
`htmlType: 'account'` (which is not a valid HTML input type) instead of
`htmlType: 'number'`.

2. **AnyOfField toggle fix**: When a user cleared a number input field,
the input would disappear because `useAnyOfField` checked for both
`null` AND `undefined` in `isEnabled`. This PR changes it to only check
for explicit `null` (set by toggle off), allowing `undefined` (empty
input) to keep the field visible.

### Root cause analysis

When a user cleared a number input:
1. `handleChange` returned `undefined` (because `v === "" ? undefined :
Number(v)`)
2. In `useAnyOfField`, `isEnabled = formData !== null && formData !==
undefined` became `false`
3. The input field disappeared

### Fix

Changed `useAnyOfField.tsx` line 67:
```diff
- const isEnabled = formData !== null && formData !== undefined;
+ const isEnabled = formData !== null;
```

This way:
- When toggle is OFF → `formData` is `null` → `isEnabled` is `false`
(input hidden) ✓
- When toggle is ON but input is cleared → `formData` is `undefined` →
`isEnabled` is `true` (input visible) ✓

## Test plan

- [x] Verified INTEGER inputs now render correctly with `type="number"`
- [x] Verified clearing a number input keeps the field visible
- [x] Verified toggling the nullable switch still works correctly

Fixes #11594

🤖 AI-assisted development disclaimer: This PR was developed with
assistance from Claude Code.

---------

Signed-off-by: majiayu000 <1835304752@qq.com>
Co-authored-by: Abhimanyu Yadav <122007096+Abhi1992002@users.noreply.github.com>
2026-01-05 16:10:47 +00:00
Ubbe
bdba0033de refactor(frontend): move NodeInput files (#11695)
## Changes 🏗️

Move the `<NodeInput />` component next to the old builder code where it
is used.

## Checklist 📋

### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Run app locally and click around, E2E is fine
2026-01-05 10:29:12 +00:00
Abhimanyu Yadav
b87c64ce38 feat(frontend): Add delete key bindings to ReactFlow editor
(#11693)

Issues fixed by this PR
- https://github.com/Significant-Gravitas/AutoGPT/issues/11688
- https://github.com/Significant-Gravitas/AutoGPT/issues/11687

### **Changes 🏗️**

Added keyboard delete functionality to the ReactFlow editor by enabling
the `deleteKeyCode` prop with both "Backspace" and "Delete" keys. This
allows users to delete selected nodes and edges using standard keyboard
shortcuts, improving the editing experience.

**Modified:**

- `Flow.tsx`: Added `deleteKeyCode={["Backspace", "Delete"]}` prop to
the ReactFlow component to enable deletion of selected elements via
keyboard

### **Checklist 📋**

#### **For code changes:**

- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] Select a node in the flow editor and press Delete key to confirm
it deletes
- [x] Select a node in the flow editor and press Backspace key to
confirm it deletes
    - [x] Verify deletion works for multiple selected elements
2026-01-05 10:28:57 +00:00
Ubbe
003affca43 refactor(frontend): fix new builder buttons (#11696)
## Changes 🏗️

<img width="800" height="964" alt="Screenshot 2026-01-05 at 15 26 21"
src="https://github.com/user-attachments/assets/f8c7fc47-894a-4db2-b2f1-62b4d70e8453"
/>

- Adjust the new builder to use the Design System components
- Re-structure imports to match formatting rules
- Small improvement on `use-get-flag`
- Move file which is the main hook

## Checklist 📋

### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Run locally and check the new buttons look good
2026-01-05 09:09:47 +00:00
Abhimanyu Yadav
290d0d9a9b feat(frontend): add auto-save Draft Recovery feature with IndexedDB persistence
(#11658)

## Summary
Implements an auto-save draft recovery system that persists unsaved flow
builder state across browser sessions, tab closures, and refreshes. When
users return to a flow with unsaved changes, they can choose to restore
or discard the draft via an intuitive recovery popup.



https://github.com/user-attachments/assets/0f77173b-7834-48d2-b7aa-73c6cd2eaff6



## Changes 🏗️

### Core Features
- **Draft Recovery Popup** (`DraftRecoveryPopup.tsx`)
  - Displays amber-themed notification with unsaved changes metadata
  - Shows node count, edge count, and relative time since last save
  - Provides restore and discard actions with tooltips
  - Auto-dismisses on click outside or ESC key

- **Auto-Save System** (`useDraftManager.ts`)
  - Automatically saves draft state every 15 seconds
  - Saves on browser tab close/refresh via `beforeunload`
  - Tracks nodes, edges, graph schemas, node counter, and flow version
  - Smart dirty checking - only saves when actual changes detected
  - Cleans up expired drafts (24-hour TTL)

- **IndexedDB Persistence** (`db.ts`, `draft-service.ts`)
  - Uses Dexie library for reliable client-side storage
  - Handles both existing flows (by flowID) and new flows (via temp session IDs)
  - Compares draft state with current state to determine if recovery is needed
  - Automatically clears drafts after successful save (see the Dexie sketch below)
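
A minimal sketch of what a Dexie-backed draft store could look like; the table and field names are illustrative assumptions, not the actual `db.ts` / `draft-service.ts` contents:

```ts
import Dexie, { type Table } from "dexie";

// Hypothetical draft record; names are illustrative, not the actual schema.
interface FlowDraft {
  flowKey: string; // flowID for existing flows, temp session ID for new ones
  nodes: unknown[];
  edges: unknown[];
  savedAt: number; // timestamp used for the 24-hour TTL cleanup
}

class DraftDatabase extends Dexie {
  drafts!: Table<FlowDraft, string>;

  constructor() {
    super("flow-drafts");
    this.version(1).stores({ drafts: "flowKey, savedAt" });
  }
}

export const db = new DraftDatabase();

// Save or replace the draft for a flow; called from the debounced auto-save.
export async function saveDraft(draft: FlowDraft): Promise<void> {
  await db.drafts.put(draft);
}

// Remove drafts older than 24 hours.
export async function cleanupExpiredDrafts(ttlMs = 24 * 60 * 60 * 1000): Promise<void> {
  const cutoff = Date.now() - ttlMs;
  await db.drafts.where("savedAt").below(cutoff).delete();
}
```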

### Integration Changes
- **Flow Editor** (`Flow.tsx`)
  - Integrated `DraftRecoveryPopup` component
  - Passes `isInitialLoadComplete` state for proper timing

- **useFlow Hook** (`useFlow.ts`)
  - Added `isInitialLoadComplete` state to track when flow is ready
  - Ensures draft check happens after initial graph load
  - Resets state on flow/version changes

- **useCopyPaste Hook** (`useCopyPaste.ts`)
  - Refactored to manage keyboard event listeners internally
  - Simplified integration by removing external event handler setup

- **useSaveGraph Hook** (`useSaveGraph.ts`)
  - Clears draft after successful save (both create and update)
  - Removes temp flow ID from session storage on first save

### Dependencies
- Added `dexie@4.2.1` - Modern IndexedDB wrapper for reliable
client-side storage

## Technical Details

**Auto-Save Flow:**
1. User makes changes to nodes/edges
2. Change triggers 15-second debounced save
3. Draft saved to IndexedDB with timestamp
4. On save, current state compared with last saved state
5. Only saves if meaningful changes detected
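
Reading steps 1–2 as a trailing debounce, a minimal sketch (function names and wiring are illustrative assumptions, not the actual `useDraftManager.ts` code):

```ts
// Each change resets a 15-second timer, so the draft is only written
// after the user pauses editing for that long.
let saveTimer: ReturnType<typeof setTimeout> | undefined;

export function scheduleDraftSave(saveDraft: () => Promise<void>, delayMs = 15_000) {
  if (saveTimer !== undefined) clearTimeout(saveTimer);
  saveTimer = setTimeout(() => {
    void saveDraft();
  }, delayMs);
}
```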

**Recovery Flow:**
1. User loads flow/refreshes page
2. After initial load completes, check for existing draft
3. Compare draft with current state
4. If different and non-empty, show recovery popup
5. User chooses to restore or discard
6. Draft cleared after either action

**Session Management:**
- Existing flows: Use actual flowID for draft key

### Test Plan 🧪

- [x] Create a new flow with 3+ blocks and connections, wait 15+
seconds, then refresh the page - verify recovery popup appears with
correct counts and restoring works
- [x] Create a flow with blocks, refresh, then click "Discard" button on
recovery popup - verify popup disappears and draft is deleted
- [x] Add blocks to a flow, save successfully - verify draft is cleared
from IndexedDB (check DevTools > Application > IndexedDB)
- [x] Make changes to an existing flow, refresh page - verify recovery
popup shows and restoring preserves all changes correctly
- [x] Verify empty flows (0 nodes) don't trigger recovery popup or save
drafts
2025-12-31 14:49:53 +00:00
Abhimanyu Yadav
fba61c72ed feat(frontend): fix duplicate publish button and improve BuilderActionButton styling
(#11669)

Fixes duplicate "Publish to Marketplace" buttons in the builder by
adding a `showTrigger` prop to control modal trigger visibility.

<img width="296" height="99" alt="Screenshot 2025-12-23 at 8 18 58 AM"
src="https://github.com/user-attachments/assets/d5dbfba8-e854-4c0c-a6b7-da47133ec815"
/>


### Changes 🏗️

**BuilderActionButton.tsx**

- Removed borders on hover and active states for a cleaner visual
appearance
- Added `hover:border-none` and `active:border-none` to maintain
consistent styling during interactions

**PublishToMarketplace.tsx**

- Pass `showTrigger={false}` to `PublishAgentModal` to hide the default
trigger button
- This prevents duplicate buttons when a custom trigger is already
rendered

**PublishAgentModal.tsx**

- Added `showTrigger` prop (defaults to `true`) to conditionally render
the modal trigger
- Allows parent components to control whether the built-in trigger
button should be displayed
- Maintains backward compatibility with existing usage
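
A minimal sketch of the `showTrigger` pattern described above; the modal internals are simplified stand-ins for the actual component:

```tsx
import { useState, type ReactNode } from "react";

interface PublishAgentModalProps {
  showTrigger?: boolean; // defaults to true so existing call sites keep working
  children?: ReactNode;
}

export function PublishAgentModal({ showTrigger = true, children }: PublishAgentModalProps) {
  const [open, setOpen] = useState(false);
  return (
    <>
      {/* Render the built-in trigger only when the parent has not supplied its own */}
      {showTrigger && (
        <button onClick={() => setOpen(true)}>Publish to Marketplace</button>
      )}
      {open && <div role="dialog">{children}</div>}
    </>
  );
}

// In PublishToMarketplace.tsx, which renders its own trigger button:
// <PublishAgentModal showTrigger={false}>…</PublishAgentModal>
```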

### Checklist 📋

#### For code changes:

- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Verify only one "Publish to Marketplace" button appears in the builder
  - [x] Confirm button hover/active states display correctly without border artifacts
  - [x] Verify modal can still be triggered programmatically without the trigger button
2025-12-31 09:46:12 +00:00
Nicholas Tindle
79d45a15d0 feat(platform): Deduplicate insufficient funds Discord + email notifications (#11672)
Add Redis-based deduplication for insufficient funds notifications (both
Discord alerts and user emails) when users run out of credits. This
prevents spamming users and the PRODUCT Discord channel with repeated
alerts for the same user+agent combination.

### Changes 🏗️

- **Redis-based deduplication** (`backend/executor/manager.py`):
- Add `INSUFFICIENT_FUNDS_NOTIFIED_PREFIX` constant for Redis key prefix
- Add `INSUFFICIENT_FUNDS_NOTIFIED_TTL_SECONDS` (30 days) as fallback
cleanup
- Implement deduplication in `_handle_insufficient_funds_notif` using
Redis `SET NX` (see the sketch after this list)
- Skip both email (`ZERO_BALANCE`) and Discord notifications for
duplicate alerts per user+agent
- Add `clear_insufficient_funds_notifications(user_id)` function to
remove all notification flags for a user

- **Clear flags on credit top-up** (`backend/data/credit.py`):
- Call `clear_insufficient_funds_notifications` in `_top_up_credits`
after successful auto-charge
- Call `clear_insufficient_funds_notifications` in `fulfill_checkout`
after successful manual top-up
- This allows users to receive notifications again if they run out of
funds in the future

- **Comprehensive test coverage**
(`backend/executor/manager_insufficient_funds_test.py`):
  - Test first-time notification sends both email and Discord alert
  - Test duplicate notifications are skipped for same user+agent
  - Test different agents for same user get separate alerts
  - Test clearing notifications removes all keys for a user
  - Test handling when no notification keys exist
- Test notifications still sent when Redis fails (graceful degradation)
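
The dedup idea, sketched with ioredis in TypeScript purely for illustration; the actual change lives in the Python executor, and the key format and prefix value below are assumptions:

```ts
import Redis from "ioredis";

const redis = new Redis();
const INSUFFICIENT_FUNDS_NOTIFIED_PREFIX = "insufficient_funds_notified:"; // assumed value
const INSUFFICIENT_FUNDS_NOTIFIED_TTL_SECONDS = 30 * 24 * 60 * 60; // 30-day fallback cleanup

// Returns true only for the first alert per user+agent; later calls find the existing key.
async function shouldSendInsufficientFundsAlert(userId: string, graphId: string): Promise<boolean> {
  const key = `${INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}${userId}:${graphId}`; // assumed key shape
  try {
    // SET ... EX <ttl> NX: only succeeds if the key does not exist yet.
    const result = await redis.set(key, "1", "EX", INSUFFICIENT_FUNDS_NOTIFIED_TTL_SECONDS, "NX");
    return result === "OK";
  } catch {
    // Graceful degradation: if Redis is down, prefer a duplicate alert over a dropped one.
    return true;
  }
}

// Called after a successful top-up so future shortfalls notify again.
async function clearInsufficientFundsNotifications(userId: string): Promise<void> {
  // KEYS is fine for a sketch; a production version would track or scan keys instead.
  const keys = await redis.keys(`${INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}${userId}:*`);
  if (keys.length > 0) await redis.del(...keys);
}
```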

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] First insufficient funds alert sends both email and Discord notification
  - [x] Duplicate alerts for same user+agent are skipped
  - [x] Different agents for same user each get their own notification
  - [x] Topping up credits clears notification flags
  - [x] Redis failure gracefully falls back to sending notifications
  - [x] 30-day TTL provides automatic cleanup as fallback
  - [x] Manually test this works with scheduled agents
 

<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> Introduces Redis-backed deduplication for insufficient-funds alerts
and resets flags on successful credit additions.
> 
> - **Dedup insufficient-funds alerts** in `executor/manager.py` using
Redis `SET NX` with `INSUFFICIENT_FUNDS_NOTIFIED_PREFIX` and 30‑day TTL;
skips duplicate ZERO_BALANCE email + Discord alerts per
`user_id`+`graph_id`, with graceful fallback if Redis fails.
> - **Reset notification flags on credit increases** by adding
`clear_insufficient_funds_notifications(user_id)` and invoking it when
enabling/adding positive `GRANT`/`TOP_UP` transactions in
`data/credit.py`.
> - **Tests** (`executor/manager_insufficient_funds_test.py`):
first-time vs duplicate behavior, per-agent separation, clearing keys
(including no-key and Redis-error cases), and clearing on
`_add_transaction`/`_enable_transaction`.
> 
> <sup>Written by [Cursor
Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit
1a4413b3a1. This will update automatically
on new commits. Configure
[here](https://cursor.com/dashboard?tab=bugbot).</sup>
<!-- /CURSOR_SUMMARY -->

---------

Co-authored-by: Ubbe <hi@ubbe.dev>
Co-authored-by: Claude <noreply@anthropic.com>
2025-12-30 18:10:30 +00:00
Ubbe
66f0d97ca2 fix(frontend): hide better chat link if not enabled (#11648)
## Changes 🏗️

- Make `<Navbar />` a client component so its rendering is more
predictable
- Remove the `useMemo()` for the chat link to prevent the flash...
- Make sure chat is added to the navbar links only after checking the
flag is enabled
- Improve logout with `useTransition`
- Simplify feature flags setup
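
A minimal sketch of the flag-gated filtering, with link data and helper names as illustrative assumptions (the real code checks `Flag.CHAT` via the project's `use-get-flag` hook):

```ts
type NavLink = { name: string; href: string };

const loggedInLinks: NavLink[] = [
  { name: "Marketplace", href: "/marketplace" },
  { name: "Library", href: "/library" },
  { name: "Chat", href: "/chat" },
];

// Drop the Chat entry unless the feature flag has been confirmed enabled,
// so the link never flashes while the flag is still loading.
export function getVisibleLinks(isChatEnabled: boolean): NavLink[] {
  return loggedInLinks.filter((link) => link.name !== "Chat" || isChatEnabled);
}
```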

## Checklist 📋

### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Run locally and test the above

<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> Ensures the `Chat` nav item is hidden when the feature flag is off
across desktop and mobile nav.
> 
> - Inline-filters `loggedInLinks` to skip `Chat` when `Flag.CHAT` is
false for both `NavbarLink` rendering and `MobileNavBar` menu items
> - Removes `useMemo`/`linksWithChat` helper; maps directly over
`loggedInLinks` and filters nulls in mobile, keeping icon mapping intact
> - Cleans up unused `useMemo` import
> 
> <sup>Written by [Cursor
Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit
79c42d87b4. This will update automatically
on new commits. Configure
[here](https://cursor.com/dashboard?tab=bugbot).</sup>
<!-- /CURSOR_SUMMARY -->
2025-12-30 13:21:53 +00:00
Ubbe
5894a8fcdf fix(frontend): use DS Dialog on old builder (#11643)
## Changes 🏗️

Use the Design System `<Dialog />` on the old builder, which supports
long content scrolling (the current one does not, causing issues in
graphs with many run inputs)...

## Checklist 📋

### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Run locally and test the above


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **New Features**
* Added Enhanced Rendering toggle for improved output handling and
display (controlled via feature flag)

* **Improvements**
  * Refined dialog layouts and user interactions
* Enhanced copy-to-clipboard functionality with toast notifications upon
copying

<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2025-12-30 20:22:57 +07:00
Ubbe
dff8efa35d fix(frontend): favico colour override issue (#11681)
## Changes 🏗️

Sometimes, on Dev, when navigating between pages, the favicon colour
would revert from Green 🟢 (Dev) to Purple 🟣 (Default). That's because
the `/marketplace` page had custom code overriding it that I didn't
notice earlier...

I also made it use the Next.js metadata API, so it handles the favicon
correctly across navigations.
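
A minimal sketch of declaring the favicon through the Next.js Metadata API, assuming an environment-based icon path (the file names and env variable are illustrative):

```ts
// app/layout.tsx (illustrative): declare the icon once so it persists across navigations.
import type { Metadata } from "next";

const isDev = process.env.NEXT_PUBLIC_APP_ENV === "dev"; // env var name is an assumption

export const metadata: Metadata = {
  icons: {
    icon: isDev ? "/favicon-dev.ico" : "/favicon.ico",
  },
};
```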

## Checklist 📋

### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Run locally and test the above
2025-12-30 20:22:32 +07:00
seer-by-sentry[bot]
e26822998f fix: Handle missing or null 'items' key in DataForSEO Related Keywords block (#10989)
### Changes 🏗️

- Modified the DataForSEO Related Keywords block to handle cases where
the 'items' key is missing or has a null value in the API response.
- Ensures that the code gracefully handles these scenarios by defaulting
to an empty list, preventing potential errors. Fixes
[AUTOGPT-SERVER-66D](https://sentry.io/organizations/significant-gravitas/issues/6902944636/).

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  <!-- Put your test plan here: -->
- [x] The block now defaults to an empty list when the DataForSEO API
returns no items, preventing the code from attempting to iterate over a null value.

<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> Strengthens parsing of DataForSEO Labs response to avoid errors when
`items` is missing or null.
> 
> - In `backend/blocks/dataforseo/related_keywords.py` `run()`, sets
`items = first_result.get("items") or []` when `first_result` is a
`dict`, otherwise `[]`, ensuring safe iteration
> - Prevents exceptions and yields empty results when no items are
returned
> 
> <sup>Written by [Cursor
Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit
cc465ddbf2. This will update automatically
on new commits. Configure
[here](https://cursor.com/dashboard?tab=bugbot).</sup>
<!-- /CURSOR_SUMMARY -->

Co-authored-by: seer-by-sentry[bot] <157164994+seer-by-sentry[bot]@users.noreply.github.com>
Co-authored-by: Toran Bruce Richards <toran.richards@gmail.com>
Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
Co-authored-by: Nicholas Tindle <ntindle@users.noreply.github.com>
Co-authored-by: Nicholas Tindle <nicholas.tindle@agpt.co>
2025-12-26 16:17:24 +00:00
Ubbe
4a7bc006a8 hotfix(frontend): chat should be disabled by default (#11639)
### Changes 🏗️

Chat should be disabled by default; otherwise it flashes, and if
LaunchDarkly fails to load, it is dangerous.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Run locally with Launch Darkly disabled and test the above
2025-12-18 19:04:13 +01:00
657 changed files with 42385 additions and 12098 deletions

37
.branchlet.json Normal file
View File

@@ -0,0 +1,37 @@
{
"worktreeCopyPatterns": [
".env*",
".vscode/**",
".auth/**",
".claude/**",
"autogpt_platform/.env*",
"autogpt_platform/backend/.env*",
"autogpt_platform/frontend/.env*",
"autogpt_platform/frontend/.auth/**",
"autogpt_platform/db/docker/.env*"
],
"worktreeCopyIgnores": [
"**/node_modules/**",
"**/dist/**",
"**/.git/**",
"**/Thumbs.db",
"**/.DS_Store",
"**/.next/**",
"**/__pycache__/**",
"**/.ruff_cache/**",
"**/.pytest_cache/**",
"**/*.pyc",
"**/playwright-report/**",
"**/logs/**",
"**/site/**"
],
"worktreePathTemplate": "$BASE_PATH.worktree",
"postCreateCmd": [
"cd autogpt_platform/autogpt_libs && poetry install",
"cd autogpt_platform/backend && poetry install && poetry run prisma generate",
"cd autogpt_platform/frontend && pnpm install",
"cd docs && pip install -r requirements.txt"
],
"terminalCommand": "code .",
"deleteBranchWithWorktree": false
}

View File

@@ -1,6 +1,9 @@
# Ignore everything by default, selectively add things to context
*
# Documentation (for embeddings/search)
!docs/
# Platform - Libs
!autogpt_platform/autogpt_libs/autogpt_libs/
!autogpt_platform/autogpt_libs/pyproject.toml
@@ -16,6 +19,7 @@
!autogpt_platform/backend/poetry.lock
!autogpt_platform/backend/README.md
!autogpt_platform/backend/.env
!autogpt_platform/backend/gen_prisma_types_stub.py
# Platform - Market
!autogpt_platform/market/market/

View File

@@ -74,7 +74,7 @@ jobs:
- name: Generate Prisma Client
working-directory: autogpt_platform/backend
run: poetry run prisma generate
run: poetry run prisma generate && poetry run gen-prisma-stub
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
- name: Set up Node.js

View File

@@ -90,7 +90,7 @@ jobs:
- name: Generate Prisma Client
working-directory: autogpt_platform/backend
run: poetry run prisma generate
run: poetry run prisma generate && poetry run gen-prisma-stub
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
- name: Set up Node.js

View File

@@ -72,7 +72,7 @@ jobs:
- name: Generate Prisma Client
working-directory: autogpt_platform/backend
run: poetry run prisma generate
run: poetry run prisma generate && poetry run gen-prisma-stub
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
- name: Set up Node.js
@@ -108,6 +108,16 @@ jobs:
# run: pnpm playwright install --with-deps chromium
# Docker setup for development environment
- name: Free up disk space
run: |
# Remove large unused tools to free disk space for Docker builds
sudo rm -rf /usr/share/dotnet
sudo rm -rf /usr/local/lib/android
sudo rm -rf /opt/ghc
sudo rm -rf /opt/hostedtoolcache/CodeQL
sudo docker system prune -af
df -h
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3

View File

@@ -134,7 +134,7 @@ jobs:
run: poetry install
- name: Generate Prisma Client
run: poetry run prisma generate
run: poetry run prisma generate && poetry run gen-prisma-stub
- id: supabase
name: Start Supabase
@@ -176,7 +176,7 @@ jobs:
}
- name: Run Database Migrations
run: poetry run prisma migrate dev --name updates
run: poetry run prisma migrate deploy
env:
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}

View File

@@ -11,6 +11,7 @@ on:
- ".github/workflows/platform-frontend-ci.yml"
- "autogpt_platform/frontend/**"
merge_group:
workflow_dispatch:
concurrency:
group: ${{ github.workflow }}-${{ github.event_name == 'merge_group' && format('merge-queue-{0}', github.ref) || format('{0}-{1}', github.ref, github.event.pull_request.number || github.sha) }}
@@ -151,6 +152,14 @@ jobs:
run: |
cp ../.env.default ../.env
- name: Copy backend .env and set OpenAI API key
run: |
cp ../backend/.env.default ../backend/.env
echo "OPENAI_INTERNAL_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> ../backend/.env
env:
# Used by E2E test data script to generate embeddings for approved store agents
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
@@ -226,13 +235,25 @@ jobs:
- name: Run Playwright tests
run: pnpm test:no-build
continue-on-error: false
- name: Upload Playwright artifacts
if: failure()
- name: Upload Playwright report
if: always()
uses: actions/upload-artifact@v4
with:
name: playwright-report
path: playwright-report
if-no-files-found: ignore
retention-days: 3
- name: Upload Playwright test results
if: always()
uses: actions/upload-artifact@v4
with:
name: playwright-test-results
path: test-results
if-no-files-found: ignore
retention-days: 3
- name: Print Final Docker Compose logs
if: always()

View File

@@ -6,12 +6,14 @@ start-core:
# Stop core services
stop-core:
docker compose stop deps
docker compose stop
reset-db:
docker compose stop db
rm -rf db/docker/volumes/db/data
cd backend && poetry run prisma migrate deploy
cd backend && poetry run prisma generate
cd backend && poetry run gen-prisma-stub
# View logs for core services
logs-core:
@@ -33,6 +35,7 @@ init-env:
migrate:
cd backend && poetry run prisma migrate deploy
cd backend && poetry run prisma generate
cd backend && poetry run gen-prisma-stub
run-backend:
cd backend && poetry run app
@@ -58,4 +61,4 @@ help:
@echo " run-backend - Run the backend FastAPI server"
@echo " run-frontend - Run the frontend Next.js development server"
@echo " test-data - Run the test data creator"
@echo " load-store-agents - Load store agents from agents/ folder into test database"
@echo " load-store-agents - Load store agents from agents/ folder into test database"

View File

@@ -58,6 +58,13 @@ V0_API_KEY=
OPEN_ROUTER_API_KEY=
NVIDIA_API_KEY=
# Langfuse Prompt Management
# Used for managing the CoPilot system prompt externally
# Get credentials from https://cloud.langfuse.com or your self-hosted instance
LANGFUSE_PUBLIC_KEY=
LANGFUSE_SECRET_KEY=
LANGFUSE_HOST=https://cloud.langfuse.com
# OAuth Credentials
# For the OAuth callback URL, use <your_frontend_url>/auth/integrations/oauth_callback,
# e.g. http://localhost:3000/auth/integrations/oauth_callback

View File

@@ -18,3 +18,4 @@ load-tests/results/
load-tests/*.json
load-tests/*.log
load-tests/node_modules/*
migrations/*/rollback*.sql

View File

@@ -48,7 +48,8 @@ RUN poetry install --no-ansi --no-root
# Generate Prisma client
COPY autogpt_platform/backend/schema.prisma ./
COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/partial_types.py
RUN poetry run prisma generate
COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
RUN poetry run prisma generate && poetry run gen-prisma-stub
FROM debian:13-slim AS server_dependencies
@@ -99,6 +100,7 @@ COPY autogpt_platform/backend/migrations /app/autogpt_platform/backend/migration
FROM server_dependencies AS server
COPY autogpt_platform/backend /app/autogpt_platform/backend
COPY docs /app/docs
RUN poetry install --no-ansi --only-root
ENV PORT=8000

View File

@@ -70,7 +70,7 @@ class RunAgentRequest(BaseModel):
)
def _create_ephemeral_session(user_id: str | None) -> ChatSession:
def _create_ephemeral_session(user_id: str) -> ChatSession:
"""Create an ephemeral session for stateless API requests."""
return ChatSession.new(user_id)

View File

@@ -1,7 +1,6 @@
"""Configuration management for chat system."""
import os
from pathlib import Path
from pydantic import Field, field_validator
from pydantic_settings import BaseSettings
@@ -12,7 +11,11 @@ class ChatConfig(BaseSettings):
# OpenAI API Configuration
model: str = Field(
default="qwen/qwen3-235b-a22b-2507", description="Default model to use"
default="anthropic/claude-opus-4.5", description="Default model to use"
)
title_model: str = Field(
default="openai/gpt-4o-mini",
description="Model to use for generating session titles (should be fast/cheap)",
)
api_key: str | None = Field(default=None, description="OpenAI API key")
base_url: str | None = Field(
@@ -23,12 +26,6 @@ class ChatConfig(BaseSettings):
# Session TTL Configuration - 12 hours
session_ttl: int = Field(default=43200, description="Session TTL in seconds")
# System Prompt Configuration
system_prompt_path: str = Field(
default="prompts/chat_system.md",
description="Path to system prompt file relative to chat module",
)
# Streaming Configuration
max_context_messages: int = Field(
default=50, ge=1, le=200, description="Maximum context messages"
@@ -41,6 +38,13 @@ class ChatConfig(BaseSettings):
default=3, description="Maximum number of agent schedules"
)
# Langfuse Prompt Management Configuration
# Note: Langfuse credentials are in Settings().secrets (settings.py)
langfuse_prompt_name: str = Field(
default="CoPilot Prompt",
description="Name of the prompt in Langfuse to fetch",
)
@field_validator("api_key", mode="before")
@classmethod
def get_api_key(cls, v):
@@ -72,43 +76,11 @@ class ChatConfig(BaseSettings):
v = "https://openrouter.ai/api/v1"
return v
def get_system_prompt(self, **template_vars) -> str:
"""Load and render the system prompt from file.
Args:
**template_vars: Variables to substitute in the template
Returns:
Rendered system prompt string
"""
# Get the path relative to this module
module_dir = Path(__file__).parent
prompt_path = module_dir / self.system_prompt_path
# Check for .j2 extension first (Jinja2 template)
j2_path = Path(str(prompt_path) + ".j2")
if j2_path.exists():
try:
from jinja2 import Template
template = Template(j2_path.read_text())
return template.render(**template_vars)
except ImportError:
# Jinja2 not installed, fall back to reading as plain text
return j2_path.read_text()
# Check for markdown file
if prompt_path.exists():
content = prompt_path.read_text()
# Simple variable substitution if Jinja2 is not available
for key, value in template_vars.items():
placeholder = f"{{{key}}}"
content = content.replace(placeholder, str(value))
return content
raise FileNotFoundError(f"System prompt file not found: {prompt_path}")
# Prompt paths for different contexts
PROMPT_PATHS: dict[str, str] = {
"default": "prompts/chat_system.md",
"onboarding": "prompts/onboarding_system.md",
}
class Config:
"""Pydantic config."""

View File

@@ -0,0 +1,249 @@
"""Database operations for chat sessions."""
import asyncio
import logging
from datetime import UTC, datetime
from typing import Any, cast
from prisma.models import ChatMessage as PrismaChatMessage
from prisma.models import ChatSession as PrismaChatSession
from prisma.types import (
ChatMessageCreateInput,
ChatSessionCreateInput,
ChatSessionUpdateInput,
ChatSessionWhereInput,
)
from backend.data.db import transaction
from backend.util.json import SafeJson
logger = logging.getLogger(__name__)
async def get_chat_session(session_id: str) -> PrismaChatSession | None:
"""Get a chat session by ID from the database."""
session = await PrismaChatSession.prisma().find_unique(
where={"id": session_id},
include={"Messages": True},
)
if session and session.Messages:
# Sort messages by sequence in Python - Prisma Python client doesn't support
# order_by in include clauses (unlike Prisma JS), so we sort after fetching
session.Messages.sort(key=lambda m: m.sequence)
return session
async def create_chat_session(
session_id: str,
user_id: str,
) -> PrismaChatSession:
"""Create a new chat session in the database."""
data = ChatSessionCreateInput(
id=session_id,
userId=user_id,
credentials=SafeJson({}),
successfulAgentRuns=SafeJson({}),
successfulAgentSchedules=SafeJson({}),
)
return await PrismaChatSession.prisma().create(
data=data,
include={"Messages": True},
)
async def update_chat_session(
session_id: str,
credentials: dict[str, Any] | None = None,
successful_agent_runs: dict[str, Any] | None = None,
successful_agent_schedules: dict[str, Any] | None = None,
total_prompt_tokens: int | None = None,
total_completion_tokens: int | None = None,
title: str | None = None,
) -> PrismaChatSession | None:
"""Update a chat session's metadata."""
data: ChatSessionUpdateInput = {"updatedAt": datetime.now(UTC)}
if credentials is not None:
data["credentials"] = SafeJson(credentials)
if successful_agent_runs is not None:
data["successfulAgentRuns"] = SafeJson(successful_agent_runs)
if successful_agent_schedules is not None:
data["successfulAgentSchedules"] = SafeJson(successful_agent_schedules)
if total_prompt_tokens is not None:
data["totalPromptTokens"] = total_prompt_tokens
if total_completion_tokens is not None:
data["totalCompletionTokens"] = total_completion_tokens
if title is not None:
data["title"] = title
session = await PrismaChatSession.prisma().update(
where={"id": session_id},
data=data,
include={"Messages": True},
)
if session and session.Messages:
# Sort in Python - Prisma Python doesn't support order_by in include clauses
session.Messages.sort(key=lambda m: m.sequence)
return session
async def add_chat_message(
session_id: str,
role: str,
sequence: int,
content: str | None = None,
name: str | None = None,
tool_call_id: str | None = None,
refusal: str | None = None,
tool_calls: list[dict[str, Any]] | None = None,
function_call: dict[str, Any] | None = None,
) -> PrismaChatMessage:
"""Add a message to a chat session."""
# Build input dict dynamically rather than using ChatMessageCreateInput directly
# because Prisma's TypedDict validation rejects optional fields set to None.
# We only include fields that have values, then cast at the end.
data: dict[str, Any] = {
"Session": {"connect": {"id": session_id}},
"role": role,
"sequence": sequence,
}
# Add optional string fields
if content is not None:
data["content"] = content
if name is not None:
data["name"] = name
if tool_call_id is not None:
data["toolCallId"] = tool_call_id
if refusal is not None:
data["refusal"] = refusal
# Add optional JSON fields only when they have values
if tool_calls is not None:
data["toolCalls"] = SafeJson(tool_calls)
if function_call is not None:
data["functionCall"] = SafeJson(function_call)
# Run message create and session timestamp update in parallel for lower latency
_, message = await asyncio.gather(
PrismaChatSession.prisma().update(
where={"id": session_id},
data={"updatedAt": datetime.now(UTC)},
),
PrismaChatMessage.prisma().create(data=cast(ChatMessageCreateInput, data)),
)
return message
async def add_chat_messages_batch(
session_id: str,
messages: list[dict[str, Any]],
start_sequence: int,
) -> list[PrismaChatMessage]:
"""Add multiple messages to a chat session in a batch.
Uses a transaction for atomicity - if any message creation fails,
the entire batch is rolled back.
"""
if not messages:
return []
created_messages = []
async with transaction() as tx:
for i, msg in enumerate(messages):
# Build input dict dynamically rather than using ChatMessageCreateInput
# directly because Prisma's TypedDict validation rejects optional fields
# set to None. We only include fields that have values, then cast.
data: dict[str, Any] = {
"Session": {"connect": {"id": session_id}},
"role": msg["role"],
"sequence": start_sequence + i,
}
# Add optional string fields
if msg.get("content") is not None:
data["content"] = msg["content"]
if msg.get("name") is not None:
data["name"] = msg["name"]
if msg.get("tool_call_id") is not None:
data["toolCallId"] = msg["tool_call_id"]
if msg.get("refusal") is not None:
data["refusal"] = msg["refusal"]
# Add optional JSON fields only when they have values
if msg.get("tool_calls") is not None:
data["toolCalls"] = SafeJson(msg["tool_calls"])
if msg.get("function_call") is not None:
data["functionCall"] = SafeJson(msg["function_call"])
created = await PrismaChatMessage.prisma(tx).create(
data=cast(ChatMessageCreateInput, data)
)
created_messages.append(created)
# Update session's updatedAt timestamp within the same transaction.
# Note: Token usage (total_prompt_tokens, total_completion_tokens) is updated
# separately via update_chat_session() after streaming completes.
await PrismaChatSession.prisma(tx).update(
where={"id": session_id},
data={"updatedAt": datetime.now(UTC)},
)
return created_messages
async def get_user_chat_sessions(
user_id: str,
limit: int = 50,
offset: int = 0,
) -> list[PrismaChatSession]:
"""Get chat sessions for a user, ordered by most recent."""
return await PrismaChatSession.prisma().find_many(
where={"userId": user_id},
order={"updatedAt": "desc"},
take=limit,
skip=offset,
)
async def get_user_session_count(user_id: str) -> int:
"""Get the total number of chat sessions for a user."""
return await PrismaChatSession.prisma().count(where={"userId": user_id})
async def delete_chat_session(session_id: str, user_id: str | None = None) -> bool:
"""Delete a chat session and all its messages.
Args:
session_id: The session ID to delete.
user_id: If provided, validates that the session belongs to this user
before deletion. This prevents unauthorized deletion of other
users' sessions.
Returns:
True if deleted successfully, False otherwise.
"""
try:
# Build typed where clause with optional user_id validation
where_clause: ChatSessionWhereInput = {"id": session_id}
if user_id is not None:
where_clause["userId"] = user_id
result = await PrismaChatSession.prisma().delete_many(where=where_clause)
if result == 0:
logger.warning(
f"No session deleted for {session_id} "
f"(user_id validation: {user_id is not None})"
)
return False
return True
except Exception as e:
logger.error(f"Failed to delete chat session {session_id}: {e}")
return False
async def get_chat_session_message_count(session_id: str) -> int:
"""Get the number of messages in a chat session."""
count = await PrismaChatMessage.prisma().count(where={"sessionId": session_id})
return count

View File

@@ -1,6 +1,9 @@
import asyncio
import logging
import uuid
from datetime import UTC, datetime
from typing import Any
from weakref import WeakValueDictionary
from openai.types.chat import (
ChatCompletionAssistantMessageParam,
@@ -16,17 +19,63 @@ from openai.types.chat.chat_completion_message_tool_call_param import (
ChatCompletionMessageToolCallParam,
Function,
)
from prisma.models import ChatMessage as PrismaChatMessage
from prisma.models import ChatSession as PrismaChatSession
from pydantic import BaseModel
from backend.data.redis_client import get_redis_async
from backend.util.exceptions import RedisError
from backend.util import json
from backend.util.exceptions import DatabaseError, RedisError
from . import db as chat_db
from .config import ChatConfig
logger = logging.getLogger(__name__)
config = ChatConfig()
def _parse_json_field(value: str | dict | list | None, default: Any = None) -> Any:
"""Parse a JSON field that may be stored as string or already parsed."""
if value is None:
return default
if isinstance(value, str):
return json.loads(value)
return value
# Redis cache key prefix for chat sessions
CHAT_SESSION_CACHE_PREFIX = "chat:session:"
def _get_session_cache_key(session_id: str) -> str:
"""Get the Redis cache key for a chat session."""
return f"{CHAT_SESSION_CACHE_PREFIX}{session_id}"
# Session-level locks to prevent race conditions during concurrent upserts.
# Uses WeakValueDictionary to automatically garbage collect locks when no longer referenced,
# preventing unbounded memory growth while maintaining lock semantics for active sessions.
# Invalidation: Locks are auto-removed by GC when no coroutine holds a reference (after
# async with lock: completes). Explicit cleanup also occurs in delete_chat_session().
_session_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
_session_locks_mutex = asyncio.Lock()
async def _get_session_lock(session_id: str) -> asyncio.Lock:
"""Get or create a lock for a specific session to prevent concurrent upserts.
Uses WeakValueDictionary for automatic cleanup: locks are garbage collected
when no coroutine holds a reference to them, preventing memory leaks from
unbounded growth of session locks.
"""
async with _session_locks_mutex:
lock = _session_locks.get(session_id)
if lock is None:
lock = asyncio.Lock()
_session_locks[session_id] = lock
return lock
class ChatMessage(BaseModel):
role: str
content: str | None = None
@@ -45,7 +94,8 @@ class Usage(BaseModel):
class ChatSession(BaseModel):
session_id: str
user_id: str | None
user_id: str
title: str | None = None
messages: list[ChatMessage]
usage: list[Usage]
credentials: dict[str, dict] = {} # Map of provider -> credential metadata
@@ -55,10 +105,11 @@ class ChatSession(BaseModel):
successful_agent_schedules: dict[str, int] = {}
@staticmethod
def new(user_id: str | None) -> "ChatSession":
def new(user_id: str) -> "ChatSession":
return ChatSession(
session_id=str(uuid.uuid4()),
user_id=user_id,
title=None,
messages=[],
usage=[],
credentials={},
@@ -66,6 +117,61 @@ class ChatSession(BaseModel):
updated_at=datetime.now(UTC),
)
@staticmethod
def from_db(
prisma_session: PrismaChatSession,
prisma_messages: list[PrismaChatMessage] | None = None,
) -> "ChatSession":
"""Convert Prisma models to Pydantic ChatSession."""
messages = []
if prisma_messages:
for msg in prisma_messages:
messages.append(
ChatMessage(
role=msg.role,
content=msg.content,
name=msg.name,
tool_call_id=msg.toolCallId,
refusal=msg.refusal,
tool_calls=_parse_json_field(msg.toolCalls),
function_call=_parse_json_field(msg.functionCall),
)
)
# Parse JSON fields from Prisma
credentials = _parse_json_field(prisma_session.credentials, default={})
successful_agent_runs = _parse_json_field(
prisma_session.successfulAgentRuns, default={}
)
successful_agent_schedules = _parse_json_field(
prisma_session.successfulAgentSchedules, default={}
)
# Calculate usage from token counts
usage = []
if prisma_session.totalPromptTokens or prisma_session.totalCompletionTokens:
usage.append(
Usage(
prompt_tokens=prisma_session.totalPromptTokens or 0,
completion_tokens=prisma_session.totalCompletionTokens or 0,
total_tokens=(prisma_session.totalPromptTokens or 0)
+ (prisma_session.totalCompletionTokens or 0),
)
)
return ChatSession(
session_id=prisma_session.id,
user_id=prisma_session.userId,
title=prisma_session.title,
messages=messages,
usage=usage,
credentials=credentials,
started_at=prisma_session.createdAt,
updated_at=prisma_session.updatedAt,
successful_agent_runs=successful_agent_runs,
successful_agent_schedules=successful_agent_schedules,
)
def to_openai_messages(self) -> list[ChatCompletionMessageParam]:
messages = []
for message in self.messages:
@@ -155,50 +261,337 @@ class ChatSession(BaseModel):
return messages
async def get_chat_session(
session_id: str,
user_id: str | None,
) -> ChatSession | None:
"""Get a chat session by ID."""
redis_key = f"chat:session:{session_id}"
async def _get_session_from_cache(session_id: str) -> ChatSession | None:
"""Get a chat session from Redis cache."""
redis_key = _get_session_cache_key(session_id)
async_redis = await get_redis_async()
raw_session: bytes | None = await async_redis.get(redis_key)
if raw_session is None:
logger.warning(f"Session {session_id} not found in Redis")
return None
try:
session = ChatSession.model_validate_json(raw_session)
logger.info(
f"Loading session {session_id} from cache: "
f"message_count={len(session.messages)}, "
f"roles={[m.role for m in session.messages]}"
)
return session
except Exception as e:
logger.error(f"Failed to deserialize session {session_id}: {e}", exc_info=True)
raise RedisError(f"Corrupted session data for {session_id}") from e
if session.user_id is not None and session.user_id != user_id:
async def _cache_session(session: ChatSession) -> None:
"""Cache a chat session in Redis."""
redis_key = _get_session_cache_key(session.session_id)
async_redis = await get_redis_async()
await async_redis.setex(redis_key, config.session_ttl, session.model_dump_json())
async def _get_session_from_db(session_id: str) -> ChatSession | None:
"""Get a chat session from the database."""
prisma_session = await chat_db.get_chat_session(session_id)
if not prisma_session:
return None
messages = prisma_session.Messages
logger.info(
f"Loading session {session_id} from DB: "
f"has_messages={messages is not None}, "
f"message_count={len(messages) if messages else 0}, "
f"roles={[m.role for m in messages] if messages else []}"
)
return ChatSession.from_db(prisma_session, messages)
async def _save_session_to_db(
session: ChatSession, existing_message_count: int
) -> None:
"""Save or update a chat session in the database."""
# Check if session exists in DB
existing = await chat_db.get_chat_session(session.session_id)
if not existing:
# Create new session
await chat_db.create_chat_session(
session_id=session.session_id,
user_id=session.user_id,
)
existing_message_count = 0
# Calculate total tokens from usage
total_prompt = sum(u.prompt_tokens for u in session.usage)
total_completion = sum(u.completion_tokens for u in session.usage)
# Update session metadata
await chat_db.update_chat_session(
session_id=session.session_id,
credentials=session.credentials,
successful_agent_runs=session.successful_agent_runs,
successful_agent_schedules=session.successful_agent_schedules,
total_prompt_tokens=total_prompt,
total_completion_tokens=total_completion,
)
# Add new messages (only those after existing count)
new_messages = session.messages[existing_message_count:]
if new_messages:
messages_data = []
for msg in new_messages:
messages_data.append(
{
"role": msg.role,
"content": msg.content,
"name": msg.name,
"tool_call_id": msg.tool_call_id,
"refusal": msg.refusal,
"tool_calls": msg.tool_calls,
"function_call": msg.function_call,
}
)
logger.info(
f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: "
f"roles={[m['role'] for m in messages_data]}, "
f"start_sequence={existing_message_count}"
)
await chat_db.add_chat_messages_batch(
session_id=session.session_id,
messages=messages_data,
start_sequence=existing_message_count,
)
async def get_chat_session(
session_id: str,
user_id: str | None = None,
) -> ChatSession | None:
"""Get a chat session by ID.
Checks Redis cache first, falls back to database if not found.
Caches database results back to Redis.
Args:
session_id: The session ID to fetch.
user_id: If provided, validates that the session belongs to this user.
If None, ownership is not validated (admin/system access).
"""
# Try cache first
try:
session = await _get_session_from_cache(session_id)
if session:
# Verify user ownership if user_id was provided for validation
if user_id is not None and session.user_id != user_id:
logger.warning(
f"Session {session_id} user id mismatch: {session.user_id} != {user_id}"
)
return None
return session
except RedisError:
logger.warning(f"Cache error for session {session_id}, trying database")
except Exception as e:
logger.warning(f"Unexpected cache error for session {session_id}: {e}")
# Fall back to database
logger.info(f"Session {session_id} not in cache, checking database")
session = await _get_session_from_db(session_id)
if session is None:
logger.warning(f"Session {session_id} not found in cache or database")
return None
# Verify user ownership if user_id was provided for validation
if user_id is not None and session.user_id != user_id:
logger.warning(
f"Session {session_id} user id mismatch: {session.user_id} != {user_id}"
)
return None
# Cache the session from DB
try:
await _cache_session(session)
logger.info(f"Cached session {session_id} from database")
except Exception as e:
logger.warning(f"Failed to cache session {session_id}: {e}")
return session
async def upsert_chat_session(
session: ChatSession,
) -> ChatSession:
"""Update a chat session with the given messages."""
"""Update a chat session in both cache and database.
redis_key = f"chat:session:{session.session_id}"
Uses session-level locking to prevent race conditions when concurrent
operations (e.g., background title update and main stream handler)
attempt to upsert the same session simultaneously.
async_redis = await get_redis_async()
resp = await async_redis.setex(
redis_key, config.session_ttl, session.model_dump_json()
)
Raises:
DatabaseError: If the database write fails. The cache is still updated
as a best-effort optimization, but the error is propagated to ensure
callers are aware of the persistence failure.
RedisError: If the cache write fails (after successful DB write).
"""
# Acquire session-specific lock to prevent concurrent upserts
lock = await _get_session_lock(session.session_id)
if not resp:
raise RedisError(
f"Failed to persist chat session {session.session_id} to Redis: {resp}"
async with lock:
# Get existing message count from DB for incremental saves
existing_message_count = await chat_db.get_chat_session_message_count(
session.session_id
)
db_error: Exception | None = None
# Save to database (primary storage)
try:
await _save_session_to_db(session, existing_message_count)
except Exception as e:
logger.error(
f"Failed to save session {session.session_id} to database: {e}"
)
db_error = e
# Save to cache (best-effort, even if DB failed)
try:
await _cache_session(session)
except Exception as e:
# If DB succeeded but cache failed, raise cache error
if db_error is None:
raise RedisError(
f"Failed to persist chat session {session.session_id} to Redis: {e}"
) from e
# If both failed, log cache error but raise DB error (more critical)
logger.warning(
f"Cache write also failed for session {session.session_id}: {e}"
)
# Propagate DB error after attempting cache (prevents data loss)
if db_error is not None:
raise DatabaseError(
f"Failed to persist chat session {session.session_id} to database"
) from db_error
return session
async def create_chat_session(user_id: str) -> ChatSession:
"""Create a new chat session and persist it.
Raises:
DatabaseError: If the database write fails. We fail fast to ensure
callers never receive a non-persisted session that only exists
in cache (which would be lost when the cache expires).
"""
session = ChatSession.new(user_id)
# Create in database first - fail fast if this fails
try:
await chat_db.create_chat_session(
session_id=session.session_id,
user_id=user_id,
)
except Exception as e:
logger.error(f"Failed to create session {session.session_id} in database: {e}")
raise DatabaseError(
f"Failed to create chat session {session.session_id} in database"
) from e
# Cache the session (best-effort optimization, DB is source of truth)
try:
await _cache_session(session)
except Exception as e:
logger.warning(f"Failed to cache new session {session.session_id}: {e}")
return session
async def get_user_sessions(
user_id: str,
limit: int = 50,
offset: int = 0,
) -> tuple[list[ChatSession], int]:
"""Get chat sessions for a user from the database with total count.
Returns:
A tuple of (sessions, total_count) where total_count is the overall
number of sessions for the user (not just the current page).
"""
prisma_sessions = await chat_db.get_user_chat_sessions(user_id, limit, offset)
total_count = await chat_db.get_user_session_count(user_id)
sessions = []
for prisma_session in prisma_sessions:
# Convert without messages for listing (lighter weight)
sessions.append(ChatSession.from_db(prisma_session, None))
return sessions, total_count
async def delete_chat_session(session_id: str, user_id: str | None = None) -> bool:
"""Delete a chat session from both cache and database.
Args:
session_id: The session ID to delete.
user_id: If provided, validates that the session belongs to this user
before deletion. This prevents unauthorized deletion.
Returns:
True if deleted successfully, False otherwise.
"""
# Delete from database first (with optional user_id validation)
# This confirms ownership before invalidating cache
deleted = await chat_db.delete_chat_session(session_id, user_id)
if not deleted:
return False
# Only invalidate cache and clean up lock after DB confirms deletion
try:
redis_key = _get_session_cache_key(session_id)
async_redis = await get_redis_async()
await async_redis.delete(redis_key)
except Exception as e:
logger.warning(f"Failed to delete session {session_id} from cache: {e}")
# Clean up session lock (belt-and-suspenders with WeakValueDictionary)
async with _session_locks_mutex:
_session_locks.pop(session_id, None)
return True
async def update_session_title(session_id: str, title: str) -> bool:
"""Update only the title of a chat session.
This is a lightweight operation that doesn't touch messages, avoiding
race conditions with concurrent message updates. Use this for background
title generation instead of upsert_chat_session.
Args:
session_id: The session ID to update.
title: The new title to set.
Returns:
True if updated successfully, False otherwise.
"""
try:
result = await chat_db.update_chat_session(session_id=session_id, title=title)
if result is None:
logger.warning(f"Session {session_id} not found for title update")
return False
# Invalidate cache so next fetch gets updated title
try:
redis_key = _get_session_cache_key(session_id)
async_redis = await get_redis_async()
await async_redis.delete(redis_key)
except Exception as e:
logger.warning(f"Failed to invalidate cache for session {session_id}: {e}")
return True
except Exception as e:
logger.error(f"Failed to update title for session {session_id}: {e}")
return False

View File

@@ -43,9 +43,9 @@ async def test_chatsession_serialization_deserialization():
@pytest.mark.asyncio(loop_scope="session")
async def test_chatsession_redis_storage():
async def test_chatsession_redis_storage(setup_test_user, test_user_id):
s = ChatSession.new(user_id=None)
s = ChatSession.new(user_id=test_user_id)
s.messages = messages
s = await upsert_chat_session(s)
@@ -59,12 +59,61 @@ async def test_chatsession_redis_storage():
@pytest.mark.asyncio(loop_scope="session")
async def test_chatsession_redis_storage_user_id_mismatch():
async def test_chatsession_redis_storage_user_id_mismatch(
setup_test_user, test_user_id
):
s = ChatSession.new(user_id="abc123")
s = ChatSession.new(user_id=test_user_id)
s.messages = messages
s = await upsert_chat_session(s)
s2 = await get_chat_session(s.session_id, None)
s2 = await get_chat_session(s.session_id, "different_user_id")
assert s2 is None
@pytest.mark.asyncio(loop_scope="session")
async def test_chatsession_db_storage(setup_test_user, test_user_id):
"""Test that messages are correctly saved to and loaded from DB (not cache)."""
from backend.data.redis_client import get_redis_async
# Create session with messages including assistant message
s = ChatSession.new(user_id=test_user_id)
s.messages = messages # Contains user, assistant, and tool messages
assert s.session_id is not None, "Session id is not set"
# Upsert to save to both cache and DB
s = await upsert_chat_session(s)
# Clear the Redis cache to force DB load
redis_key = f"chat:session:{s.session_id}"
async_redis = await get_redis_async()
await async_redis.delete(redis_key)
# Load from DB (cache was cleared)
s2 = await get_chat_session(
session_id=s.session_id,
user_id=s.user_id,
)
assert s2 is not None, "Session not found after loading from DB"
assert len(s2.messages) == len(
s.messages
), f"Message count mismatch: expected {len(s.messages)}, got {len(s2.messages)}"
# Verify all roles are present
roles = [m.role for m in s2.messages]
assert "user" in roles, f"User message missing. Roles found: {roles}"
assert "assistant" in roles, f"Assistant message missing. Roles found: {roles}"
assert "tool" in roles, f"Tool message missing. Roles found: {roles}"
# Verify message content
for orig, loaded in zip(s.messages, s2.messages):
assert orig.role == loaded.role, f"Role mismatch: {orig.role} != {loaded.role}"
assert (
orig.content == loaded.content
), f"Content mismatch for {orig.role}: {orig.content} != {loaded.content}"
if orig.tool_calls:
assert (
loaded.tool_calls is not None
), f"Tool calls missing for {orig.role} message"
assert len(orig.tool_calls) == len(loaded.tool_calls)

View File

@@ -1,104 +0,0 @@
You are Otto, an AI Co-Pilot and Forward Deployed Engineer for AutoGPT, an AI Business Automation tool. Your mission is to help users quickly find and set up AutoGPT agents to solve their business problems.
Here are the functions available to you:
<functions>
1. **find_agent** - Search for agents that solve the user's problem
2. **run_agent** - Run or schedule an agent (automatically handles setup)
</functions>
## HOW run_agent WORKS
The `run_agent` tool automatically handles the entire setup flow:
1. **First call** (no inputs) → Returns available inputs so user can decide what values to use
2. **Credentials check** → If missing, UI automatically prompts user to add them (you don't need to mention this)
3. **Execution** → Runs when you provide `inputs` OR set `use_defaults=true`
Parameters:
- `username_agent_slug` (required): Agent identifier like "creator/agent-name"
- `inputs`: Object with input values for the agent
- `use_defaults`: Set to `true` to run with default values (only after user confirms)
- `schedule_name` + `cron`: For scheduled execution
## WORKFLOW
1. **find_agent** - Search for agents that solve the user's problem
2. **run_agent** (first call, no inputs) - Get available inputs for the agent
3. **Ask user** what values they want to use OR if they want to use defaults
4. **run_agent** (second call) - Either with `inputs={...}` or `use_defaults=true`
## YOUR APPROACH
**Step 1: Understand the Problem**
- Ask maximum 1-2 targeted questions
- Focus on: What business problem are they solving?
- Move quickly to searching for solutions
**Step 2: Find Agents**
- Use `find_agent` immediately with relevant keywords
- Suggest the best option from search results
- Explain briefly how it solves their problem
**Step 3: Get Agent Inputs**
- Call `run_agent(username_agent_slug="creator/agent-name")` without inputs
- This returns the available inputs (required and optional)
- Present these to the user and ask what values they want
**Step 4: Run with User's Choice**
- If user provides values: `run_agent(username_agent_slug="...", inputs={...})`
- If user says "use defaults": `run_agent(username_agent_slug="...", use_defaults=true)`
- On success, share the agent link with the user
**For Scheduled Execution:**
- Add `schedule_name` and `cron` parameters
- Example: `run_agent(username_agent_slug="...", inputs={...}, schedule_name="Daily Report", cron="0 9 * * *")`
## FUNCTION CALL FORMAT
To call a function, use this exact format:
`<function_call>function_name(parameter="value")</function_call>`
Examples:
- `<function_call>find_agent(query="social media automation")</function_call>`
- `<function_call>run_agent(username_agent_slug="creator/agent-name")</function_call>` (get inputs)
- `<function_call>run_agent(username_agent_slug="creator/agent-name", inputs={"topic": "AI news"})</function_call>`
- `<function_call>run_agent(username_agent_slug="creator/agent-name", use_defaults=true)</function_call>`
## KEY RULES
**What You DON'T Do:**
- Don't help with login (frontend handles this)
- Don't mention or explain credentials to the user (frontend handles this automatically)
- Don't run agents without first showing available inputs to the user
- Don't use `use_defaults=true` without user explicitly confirming
- Don't write responses longer than 3 sentences
**What You DO:**
- Always call run_agent first without inputs to see what's available
- Ask user what values they want OR if they want to use defaults
- Keep all responses to maximum 3 sentences
- Include the agent link in your response after successful execution
**Error Handling:**
- Authentication needed → "Please sign in via the interface"
- Credentials missing → The UI handles this automatically. Focus on asking the user about input values instead.
## RESPONSE STRUCTURE
Before responding, wrap your analysis in <thinking> tags to systematically plan your approach:
- Extract the key business problem or request from the user's message
- Determine what function call (if any) you need to make next
- Plan your response to stay under the 3-sentence maximum
Example interaction:
```
User: "Run the AI news agent for me"
Otto: <function_call>run_agent(username_agent_slug="autogpt/ai-news")</function_call>
[Tool returns: Agent accepts inputs - Required: topic. Optional: num_articles (default: 5)]
Otto: The AI News agent needs a topic. What topic would you like news about, or should I use the defaults?
User: "Use defaults"
Otto: <function_call>run_agent(username_agent_slug="autogpt/ai-news", use_defaults=true)</function_call>
```
KEEP ANSWERS TO 3 SENTENCES

View File

@@ -1,3 +1,10 @@
"""
Response models for Vercel AI SDK UI Stream Protocol.
This module implements the AI SDK UI Stream Protocol (v1) for streaming chat responses.
See: https://ai-sdk.dev/docs/ai-sdk-ui/stream-protocol
"""
from enum import Enum
from typing import Any
@@ -5,97 +12,133 @@ from pydantic import BaseModel, Field
class ResponseType(str, Enum):
"""Types of streaming responses."""
"""Types of streaming responses following AI SDK protocol."""
TEXT_CHUNK = "text_chunk"
TEXT_ENDED = "text_ended"
TOOL_CALL = "tool_call"
TOOL_CALL_START = "tool_call_start"
TOOL_RESPONSE = "tool_response"
# Message lifecycle
START = "start"
FINISH = "finish"
# Text streaming
TEXT_START = "text-start"
TEXT_DELTA = "text-delta"
TEXT_END = "text-end"
# Tool interaction
TOOL_INPUT_START = "tool-input-start"
TOOL_INPUT_AVAILABLE = "tool-input-available"
TOOL_OUTPUT_AVAILABLE = "tool-output-available"
# Other
ERROR = "error"
USAGE = "usage"
STREAM_END = "stream_end"
class StreamBaseResponse(BaseModel):
"""Base response model for all streaming responses."""
type: ResponseType
timestamp: str | None = None
def to_sse(self) -> str:
"""Convert to SSE format."""
return f"data: {self.model_dump_json()}\n\n"
class StreamTextChunk(StreamBaseResponse):
"""Streaming text content from the assistant."""
type: ResponseType = ResponseType.TEXT_CHUNK
content: str = Field(..., description="Text content chunk")
# ========== Message Lifecycle ==========
class StreamToolCallStart(StreamBaseResponse):
class StreamStart(StreamBaseResponse):
"""Start of a new message."""
type: ResponseType = ResponseType.START
messageId: str = Field(..., description="Unique message ID")
class StreamFinish(StreamBaseResponse):
"""End of message/stream."""
type: ResponseType = ResponseType.FINISH
# ========== Text Streaming ==========
class StreamTextStart(StreamBaseResponse):
"""Start of a text block."""
type: ResponseType = ResponseType.TEXT_START
id: str = Field(..., description="Text block ID")
class StreamTextDelta(StreamBaseResponse):
"""Streaming text content delta."""
type: ResponseType = ResponseType.TEXT_DELTA
id: str = Field(..., description="Text block ID")
delta: str = Field(..., description="Text content delta")
class StreamTextEnd(StreamBaseResponse):
"""End of a text block."""
type: ResponseType = ResponseType.TEXT_END
id: str = Field(..., description="Text block ID")
# ========== Tool Interaction ==========
class StreamToolInputStart(StreamBaseResponse):
"""Tool call started notification."""
type: ResponseType = ResponseType.TOOL_CALL_START
tool_name: str = Field(..., description="Name of the tool that was executed")
tool_id: str = Field(..., description="Unique tool call ID")
type: ResponseType = ResponseType.TOOL_INPUT_START
toolCallId: str = Field(..., description="Unique tool call ID")
toolName: str = Field(..., description="Name of the tool being called")
class StreamToolCall(StreamBaseResponse):
"""Tool invocation notification."""
class StreamToolInputAvailable(StreamBaseResponse):
"""Tool input is ready for execution."""
type: ResponseType = ResponseType.TOOL_CALL
tool_id: str = Field(..., description="Unique tool call ID")
tool_name: str = Field(..., description="Name of the tool being called")
arguments: dict[str, Any] = Field(
default_factory=dict, description="Tool arguments"
type: ResponseType = ResponseType.TOOL_INPUT_AVAILABLE
toolCallId: str = Field(..., description="Unique tool call ID")
toolName: str = Field(..., description="Name of the tool being called")
input: dict[str, Any] = Field(
default_factory=dict, description="Tool input arguments"
)
class StreamToolExecutionResult(StreamBaseResponse):
class StreamToolOutputAvailable(StreamBaseResponse):
"""Tool execution result."""
type: ResponseType = ResponseType.TOOL_RESPONSE
tool_id: str = Field(..., description="Tool call ID this responds to")
tool_name: str = Field(..., description="Name of the tool that was executed")
result: str | dict[str, Any] = Field(..., description="Tool execution result")
type: ResponseType = ResponseType.TOOL_OUTPUT_AVAILABLE
toolCallId: str = Field(..., description="Tool call ID this responds to")
output: str | dict[str, Any] = Field(..., description="Tool execution output")
# Additional fields for internal use (not part of AI SDK spec but useful)
toolName: str | None = Field(
default=None, description="Name of the tool that was executed"
)
success: bool = Field(
default=True, description="Whether the tool execution succeeded"
)
# ========== Other ==========
class StreamUsage(StreamBaseResponse):
"""Token usage statistics."""
type: ResponseType = ResponseType.USAGE
prompt_tokens: int
completion_tokens: int
total_tokens: int
promptTokens: int = Field(..., description="Number of prompt tokens")
completionTokens: int = Field(..., description="Number of completion tokens")
totalTokens: int = Field(..., description="Total number of tokens")
class StreamError(StreamBaseResponse):
"""Error response."""
type: ResponseType = ResponseType.ERROR
message: str = Field(..., description="Error message")
errorText: str = Field(..., description="Error message text")
code: str | None = Field(default=None, description="Error code")
details: dict[str, Any] | None = Field(
default=None, description="Additional error details"
)
class StreamTextEnded(StreamBaseResponse):
"""Text streaming completed marker."""
type: ResponseType = ResponseType.TEXT_ENDED
class StreamEnd(StreamBaseResponse):
"""End of stream marker."""
type: ResponseType = ResponseType.STREAM_END
summary: dict[str, Any] | None = Field(
default=None, description="Stream summary statistics"
)
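To make the new protocol concrete, here is a minimal producer sketch (not part of the diff) that emits one complete message using the models above and their `to_sse()` helper; the greeting text and the message/text-block IDs are arbitrary illustration values:

```python
# Illustrative producer of AI SDK UI Stream Protocol events using the
# response models defined above. IDs and content are made up.
import uuid

from backend.api.features.chat.response_model import (
    StreamFinish,
    StreamStart,
    StreamTextDelta,
    StreamTextEnd,
    StreamTextStart,
)


async def emit_greeting():
    message_id = str(uuid.uuid4())
    text_id = str(uuid.uuid4())

    yield StreamStart(messageId=message_id).to_sse()
    yield StreamTextStart(id=text_id).to_sse()
    for word in ("Hello", " ", "world"):
        yield StreamTextDelta(id=text_id, delta=word).to_sse()
    yield StreamTextEnd(id=text_id).to_sse()
    yield StreamFinish().to_sse()
    # AI SDK clients also expect the literal terminator after the last event.
    yield "data: [DONE]\n\n"
```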

View File

@@ -13,12 +13,25 @@ from backend.util.exceptions import NotFoundError
from . import service as chat_service
from .config import ChatConfig
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
config = ChatConfig()
logger = logging.getLogger(__name__)
async def _validate_and_get_session(
session_id: str,
user_id: str | None,
) -> ChatSession:
"""Validate session exists and belongs to user."""
session = await get_chat_session(session_id, user_id)
if not session:
raise NotFoundError(f"Session {session_id} not found.")
return session
router = APIRouter(
tags=["chat"],
)
@@ -26,6 +39,14 @@ router = APIRouter(
# ========== Request/Response Models ==========
class StreamChatRequest(BaseModel):
"""Request model for streaming chat with optional context."""
message: str
is_user_message: bool = True
context: dict[str, str] | None = None # {url: str, content: str}
class CreateSessionResponse(BaseModel):
"""Response model containing information on a newly created chat session."""
@@ -44,22 +65,77 @@ class SessionDetailResponse(BaseModel):
messages: list[dict]
class SessionSummaryResponse(BaseModel):
"""Response model for a session summary (without messages)."""
id: str
created_at: str
updated_at: str
title: str | None = None
class ListSessionsResponse(BaseModel):
"""Response model for listing chat sessions."""
sessions: list[SessionSummaryResponse]
total: int
# ========== Routes ==========
@router.get(
"/sessions",
dependencies=[Security(auth.requires_user)],
)
async def list_sessions(
user_id: Annotated[str, Security(auth.get_user_id)],
limit: int = Query(default=50, ge=1, le=100),
offset: int = Query(default=0, ge=0),
) -> ListSessionsResponse:
"""
List chat sessions for the authenticated user.
Returns a paginated list of chat sessions belonging to the current user,
ordered by most recently updated.
Args:
user_id: The authenticated user's ID.
limit: Maximum number of sessions to return (1-100).
offset: Number of sessions to skip for pagination.
Returns:
ListSessionsResponse: List of session summaries and total count.
"""
sessions, total_count = await get_user_sessions(user_id, limit, offset)
return ListSessionsResponse(
sessions=[
SessionSummaryResponse(
id=session.session_id,
created_at=session.started_at.isoformat(),
updated_at=session.updated_at.isoformat(),
title=session.title,
)
for session in sessions
],
total=total_count,
)
@router.post(
"/sessions",
)
async def create_session(
user_id: Annotated[str | None, Depends(auth.get_user_id)],
user_id: Annotated[str, Depends(auth.get_user_id)],
) -> CreateSessionResponse:
"""
Create a new chat session.
Initiates a new chat session for either an authenticated or anonymous user.
Initiates a new chat session for the authenticated user.
Args:
user_id: The optional authenticated user ID parsed from the JWT. If missing, creates an anonymous session.
user_id: The authenticated user ID parsed from the JWT (required).
Returns:
CreateSessionResponse: Details of the created session.
@@ -67,15 +143,15 @@ async def create_session(
"""
logger.info(
f"Creating session with user_id: "
f"...{user_id[-8:] if user_id and len(user_id) > 8 else '<redacted>'}"
f"...{user_id[-8:] if len(user_id) > 8 else '<redacted>'}"
)
session = await chat_service.create_chat_session(user_id)
session = await create_chat_session(user_id)
return CreateSessionResponse(
id=session.session_id,
created_at=session.started_at.isoformat(),
user_id=session.user_id or None,
user_id=session.user_id,
)
@@ -99,29 +175,88 @@ async def get_session(
SessionDetailResponse: Details for the requested session; raises NotFoundError if not found.
"""
session = await chat_service.get_session(session_id, user_id)
session = await get_chat_session(session_id, user_id)
if not session:
raise NotFoundError(f"Session {session_id} not found")
messages = [message.model_dump() for message in session.messages]
logger.info(
f"Returning session {session_id}: "
f"message_count={len(messages)}, "
f"roles={[m.get('role') for m in messages]}"
)
return SessionDetailResponse(
id=session.session_id,
created_at=session.started_at.isoformat(),
updated_at=session.updated_at.isoformat(),
user_id=session.user_id or None,
messages=[message.model_dump() for message in session.messages],
messages=messages,
)
@router.post(
"/sessions/{session_id}/stream",
)
async def stream_chat_post(
session_id: str,
request: StreamChatRequest,
user_id: str | None = Depends(auth.get_user_id),
):
"""
Stream chat responses for a session (POST with context support).
Streams the AI/completion responses in real time over Server-Sent Events (SSE), including:
- Text fragments as they are generated
- Tool call UI elements (if invoked)
- Tool execution results
Args:
session_id: The chat session identifier to associate with the streamed messages.
request: Request body containing message, is_user_message, and optional context.
user_id: Optional authenticated user ID.
Returns:
StreamingResponse: SSE-formatted response chunks.
"""
session = await _validate_and_get_session(session_id, user_id)
async def event_generator() -> AsyncGenerator[str, None]:
async for chunk in chat_service.stream_chat_completion(
session_id,
request.message,
is_user_message=request.is_user_message,
user_id=user_id,
session=session, # Pass pre-fetched session to avoid double-fetch
context=request.context,
):
yield chunk.to_sse()
# AI SDK protocol termination
yield "data: [DONE]\n\n"
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no", # Disable nginx buffering
"x-vercel-ai-ui-message-stream": "v1", # AI SDK protocol header
},
)
@router.get(
"/sessions/{session_id}/stream",
)
async def stream_chat(
async def stream_chat_get(
session_id: str,
message: Annotated[str, Query(min_length=1, max_length=10000)],
user_id: str | None = Depends(auth.get_user_id),
is_user_message: bool = Query(default=True),
):
"""
Stream chat responses for a session.
Stream chat responses for a session (GET - legacy endpoint).
Streams the AI/completion responses in real time over Server-Sent Events (SSE), including:
- Text fragments as they are generated
@@ -137,14 +272,7 @@ async def stream_chat(
StreamingResponse: SSE-formatted response chunks.
"""
# Validate session exists before starting the stream
# This prevents errors after the response has already started
session = await chat_service.get_session(session_id, user_id)
if not session:
raise NotFoundError(f"Session {session_id} not found. ")
if session.user_id is None and user_id is not None:
session = await chat_service.assign_user_to_session(session_id, user_id)
session = await _validate_and_get_session(session_id, user_id)
async def event_generator() -> AsyncGenerator[str, None]:
async for chunk in chat_service.stream_chat_completion(
@@ -155,6 +283,8 @@ async def stream_chat(
session=session, # Pass pre-fetched session to avoid double-fetch
):
yield chunk.to_sse()
# AI SDK protocol termination
yield "data: [DONE]\n\n"
return StreamingResponse(
event_generator(),
@@ -163,6 +293,7 @@ async def stream_chat(
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no", # Disable nginx buffering
"x-vercel-ai-ui-message-stream": "v1", # AI SDK protocol header
},
)
@@ -201,16 +332,28 @@ async def health_check() -> dict:
"""
Health check endpoint for the chat service.
Performs a full cycle test of session creation, assignment, and retrieval. Should always return healthy
Performs a full cycle test of session creation and retrieval. Should always return healthy
if the service and data layer are operational.
Returns:
dict: A status dictionary indicating health, service name, and API version.
"""
session = await chat_service.create_chat_session(None)
await chat_service.assign_user_to_session(session.session_id, "test_user")
await chat_service.get_session(session.session_id, "test_user")
from backend.data.user import get_or_create_user
# Ensure health check user exists (required for FK constraint)
health_check_user_id = "health-check-user"
await get_or_create_user(
{
"sub": health_check_user_id,
"email": "health-check@system.local",
"user_metadata": {"name": "Health Check User"},
}
)
# Create and retrieve session to verify full data layer
session = await create_chat_session(health_check_user_id)
await get_chat_session(session.session_id, health_check_user_id)
return {
"status": "healthy",

File diff suppressed because it is too large.

View File

@@ -4,18 +4,19 @@ from os import getenv
import pytest
from . import service as chat_service
from .model import create_chat_session, get_chat_session, upsert_chat_session
from .response_model import (
StreamEnd,
StreamError,
StreamTextChunk,
StreamToolExecutionResult,
StreamFinish,
StreamTextDelta,
StreamToolOutputAvailable,
)
logger = logging.getLogger(__name__)
@pytest.mark.asyncio(loop_scope="session")
async def test_stream_chat_completion():
async def test_stream_chat_completion(setup_test_user, test_user_id):
"""
Test the stream_chat_completion function.
"""
@@ -23,7 +24,7 @@ async def test_stream_chat_completion():
if not api_key:
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
session = await chat_service.create_chat_session()
session = await create_chat_session(test_user_id)
has_errors = False
has_ended = False
@@ -34,9 +35,9 @@ async def test_stream_chat_completion():
logger.info(chunk)
if isinstance(chunk, StreamError):
has_errors = True
if isinstance(chunk, StreamTextChunk):
assistant_message += chunk.content
if isinstance(chunk, StreamEnd):
if isinstance(chunk, StreamTextDelta):
assistant_message += chunk.delta
if isinstance(chunk, StreamFinish):
has_ended = True
assert has_ended, "Chat completion did not end"
@@ -45,7 +46,7 @@ async def test_stream_chat_completion():
@pytest.mark.asyncio(loop_scope="session")
async def test_stream_chat_completion_with_tool_calls():
async def test_stream_chat_completion_with_tool_calls(setup_test_user, test_user_id):
"""
Test the stream_chat_completion function.
"""
@@ -53,8 +54,8 @@ async def test_stream_chat_completion_with_tool_calls():
if not api_key:
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
session = await chat_service.create_chat_session()
session = await chat_service.upsert_chat_session(session)
session = await create_chat_session(test_user_id)
session = await upsert_chat_session(session)
has_errors = False
has_ended = False
@@ -68,14 +69,14 @@ async def test_stream_chat_completion_with_tool_calls():
if isinstance(chunk, StreamError):
has_errors = True
if isinstance(chunk, StreamEnd):
if isinstance(chunk, StreamFinish):
has_ended = True
if isinstance(chunk, StreamToolExecutionResult):
if isinstance(chunk, StreamToolOutputAvailable):
had_tool_calls = True
assert has_ended, "Chat completion did not end"
assert not has_errors, "Error occurred while streaming chat completion"
assert had_tool_calls, "Tool calls did not occur"
session = await chat_service.get_session(session.session_id)
session = await get_chat_session(session.session_id)
assert session, "Session not found"
assert session.usage, "Usage is empty"

View File

@@ -4,21 +4,40 @@ from openai.types.chat import ChatCompletionToolParam
from backend.api.features.chat.model import ChatSession
from .add_understanding import AddUnderstandingTool
from .agent_output import AgentOutputTool
from .base import BaseTool
from .find_agent import FindAgentTool
from .find_block import FindBlockTool
from .find_library_agent import FindLibraryAgentTool
from .get_doc_page import GetDocPageTool
from .run_agent import RunAgentTool
from .run_block import RunBlockTool
from .search_docs import SearchDocsTool
if TYPE_CHECKING:
from backend.api.features.chat.response_model import StreamToolExecutionResult
from backend.api.features.chat.response_model import StreamToolOutputAvailable
# Initialize tool instances
find_agent_tool = FindAgentTool()
run_agent_tool = RunAgentTool()
# Single source of truth for all tools
TOOL_REGISTRY: dict[str, BaseTool] = {
"add_understanding": AddUnderstandingTool(),
"find_agent": FindAgentTool(),
"find_block": FindBlockTool(),
"find_library_agent": FindLibraryAgentTool(),
"run_agent": RunAgentTool(),
"run_block": RunBlockTool(),
"agent_output": AgentOutputTool(),
"search_docs": SearchDocsTool(),
"get_doc_page": GetDocPageTool(),
}
# Export tools as OpenAI format
# Export individual tool instances for backwards compatibility
find_agent_tool = TOOL_REGISTRY["find_agent"]
run_agent_tool = TOOL_REGISTRY["run_agent"]
# Generated from registry for OpenAI API
tools: list[ChatCompletionToolParam] = [
find_agent_tool.as_openai_tool(),
run_agent_tool.as_openai_tool(),
tool.as_openai_tool() for tool in TOOL_REGISTRY.values()
]
@@ -28,14 +47,9 @@ async def execute_tool(
user_id: str | None,
session: ChatSession,
tool_call_id: str,
) -> "StreamToolExecutionResult":
tool_map: dict[str, BaseTool] = {
"find_agent": find_agent_tool,
"run_agent": run_agent_tool,
}
if tool_name not in tool_map:
) -> "StreamToolOutputAvailable":
"""Execute a tool by name."""
tool = TOOL_REGISTRY.get(tool_name)
if not tool:
raise ValueError(f"Tool {tool_name} not found")
return await tool_map[tool_name].execute(
user_id, session, tool_call_id, **parameters
)
return await tool.execute(user_id, session, tool_call_id, **parameters)
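Since `TOOL_REGISTRY` is now the single source of truth, adding a tool is a matter of subclassing `BaseTool` and registering an instance. A hypothetical sketch, following the pattern of the tools in this PR (`GreetTool` and `GreetResponse` are invented; only the registration line mirrors real code):

```python
# Hypothetical new tool wired into the registry-based dispatch above.
from typing import Any

from backend.api.features.chat.model import ChatSession

from .base import BaseTool
from .models import ResponseType, ToolResponseBase


class GreetResponse(ToolResponseBase):
    """Response payload for the example tool (invented)."""

    type: ResponseType = ResponseType.SUCCESS


class GreetTool(BaseTool):
    """Toy tool that greets the user by name."""

    @property
    def name(self) -> str:
        return "greet"

    @property
    def description(self) -> str:
        return "Greet the user by name."

    @property
    def parameters(self) -> dict[str, Any]:
        return {
            "type": "object",
            "properties": {
                "name": {"type": "string", "description": "Name to greet"}
            },
            "required": ["name"],
        }

    @property
    def requires_auth(self) -> bool:
        return False

    async def _execute(
        self, user_id: str | None, session: ChatSession, **kwargs
    ) -> ToolResponseBase:
        return GreetResponse(
            message=f"Hello, {kwargs.get('name', 'there')}!",
            session_id=session.session_id,
        )


# Registering the instance exposes it to the model and to execute_tool:
# TOOL_REGISTRY["greet"] = GreetTool()
```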

View File

@@ -3,6 +3,7 @@ from datetime import UTC, datetime
from os import getenv
import pytest
from prisma.types import ProfileCreateInput
from pydantic import SecretStr
from backend.api.features.chat.model import ChatSession
@@ -17,7 +18,7 @@ from backend.data.user import get_or_create_user
from backend.integrations.credentials_store import IntegrationCredentialsStore
def make_session(user_id: str | None = None):
def make_session(user_id: str):
return ChatSession(
session_id=str(uuid.uuid4()),
user_id=user_id,
@@ -49,13 +50,13 @@ async def setup_test_data():
# 1b. Create a profile with username for the user (required for store agent lookup)
username = user.email.split("@")[0]
await prisma.profile.create(
data={
"userId": user.id,
"username": username,
"name": f"Test User {username}",
"description": "Test user profile",
"links": [], # Required field - empty array for test profiles
}
data=ProfileCreateInput(
userId=user.id,
username=username,
name=f"Test User {username}",
description="Test user profile",
links=[], # Required field - empty array for test profiles
)
)
# 2. Create a test graph with agent input -> agent output
@@ -172,13 +173,13 @@ async def setup_llm_test_data():
# 1b. Create a profile with username for the user (required for store agent lookup)
username = user.email.split("@")[0]
await prisma.profile.create(
data={
"userId": user.id,
"username": username,
"name": f"Test User {username}",
"description": "Test user profile for LLM tests",
"links": [], # Required field - empty array for test profiles
}
data=ProfileCreateInput(
userId=user.id,
username=username,
name=f"Test User {username}",
description="Test user profile for LLM tests",
links=[], # Required field - empty array for test profiles
)
)
# 2. Create test OpenAI credentials for the user
@@ -332,13 +333,13 @@ async def setup_firecrawl_test_data():
# 1b. Create a profile with username for the user (required for store agent lookup)
username = user.email.split("@")[0]
await prisma.profile.create(
data={
"userId": user.id,
"username": username,
"name": f"Test User {username}",
"description": "Test user profile for Firecrawl tests",
"links": [], # Required field - empty array for test profiles
}
data=ProfileCreateInput(
userId=user.id,
username=username,
name=f"Test User {username}",
description="Test user profile for Firecrawl tests",
links=[], # Required field - empty array for test profiles
)
)
# NOTE: We deliberately do NOT create Firecrawl credentials for this user

View File

@@ -0,0 +1,119 @@
"""Tool for capturing user business understanding incrementally."""
import logging
from typing import Any
from backend.api.features.chat.model import ChatSession
from backend.data.understanding import (
BusinessUnderstandingInput,
upsert_business_understanding,
)
from .base import BaseTool
from .models import ErrorResponse, ToolResponseBase, UnderstandingUpdatedResponse
logger = logging.getLogger(__name__)
class AddUnderstandingTool(BaseTool):
"""Tool for capturing user's business understanding incrementally."""
@property
def name(self) -> str:
return "add_understanding"
@property
def description(self) -> str:
return """Capture and store information about the user's business context,
workflows, pain points, and automation goals. Call this tool whenever the user
shares information about their business. Each call incrementally adds to the
existing understanding - you don't need to provide all fields at once.
Use this to build a comprehensive profile that helps recommend better agents
and automations for the user's specific needs."""
@property
def parameters(self) -> dict[str, Any]:
# Auto-generate from Pydantic model schema
schema = BusinessUnderstandingInput.model_json_schema()
properties = {}
for field_name, field_schema in schema.get("properties", {}).items():
prop: dict[str, Any] = {"description": field_schema.get("description", "")}
# Handle anyOf for Optional types
if "anyOf" in field_schema:
for option in field_schema["anyOf"]:
if option.get("type") != "null":
prop["type"] = option.get("type", "string")
if "items" in option:
prop["items"] = option["items"]
break
else:
prop["type"] = field_schema.get("type", "string")
if "items" in field_schema:
prop["items"] = field_schema["items"]
properties[field_name] = prop
return {"type": "object", "properties": properties, "required": []}
@property
def requires_auth(self) -> bool:
"""Requires authentication to store user-specific data."""
return True
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
"""
Capture and store business understanding incrementally.
Each call merges new data with existing understanding:
- String fields are overwritten if provided
- List fields are appended (with deduplication)
"""
session_id = session.session_id
if not user_id:
return ErrorResponse(
message="Authentication required to save business understanding.",
session_id=session_id,
)
# Check if any data was provided
if not any(v is not None for v in kwargs.values()):
return ErrorResponse(
message="Please provide at least one field to update.",
session_id=session_id,
)
# Build input model from kwargs (only include fields defined in the model)
valid_fields = set(BusinessUnderstandingInput.model_fields.keys())
input_data = BusinessUnderstandingInput(
**{k: v for k, v in kwargs.items() if k in valid_fields}
)
# Track which fields were updated
updated_fields = [
k for k, v in kwargs.items() if k in valid_fields and v is not None
]
# Upsert with merge
understanding = await upsert_business_understanding(user_id, input_data)
# Build current understanding summary (filter out empty values)
current_understanding = {
k: v
for k, v in understanding.model_dump(
exclude={"id", "user_id", "created_at", "updated_at"}
).items()
if v is not None and v != [] and v != ""
}
return UnderstandingUpdatedResponse(
message=f"Updated understanding with: {', '.join(updated_fields)}. "
"I now have a better picture of your business context.",
session_id=session_id,
updated_fields=updated_fields,
current_understanding=current_understanding,
)
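The `parameters` property above flattens Pydantic's `anyOf` encoding of optional fields into plain OpenAI-style parameter types. A standalone sketch of the same transformation on a hypothetical model (`MiniUnderstanding` and its fields are invented; the real field set lives in `BusinessUnderstandingInput`):

```python
# Illustrates why the anyOf branch is needed: Optional[X] is serialized by
# Pydantic v2 as anyOf [X, null], and the non-null branch is kept.
from typing import Any

from pydantic import BaseModel, Field


class MiniUnderstanding(BaseModel):
    industry: str | None = Field(default=None, description="User's industry")
    pain_points: list[str] | None = Field(default=None, description="Known pain points")


def to_tool_parameters(model: type[BaseModel]) -> dict[str, Any]:
    schema = model.model_json_schema()
    properties: dict[str, Any] = {}
    for name, field_schema in schema.get("properties", {}).items():
        prop: dict[str, Any] = {"description": field_schema.get("description", "")}
        if "anyOf" in field_schema:
            for option in field_schema["anyOf"]:
                if option.get("type") != "null":
                    prop["type"] = option.get("type", "string")
                    if "items" in option:
                        prop["items"] = option["items"]
                    break
        else:
            prop["type"] = field_schema.get("type", "string")
        properties[name] = prop
    return {"type": "object", "properties": properties, "required": []}


# to_tool_parameters(MiniUnderstanding)["properties"]["pain_points"]
# -> {"description": "Known pain points", "type": "array", "items": {"type": "string"}}
```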

View File

@@ -0,0 +1,446 @@
"""Tool for retrieving agent execution outputs from user's library."""
import logging
import re
from datetime import datetime, timedelta, timezone
from typing import Any
from pydantic import BaseModel, field_validator
from backend.api.features.chat.model import ChatSession
from backend.api.features.library import db as library_db
from backend.api.features.library.model import LibraryAgent
from backend.data import execution as execution_db
from backend.data.execution import ExecutionStatus, GraphExecution, GraphExecutionMeta
from .base import BaseTool
from .models import (
AgentOutputResponse,
ErrorResponse,
ExecutionOutputInfo,
NoResultsResponse,
ToolResponseBase,
)
from .utils import fetch_graph_from_store_slug
logger = logging.getLogger(__name__)
class AgentOutputInput(BaseModel):
"""Input parameters for the agent_output tool."""
agent_name: str = ""
library_agent_id: str = ""
store_slug: str = ""
execution_id: str = ""
run_time: str = "latest"
@field_validator(
"agent_name",
"library_agent_id",
"store_slug",
"execution_id",
"run_time",
mode="before",
)
@classmethod
def strip_strings(cls, v: Any) -> Any:
"""Strip whitespace from string fields."""
return v.strip() if isinstance(v, str) else v
def parse_time_expression(
time_expr: str | None,
) -> tuple[datetime | None, datetime | None]:
"""
Parse time expression into datetime range (start, end).
Supports: "latest", "yesterday", "today", "last week", "last 7 days",
"last month", "last 30 days", ISO date "YYYY-MM-DD", ISO datetime.
"""
if not time_expr or time_expr.lower() == "latest":
return None, None
now = datetime.now(timezone.utc)
today_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
expr = time_expr.lower().strip()
# Relative time expressions lookup
relative_times: dict[str, tuple[datetime, datetime]] = {
"yesterday": (today_start - timedelta(days=1), today_start),
"today": (today_start, now),
"last week": (now - timedelta(days=7), now),
"last 7 days": (now - timedelta(days=7), now),
"last month": (now - timedelta(days=30), now),
"last 30 days": (now - timedelta(days=30), now),
}
if expr in relative_times:
return relative_times[expr]
# Try ISO date format (YYYY-MM-DD)
date_match = re.match(r"^(\d{4})-(\d{2})-(\d{2})$", expr)
if date_match:
try:
year, month, day = map(int, date_match.groups())
start = datetime(year, month, day, 0, 0, 0, tzinfo=timezone.utc)
return start, start + timedelta(days=1)
except ValueError:
# Invalid date components (e.g., month=13, day=32)
pass
# Try ISO datetime
try:
parsed = datetime.fromisoformat(expr.replace("Z", "+00:00"))
if parsed.tzinfo is None:
parsed = parsed.replace(tzinfo=timezone.utc)
return parsed - timedelta(hours=1), parsed + timedelta(hours=1)
except ValueError:
return None, None
class AgentOutputTool(BaseTool):
"""Tool for retrieving execution outputs from user's library agents."""
@property
def name(self) -> str:
return "agent_output"
@property
def description(self) -> str:
return """Retrieve execution outputs from agents in the user's library.
Identify the agent using one of:
- agent_name: Fuzzy search in user's library
- library_agent_id: Exact library agent ID
- store_slug: Marketplace format 'username/agent-name'
Select which run to retrieve using:
- execution_id: Specific execution ID
- run_time: 'latest' (default), 'yesterday', 'last week', or ISO date 'YYYY-MM-DD'
"""
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"agent_name": {
"type": "string",
"description": "Agent name to search for in user's library (fuzzy match)",
},
"library_agent_id": {
"type": "string",
"description": "Exact library agent ID",
},
"store_slug": {
"type": "string",
"description": "Marketplace identifier: 'username/agent-slug'",
},
"execution_id": {
"type": "string",
"description": "Specific execution ID to retrieve",
},
"run_time": {
"type": "string",
"description": (
"Time filter: 'latest', 'yesterday', 'last week', or 'YYYY-MM-DD'"
),
},
},
"required": [],
}
@property
def requires_auth(self) -> bool:
return True
async def _resolve_agent(
self,
user_id: str,
agent_name: str | None,
library_agent_id: str | None,
store_slug: str | None,
) -> tuple[LibraryAgent | None, str | None]:
"""
Resolve agent from provided identifiers.
Returns (library_agent, error_message).
"""
# Priority 1: Exact library agent ID
if library_agent_id:
try:
agent = await library_db.get_library_agent(library_agent_id, user_id)
return agent, None
except Exception as e:
logger.warning(f"Failed to get library agent by ID: {e}")
return None, f"Library agent '{library_agent_id}' not found"
# Priority 2: Store slug (username/agent-name)
if store_slug and "/" in store_slug:
username, agent_slug = store_slug.split("/", 1)
graph, _ = await fetch_graph_from_store_slug(username, agent_slug)
if not graph:
return None, f"Agent '{store_slug}' not found in marketplace"
# Find in user's library by graph_id
agent = await library_db.get_library_agent_by_graph_id(user_id, graph.id)
if not agent:
return (
None,
f"Agent '{store_slug}' is not in your library. "
"Add it first to see outputs.",
)
return agent, None
# Priority 3: Fuzzy name search in library
if agent_name:
try:
response = await library_db.list_library_agents(
user_id=user_id,
search_term=agent_name,
page_size=5,
)
if not response.agents:
return (
None,
f"No agents matching '{agent_name}' found in your library",
)
# Return best match (first result from search)
return response.agents[0], None
except Exception as e:
logger.error(f"Error searching library agents: {e}")
return None, f"Error searching for agent: {e}"
return (
None,
"Please specify an agent name, library_agent_id, or store_slug",
)
async def _get_execution(
self,
user_id: str,
graph_id: str,
execution_id: str | None,
time_start: datetime | None,
time_end: datetime | None,
) -> tuple[GraphExecution | None, list[GraphExecutionMeta], str | None]:
"""
Fetch execution(s) based on filters.
Returns (single_execution, available_executions_meta, error_message).
"""
# If specific execution_id provided, fetch it directly
if execution_id:
execution = await execution_db.get_graph_execution(
user_id=user_id,
execution_id=execution_id,
include_node_executions=False,
)
if not execution:
return None, [], f"Execution '{execution_id}' not found"
return execution, [], None
# Get completed executions with time filters
executions = await execution_db.get_graph_executions(
graph_id=graph_id,
user_id=user_id,
statuses=[ExecutionStatus.COMPLETED],
created_time_gte=time_start,
created_time_lte=time_end,
limit=10,
)
if not executions:
return None, [], None # No error, just no executions
# If only one execution, fetch full details
if len(executions) == 1:
full_execution = await execution_db.get_graph_execution(
user_id=user_id,
execution_id=executions[0].id,
include_node_executions=False,
)
return full_execution, [], None
# Multiple executions - return latest with full details, plus list of available
full_execution = await execution_db.get_graph_execution(
user_id=user_id,
execution_id=executions[0].id,
include_node_executions=False,
)
return full_execution, executions, None
def _build_response(
self,
agent: LibraryAgent,
execution: GraphExecution | None,
available_executions: list[GraphExecutionMeta],
session_id: str | None,
) -> AgentOutputResponse:
"""Build the response based on execution data."""
library_agent_link = f"/library/agents/{agent.id}"
if not execution:
return AgentOutputResponse(
message=f"No completed executions found for agent '{agent.name}'",
session_id=session_id,
agent_name=agent.name,
agent_id=agent.graph_id,
library_agent_id=agent.id,
library_agent_link=library_agent_link,
total_executions=0,
)
execution_info = ExecutionOutputInfo(
execution_id=execution.id,
status=execution.status.value,
started_at=execution.started_at,
ended_at=execution.ended_at,
outputs=dict(execution.outputs),
inputs_summary=execution.inputs if execution.inputs else None,
)
available_list = None
if len(available_executions) > 1:
available_list = [
{
"id": e.id,
"status": e.status.value,
"started_at": e.started_at.isoformat() if e.started_at else None,
}
for e in available_executions[:5]
]
message = f"Found execution outputs for agent '{agent.name}'"
if len(available_executions) > 1:
message += (
f". Showing latest of {len(available_executions)} matching executions."
)
return AgentOutputResponse(
message=message,
session_id=session_id,
agent_name=agent.name,
agent_id=agent.graph_id,
library_agent_id=agent.id,
library_agent_link=library_agent_link,
execution=execution_info,
available_executions=available_list,
total_executions=len(available_executions) if available_executions else 1,
)
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
"""Execute the agent_output tool."""
session_id = session.session_id
# Parse and validate input
try:
input_data = AgentOutputInput(**kwargs)
except Exception as e:
logger.error(f"Invalid input: {e}")
return ErrorResponse(
message="Invalid input parameters",
error=str(e),
session_id=session_id,
)
# Ensure user_id is present (should be guaranteed by requires_auth)
if not user_id:
return ErrorResponse(
message="User authentication required",
session_id=session_id,
)
# Check if at least one identifier is provided
if not any(
[
input_data.agent_name,
input_data.library_agent_id,
input_data.store_slug,
input_data.execution_id,
]
):
return ErrorResponse(
message=(
"Please specify at least one of: agent_name, "
"library_agent_id, store_slug, or execution_id"
),
session_id=session_id,
)
# If only execution_id provided, we need to find the agent differently
if (
input_data.execution_id
and not input_data.agent_name
and not input_data.library_agent_id
and not input_data.store_slug
):
# Fetch execution directly to get graph_id
execution = await execution_db.get_graph_execution(
user_id=user_id,
execution_id=input_data.execution_id,
include_node_executions=False,
)
if not execution:
return ErrorResponse(
message=f"Execution '{input_data.execution_id}' not found",
session_id=session_id,
)
# Find library agent by graph_id
agent = await library_db.get_library_agent_by_graph_id(
user_id, execution.graph_id
)
if not agent:
return NoResultsResponse(
message=(
f"Execution found but agent not in your library. "
f"Graph ID: {execution.graph_id}"
),
session_id=session_id,
suggestions=["Add the agent to your library to see more details"],
)
return self._build_response(agent, execution, [], session_id)
# Resolve agent from identifiers
agent, error = await self._resolve_agent(
user_id=user_id,
agent_name=input_data.agent_name or None,
library_agent_id=input_data.library_agent_id or None,
store_slug=input_data.store_slug or None,
)
if error or not agent:
return NoResultsResponse(
message=error or "Agent not found",
session_id=session_id,
suggestions=[
"Check the agent name or ID",
"Make sure the agent is in your library",
],
)
# Parse time expression
time_start, time_end = parse_time_expression(input_data.run_time)
# Fetch execution(s)
execution, available_executions, exec_error = await self._get_execution(
user_id=user_id,
graph_id=agent.graph_id,
execution_id=input_data.execution_id or None,
time_start=time_start,
time_end=time_end,
)
if exec_error:
return ErrorResponse(
message=exec_error,
session_id=session_id,
)
return self._build_response(agent, execution, available_executions, session_id)
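For reference, the `run_time` filter maps to datetime ranges as implemented by `parse_time_expression` above; the example values below are illustrative:

```python
# Expected behaviour of parse_time_expression (UTC-based ranges).
from backend.api.features.chat.tools.agent_output import parse_time_expression

print(parse_time_expression("latest"))       # (None, None) -> newest completed run
print(parse_time_expression("yesterday"))    # (start of yesterday, start of today)
print(parse_time_expression("last 7 days"))  # (now - 7 days, now)
print(parse_time_expression("2026-01-15"))   # (2026-01-15 00:00 UTC, 2026-01-16 00:00 UTC)
print(parse_time_expression("not a date"))   # (None, None) -> falls back to latest
```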

View File

@@ -0,0 +1,151 @@
"""Shared agent search functionality for find_agent and find_library_agent tools."""
import logging
from typing import Literal
from backend.api.features.library import db as library_db
from backend.api.features.store import db as store_db
from backend.util.exceptions import DatabaseError, NotFoundError
from .models import (
AgentInfo,
AgentsFoundResponse,
ErrorResponse,
NoResultsResponse,
ToolResponseBase,
)
logger = logging.getLogger(__name__)
SearchSource = Literal["marketplace", "library"]
async def search_agents(
query: str,
source: SearchSource,
session_id: str | None,
user_id: str | None = None,
) -> ToolResponseBase:
"""
Search for agents in marketplace or user library.
Args:
query: Search query string
source: "marketplace" or "library"
session_id: Chat session ID
user_id: User ID (required for library search)
Returns:
AgentsFoundResponse, NoResultsResponse, or ErrorResponse
"""
if not query:
return ErrorResponse(
message="Please provide a search query", session_id=session_id
)
if source == "library" and not user_id:
return ErrorResponse(
message="User authentication required to search library",
session_id=session_id,
)
agents: list[AgentInfo] = []
try:
if source == "marketplace":
logger.info(f"Searching marketplace for: {query}")
results = await store_db.get_store_agents(search_query=query, page_size=5)
for agent in results.agents:
agents.append(
AgentInfo(
id=f"{agent.creator}/{agent.slug}",
name=agent.agent_name,
description=agent.description or "",
source="marketplace",
in_library=False,
creator=agent.creator,
category="general",
rating=agent.rating,
runs=agent.runs,
is_featured=False,
)
)
else: # library
logger.info(f"Searching user library for: {query}")
results = await library_db.list_library_agents(
user_id=user_id, # type: ignore[arg-type]
search_term=query,
page_size=10,
)
for agent in results.agents:
agents.append(
AgentInfo(
id=agent.id,
name=agent.name,
description=agent.description or "",
source="library",
in_library=True,
creator=agent.creator_name,
status=agent.status.value,
can_access_graph=agent.can_access_graph,
has_external_trigger=agent.has_external_trigger,
new_output=agent.new_output,
graph_id=agent.graph_id,
)
)
logger.info(f"Found {len(agents)} agents in {source}")
except NotFoundError:
pass
except DatabaseError as e:
logger.error(f"Error searching {source}: {e}", exc_info=True)
return ErrorResponse(
message=f"Failed to search {source}. Please try again.",
error=str(e),
session_id=session_id,
)
if not agents:
suggestions = (
[
"Try more general terms",
"Browse categories in the marketplace",
"Check spelling",
]
if source == "marketplace"
else [
"Try different keywords",
"Use find_agent to search the marketplace",
"Check your library at /library",
]
)
no_results_msg = (
f"No agents found matching '{query}'. Try different keywords or browse the marketplace."
if source == "marketplace"
else f"No agents matching '{query}' found in your library."
)
return NoResultsResponse(
message=no_results_msg, session_id=session_id, suggestions=suggestions
)
title = f"Found {len(agents)} agent{'s' if len(agents) != 1 else ''} "
title += (
f"for '{query}'"
if source == "marketplace"
else f"in your library for '{query}'"
)
message = (
"Now you have found some options for the user to choose from. "
"You can add a link to a recommended agent at: /marketplace/agent/agent_id "
"Please ask the user if they would like to use any of these agents."
if source == "marketplace"
else "Found agents in the user's library. You can provide a link to view an agent at: "
"/library/agents/{agent_id}. Use agent_output to get execution results, or run_agent to execute."
)
return AgentsFoundResponse(
message=message,
title=title,
agents=agents,
count=len(agents),
session_id=session_id,
)

View File

@@ -6,7 +6,7 @@ from typing import Any
from openai.types.chat import ChatCompletionToolParam
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.response_model import StreamToolExecutionResult
from backend.api.features.chat.response_model import StreamToolOutputAvailable
from .models import ErrorResponse, NeedLoginResponse, ToolResponseBase
@@ -53,7 +53,7 @@ class BaseTool:
session: ChatSession,
tool_call_id: str,
**kwargs,
) -> StreamToolExecutionResult:
) -> StreamToolOutputAvailable:
"""Execute the tool with authentication check.
Args:
@@ -69,10 +69,10 @@ class BaseTool:
logger.error(
f"Attempted tool call for {self.name} but user not authenticated"
)
return StreamToolExecutionResult(
tool_id=tool_call_id,
tool_name=self.name,
result=NeedLoginResponse(
return StreamToolOutputAvailable(
toolCallId=tool_call_id,
toolName=self.name,
output=NeedLoginResponse(
message=f"Please sign in to use {self.name}",
session_id=session.session_id,
).model_dump_json(),
@@ -81,17 +81,17 @@ class BaseTool:
try:
result = await self._execute(user_id, session, **kwargs)
return StreamToolExecutionResult(
tool_id=tool_call_id,
tool_name=self.name,
result=result.model_dump_json(),
return StreamToolOutputAvailable(
toolCallId=tool_call_id,
toolName=self.name,
output=result.model_dump_json(),
)
except Exception as e:
logger.error(f"Error in {self.name}: {e}", exc_info=True)
return StreamToolExecutionResult(
tool_id=tool_call_id,
tool_name=self.name,
result=ErrorResponse(
return StreamToolOutputAvailable(
toolCallId=tool_call_id,
toolName=self.name,
output=ErrorResponse(
message=f"An error occurred while executing {self.name}",
error=str(e),
session_id=session.session_id,

View File

@@ -1,26 +1,16 @@
"""Tool for discovering agents from marketplace and user library."""
"""Tool for discovering agents from marketplace."""
import logging
from typing import Any
from backend.api.features.chat.model import ChatSession
from backend.api.features.store import db as store_db
from backend.util.exceptions import DatabaseError, NotFoundError
from .agent_search import search_agents
from .base import BaseTool
from .models import (
AgentCarouselResponse,
AgentInfo,
ErrorResponse,
NoResultsResponse,
ToolResponseBase,
)
logger = logging.getLogger(__name__)
from .models import ToolResponseBase
class FindAgentTool(BaseTool):
"""Tool for discovering agents based on user needs."""
"""Tool for discovering agents from the marketplace."""
@property
def name(self) -> str:
@@ -46,84 +36,11 @@ class FindAgentTool(BaseTool):
}
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
self, user_id: str | None, session: ChatSession, **kwargs
) -> ToolResponseBase:
"""Search for agents in the marketplace.
Args:
user_id: User ID (may be anonymous)
session_id: Chat session ID
query: Search query
Returns:
AgentCarouselResponse: List of agents found in the marketplace
NoResultsResponse: No agents found in the marketplace
ErrorResponse: Error message
"""
query = kwargs.get("query", "").strip()
session_id = session.session_id
if not query:
return ErrorResponse(
message="Please provide a search query",
session_id=session_id,
)
agents = []
try:
logger.info(f"Searching marketplace for: {query}")
store_results = await store_db.get_store_agents(
search_query=query,
page_size=5,
)
logger.info(f"Find agents tool found {len(store_results.agents)} agents")
for agent in store_results.agents:
agent_id = f"{agent.creator}/{agent.slug}"
logger.info(f"Building agent ID = {agent_id}")
agents.append(
AgentInfo(
id=agent_id,
name=agent.agent_name,
description=agent.description or "",
source="marketplace",
in_library=False,
creator=agent.creator,
category="general",
rating=agent.rating,
runs=agent.runs,
is_featured=False,
),
)
except NotFoundError:
pass
except DatabaseError as e:
logger.error(f"Error searching agents: {e}", exc_info=True)
return ErrorResponse(
message="Failed to search for agents. Please try again.",
error=str(e),
session_id=session_id,
)
if not agents:
return NoResultsResponse(
message=f"No agents found matching '{query}'. Try different keywords or browse the marketplace. If you have 3 consecutive find_agent tool calls results and found no agents. Please stop trying and ask the user if there is anything else you can help with.",
session_id=session_id,
suggestions=[
"Try more general terms",
"Browse categories in the marketplace",
"Check spelling",
],
)
# Return formatted carousel
title = (
f"Found {len(agents)} agent{'s' if len(agents) != 1 else ''} for '{query}'"
)
return AgentCarouselResponse(
message="Now you have found some options for the user to choose from. You can add a link to a recommended agent at: /marketplace/agent/agent_id Please ask the user if they would like to use any of these agents. If they do, please call the get_agent_details tool for this agent.",
title=title,
agents=agents,
count=len(agents),
session_id=session_id,
return await search_agents(
query=kwargs.get("query", "").strip(),
source="marketplace",
session_id=session.session_id,
user_id=user_id,
)

View File

@@ -0,0 +1,192 @@
import logging
from typing import Any
from prisma.enums import ContentType
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool, ToolResponseBase
from backend.api.features.chat.tools.models import (
BlockInfoSummary,
BlockInputFieldInfo,
BlockListResponse,
ErrorResponse,
NoResultsResponse,
)
from backend.api.features.store.hybrid_search import unified_hybrid_search
from backend.data.block import get_block
logger = logging.getLogger(__name__)
class FindBlockTool(BaseTool):
"""Tool for searching available blocks."""
@property
def name(self) -> str:
return "find_block"
@property
def description(self) -> str:
return (
"Search for available blocks by name or description. "
"Blocks are reusable components that perform specific tasks like "
"sending emails, making API calls, processing text, etc. "
"IMPORTANT: Use this tool FIRST to get the block's 'id' before calling run_block. "
"The response includes each block's id, required_inputs, and input_schema."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": (
"Search query to find blocks by name or description. "
"Use keywords like 'email', 'http', 'text', 'ai', etc."
),
},
},
"required": ["query"],
}
@property
def requires_auth(self) -> bool:
return True
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
"""Search for blocks matching the query.
Args:
user_id: User ID (required)
session: Chat session
query: Search query
Returns:
BlockListResponse: List of matching blocks
NoResultsResponse: No blocks found
ErrorResponse: Error message
"""
query = kwargs.get("query", "").strip()
session_id = session.session_id
if not query:
return ErrorResponse(
message="Please provide a search query",
session_id=session_id,
)
try:
# Search for blocks using hybrid search
results, total = await unified_hybrid_search(
query=query,
content_types=[ContentType.BLOCK],
page=1,
page_size=10,
)
if not results:
return NoResultsResponse(
message=f"No blocks found for '{query}'",
suggestions=[
"Try broader keywords like 'email', 'http', 'text', 'ai'",
"Check spelling of technical terms",
],
session_id=session_id,
)
# Enrich results with full block information
blocks: list[BlockInfoSummary] = []
for result in results:
block_id = result["content_id"]
block = get_block(block_id)
if block:
# Get input/output schemas
input_schema = {}
output_schema = {}
try:
input_schema = block.input_schema.jsonschema()
except Exception:
pass
try:
output_schema = block.output_schema.jsonschema()
except Exception:
pass
# Get categories from block instance
categories = []
if hasattr(block, "categories") and block.categories:
categories = [cat.value for cat in block.categories]
# Extract required inputs for easier use
required_inputs: list[BlockInputFieldInfo] = []
if input_schema:
properties = input_schema.get("properties", {})
required_fields = set(input_schema.get("required", []))
# Get credential field names to exclude from required inputs
credentials_fields = set(
block.input_schema.get_credentials_fields().keys()
)
for field_name, field_schema in properties.items():
# Skip credential fields - they're handled separately
if field_name in credentials_fields:
continue
required_inputs.append(
BlockInputFieldInfo(
name=field_name,
type=field_schema.get("type", "string"),
description=field_schema.get("description", ""),
required=field_name in required_fields,
default=field_schema.get("default"),
)
)
blocks.append(
BlockInfoSummary(
id=block_id,
name=block.name,
description=block.description or "",
categories=categories,
input_schema=input_schema,
output_schema=output_schema,
required_inputs=required_inputs,
)
)
if not blocks:
return NoResultsResponse(
message=f"No blocks found for '{query}'",
suggestions=[
"Try broader keywords like 'email', 'http', 'text', 'ai'",
],
session_id=session_id,
)
return BlockListResponse(
message=(
f"Found {len(blocks)} block(s) matching '{query}'. "
"To execute a block, use run_block with the block's 'id' field "
"and provide 'input_data' matching the block's input_schema."
),
blocks=blocks,
count=len(blocks),
query=query,
session_id=session_id,
)
except Exception as e:
logger.error(f"Error searching blocks: {e}", exc_info=True)
return ErrorResponse(
message="Failed to search blocks",
error=str(e),
session_id=session_id,
)
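The intended flow is find_block first, then run_block with the returned block `id`. A hedged sketch via the registry dispatcher (the import path for `execute_tool`, the query, the chosen block, and the input values are assumptions; the `block_id`/`input_data` parameter names come from the usage hint above):

```python
# Illustrative two-step find_block -> run_block flow.
from backend.api.features.chat.tools import execute_tool  # path assumed


async def run_first_matching_block(session, user_id: str) -> None:
    found = await execute_tool(
        tool_name="find_block",
        parameters={"query": "current time"},
        user_id=user_id,
        session=session,
        tool_call_id="call-1",
    )
    # found.output is a JSON string of BlockListResponse; pick a block id
    # and supply input_data that matches its input_schema.
    ran = await execute_tool(
        tool_name="run_block",
        parameters={"block_id": "<block id from find_block>", "input_data": {}},
        user_id=user_id,
        session=session,
        tool_call_id="call-2",
    )
    print(ran.output)
```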

View File

@@ -0,0 +1,52 @@
"""Tool for searching agents in the user's library."""
from typing import Any
from backend.api.features.chat.model import ChatSession
from .agent_search import search_agents
from .base import BaseTool
from .models import ToolResponseBase
class FindLibraryAgentTool(BaseTool):
"""Tool for searching agents in the user's library."""
@property
def name(self) -> str:
return "find_library_agent"
@property
def description(self) -> str:
return (
"Search for agents in the user's library. Use this to find agents "
"the user has already added to their library, including agents they "
"created or added from the marketplace."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Search query to find agents by name or description.",
},
},
"required": ["query"],
}
@property
def requires_auth(self) -> bool:
return True
async def _execute(
self, user_id: str | None, session: ChatSession, **kwargs
) -> ToolResponseBase:
return await search_agents(
query=kwargs.get("query", "").strip(),
source="library",
session_id=session.session_id,
user_id=user_id,
)

View File

@@ -0,0 +1,148 @@
"""GetDocPageTool - Fetch full content of a documentation page."""
import logging
from pathlib import Path
from typing import Any
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool
from backend.api.features.chat.tools.models import (
DocPageResponse,
ErrorResponse,
ToolResponseBase,
)
logger = logging.getLogger(__name__)
# Base URL for documentation (can be configured)
DOCS_BASE_URL = "https://docs.agpt.co"
class GetDocPageTool(BaseTool):
"""Tool for fetching full content of a documentation page."""
@property
def name(self) -> str:
return "get_doc_page"
@property
def description(self) -> str:
return (
"Get the full content of a documentation page by its path. "
"Use this after search_docs to read the complete content of a relevant page."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"path": {
"type": "string",
"description": (
"The path to the documentation file, as returned by search_docs. "
"Example: 'platform/block-sdk-guide.md'"
),
},
},
"required": ["path"],
}
@property
def requires_auth(self) -> bool:
return False # Documentation is public
def _get_docs_root(self) -> Path:
"""Get the documentation root directory."""
this_file = Path(__file__)
project_root = this_file.parent.parent.parent.parent.parent.parent.parent.parent
return project_root / "docs"
def _extract_title(self, content: str, fallback: str) -> str:
"""Extract title from markdown content."""
lines = content.split("\n")
for line in lines:
if line.startswith("# "):
return line[2:].strip()
return fallback
def _make_doc_url(self, path: str) -> str:
"""Create a URL for a documentation page."""
url_path = path.rsplit(".", 1)[0] if "." in path else path
return f"{DOCS_BASE_URL}/{url_path}"
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
"""Fetch full content of a documentation page.
Args:
user_id: User ID (not required for docs)
session: Chat session
path: Path to the documentation file
Returns:
DocPageResponse: Full document content
ErrorResponse: Error message
"""
path = kwargs.get("path", "").strip()
session_id = session.session_id if session else None
if not path:
return ErrorResponse(
message="Please provide a documentation path.",
error="Missing path parameter",
session_id=session_id,
)
# Sanitize path to prevent directory traversal
if ".." in path or path.startswith("/"):
return ErrorResponse(
message="Invalid documentation path.",
error="invalid_path",
session_id=session_id,
)
docs_root = self._get_docs_root()
full_path = docs_root / path
if not full_path.exists():
return ErrorResponse(
message=f"Documentation page not found: {path}",
error="not_found",
session_id=session_id,
)
# Ensure the path is within docs root
try:
full_path.resolve().relative_to(docs_root.resolve())
except ValueError:
return ErrorResponse(
message="Invalid documentation path.",
error="invalid_path",
session_id=session_id,
)
try:
content = full_path.read_text(encoding="utf-8")
title = self._extract_title(content, path)
return DocPageResponse(
message=f"Retrieved documentation page: {title}",
title=title,
path=path,
content=content,
doc_url=self._make_doc_url(path),
session_id=session_id,
)
except Exception as e:
logger.error(f"Failed to read documentation page {path}: {e}")
return ErrorResponse(
message=f"Failed to read documentation page: {str(e)}",
error="read_failed",
session_id=session_id,
)
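A quick illustration of the helpers above (paths are invented; `_make_doc_url` and the sanitization checks are from the code in this diff):

```python
# Illustrative behaviour of GetDocPageTool's URL mapping and path guards.
from backend.api.features.chat.tools.get_doc_page import GetDocPageTool

tool = GetDocPageTool()
print(tool._make_doc_url("platform/block-sdk-guide.md"))
# -> https://docs.agpt.co/platform/block-sdk-guide

# Rejected by the sanitization checks before any file read:
#   "../secrets.txt"   (contains "..")
#   "/etc/passwd"      (absolute path)
```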

View File

@@ -1,5 +1,6 @@
"""Pydantic models for tool responses."""
from datetime import datetime
from enum import Enum
from typing import Any
@@ -11,14 +12,19 @@ from backend.data.model import CredentialsMetaInput
class ResponseType(str, Enum):
"""Types of tool responses."""
AGENT_CAROUSEL = "agent_carousel"
AGENTS_FOUND = "agents_found"
AGENT_DETAILS = "agent_details"
SETUP_REQUIREMENTS = "setup_requirements"
EXECUTION_STARTED = "execution_started"
NEED_LOGIN = "need_login"
ERROR = "error"
NO_RESULTS = "no_results"
SUCCESS = "success"
AGENT_OUTPUT = "agent_output"
UNDERSTANDING_UPDATED = "understanding_updated"
BLOCK_LIST = "block_list"
BLOCK_OUTPUT = "block_output"
DOC_SEARCH_RESULTS = "doc_search_results"
DOC_PAGE = "doc_page"
# Base response model
@@ -51,14 +57,14 @@ class AgentInfo(BaseModel):
graph_id: str | None = None
class AgentCarouselResponse(ToolResponseBase):
class AgentsFoundResponse(ToolResponseBase):
"""Response for find_agent tool."""
type: ResponseType = ResponseType.AGENT_CAROUSEL
type: ResponseType = ResponseType.AGENTS_FOUND
title: str = "Available Agents"
agents: list[AgentInfo]
count: int
name: str = "agent_carousel"
name: str = "agents_found"
class NoResultsResponse(ToolResponseBase):
@@ -173,3 +179,117 @@ class ErrorResponse(ToolResponseBase):
type: ResponseType = ResponseType.ERROR
error: str | None = None
details: dict[str, Any] | None = None
# Agent output models
class ExecutionOutputInfo(BaseModel):
"""Summary of a single execution's outputs."""
execution_id: str
status: str
started_at: datetime | None = None
ended_at: datetime | None = None
outputs: dict[str, list[Any]]
inputs_summary: dict[str, Any] | None = None
class AgentOutputResponse(ToolResponseBase):
"""Response for agent_output tool."""
type: ResponseType = ResponseType.AGENT_OUTPUT
agent_name: str
agent_id: str
library_agent_id: str | None = None
library_agent_link: str | None = None
execution: ExecutionOutputInfo | None = None
available_executions: list[dict[str, Any]] | None = None
total_executions: int = 0
# Business understanding models
class UnderstandingUpdatedResponse(ToolResponseBase):
"""Response for add_understanding tool."""
type: ResponseType = ResponseType.UNDERSTANDING_UPDATED
updated_fields: list[str] = Field(default_factory=list)
current_understanding: dict[str, Any] = Field(default_factory=dict)
# Documentation search models
class DocSearchResult(BaseModel):
"""A single documentation search result."""
title: str
path: str
section: str
snippet: str # Short excerpt for UI display
score: float
doc_url: str | None = None
class DocSearchResultsResponse(ToolResponseBase):
"""Response for search_docs tool."""
type: ResponseType = ResponseType.DOC_SEARCH_RESULTS
results: list[DocSearchResult]
count: int
query: str
class DocPageResponse(ToolResponseBase):
"""Response for get_doc_page tool."""
type: ResponseType = ResponseType.DOC_PAGE
title: str
path: str
content: str # Full document content
doc_url: str | None = None
# Block models
class BlockInputFieldInfo(BaseModel):
"""Information about a block input field."""
name: str
type: str
description: str = ""
required: bool = False
default: Any | None = None
class BlockInfoSummary(BaseModel):
"""Summary of a block for search results."""
id: str
name: str
description: str
categories: list[str]
input_schema: dict[str, Any]
output_schema: dict[str, Any]
required_inputs: list[BlockInputFieldInfo] = Field(
default_factory=list,
description="List of required input fields for this block",
)
class BlockListResponse(ToolResponseBase):
"""Response for find_block tool."""
type: ResponseType = ResponseType.BLOCK_LIST
blocks: list[BlockInfoSummary]
count: int
query: str
usage_hint: str = Field(
default="To execute a block, call run_block with block_id set to the block's "
"'id' field and input_data containing the required fields from input_schema."
)
class BlockOutputResponse(ToolResponseBase):
"""Response for run_block tool."""
type: ResponseType = ResponseType.BLOCK_OUTPUT
block_id: str
block_name: str
outputs: dict[str, list[Any]]
success: bool = True

View File

@@ -7,6 +7,7 @@ from pydantic import BaseModel, Field, field_validator
from backend.api.features.chat.config import ChatConfig
from backend.api.features.chat.model import ChatSession
from backend.api.features.library import db as library_db
from backend.data.graph import GraphModel
from backend.data.model import CredentialsMetaInput
from backend.data.user import get_user_by_id
@@ -57,6 +58,7 @@ class RunAgentInput(BaseModel):
"""Input parameters for the run_agent tool."""
username_agent_slug: str = ""
library_agent_id: str = ""
inputs: dict[str, Any] = Field(default_factory=dict)
use_defaults: bool = False
schedule_name: str = ""
@@ -64,7 +66,12 @@ class RunAgentInput(BaseModel):
timezone: str = "UTC"
@field_validator(
"username_agent_slug", "schedule_name", "cron", "timezone", mode="before"
"username_agent_slug",
"library_agent_id",
"schedule_name",
"cron",
"timezone",
mode="before",
)
@classmethod
def strip_strings(cls, v: Any) -> Any:
@@ -90,7 +97,7 @@ class RunAgentTool(BaseTool):
@property
def description(self) -> str:
return """Run or schedule an agent from the marketplace.
return """Run or schedule an agent from the marketplace or user's library.
The tool automatically handles the setup flow:
- Returns missing inputs if required fields are not provided
@@ -98,6 +105,10 @@ class RunAgentTool(BaseTool):
- Executes immediately if all requirements are met
- Schedules execution if cron expression is provided
Identify the agent using either:
- username_agent_slug: Marketplace format 'username/agent-name'
- library_agent_id: ID of an agent in the user's library
For scheduled execution, provide: schedule_name, cron, and optionally timezone."""
@property
@@ -109,6 +120,10 @@ class RunAgentTool(BaseTool):
"type": "string",
"description": "Agent identifier in format 'username/agent-name'",
},
"library_agent_id": {
"type": "string",
"description": "Library agent ID from user's library",
},
"inputs": {
"type": "object",
"description": "Input values for the agent",
@@ -131,7 +146,7 @@ class RunAgentTool(BaseTool):
"description": "IANA timezone for schedule (default: UTC)",
},
},
"required": ["username_agent_slug"],
"required": [],
}
@property
@@ -149,10 +164,16 @@ class RunAgentTool(BaseTool):
params = RunAgentInput(**kwargs)
session_id = session.session_id
# Validate agent slug format
if not params.username_agent_slug or "/" not in params.username_agent_slug:
# Validate at least one identifier is provided
has_slug = params.username_agent_slug and "/" in params.username_agent_slug
has_library_id = bool(params.library_agent_id)
if not has_slug and not has_library_id:
return ErrorResponse(
message="Please provide an agent slug in format 'username/agent-name'",
message=(
"Please provide either a username_agent_slug "
"(format 'username/agent-name') or a library_agent_id"
),
session_id=session_id,
)
@@ -167,13 +188,41 @@ class RunAgentTool(BaseTool):
is_schedule = bool(params.schedule_name or params.cron)
try:
# Step 1: Fetch agent details (always happens first)
username, agent_name = params.username_agent_slug.split("/", 1)
graph, store_agent = await fetch_graph_from_store_slug(username, agent_name)
# Step 1: Fetch agent details
graph: GraphModel | None = None
library_agent = None
# Priority: library_agent_id if provided
if has_library_id:
library_agent = await library_db.get_library_agent(
params.library_agent_id, user_id
)
if not library_agent:
return ErrorResponse(
message=f"Library agent '{params.library_agent_id}' not found",
session_id=session_id,
)
# Get the graph from the library agent
from backend.data.graph import get_graph
graph = await get_graph(
library_agent.graph_id,
library_agent.graph_version,
user_id=user_id,
)
else:
# Fetch from marketplace slug
username, agent_name = params.username_agent_slug.split("/", 1)
graph, _ = await fetch_graph_from_store_slug(username, agent_name)
if not graph:
identifier = (
params.library_agent_id
if has_library_id
else params.username_agent_slug
)
return ErrorResponse(
message=f"Agent '{params.username_agent_slug}' not found in marketplace",
message=f"Agent '{identifier}' not found",
session_id=session_id,
)
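As a rough illustration of the two identifier styles this change accepts, a run_agent call could pass either of the following (values are placeholders, not real agents):

# Marketplace slug form
by_slug_args = {"username_agent_slug": "username/agent-name", "inputs": {"topic": "example"}}
# Library agent form
by_library_args = {"library_agent_id": "<library agent id>", "use_defaults": True}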

View File

@@ -1,4 +1,5 @@
import uuid
from unittest.mock import AsyncMock, patch
import orjson
import pytest
@@ -17,6 +18,17 @@ setup_test_data = setup_test_data
setup_firecrawl_test_data = setup_firecrawl_test_data
@pytest.fixture(scope="session", autouse=True)
def mock_embedding_functions():
"""Mock embedding functions for all tests to avoid database/API dependencies."""
with patch(
"backend.api.features.store.db.ensure_embedding",
new_callable=AsyncMock,
return_value=True,
):
yield
@pytest.mark.asyncio(scope="session")
async def test_run_agent(setup_test_data):
"""Test that the run_agent tool successfully executes an approved agent"""
@@ -46,11 +58,11 @@ async def test_run_agent(setup_test_data):
# Verify the response
assert response is not None
assert hasattr(response, "result")
assert hasattr(response, "output")
# Parse the result JSON to verify the execution started
assert isinstance(response.result, str)
result_data = orjson.loads(response.result)
assert isinstance(response.output, str)
result_data = orjson.loads(response.output)
assert "execution_id" in result_data
assert "graph_id" in result_data
assert result_data["graph_id"] == graph.id
@@ -86,11 +98,11 @@ async def test_run_agent_missing_inputs(setup_test_data):
# Verify that we get an error response
assert response is not None
assert hasattr(response, "result")
assert hasattr(response, "output")
# The tool should return an ErrorResponse when setup info indicates not ready
assert isinstance(response.result, str)
result_data = orjson.loads(response.result)
assert isinstance(response.output, str)
result_data = orjson.loads(response.output)
assert "message" in result_data
@@ -118,10 +130,10 @@ async def test_run_agent_invalid_agent_id(setup_test_data):
# Verify that we get an error response
assert response is not None
assert hasattr(response, "result")
assert hasattr(response, "output")
assert isinstance(response.result, str)
result_data = orjson.loads(response.result)
assert isinstance(response.output, str)
result_data = orjson.loads(response.output)
assert "message" in result_data
# Should get an error about failed setup or not found
assert any(
@@ -158,12 +170,12 @@ async def test_run_agent_with_llm_credentials(setup_llm_test_data):
# Verify the response
assert response is not None
assert hasattr(response, "result")
assert hasattr(response, "output")
# Parse the result JSON to verify the execution started
assert isinstance(response.result, str)
result_data = orjson.loads(response.result)
assert isinstance(response.output, str)
result_data = orjson.loads(response.output)
# Should successfully start execution since credentials are available
assert "execution_id" in result_data
@@ -195,9 +207,9 @@ async def test_run_agent_shows_available_inputs_when_none_provided(setup_test_da
)
assert response is not None
assert hasattr(response, "result")
assert isinstance(response.result, str)
result_data = orjson.loads(response.result)
assert hasattr(response, "output")
assert isinstance(response.output, str)
result_data = orjson.loads(response.output)
# Should return agent_details type showing available inputs
assert result_data.get("type") == "agent_details"
@@ -230,9 +242,9 @@ async def test_run_agent_with_use_defaults(setup_test_data):
)
assert response is not None
assert hasattr(response, "result")
assert isinstance(response.result, str)
result_data = orjson.loads(response.result)
assert hasattr(response, "output")
assert isinstance(response.output, str)
result_data = orjson.loads(response.output)
# Should execute successfully
assert "execution_id" in result_data
@@ -260,9 +272,9 @@ async def test_run_agent_missing_credentials(setup_firecrawl_test_data):
)
assert response is not None
assert hasattr(response, "result")
assert isinstance(response.result, str)
result_data = orjson.loads(response.result)
assert hasattr(response, "output")
assert isinstance(response.output, str)
result_data = orjson.loads(response.output)
# Should return setup_requirements type with missing credentials
assert result_data.get("type") == "setup_requirements"
@@ -292,9 +304,9 @@ async def test_run_agent_invalid_slug_format(setup_test_data):
)
assert response is not None
assert hasattr(response, "result")
assert isinstance(response.result, str)
result_data = orjson.loads(response.result)
assert hasattr(response, "output")
assert isinstance(response.output, str)
result_data = orjson.loads(response.output)
# Should return error
assert result_data.get("type") == "error"
@@ -305,9 +317,10 @@ async def test_run_agent_invalid_slug_format(setup_test_data):
async def test_run_agent_unauthenticated():
"""Test that run_agent returns need_login for unauthenticated users."""
tool = RunAgentTool()
session = make_session(user_id=None)
# Session has a user_id (session owner), but we test tool execution without user_id
session = make_session(user_id="test-session-owner")
# Execute without user_id
# Execute without user_id to test unauthenticated behavior
response = await tool.execute(
user_id=None,
session_id=str(uuid.uuid4()),
@@ -318,9 +331,9 @@ async def test_run_agent_unauthenticated():
)
assert response is not None
assert hasattr(response, "result")
assert isinstance(response.result, str)
result_data = orjson.loads(response.result)
assert hasattr(response, "output")
assert isinstance(response.output, str)
result_data = orjson.loads(response.output)
# Base tool returns need_login type for unauthenticated users
assert result_data.get("type") == "need_login"
@@ -350,9 +363,9 @@ async def test_run_agent_schedule_without_cron(setup_test_data):
)
assert response is not None
assert hasattr(response, "result")
assert isinstance(response.result, str)
result_data = orjson.loads(response.result)
assert hasattr(response, "output")
assert isinstance(response.output, str)
result_data = orjson.loads(response.output)
# Should return error about missing cron
assert result_data.get("type") == "error"
@@ -382,9 +395,9 @@ async def test_run_agent_schedule_without_name(setup_test_data):
)
assert response is not None
assert hasattr(response, "result")
assert isinstance(response.result, str)
result_data = orjson.loads(response.result)
assert hasattr(response, "output")
assert isinstance(response.output, str)
result_data = orjson.loads(response.output)
# Should return error about missing schedule_name
assert result_data.get("type") == "error"

View File

@@ -0,0 +1,297 @@
"""Tool for executing blocks directly."""
import logging
from collections import defaultdict
from typing import Any
from backend.api.features.chat.model import ChatSession
from backend.data.block import get_block
from backend.data.execution import ExecutionContext
from backend.data.model import CredentialsMetaInput
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util.exceptions import BlockError
from .base import BaseTool
from .models import (
BlockOutputResponse,
ErrorResponse,
SetupInfo,
SetupRequirementsResponse,
ToolResponseBase,
UserReadiness,
)
logger = logging.getLogger(__name__)
class RunBlockTool(BaseTool):
"""Tool for executing a block and returning its outputs."""
@property
def name(self) -> str:
return "run_block"
@property
def description(self) -> str:
return (
"Execute a specific block with the provided input data. "
"IMPORTANT: You MUST call find_block first to get the block's 'id' - "
"do NOT guess or make up block IDs. "
"Use the 'id' from find_block results and provide input_data "
"matching the block's required_inputs."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"block_id": {
"type": "string",
"description": (
"The block's 'id' field from find_block results. "
"NEVER guess this - always get it from find_block first."
),
},
"input_data": {
"type": "object",
"description": (
"Input values for the block. Use the 'required_inputs' field "
"from find_block to see what fields are needed."
),
},
},
"required": ["block_id", "input_data"],
}
@property
def requires_auth(self) -> bool:
return True
async def _check_block_credentials(
self,
user_id: str,
block: Any,
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
"""
Check if user has required credentials for a block.
Returns:
tuple[matched_credentials, missing_credentials]
"""
matched_credentials: dict[str, CredentialsMetaInput] = {}
missing_credentials: list[CredentialsMetaInput] = []
# Get credential field info from block's input schema
credentials_fields_info = block.input_schema.get_credentials_fields_info()
if not credentials_fields_info:
return matched_credentials, missing_credentials
# Get user's available credentials
creds_manager = IntegrationCredentialsManager()
available_creds = await creds_manager.store.get_all_creds(user_id)
for field_name, field_info in credentials_fields_info.items():
# field_info.provider is a frozenset of acceptable providers
# field_info.supported_types is a frozenset of acceptable types
matching_cred = next(
(
cred
for cred in available_creds
if cred.provider in field_info.provider
and cred.type in field_info.supported_types
),
None,
)
if matching_cred:
matched_credentials[field_name] = CredentialsMetaInput(
id=matching_cred.id,
provider=matching_cred.provider, # type: ignore
type=matching_cred.type,
title=matching_cred.title,
)
else:
# Create a placeholder for the missing credential
provider = next(iter(field_info.provider), "unknown")
cred_type = next(iter(field_info.supported_types), "api_key")
missing_credentials.append(
CredentialsMetaInput(
id=field_name,
provider=provider, # type: ignore
type=cred_type, # type: ignore
title=field_name.replace("_", " ").title(),
)
)
return matched_credentials, missing_credentials
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
"""Execute a block with the given input data.
Args:
user_id: User ID (required)
session: Chat session
block_id: Block UUID to execute
input_data: Input values for the block
Returns:
BlockOutputResponse: Block execution outputs
SetupRequirementsResponse: Missing credentials
ErrorResponse: Error message
"""
block_id = kwargs.get("block_id", "").strip()
input_data = kwargs.get("input_data", {})
session_id = session.session_id
if not block_id:
return ErrorResponse(
message="Please provide a block_id",
session_id=session_id,
)
if not isinstance(input_data, dict):
return ErrorResponse(
message="input_data must be an object",
session_id=session_id,
)
if not user_id:
return ErrorResponse(
message="Authentication required",
session_id=session_id,
)
# Get the block
block = get_block(block_id)
if not block:
return ErrorResponse(
message=f"Block '{block_id}' not found",
session_id=session_id,
)
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
# Check credentials
creds_manager = IntegrationCredentialsManager()
matched_credentials, missing_credentials = await self._check_block_credentials(
user_id, block
)
if missing_credentials:
# Return setup requirements response with missing credentials
missing_creds_dict = {c.id: c.model_dump() for c in missing_credentials}
return SetupRequirementsResponse(
message=(
f"Block '{block.name}' requires credentials that are not configured. "
"Please set up the required credentials before running this block."
),
session_id=session_id,
setup_info=SetupInfo(
agent_id=block_id,
agent_name=block.name,
user_readiness=UserReadiness(
has_all_credentials=False,
missing_credentials=missing_creds_dict,
ready_to_run=False,
),
requirements={
"credentials": [c.model_dump() for c in missing_credentials],
"inputs": self._get_inputs_list(block),
"execution_modes": ["immediate"],
},
),
graph_id=None,
graph_version=None,
)
try:
# Fetch actual credentials and prepare kwargs for block execution
# Create execution context with defaults (blocks may require it)
exec_kwargs: dict[str, Any] = {
"user_id": user_id,
"execution_context": ExecutionContext(),
}
for field_name, cred_meta in matched_credentials.items():
# Inject metadata into input_data (for validation)
if field_name not in input_data:
input_data[field_name] = cred_meta.model_dump()
# Fetch actual credentials and pass as kwargs (for execution)
actual_credentials = await creds_manager.get(
user_id, cred_meta.id, lock=False
)
if actual_credentials:
exec_kwargs[field_name] = actual_credentials
else:
return ErrorResponse(
message=f"Failed to retrieve credentials for {field_name}",
session_id=session_id,
)
# Execute the block and collect outputs
outputs: dict[str, list[Any]] = defaultdict(list)
async for output_name, output_data in block.execute(
input_data,
**exec_kwargs,
):
outputs[output_name].append(output_data)
return BlockOutputResponse(
message=f"Block '{block.name}' executed successfully",
block_id=block_id,
block_name=block.name,
outputs=dict(outputs),
success=True,
session_id=session_id,
)
except BlockError as e:
logger.warning(f"Block execution failed: {e}")
return ErrorResponse(
message=f"Block execution failed: {e}",
error=str(e),
session_id=session_id,
)
except Exception as e:
logger.error(f"Unexpected error executing block: {e}", exc_info=True)
return ErrorResponse(
message=f"Failed to execute block: {str(e)}",
error=str(e),
session_id=session_id,
)
def _get_inputs_list(self, block: Any) -> list[dict[str, Any]]:
"""Extract non-credential inputs from block schema."""
inputs_list = []
schema = block.input_schema.jsonschema()
properties = schema.get("properties", {})
required_fields = set(schema.get("required", []))
# Get credential field names to exclude
credentials_fields = set(block.input_schema.get_credentials_fields().keys())
for field_name, field_schema in properties.items():
# Skip credential fields
if field_name in credentials_fields:
continue
inputs_list.append(
{
"name": field_name,
"title": field_schema.get("title", field_name),
"type": field_schema.get("type", "string"),
"description": field_schema.get("description", ""),
"required": field_name in required_fields,
}
)
return inputs_list
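A sketch of the outputs shape this tool returns: each output name maps to a list because block.execute() may yield the same output name several times (names and values below are hypothetical):

example_outputs = {
    "result": [4],                   # a single yield of the "result" output
    "log": ["started", "finished"],  # an output yielded twice
}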

View File

@@ -0,0 +1,208 @@
"""SearchDocsTool - Search documentation using hybrid search."""
import logging
from typing import Any
from prisma.enums import ContentType
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool
from backend.api.features.chat.tools.models import (
DocSearchResult,
DocSearchResultsResponse,
ErrorResponse,
NoResultsResponse,
ToolResponseBase,
)
from backend.api.features.store.hybrid_search import unified_hybrid_search
logger = logging.getLogger(__name__)
# Base URL for documentation (can be configured)
DOCS_BASE_URL = "https://docs.agpt.co"
# Maximum number of results to return
MAX_RESULTS = 5
# Snippet length for preview
SNIPPET_LENGTH = 200
class SearchDocsTool(BaseTool):
"""Tool for searching AutoGPT platform documentation."""
@property
def name(self) -> str:
return "search_docs"
@property
def description(self) -> str:
return (
"Search the AutoGPT platform documentation for information about "
"how to use the platform, build agents, configure blocks, and more. "
"Returns relevant documentation sections. Use get_doc_page to read full content."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": (
"Search query to find relevant documentation. "
"Use natural language to describe what you're looking for."
),
},
},
"required": ["query"],
}
@property
def requires_auth(self) -> bool:
return False # Documentation is public
def _create_snippet(self, content: str, max_length: int = SNIPPET_LENGTH) -> str:
"""Create a short snippet from content for preview."""
# Remove markdown formatting for cleaner snippet
clean_content = content.replace("#", "").replace("*", "").replace("`", "")
# Remove extra whitespace
clean_content = " ".join(clean_content.split())
if len(clean_content) <= max_length:
return clean_content
# Truncate at word boundary
truncated = clean_content[:max_length]
last_space = truncated.rfind(" ")
if last_space > max_length // 2:
truncated = truncated[:last_space]
return truncated + "..."
def _make_doc_url(self, path: str) -> str:
"""Create a URL for a documentation page."""
# Remove file extension for URL
url_path = path.rsplit(".", 1)[0] if "." in path else path
return f"{DOCS_BASE_URL}/{url_path}"
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
"""Search documentation and return relevant sections.
Args:
user_id: User ID (not required for docs)
session: Chat session
query: Search query
Returns:
DocSearchResultsResponse: List of matching documentation sections
NoResultsResponse: No results found
ErrorResponse: Error message
"""
query = kwargs.get("query", "").strip()
session_id = session.session_id if session else None
if not query:
return ErrorResponse(
message="Please provide a search query.",
error="Missing query parameter",
session_id=session_id,
)
try:
# Search using hybrid search for DOCUMENTATION content type only
results, total = await unified_hybrid_search(
query=query,
content_types=[ContentType.DOCUMENTATION],
page=1,
page_size=MAX_RESULTS * 2, # Fetch extra for deduplication
min_score=0.1, # Lower threshold for docs
)
if not results:
return NoResultsResponse(
message=f"No documentation found for '{query}'.",
suggestions=[
"Try different keywords",
"Use more general terms",
"Check for typos in your query",
],
session_id=session_id,
)
# Deduplicate by document path (keep highest scoring section per doc)
seen_docs: dict[str, dict[str, Any]] = {}
for result in results:
metadata = result.get("metadata", {})
doc_path = metadata.get("path", "")
if not doc_path:
continue
# Keep the highest scoring result for each document
if doc_path not in seen_docs:
seen_docs[doc_path] = result
elif result.get("combined_score", 0) > seen_docs[doc_path].get(
"combined_score", 0
):
seen_docs[doc_path] = result
# Sort by score and take top MAX_RESULTS
deduplicated = sorted(
seen_docs.values(),
key=lambda x: x.get("combined_score", 0),
reverse=True,
)[:MAX_RESULTS]
if not deduplicated:
return NoResultsResponse(
message=f"No documentation found for '{query}'.",
suggestions=[
"Try different keywords",
"Use more general terms",
],
session_id=session_id,
)
# Build response
doc_results: list[DocSearchResult] = []
for result in deduplicated:
metadata = result.get("metadata", {})
doc_path = metadata.get("path", "")
doc_title = metadata.get("doc_title", "")
section_title = metadata.get("section_title", "")
searchable_text = result.get("searchable_text", "")
score = result.get("combined_score", 0)
doc_results.append(
DocSearchResult(
title=doc_title or section_title or doc_path,
path=doc_path,
section=section_title,
snippet=self._create_snippet(searchable_text),
score=round(score, 3),
doc_url=self._make_doc_url(doc_path),
)
)
return DocSearchResultsResponse(
message=f"Found {len(doc_results)} relevant documentation sections.",
results=doc_results,
count=len(doc_results),
query=query,
session_id=session_id,
)
except Exception as e:
logger.error(f"Documentation search failed: {e}")
return ErrorResponse(
message=f"Failed to search documentation: {str(e)}",
error="search_failed",
session_id=session_id,
)
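A minimal standalone sketch of the URL mapping performed by _make_doc_url above, with a hypothetical path:

DOCS_BASE_URL = "https://docs.agpt.co"
path = "platform/blocks.md"
url = f"{DOCS_BASE_URL}/{path.rsplit('.', 1)[0] if '.' in path else path}"
assert url == "https://docs.agpt.co/platform/blocks"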

View File

@@ -35,11 +35,7 @@ from backend.data.model import (
OAuth2Credentials,
UserIntegrations,
)
from backend.data.onboarding import (
OnboardingStep,
complete_onboarding_step,
increment_runs,
)
from backend.data.onboarding import OnboardingStep, complete_onboarding_step
from backend.data.user import get_user_integrations
from backend.executor.utils import add_graph_execution
from backend.integrations.ayrshare import AyrshareClient, SocialPlatform
@@ -175,6 +171,7 @@ async def callback(
f"Successfully processed OAuth callback for user {user_id} "
f"and provider {provider.value}"
)
return CredentialsMetaResponse(
id=credentials.id,
provider=credentials.provider,
@@ -193,6 +190,7 @@ async def list_credentials(
user_id: Annotated[str, Security(get_user_id)],
) -> list[CredentialsMetaResponse]:
credentials = await creds_manager.store.get_all_creds(user_id)
return [
CredentialsMetaResponse(
id=cred.id,
@@ -215,6 +213,7 @@ async def list_credentials_by_provider(
user_id: Annotated[str, Security(get_user_id)],
) -> list[CredentialsMetaResponse]:
credentials = await creds_manager.store.get_creds_by_provider(user_id, provider)
return [
CredentialsMetaResponse(
id=cred.id,
@@ -378,7 +377,6 @@ async def webhook_ingress_generic(
return
await complete_onboarding_step(user_id, OnboardingStep.TRIGGER_WEBHOOK)
await increment_runs(user_id)
# Execute all triggers concurrently for better performance
tasks = []
@@ -831,6 +829,18 @@ async def list_providers() -> List[str]:
return all_providers
@router.get("/providers/system", response_model=List[str])
async def list_system_providers() -> List[str]:
"""
Get a list of providers that have platform credits (system credentials) available.
These providers can be used without the user providing their own API keys.
"""
from backend.integrations.credentials_store import SYSTEM_PROVIDERS
return list(SYSTEM_PROVIDERS)
@router.get("/providers/names", response_model=ProviderNamesResponse)
async def get_provider_names() -> ProviderNamesResponse:
"""

View File

@@ -489,7 +489,7 @@ async def update_agent_version_in_library(
agent_graph_version: int,
) -> library_model.LibraryAgent:
"""
Updates the agent version in the library if useGraphIsActiveVersion is True.
Updates the agent version in the library for any agent owned by the user.
Args:
user_id: Owner of the LibraryAgent.
@@ -498,20 +498,31 @@ async def update_agent_version_in_library(
Raises:
DatabaseError: If there's an error with the update.
NotFoundError: If no library agent is found for this user and agent.
"""
logger.debug(
f"Updating agent version in library for user #{user_id}, "
f"agent #{agent_graph_id} v{agent_graph_version}"
)
try:
library_agent = await prisma.models.LibraryAgent.prisma().find_first_or_raise(
async with transaction() as tx:
library_agent = await prisma.models.LibraryAgent.prisma(tx).find_first_or_raise(
where={
"userId": user_id,
"agentGraphId": agent_graph_id,
"useGraphIsActiveVersion": True,
},
)
lib = await prisma.models.LibraryAgent.prisma().update(
# Delete any conflicting LibraryAgent for the target version
await prisma.models.LibraryAgent.prisma(tx).delete_many(
where={
"userId": user_id,
"agentGraphId": agent_graph_id,
"agentGraphVersion": agent_graph_version,
"id": {"not": library_agent.id},
}
)
lib = await prisma.models.LibraryAgent.prisma(tx).update(
where={"id": library_agent.id},
data={
"AgentGraph": {
@@ -525,13 +536,13 @@ async def update_agent_version_in_library(
},
include={"AgentGraph": True},
)
if lib is None:
raise NotFoundError(f"Library agent {library_agent.id} not found")
return library_model.LibraryAgent.from_db(lib)
except prisma.errors.PrismaError as e:
logger.error(f"Database error updating agent version in library: {e}")
raise DatabaseError("Failed to update agent version in library") from e
if lib is None:
raise NotFoundError(
f"Failed to update library agent for {agent_graph_id} v{agent_graph_version}"
)
return library_model.LibraryAgent.from_db(lib)
async def update_library_agent(
@@ -825,6 +836,7 @@ async def add_store_agent_to_library(
}
},
"isCreatedByUser": False,
"useGraphIsActiveVersion": False,
"settings": SafeJson(
_initialize_graph_settings(graph_model).model_dump()
),

View File

@@ -48,6 +48,7 @@ class LibraryAgent(pydantic.BaseModel):
id: str
graph_id: str
graph_version: int
owner_user_id: str # ID of user who owns/created this agent graph
image_url: str | None
@@ -163,6 +164,7 @@ class LibraryAgent(pydantic.BaseModel):
id=agent.id,
graph_id=agent.agentGraphId,
graph_version=agent.agentGraphVersion,
owner_user_id=agent.userId,
image_url=agent.imageUrl,
creator_name=creator_name,
creator_image_url=creator_image_url,

View File

@@ -8,7 +8,6 @@ from backend.data.execution import GraphExecutionMeta
from backend.data.graph import get_graph
from backend.data.integrations import get_webhook
from backend.data.model import CredentialsMetaInput
from backend.data.onboarding import increment_runs
from backend.executor.utils import add_graph_execution, make_node_credentials_input_map
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.webhooks import get_webhook_manager
@@ -403,8 +402,6 @@ async def execute_preset(
merged_node_input = preset.inputs | inputs
merged_credential_inputs = preset.credentials | credential_inputs
await increment_runs(user_id)
return await add_graph_execution(
user_id=user_id,
graph_id=preset.graph_id,

View File

@@ -42,6 +42,7 @@ async def test_get_library_agents_success(
id="test-agent-1",
graph_id="test-agent-1",
graph_version=1,
owner_user_id=test_user_id,
name="Test Agent 1",
description="Test Description 1",
image_url=None,
@@ -64,6 +65,7 @@ async def test_get_library_agents_success(
id="test-agent-2",
graph_id="test-agent-2",
graph_version=1,
owner_user_id=test_user_id,
name="Test Agent 2",
description="Test Description 2",
image_url=None,
@@ -138,6 +140,7 @@ async def test_get_favorite_library_agents_success(
id="test-agent-1",
graph_id="test-agent-1",
graph_version=1,
owner_user_id=test_user_id,
name="Favorite Agent 1",
description="Test Favorite Description 1",
image_url=None,
@@ -205,6 +208,7 @@ def test_add_agent_to_library_success(
id="test-library-agent-id",
graph_id="test-agent-1",
graph_version=1,
owner_user_id=test_user_id,
name="Test Agent 1",
description="Test Description 1",
image_url=None,

View File

@@ -0,0 +1,610 @@
"""
Content Type Handlers for Unified Embeddings
Pluggable system for different content sources (store agents, blocks, docs).
Each handler knows how to fetch and process its content type for embedding.
"""
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from prisma.enums import ContentType
from backend.data.db import query_raw_with_schema
logger = logging.getLogger(__name__)
@dataclass
class ContentItem:
"""Represents a piece of content to be embedded."""
content_id: str # Unique identifier (DB ID or file path)
content_type: ContentType
searchable_text: str # Combined text for embedding
metadata: dict[str, Any] # Content-specific metadata
user_id: str | None = None # For user-scoped content
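# Illustration (not part of the module): a documentation-section ContentItem
# might look roughly like this; all values below are hypothetical.
#
#   ContentItem(
#       content_id="platform/getting-started.md::1",
#       content_type=ContentType.DOCUMENTATION,
#       searchable_text="Getting Started - Introduction\n\n## Introduction\n...",
#       metadata={
#           "doc_title": "Getting Started",
#           "section_title": "Introduction",
#           "section_index": 1,
#           "heading_level": 2,
#           "path": "platform/getting-started.md",
#       },
#       user_id=None,
#   )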
class ContentHandler(ABC):
"""Base handler for fetching and processing content for embeddings."""
@property
@abstractmethod
def content_type(self) -> ContentType:
"""The ContentType this handler manages."""
pass
@abstractmethod
async def get_missing_items(self, batch_size: int) -> list[ContentItem]:
"""
Fetch items that don't have embeddings yet.
Args:
batch_size: Maximum number of items to return
Returns:
List of ContentItem objects ready for embedding
"""
pass
@abstractmethod
async def get_stats(self) -> dict[str, int]:
"""
Get statistics about embedding coverage.
Returns:
Dict with keys: total, with_embeddings, without_embeddings
"""
pass
class StoreAgentHandler(ContentHandler):
"""Handler for marketplace store agent listings."""
@property
def content_type(self) -> ContentType:
return ContentType.STORE_AGENT
async def get_missing_items(self, batch_size: int) -> list[ContentItem]:
"""Fetch approved store listings without embeddings."""
from backend.api.features.store.embeddings import build_searchable_text
missing = await query_raw_with_schema(
"""
SELECT
slv.id,
slv.name,
slv.description,
slv."subHeading",
slv.categories
FROM {schema_prefix}"StoreListingVersion" slv
LEFT JOIN {schema_prefix}"UnifiedContentEmbedding" uce
ON slv.id = uce."contentId" AND uce."contentType" = 'STORE_AGENT'::{schema_prefix}"ContentType"
WHERE slv."submissionStatus" = 'APPROVED'
AND slv."isDeleted" = false
AND uce."contentId" IS NULL
LIMIT $1
""",
batch_size,
)
return [
ContentItem(
content_id=row["id"],
content_type=ContentType.STORE_AGENT,
searchable_text=build_searchable_text(
name=row["name"],
description=row["description"],
sub_heading=row["subHeading"],
categories=row["categories"] or [],
),
metadata={
"name": row["name"],
"categories": row["categories"] or [],
},
user_id=None, # Store agents are public
)
for row in missing
]
async def get_stats(self) -> dict[str, int]:
"""Get statistics about store agent embedding coverage."""
# Count approved versions
approved_result = await query_raw_with_schema(
"""
SELECT COUNT(*) as count
FROM {schema_prefix}"StoreListingVersion"
WHERE "submissionStatus" = 'APPROVED'
AND "isDeleted" = false
"""
)
total_approved = approved_result[0]["count"] if approved_result else 0
# Count versions with embeddings
embedded_result = await query_raw_with_schema(
"""
SELECT COUNT(*) as count
FROM {schema_prefix}"StoreListingVersion" slv
JOIN {schema_prefix}"UnifiedContentEmbedding" uce ON slv.id = uce."contentId" AND uce."contentType" = 'STORE_AGENT'::{schema_prefix}"ContentType"
WHERE slv."submissionStatus" = 'APPROVED'
AND slv."isDeleted" = false
"""
)
with_embeddings = embedded_result[0]["count"] if embedded_result else 0
return {
"total": total_approved,
"with_embeddings": with_embeddings,
"without_embeddings": total_approved - with_embeddings,
}
class BlockHandler(ContentHandler):
"""Handler for block definitions (Python classes)."""
@property
def content_type(self) -> ContentType:
return ContentType.BLOCK
async def get_missing_items(self, batch_size: int) -> list[ContentItem]:
"""Fetch blocks without embeddings."""
from backend.data.block import get_blocks
# Get all available blocks
all_blocks = get_blocks()
# Check which ones have embeddings
if not all_blocks:
return []
block_ids = list(all_blocks.keys())
# Query for existing embeddings
placeholders = ",".join([f"${i+1}" for i in range(len(block_ids))])
existing_result = await query_raw_with_schema(
f"""
SELECT "contentId"
FROM {{schema_prefix}}"UnifiedContentEmbedding"
WHERE "contentType" = 'BLOCK'::{{schema_prefix}}"ContentType"
AND "contentId" = ANY(ARRAY[{placeholders}])
""",
*block_ids,
)
existing_ids = {row["contentId"] for row in existing_result}
missing_blocks = [
(block_id, block_cls)
for block_id, block_cls in all_blocks.items()
if block_id not in existing_ids
]
# Convert to ContentItem
items = []
for block_id, block_cls in missing_blocks[:batch_size]:
try:
block_instance = block_cls()
# Build searchable text from block metadata
parts = []
if hasattr(block_instance, "name") and block_instance.name:
parts.append(block_instance.name)
if (
hasattr(block_instance, "description")
and block_instance.description
):
parts.append(block_instance.description)
if hasattr(block_instance, "categories") and block_instance.categories:
# Convert BlockCategory enum to strings
parts.append(
" ".join(str(cat.value) for cat in block_instance.categories)
)
# Add input/output schema info
if hasattr(block_instance, "input_schema"):
schema = block_instance.input_schema
if hasattr(schema, "model_json_schema"):
schema_dict = schema.model_json_schema()
if "properties" in schema_dict:
for prop_name, prop_info in schema_dict[
"properties"
].items():
if "description" in prop_info:
parts.append(
f"{prop_name}: {prop_info['description']}"
)
searchable_text = " ".join(parts)
# Convert categories set of enums to list of strings for JSON serialization
categories = getattr(block_instance, "categories", set())
categories_list = (
[cat.value for cat in categories] if categories else []
)
items.append(
ContentItem(
content_id=block_id,
content_type=ContentType.BLOCK,
searchable_text=searchable_text,
metadata={
"name": getattr(block_instance, "name", ""),
"categories": categories_list,
},
user_id=None, # Blocks are public
)
)
except Exception as e:
logger.warning(f"Failed to process block {block_id}: {e}")
continue
return items
async def get_stats(self) -> dict[str, int]:
"""Get statistics about block embedding coverage."""
from backend.data.block import get_blocks
all_blocks = get_blocks()
total_blocks = len(all_blocks)
if total_blocks == 0:
return {"total": 0, "with_embeddings": 0, "without_embeddings": 0}
block_ids = list(all_blocks.keys())
placeholders = ",".join([f"${i+1}" for i in range(len(block_ids))])
embedded_result = await query_raw_with_schema(
f"""
SELECT COUNT(*) as count
FROM {{schema_prefix}}"UnifiedContentEmbedding"
WHERE "contentType" = 'BLOCK'::{{schema_prefix}}"ContentType"
AND "contentId" = ANY(ARRAY[{placeholders}])
""",
*block_ids,
)
with_embeddings = embedded_result[0]["count"] if embedded_result else 0
return {
"total": total_blocks,
"with_embeddings": with_embeddings,
"without_embeddings": total_blocks - with_embeddings,
}
@dataclass
class MarkdownSection:
"""Represents a section of a markdown document."""
title: str # Section heading text
content: str # Section content (including the heading line)
level: int # Heading level (1 for #, 2 for ##, etc.)
index: int # Section index within the document
class DocumentationHandler(ContentHandler):
"""Handler for documentation files (.md/.mdx).
Chunks documents by markdown headings to create multiple embeddings per file.
Each section (## heading) becomes a separate embedding for better retrieval.
"""
@property
def content_type(self) -> ContentType:
return ContentType.DOCUMENTATION
def _get_docs_root(self) -> Path:
"""Get the documentation root directory."""
# content_handlers.py is at: backend/backend/api/features/store/content_handlers.py
# Need to go up to project root then into docs/
# In container: /app/autogpt_platform/backend/backend/api/features/store -> /app/docs
# In development: /repo/autogpt_platform/backend/backend/api/features/store -> /repo/docs
this_file = Path(
__file__
) # .../backend/backend/api/features/store/content_handlers.py
project_root = (
this_file.parent.parent.parent.parent.parent.parent.parent
) # -> /app or /repo
docs_root = project_root / "docs"
return docs_root
def _extract_doc_title(self, file_path: Path) -> str:
"""Extract the document title from a markdown file."""
try:
content = file_path.read_text(encoding="utf-8")
lines = content.split("\n")
# Try to extract title from first # heading
for line in lines:
if line.startswith("# "):
return line[2:].strip()
# If no title found, use filename
return file_path.stem.replace("-", " ").replace("_", " ").title()
except Exception as e:
logger.warning(f"Failed to read title from {file_path}: {e}")
return file_path.stem.replace("-", " ").replace("_", " ").title()
def _chunk_markdown_by_headings(
self, file_path: Path, min_heading_level: int = 2
) -> list[MarkdownSection]:
"""
Split a markdown file into sections based on headings.
Args:
file_path: Path to the markdown file
min_heading_level: Minimum heading level to split on (default: 2 for ##)
Returns:
List of MarkdownSection objects, one per section.
If no headings found, returns a single section with all content.
"""
try:
content = file_path.read_text(encoding="utf-8")
except Exception as e:
logger.warning(f"Failed to read {file_path}: {e}")
return []
lines = content.split("\n")
sections: list[MarkdownSection] = []
current_section_lines: list[str] = []
current_title = ""
current_level = 0
section_index = 0
doc_title = ""
for line in lines:
# Check if line is a heading
if line.startswith("#"):
# Count heading level
level = 0
for char in line:
if char == "#":
level += 1
else:
break
heading_text = line[level:].strip()
# Track document title (level 1 heading)
if level == 1 and not doc_title:
doc_title = heading_text
# Don't create a section for just the title - add it to the first section
current_section_lines.append(line)
continue
# Check if this heading should start a new section
if level >= min_heading_level:
# Save previous section if it has content
if current_section_lines:
section_content = "\n".join(current_section_lines).strip()
if section_content:
# Use doc title for first section if no specific title
title = current_title if current_title else doc_title
if not title:
title = file_path.stem.replace("-", " ").replace(
"_", " "
)
sections.append(
MarkdownSection(
title=title,
content=section_content,
level=current_level if current_level else 1,
index=section_index,
)
)
section_index += 1
# Start new section
current_section_lines = [line]
current_title = heading_text
current_level = level
else:
# Lower level heading (e.g., # when splitting on ##)
current_section_lines.append(line)
else:
current_section_lines.append(line)
# Don't forget the last section
if current_section_lines:
section_content = "\n".join(current_section_lines).strip()
if section_content:
title = current_title if current_title else doc_title
if not title:
title = file_path.stem.replace("-", " ").replace("_", " ")
sections.append(
MarkdownSection(
title=title,
content=section_content,
level=current_level if current_level else 1,
index=section_index,
)
)
# If no sections were created (no headings found), create one section with all content
if not sections and content.strip():
title = (
doc_title
if doc_title
else file_path.stem.replace("-", " ").replace("_", " ")
)
sections.append(
MarkdownSection(
title=title,
content=content.strip(),
level=1,
index=0,
)
)
return sections
def _make_section_content_id(self, doc_path: str, section_index: int) -> str:
"""Create a unique content ID for a document section.
Format: doc_path::section_index
Example: 'platform/getting-started.md::0'
"""
return f"{doc_path}::{section_index}"
def _parse_section_content_id(self, content_id: str) -> tuple[str, int]:
"""Parse a section content ID back into doc_path and section_index.
Returns: (doc_path, section_index)
"""
if "::" in content_id:
parts = content_id.rsplit("::", 1)
return parts[0], int(parts[1])
# Legacy format (whole document)
return content_id, 0
async def get_missing_items(self, batch_size: int) -> list[ContentItem]:
"""Fetch documentation sections without embeddings.
Chunks each document by markdown headings and creates embeddings for each section.
Content IDs use the format: 'path/to/doc.md::section_index'
"""
docs_root = self._get_docs_root()
if not docs_root.exists():
logger.warning(f"Documentation root not found: {docs_root}")
return []
# Find all .md and .mdx files
all_docs = list(docs_root.rglob("*.md")) + list(docs_root.rglob("*.mdx"))
if not all_docs:
return []
# Build list of all sections from all documents
all_sections: list[tuple[str, Path, MarkdownSection]] = []
for doc_file in all_docs:
doc_path = str(doc_file.relative_to(docs_root))
sections = self._chunk_markdown_by_headings(doc_file)
for section in sections:
all_sections.append((doc_path, doc_file, section))
if not all_sections:
return []
# Generate content IDs for all sections
section_content_ids = [
self._make_section_content_id(doc_path, section.index)
for doc_path, _, section in all_sections
]
# Check which ones have embeddings
placeholders = ",".join([f"${i+1}" for i in range(len(section_content_ids))])
existing_result = await query_raw_with_schema(
f"""
SELECT "contentId"
FROM {{schema_prefix}}"UnifiedContentEmbedding"
WHERE "contentType" = 'DOCUMENTATION'::{{schema_prefix}}"ContentType"
AND "contentId" = ANY(ARRAY[{placeholders}])
""",
*section_content_ids,
)
existing_ids = {row["contentId"] for row in existing_result}
# Filter to missing sections
missing_sections = [
(doc_path, doc_file, section, content_id)
for (doc_path, doc_file, section), content_id in zip(
all_sections, section_content_ids
)
if content_id not in existing_ids
]
# Convert to ContentItem (up to batch_size)
items = []
for doc_path, doc_file, section, content_id in missing_sections[:batch_size]:
try:
# Get document title for context
doc_title = self._extract_doc_title(doc_file)
# Build searchable text with context
# Include doc title and section title for better search relevance
searchable_text = f"{doc_title} - {section.title}\n\n{section.content}"
items.append(
ContentItem(
content_id=content_id,
content_type=ContentType.DOCUMENTATION,
searchable_text=searchable_text,
metadata={
"doc_title": doc_title,
"section_title": section.title,
"section_index": section.index,
"heading_level": section.level,
"path": doc_path,
},
user_id=None, # Documentation is public
)
)
except Exception as e:
logger.warning(f"Failed to process section {content_id}: {e}")
continue
return items
def _get_all_section_content_ids(self, docs_root: Path) -> set[str]:
"""Get all current section content IDs from the docs directory.
Used for stats and cleanup to know what sections should exist.
"""
all_docs = list(docs_root.rglob("*.md")) + list(docs_root.rglob("*.mdx"))
content_ids = set()
for doc_file in all_docs:
doc_path = str(doc_file.relative_to(docs_root))
sections = self._chunk_markdown_by_headings(doc_file)
for section in sections:
content_ids.add(self._make_section_content_id(doc_path, section.index))
return content_ids
async def get_stats(self) -> dict[str, int]:
"""Get statistics about documentation embedding coverage.
Counts sections (not documents) since each section gets its own embedding.
"""
docs_root = self._get_docs_root()
if not docs_root.exists():
return {"total": 0, "with_embeddings": 0, "without_embeddings": 0}
# Get all section content IDs
all_section_ids = self._get_all_section_content_ids(docs_root)
total_sections = len(all_section_ids)
if total_sections == 0:
return {"total": 0, "with_embeddings": 0, "without_embeddings": 0}
# Count embeddings in database for DOCUMENTATION type
embedded_result = await query_raw_with_schema(
"""
SELECT COUNT(*) as count
FROM {schema_prefix}"UnifiedContentEmbedding"
WHERE "contentType" = 'DOCUMENTATION'::{schema_prefix}"ContentType"
"""
)
with_embeddings = embedded_result[0]["count"] if embedded_result else 0
return {
"total": total_sections,
"with_embeddings": with_embeddings,
"without_embeddings": total_sections - with_embeddings,
}
# Content handler registry
CONTENT_HANDLERS: dict[ContentType, ContentHandler] = {
ContentType.STORE_AGENT: StoreAgentHandler(),
ContentType.BLOCK: BlockHandler(),
ContentType.DOCUMENTATION: DocumentationHandler(),
}
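A hedged sketch of how this registry can drive a backfill pass, using the ensure_content_embedding helper exercised in the tests below (batching and error handling are simplified here):

from backend.api.features.store.embeddings import ensure_content_embedding

async def backfill_missing(batch_size: int = 100) -> None:
    """Walk every registered handler and embed whatever is still missing."""
    for handler in CONTENT_HANDLERS.values():
        for item in await handler.get_missing_items(batch_size):
            await ensure_content_embedding(
                content_type=item.content_type,
                content_id=item.content_id,
                searchable_text=item.searchable_text,
                metadata=item.metadata,
                user_id=item.user_id,
            )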

View File

@@ -0,0 +1,215 @@
"""
Integration tests for content handlers using real DB.
Run with: poetry run pytest backend/api/features/store/content_handlers_integration_test.py -xvs
These tests use the real database but mock OpenAI calls.
"""
from unittest.mock import patch
import pytest
from backend.api.features.store.content_handlers import (
CONTENT_HANDLERS,
BlockHandler,
DocumentationHandler,
StoreAgentHandler,
)
from backend.api.features.store.embeddings import (
EMBEDDING_DIM,
backfill_all_content_types,
ensure_content_embedding,
get_embedding_stats,
)
@pytest.mark.asyncio(loop_scope="session")
async def test_store_agent_handler_real_db():
"""Test StoreAgentHandler with real database queries."""
handler = StoreAgentHandler()
# Get stats from real DB
stats = await handler.get_stats()
# Stats should have correct structure
assert "total" in stats
assert "with_embeddings" in stats
assert "without_embeddings" in stats
assert stats["total"] >= 0
assert stats["with_embeddings"] >= 0
assert stats["without_embeddings"] >= 0
# Get missing items (max 1 to keep test fast)
items = await handler.get_missing_items(batch_size=1)
# Items should be list (may be empty if all have embeddings)
assert isinstance(items, list)
if items:
item = items[0]
assert item.content_id is not None
assert item.content_type.value == "STORE_AGENT"
assert item.searchable_text != ""
assert item.user_id is None
@pytest.mark.asyncio(loop_scope="session")
async def test_block_handler_real_db():
"""Test BlockHandler with real database queries."""
handler = BlockHandler()
# Get stats from real DB
stats = await handler.get_stats()
# Stats should have correct structure
assert "total" in stats
assert "with_embeddings" in stats
assert "without_embeddings" in stats
assert stats["total"] >= 0  # Total comes from the block registry
assert stats["with_embeddings"] >= 0
assert stats["without_embeddings"] >= 0
# Get missing items (max 1 to keep test fast)
items = await handler.get_missing_items(batch_size=1)
# Items should be list
assert isinstance(items, list)
if items:
item = items[0]
assert item.content_id is not None # Should be block UUID
assert item.content_type.value == "BLOCK"
assert item.searchable_text != ""
assert item.user_id is None
@pytest.mark.asyncio(loop_scope="session")
async def test_documentation_handler_real_fs():
"""Test DocumentationHandler with real filesystem."""
handler = DocumentationHandler()
# Get stats from real filesystem
stats = await handler.get_stats()
# Stats should have correct structure
assert "total" in stats
assert "with_embeddings" in stats
assert "without_embeddings" in stats
assert stats["total"] >= 0
assert stats["with_embeddings"] >= 0
assert stats["without_embeddings"] >= 0
# Get missing items (max 1 to keep test fast)
items = await handler.get_missing_items(batch_size=1)
# Items should be list
assert isinstance(items, list)
if items:
item = items[0]
assert item.content_id is not None  # Relative doc path with section index (path::N)
assert item.content_type.value == "DOCUMENTATION"
assert item.searchable_text != ""
assert item.user_id is None
@pytest.mark.asyncio(loop_scope="session")
async def test_get_embedding_stats_all_types():
"""Test get_embedding_stats aggregates all content types."""
stats = await get_embedding_stats()
# Should have structure with by_type and totals
assert "by_type" in stats
assert "totals" in stats
# Check each content type is present
by_type = stats["by_type"]
assert "STORE_AGENT" in by_type
assert "BLOCK" in by_type
assert "DOCUMENTATION" in by_type
# Check totals are aggregated
totals = stats["totals"]
assert totals["total"] >= 0
assert totals["with_embeddings"] >= 0
assert totals["without_embeddings"] >= 0
assert "coverage_percent" in totals
@pytest.mark.asyncio(loop_scope="session")
@patch("backend.api.features.store.embeddings.generate_embedding")
async def test_ensure_content_embedding_blocks(mock_generate):
"""Test creating embeddings for blocks (mocked OpenAI)."""
# Mock OpenAI to return fake embedding
mock_generate.return_value = [0.1] * EMBEDDING_DIM
# Get one block without embedding
handler = BlockHandler()
items = await handler.get_missing_items(batch_size=1)
if not items:
pytest.skip("No blocks without embeddings")
item = items[0]
# Try to create embedding (OpenAI mocked)
result = await ensure_content_embedding(
content_type=item.content_type,
content_id=item.content_id,
searchable_text=item.searchable_text,
metadata=item.metadata,
user_id=item.user_id,
)
# Should succeed with mocked OpenAI
assert result is True
mock_generate.assert_called_once()
@pytest.mark.asyncio(loop_scope="session")
@patch("backend.api.features.store.embeddings.generate_embedding")
async def test_backfill_all_content_types_dry_run(mock_generate):
"""Test backfill_all_content_types processes all handlers in order."""
# Mock OpenAI to return fake embedding
mock_generate.return_value = [0.1] * EMBEDDING_DIM
# Run backfill with batch_size=1 to process max 1 per type
result = await backfill_all_content_types(batch_size=1)
# Should have results for all content types
assert "by_type" in result
assert "totals" in result
by_type = result["by_type"]
assert "BLOCK" in by_type
assert "STORE_AGENT" in by_type
assert "DOCUMENTATION" in by_type
# Each type should have correct structure
for content_type, type_result in by_type.items():
assert "processed" in type_result
assert "success" in type_result
assert "failed" in type_result
# Totals should aggregate
totals = result["totals"]
assert totals["processed"] >= 0
assert totals["success"] >= 0
assert totals["failed"] >= 0
@pytest.mark.asyncio(loop_scope="session")
async def test_content_handler_registry():
"""Test all handlers are registered in correct order."""
from prisma.enums import ContentType
# All three types should be registered
assert ContentType.STORE_AGENT in CONTENT_HANDLERS
assert ContentType.BLOCK in CONTENT_HANDLERS
assert ContentType.DOCUMENTATION in CONTENT_HANDLERS
# Check handler types
assert isinstance(CONTENT_HANDLERS[ContentType.STORE_AGENT], StoreAgentHandler)
assert isinstance(CONTENT_HANDLERS[ContentType.BLOCK], BlockHandler)
assert isinstance(CONTENT_HANDLERS[ContentType.DOCUMENTATION], DocumentationHandler)

View File

@@ -0,0 +1,381 @@
"""
E2E tests for content handlers (blocks, store agents, documentation).
Tests the full flow: discovering content → generating embeddings → storing.
"""
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from prisma.enums import ContentType
from backend.api.features.store.content_handlers import (
CONTENT_HANDLERS,
BlockHandler,
DocumentationHandler,
StoreAgentHandler,
)
@pytest.mark.asyncio(loop_scope="session")
async def test_store_agent_handler_get_missing_items(mocker):
"""Test StoreAgentHandler fetches approved agents without embeddings."""
handler = StoreAgentHandler()
# Mock database query
mock_missing = [
{
"id": "agent-1",
"name": "Test Agent",
"description": "A test agent",
"subHeading": "Test heading",
"categories": ["AI", "Testing"],
}
]
with patch(
"backend.api.features.store.content_handlers.query_raw_with_schema",
return_value=mock_missing,
):
items = await handler.get_missing_items(batch_size=10)
assert len(items) == 1
assert items[0].content_id == "agent-1"
assert items[0].content_type == ContentType.STORE_AGENT
assert "Test Agent" in items[0].searchable_text
assert "A test agent" in items[0].searchable_text
assert items[0].metadata["name"] == "Test Agent"
assert items[0].user_id is None
@pytest.mark.asyncio(loop_scope="session")
async def test_store_agent_handler_get_stats(mocker):
"""Test StoreAgentHandler returns correct stats."""
handler = StoreAgentHandler()
# Mock approved count query
mock_approved = [{"count": 50}]
# Mock embedded count query
mock_embedded = [{"count": 30}]
with patch(
"backend.api.features.store.content_handlers.query_raw_with_schema",
side_effect=[mock_approved, mock_embedded],
):
stats = await handler.get_stats()
assert stats["total"] == 50
assert stats["with_embeddings"] == 30
assert stats["without_embeddings"] == 20
@pytest.mark.asyncio(loop_scope="session")
async def test_block_handler_get_missing_items(mocker):
"""Test BlockHandler discovers blocks without embeddings."""
handler = BlockHandler()
# Mock get_blocks to return test blocks
mock_block_class = MagicMock()
mock_block_instance = MagicMock()
mock_block_instance.name = "Calculator Block"
mock_block_instance.description = "Performs calculations"
mock_block_instance.categories = [MagicMock(value="MATH")]
mock_block_instance.input_schema.model_json_schema.return_value = {
"properties": {"expression": {"description": "Math expression to evaluate"}}
}
mock_block_class.return_value = mock_block_instance
mock_blocks = {"block-uuid-1": mock_block_class}
# Mock existing embeddings query (no embeddings exist)
mock_existing = []
with patch(
"backend.data.block.get_blocks",
return_value=mock_blocks,
):
with patch(
"backend.api.features.store.content_handlers.query_raw_with_schema",
return_value=mock_existing,
):
items = await handler.get_missing_items(batch_size=10)
assert len(items) == 1
assert items[0].content_id == "block-uuid-1"
assert items[0].content_type == ContentType.BLOCK
assert "Calculator Block" in items[0].searchable_text
assert "Performs calculations" in items[0].searchable_text
assert "MATH" in items[0].searchable_text
assert "expression: Math expression" in items[0].searchable_text
assert items[0].user_id is None
@pytest.mark.asyncio(loop_scope="session")
async def test_block_handler_get_stats(mocker):
"""Test BlockHandler returns correct stats."""
handler = BlockHandler()
# Mock get_blocks
mock_blocks = {
"block-1": MagicMock(),
"block-2": MagicMock(),
"block-3": MagicMock(),
}
# Mock embedded count query (2 blocks have embeddings)
mock_embedded = [{"count": 2}]
with patch(
"backend.data.block.get_blocks",
return_value=mock_blocks,
):
with patch(
"backend.api.features.store.content_handlers.query_raw_with_schema",
return_value=mock_embedded,
):
stats = await handler.get_stats()
assert stats["total"] == 3
assert stats["with_embeddings"] == 2
assert stats["without_embeddings"] == 1
@pytest.mark.asyncio(loop_scope="session")
async def test_documentation_handler_get_missing_items(tmp_path, mocker):
"""Test DocumentationHandler discovers docs without embeddings."""
handler = DocumentationHandler()
# Create temporary docs directory with test files
docs_root = tmp_path / "docs"
docs_root.mkdir()
(docs_root / "guide.md").write_text("# Getting Started\n\nThis is a guide.")
(docs_root / "api.mdx").write_text("# API Reference\n\nAPI documentation.")
# Mock _get_docs_root to return temp dir
with patch.object(handler, "_get_docs_root", return_value=docs_root):
# Mock existing embeddings query (no embeddings exist)
with patch(
"backend.api.features.store.content_handlers.query_raw_with_schema",
return_value=[],
):
items = await handler.get_missing_items(batch_size=10)
assert len(items) == 2
# Check guide.md (content_id format: doc_path::section_index)
guide_item = next(
(item for item in items if item.content_id == "guide.md::0"), None
)
assert guide_item is not None
assert guide_item.content_type == ContentType.DOCUMENTATION
assert "Getting Started" in guide_item.searchable_text
assert "This is a guide" in guide_item.searchable_text
assert guide_item.metadata["doc_title"] == "Getting Started"
assert guide_item.user_id is None
# Check api.mdx (content_id format: doc_path::section_index)
api_item = next(
(item for item in items if item.content_id == "api.mdx::0"), None
)
assert api_item is not None
assert "API Reference" in api_item.searchable_text
@pytest.mark.asyncio(loop_scope="session")
async def test_documentation_handler_get_stats(tmp_path, mocker):
"""Test DocumentationHandler returns correct stats."""
handler = DocumentationHandler()
# Create temporary docs directory
docs_root = tmp_path / "docs"
docs_root.mkdir()
(docs_root / "doc1.md").write_text("# Doc 1")
(docs_root / "doc2.md").write_text("# Doc 2")
(docs_root / "doc3.mdx").write_text("# Doc 3")
# Mock embedded count query (1 doc has embedding)
mock_embedded = [{"count": 1}]
with patch.object(handler, "_get_docs_root", return_value=docs_root):
with patch(
"backend.api.features.store.content_handlers.query_raw_with_schema",
return_value=mock_embedded,
):
stats = await handler.get_stats()
assert stats["total"] == 3
assert stats["with_embeddings"] == 1
assert stats["without_embeddings"] == 2
@pytest.mark.asyncio(loop_scope="session")
async def test_documentation_handler_title_extraction(tmp_path):
"""Test DocumentationHandler extracts title from markdown heading."""
handler = DocumentationHandler()
# Test with heading
doc_with_heading = tmp_path / "with_heading.md"
doc_with_heading.write_text("# My Title\n\nContent here")
title = handler._extract_doc_title(doc_with_heading)
assert title == "My Title"
# Test without heading
doc_without_heading = tmp_path / "no-heading.md"
doc_without_heading.write_text("Just content, no heading")
title = handler._extract_doc_title(doc_without_heading)
assert title == "No Heading" # Uses filename
@pytest.mark.asyncio(loop_scope="session")
async def test_documentation_handler_markdown_chunking(tmp_path):
"""Test DocumentationHandler chunks markdown by headings."""
handler = DocumentationHandler()
# Test document with multiple sections
doc_with_sections = tmp_path / "sections.md"
doc_with_sections.write_text(
"# Document Title\n\n"
"Intro paragraph.\n\n"
"## Section One\n\n"
"Content for section one.\n\n"
"## Section Two\n\n"
"Content for section two.\n"
)
sections = handler._chunk_markdown_by_headings(doc_with_sections)
# Should have 3 sections: intro (with doc title), section one, section two
assert len(sections) == 3
assert sections[0].title == "Document Title"
assert sections[0].index == 0
assert "Intro paragraph" in sections[0].content
assert sections[1].title == "Section One"
assert sections[1].index == 1
assert "Content for section one" in sections[1].content
assert sections[2].title == "Section Two"
assert sections[2].index == 2
assert "Content for section two" in sections[2].content
# Test document without headings
doc_no_sections = tmp_path / "no-sections.md"
doc_no_sections.write_text("Just plain content without any headings.")
sections = handler._chunk_markdown_by_headings(doc_no_sections)
assert len(sections) == 1
assert sections[0].index == 0
assert "Just plain content" in sections[0].content
@pytest.mark.asyncio(loop_scope="session")
async def test_documentation_handler_section_content_ids():
"""Test DocumentationHandler creates and parses section content IDs."""
handler = DocumentationHandler()
# Test making content ID
content_id = handler._make_section_content_id("docs/guide.md", 2)
assert content_id == "docs/guide.md::2"
# Test parsing content ID
doc_path, section_index = handler._parse_section_content_id("docs/guide.md::2")
assert doc_path == "docs/guide.md"
assert section_index == 2
# Test parsing legacy format (no section index)
doc_path, section_index = handler._parse_section_content_id("docs/old-format.md")
assert doc_path == "docs/old-format.md"
assert section_index == 0
@pytest.mark.asyncio(loop_scope="session")
async def test_content_handlers_registry():
"""Test all content types are registered."""
assert ContentType.STORE_AGENT in CONTENT_HANDLERS
assert ContentType.BLOCK in CONTENT_HANDLERS
assert ContentType.DOCUMENTATION in CONTENT_HANDLERS
assert isinstance(CONTENT_HANDLERS[ContentType.STORE_AGENT], StoreAgentHandler)
assert isinstance(CONTENT_HANDLERS[ContentType.BLOCK], BlockHandler)
assert isinstance(CONTENT_HANDLERS[ContentType.DOCUMENTATION], DocumentationHandler)
@pytest.mark.asyncio(loop_scope="session")
async def test_block_handler_handles_missing_attributes():
"""Test BlockHandler gracefully handles blocks with missing attributes."""
handler = BlockHandler()
# Mock block with minimal attributes
mock_block_class = MagicMock()
mock_block_instance = MagicMock()
mock_block_instance.name = "Minimal Block"
# No description, categories, or schema
del mock_block_instance.description
del mock_block_instance.categories
del mock_block_instance.input_schema
mock_block_class.return_value = mock_block_instance
mock_blocks = {"block-minimal": mock_block_class}
with patch(
"backend.data.block.get_blocks",
return_value=mock_blocks,
):
with patch(
"backend.api.features.store.content_handlers.query_raw_with_schema",
return_value=[],
):
items = await handler.get_missing_items(batch_size=10)
assert len(items) == 1
assert items[0].searchable_text == "Minimal Block"
@pytest.mark.asyncio(loop_scope="session")
async def test_block_handler_skips_failed_blocks():
"""Test BlockHandler skips blocks that fail to instantiate."""
handler = BlockHandler()
# Mock one good block and one bad block
good_block = MagicMock()
good_instance = MagicMock()
good_instance.name = "Good Block"
good_instance.description = "Works fine"
good_instance.categories = []
good_block.return_value = good_instance
bad_block = MagicMock()
bad_block.side_effect = Exception("Instantiation failed")
mock_blocks = {"good-block": good_block, "bad-block": bad_block}
with patch(
"backend.data.block.get_blocks",
return_value=mock_blocks,
):
with patch(
"backend.api.features.store.content_handlers.query_raw_with_schema",
return_value=[],
):
items = await handler.get_missing_items(batch_size=10)
# Should only get the good block
assert len(items) == 1
assert items[0].content_id == "good-block"
@pytest.mark.asyncio(loop_scope="session")
async def test_documentation_handler_missing_docs_directory():
"""Test DocumentationHandler handles missing docs directory gracefully."""
handler = DocumentationHandler()
# Mock _get_docs_root to return non-existent path
fake_path = Path("/nonexistent/docs")
with patch.object(handler, "_get_docs_root", return_value=fake_path):
items = await handler.get_missing_items(batch_size=10)
assert items == []
stats = await handler.get_stats()
assert stats["total"] == 0
assert stats["with_embeddings"] == 0
assert stats["without_embeddings"] == 0

View File

@@ -1,8 +1,7 @@
import asyncio
import logging
import typing
from datetime import datetime, timezone
from typing import Literal
from typing import Any, Literal
import fastapi
import prisma.enums
@@ -10,7 +9,7 @@ import prisma.errors
import prisma.models
import prisma.types
from backend.data.db import query_raw_with_schema, transaction
from backend.data.db import transaction
from backend.data.graph import (
GraphMeta,
GraphModel,
@@ -30,6 +29,8 @@ from backend.util.settings import Settings
from . import exceptions as store_exceptions
from . import model as store_model
from .embeddings import ensure_embedding
from .hybrid_search import hybrid_search
logger = logging.getLogger(__name__)
settings = Settings()
@@ -50,128 +51,77 @@ async def get_store_agents(
page_size: int = 20,
) -> store_model.StoreAgentsResponse:
"""
Get PUBLIC store agents from the StoreAgent view
Get PUBLIC store agents from the StoreAgent view.
Search behavior:
- With search_query: Uses hybrid search (semantic + lexical)
- Fallback: If embeddings unavailable, gracefully degrades to lexical-only
- Rationale: User-facing endpoint prioritizes availability over accuracy
Note: Admin operations (approval) use fail-fast to prevent inconsistent state.
"""
logger.debug(
f"Getting store agents. featured={featured}, creators={creators}, sorted_by={sorted_by}, search={search_query}, category={category}, page={page}"
)
search_used_hybrid = False
store_agents: list[store_model.StoreAgent] = []
agents: list[dict[str, Any]] = []
total = 0
total_pages = 0
try:
# If search_query is provided, use full-text search
# If search_query is provided, use hybrid search (embeddings + tsvector)
if search_query:
offset = (page - 1) * page_size
# Try hybrid search combining semantic and lexical signals
# Falls back to lexical-only if OpenAI unavailable (user-facing, high SLA)
try:
agents, total = await hybrid_search(
query=search_query,
featured=featured,
creators=creators,
category=category,
sorted_by="relevance", # Use hybrid scoring for relevance
page=page,
page_size=page_size,
)
search_used_hybrid = True
except Exception as e:
# Log error but fall back to lexical search for better UX
logger.error(
f"Hybrid search failed (likely OpenAI unavailable), "
f"falling back to lexical search: {e}"
)
# search_used_hybrid remains False, will use fallback path below
# Whitelist allowed order_by columns
ALLOWED_ORDER_BY = {
"rating": "rating DESC, rank DESC",
"runs": "runs DESC, rank DESC",
"name": "agent_name ASC, rank ASC",
"updated_at": "updated_at DESC, rank DESC",
}
# Convert hybrid search results (dict format) if hybrid succeeded
if search_used_hybrid:
total_pages = (total + page_size - 1) // page_size
store_agents: list[store_model.StoreAgent] = []
for agent in agents:
try:
store_agent = store_model.StoreAgent(
slug=agent["slug"],
agent_name=agent["agent_name"],
agent_image=(
agent["agent_image"][0] if agent["agent_image"] else ""
),
creator=agent["creator_username"] or "Needs Profile",
creator_avatar=agent["creator_avatar"] or "",
sub_heading=agent["sub_heading"],
description=agent["description"],
runs=agent["runs"],
rating=agent["rating"],
)
store_agents.append(store_agent)
except Exception as e:
logger.error(
f"Error parsing Store agent from hybrid search results: {e}"
)
continue
# Validate and get order clause
if sorted_by and sorted_by in ALLOWED_ORDER_BY:
order_by_clause = ALLOWED_ORDER_BY[sorted_by]
else:
order_by_clause = "updated_at DESC, rank DESC"
# Build WHERE conditions and parameters list
where_parts: list[str] = []
params: list[typing.Any] = [search_query] # $1 - search term
param_index = 2 # Start at $2 for next parameter
# Always filter for available agents
where_parts.append("is_available = true")
if featured:
where_parts.append("featured = true")
if creators and creators:
# Use ANY with array parameter
where_parts.append(f"creator_username = ANY(${param_index})")
params.append(creators)
param_index += 1
if category and category:
where_parts.append(f"${param_index} = ANY(categories)")
params.append(category)
param_index += 1
sql_where_clause: str = " AND ".join(where_parts) if where_parts else "1=1"
# Add pagination params
params.extend([page_size, offset])
limit_param = f"${param_index}"
offset_param = f"${param_index + 1}"
# Execute full-text search query with parameterized values
sql_query = f"""
SELECT
slug,
agent_name,
agent_image,
creator_username,
creator_avatar,
sub_heading,
description,
runs,
rating,
categories,
featured,
is_available,
updated_at,
ts_rank_cd(search, query) AS rank
FROM {{schema_prefix}}"StoreAgent",
plainto_tsquery('english', $1) AS query
WHERE {sql_where_clause}
AND search @@ query
ORDER BY {order_by_clause}
LIMIT {limit_param} OFFSET {offset_param}
"""
# Count query for pagination - only uses search term parameter
count_query = f"""
SELECT COUNT(*) as count
FROM {{schema_prefix}}"StoreAgent",
plainto_tsquery('english', $1) AS query
WHERE {sql_where_clause}
AND search @@ query
"""
# Execute both queries with parameters
agents = await query_raw_with_schema(sql_query, *params)
# For count, use params without pagination (last 2 params)
count_params = params[:-2]
count_result = await query_raw_with_schema(count_query, *count_params)
total = count_result[0]["count"] if count_result else 0
total_pages = (total + page_size - 1) // page_size
# Convert raw results to StoreAgent models
store_agents: list[store_model.StoreAgent] = []
for agent in agents:
try:
store_agent = store_model.StoreAgent(
slug=agent["slug"],
agent_name=agent["agent_name"],
agent_image=(
agent["agent_image"][0] if agent["agent_image"] else ""
),
creator=agent["creator_username"] or "Needs Profile",
creator_avatar=agent["creator_avatar"] or "",
sub_heading=agent["sub_heading"],
description=agent["description"],
runs=agent["runs"],
rating=agent["rating"],
)
store_agents.append(store_agent)
except Exception as e:
logger.error(f"Error parsing Store agent from search results: {e}")
continue
else:
# Non-search query path (original logic)
if not search_used_hybrid:
# Fallback path - use basic search or no search
where_clause: prisma.types.StoreAgentWhereInput = {"is_available": True}
if featured:
where_clause["featured"] = featured
@@ -180,6 +130,14 @@ async def get_store_agents(
if category:
where_clause["categories"] = {"has": category}
# Add basic text search if search_query provided but hybrid failed
if search_query:
where_clause["OR"] = [
{"agent_name": {"contains": search_query, "mode": "insensitive"}},
{"sub_heading": {"contains": search_query, "mode": "insensitive"}},
{"description": {"contains": search_query, "mode": "insensitive"}},
]
order_by = []
if sorted_by == "rating":
order_by.append({"rating": "desc"})
@@ -188,7 +146,7 @@ async def get_store_agents(
elif sorted_by == "name":
order_by.append({"agent_name": "asc"})
agents = await prisma.models.StoreAgent.prisma().find_many(
db_agents = await prisma.models.StoreAgent.prisma().find_many(
where=where_clause,
order=order_by,
skip=(page - 1) * page_size,
@@ -199,7 +157,7 @@ async def get_store_agents(
total_pages = (total + page_size - 1) // page_size
store_agents: list[store_model.StoreAgent] = []
for agent in agents:
for agent in db_agents:
try:
# Create the StoreAgent object safely
store_agent = store_model.StoreAgent(
@@ -614,6 +572,7 @@ async def get_store_submissions(
submission_models = []
for sub in submissions:
submission_model = store_model.StoreSubmission(
listing_id=sub.listing_id,
agent_id=sub.agent_id,
agent_version=sub.agent_version,
name=sub.name,
@@ -667,35 +626,48 @@ async def delete_store_submission(
submission_id: str,
) -> bool:
"""
Delete a store listing submission as the submitting user.
Delete a store submission version as the submitting user.
Args:
user_id: ID of the authenticated user
submission_id: ID of the submission to be deleted
submission_id: StoreListingVersion ID to delete
Returns:
bool: True if the submission was successfully deleted, False otherwise
bool: True if successfully deleted
"""
logger.debug(f"Deleting store submission {submission_id} for user {user_id}")
try:
# Verify the submission belongs to this user
submission = await prisma.models.StoreListing.prisma().find_first(
where={"agentGraphId": submission_id, "owningUserId": user_id}
# Find the submission version with ownership check
version = await prisma.models.StoreListingVersion.prisma().find_first(
where={"id": submission_id}, include={"StoreListing": True}
)
if not submission:
logger.warning(f"Submission not found for user {user_id}: {submission_id}")
raise store_exceptions.SubmissionNotFoundError(
f"Submission not found for this user. User ID: {user_id}, Submission ID: {submission_id}"
if (
not version
or not version.StoreListing
or version.StoreListing.owningUserId != user_id
):
raise store_exceptions.SubmissionNotFoundError("Submission not found")
# Prevent deletion of approved submissions
if version.submissionStatus == prisma.enums.SubmissionStatus.APPROVED:
raise store_exceptions.InvalidOperationError(
"Cannot delete approved submissions"
)
# Delete the submission
await prisma.models.StoreListing.prisma().delete(where={"id": submission.id})
logger.debug(
f"Successfully deleted submission {submission_id} for user {user_id}"
# Delete the version
await prisma.models.StoreListingVersion.prisma().delete(
where={"id": version.id}
)
# Clean up empty listing if this was the last version
remaining = await prisma.models.StoreListingVersion.prisma().count(
where={"storeListingId": version.storeListingId}
)
if remaining == 0:
await prisma.models.StoreListing.prisma().delete(
where={"id": version.storeListingId}
)
return True
except Exception as e:
@@ -759,9 +731,15 @@ async def create_store_submission(
logger.warning(
f"Agent not found for user {user_id}: {agent_id} v{agent_version}"
)
raise store_exceptions.AgentNotFoundError(
f"Agent not found for this user. User ID: {user_id}, Agent ID: {agent_id}, Version: {agent_version}"
)
# Provide more user-friendly error message when agent_id is empty
if not agent_id or agent_id.strip() == "":
raise store_exceptions.AgentNotFoundError(
"No agent selected. Please select an agent before submitting to the store."
)
else:
raise store_exceptions.AgentNotFoundError(
f"Agent not found for this user. User ID: {user_id}, Agent ID: {agent_id}, Version: {agent_version}"
)
# Check if listing already exists for this agent
existing_listing = await prisma.models.StoreListing.prisma().find_first(
@@ -833,6 +811,7 @@ async def create_store_submission(
logger.debug(f"Created store listing for agent {agent_id}")
# Return submission details
return store_model.StoreSubmission(
listing_id=listing.id,
agent_id=agent_id,
agent_version=agent_version,
name=name,
@@ -944,81 +923,56 @@ async def edit_store_submission(
# Currently we do not allow the user to update the agent associated with a submission.
# If we allow it in the future, we will need a check here to verify that the agent belongs to this user.
# Check if we can edit this submission
if current_version.submissionStatus == prisma.enums.SubmissionStatus.REJECTED:
# Only allow editing of PENDING submissions
if current_version.submissionStatus != prisma.enums.SubmissionStatus.PENDING:
raise store_exceptions.InvalidOperationError(
"Cannot edit a rejected submission"
)
# For APPROVED submissions, we need to create a new version
if current_version.submissionStatus == prisma.enums.SubmissionStatus.APPROVED:
# Create a new version for the existing listing
return await create_store_version(
user_id=user_id,
agent_id=current_version.agentGraphId,
agent_version=current_version.agentGraphVersion,
store_listing_id=current_version.storeListingId,
name=name,
video_url=video_url,
agent_output_demo_url=agent_output_demo_url,
image_urls=image_urls,
description=description,
sub_heading=sub_heading,
categories=categories,
changes_summary=changes_summary,
recommended_schedule_cron=recommended_schedule_cron,
instructions=instructions,
f"Cannot edit a {current_version.submissionStatus.value.lower()} submission. Only pending submissions can be edited."
)
# For PENDING submissions, we can update the existing version
elif current_version.submissionStatus == prisma.enums.SubmissionStatus.PENDING:
# Update the existing version
updated_version = await prisma.models.StoreListingVersion.prisma().update(
where={"id": store_listing_version_id},
data=prisma.types.StoreListingVersionUpdateInput(
name=name,
videoUrl=video_url,
agentOutputDemoUrl=agent_output_demo_url,
imageUrls=image_urls,
description=description,
categories=categories,
subHeading=sub_heading,
changesSummary=changes_summary,
recommendedScheduleCron=recommended_schedule_cron,
instructions=instructions,
),
)
logger.debug(
f"Updated existing version {store_listing_version_id} for agent {current_version.agentGraphId}"
)
if not updated_version:
raise DatabaseError("Failed to update store listing version")
return store_model.StoreSubmission(
agent_id=current_version.agentGraphId,
agent_version=current_version.agentGraphVersion,
# Update the existing version
updated_version = await prisma.models.StoreListingVersion.prisma().update(
where={"id": store_listing_version_id},
data=prisma.types.StoreListingVersionUpdateInput(
name=name,
sub_heading=sub_heading,
slug=current_version.StoreListing.slug,
videoUrl=video_url,
agentOutputDemoUrl=agent_output_demo_url,
imageUrls=image_urls,
description=description,
instructions=instructions,
image_urls=image_urls,
date_submitted=updated_version.submittedAt or updated_version.createdAt,
status=updated_version.submissionStatus,
runs=0,
rating=0.0,
store_listing_version_id=updated_version.id,
changes_summary=changes_summary,
video_url=video_url,
categories=categories,
version=updated_version.version,
)
subHeading=sub_heading,
changesSummary=changes_summary,
recommendedScheduleCron=recommended_schedule_cron,
instructions=instructions,
),
)
else:
raise store_exceptions.InvalidOperationError(
f"Cannot edit submission with status: {current_version.submissionStatus}"
)
logger.debug(
f"Updated existing version {store_listing_version_id} for agent {current_version.agentGraphId}"
)
if not updated_version:
raise DatabaseError("Failed to update store listing version")
return store_model.StoreSubmission(
listing_id=current_version.StoreListing.id,
agent_id=current_version.agentGraphId,
agent_version=current_version.agentGraphVersion,
name=name,
sub_heading=sub_heading,
slug=current_version.StoreListing.slug,
description=description,
instructions=instructions,
image_urls=image_urls,
date_submitted=updated_version.submittedAt or updated_version.createdAt,
status=updated_version.submissionStatus,
runs=0,
rating=0.0,
store_listing_version_id=updated_version.id,
changes_summary=changes_summary,
video_url=video_url,
categories=categories,
version=updated_version.version,
)
except (
store_exceptions.SubmissionNotFoundError,
@@ -1097,38 +1051,78 @@ async def create_store_version(
f"Agent not found for this user. User ID: {user_id}, Agent ID: {agent_id}, Version: {agent_version}"
)
# Get the latest version number
latest_version = listing.Versions[0] if listing.Versions else None
next_version = (latest_version.version + 1) if latest_version else 1
# Create a new version for the existing listing
new_version = await prisma.models.StoreListingVersion.prisma().create(
data=prisma.types.StoreListingVersionCreateInput(
version=next_version,
agentGraphId=agent_id,
agentGraphVersion=agent_version,
name=name,
videoUrl=video_url,
agentOutputDemoUrl=agent_output_demo_url,
imageUrls=image_urls,
description=description,
instructions=instructions,
categories=categories,
subHeading=sub_heading,
submissionStatus=prisma.enums.SubmissionStatus.PENDING,
submittedAt=datetime.now(),
changesSummary=changes_summary,
recommendedScheduleCron=recommended_schedule_cron,
storeListingId=store_listing_id,
# Check if there's already a PENDING submission for this agent (any version)
existing_pending_submission = (
await prisma.models.StoreListingVersion.prisma().find_first(
where=prisma.types.StoreListingVersionWhereInput(
storeListingId=store_listing_id,
agentGraphId=agent_id,
submissionStatus=prisma.enums.SubmissionStatus.PENDING,
isDeleted=False,
)
)
)
# Handle existing pending submission and create new one atomically
async with transaction() as tx:
# Get the latest version number first
latest_listing = await prisma.models.StoreListing.prisma(tx).find_first(
where=prisma.types.StoreListingWhereInput(
id=store_listing_id, owningUserId=user_id
),
include={"Versions": {"order_by": {"version": "desc"}, "take": 1}},
)
if not latest_listing:
raise store_exceptions.ListingNotFoundError(
f"Store listing not found. User ID: {user_id}, Listing ID: {store_listing_id}"
)
latest_version = (
latest_listing.Versions[0] if latest_listing.Versions else None
)
next_version = (latest_version.version + 1) if latest_version else 1
# If there's an existing pending submission, delete it atomically before creating the new one
if existing_pending_submission:
logger.info(
f"Found existing PENDING submission for agent {agent_id} (was v{existing_pending_submission.agentGraphVersion}, now v{agent_version}), replacing existing submission instead of creating duplicate"
)
await prisma.models.StoreListingVersion.prisma(tx).delete(
where={"id": existing_pending_submission.id}
)
logger.debug(
f"Deleted existing pending submission {existing_pending_submission.id}"
)
# Create a new version for the existing listing
new_version = await prisma.models.StoreListingVersion.prisma(tx).create(
data=prisma.types.StoreListingVersionCreateInput(
version=next_version,
agentGraphId=agent_id,
agentGraphVersion=agent_version,
name=name,
videoUrl=video_url,
agentOutputDemoUrl=agent_output_demo_url,
imageUrls=image_urls,
description=description,
instructions=instructions,
categories=categories,
subHeading=sub_heading,
submissionStatus=prisma.enums.SubmissionStatus.PENDING,
submittedAt=datetime.now(),
changesSummary=changes_summary,
recommendedScheduleCron=recommended_schedule_cron,
storeListingId=store_listing_id,
)
)
logger.debug(
f"Created new version for listing {store_listing_id} of agent {agent_id}"
)
# Return submission details
return store_model.StoreSubmission(
listing_id=listing.id,
agent_id=agent_id,
agent_version=agent_version,
name=name,
@@ -1541,7 +1535,7 @@ async def review_store_submission(
)
# Update the AgentGraph with store listing data
await prisma.models.AgentGraph.prisma().update(
await prisma.models.AgentGraph.prisma(tx).update(
where={
"graphVersionId": {
"id": store_listing_version.agentGraphId,
@@ -1556,6 +1550,23 @@ async def review_store_submission(
},
)
# Generate embedding for approved listing (blocking - admin operation)
# Inside transaction: if embedding fails, entire transaction rolls back
embedding_success = await ensure_embedding(
version_id=store_listing_version_id,
name=store_listing_version.name,
description=store_listing_version.description,
sub_heading=store_listing_version.subHeading,
categories=store_listing_version.categories or [],
tx=tx,
)
if not embedding_success:
raise ValueError(
f"Failed to generate embedding for listing {store_listing_version_id}. "
"This is likely due to OpenAI API being unavailable. "
"Please try again later or contact support if the issue persists."
)
await prisma.models.StoreListing.prisma(tx).update(
where={"id": store_listing_version.StoreListing.id},
data={
@@ -1708,15 +1719,12 @@ async def review_store_submission(
# Convert to Pydantic model for consistency
return store_model.StoreSubmission(
listing_id=(submission.StoreListing.id if submission.StoreListing else ""),
agent_id=submission.agentGraphId,
agent_version=submission.agentGraphVersion,
name=submission.name,
sub_heading=submission.subHeading,
slug=(
submission.StoreListing.slug
if hasattr(submission, "storeListing") and submission.StoreListing
else ""
),
slug=(submission.StoreListing.slug if submission.StoreListing else ""),
description=submission.description,
instructions=submission.instructions,
image_urls=submission.imageUrls or [],
@@ -1818,9 +1826,7 @@ async def get_admin_listings_with_versions(
where = prisma.types.StoreListingWhereInput(**where_dict)
include = prisma.types.StoreListingInclude(
Versions=prisma.types.FindManyStoreListingVersionArgsFromStoreListing(
order_by=prisma.types._StoreListingVersion_version_OrderByInput(
version="desc"
)
order_by={"version": "desc"}
),
OwningUser=True,
)
@@ -1845,6 +1851,7 @@ async def get_admin_listings_with_versions(
# If we have versions, turn them into StoreSubmission models
for version in listing.Versions or []:
version_model = store_model.StoreSubmission(
listing_id=listing.id,
agent_id=version.agentGraphId,
agent_version=version.agentGraphVersion,
name=version.name,

View File

@@ -0,0 +1,962 @@
"""
Unified Content Embeddings Service
Handles generation and storage of OpenAI embeddings for all content types
(store listings, blocks, documentation, library agents) to enable semantic/hybrid search.
"""
import asyncio
import logging
import time
from typing import Any
import prisma
from prisma.enums import ContentType
from tiktoken import encoding_for_model
from backend.api.features.store.content_handlers import CONTENT_HANDLERS
from backend.data.db import execute_raw_with_schema, query_raw_with_schema
from backend.util.clients import get_openai_client
from backend.util.json import dumps
logger = logging.getLogger(__name__)
# OpenAI embedding model configuration
EMBEDDING_MODEL = "text-embedding-3-small"
# Embedding dimension for the model above
# text-embedding-3-small: 1536, text-embedding-3-large: 3072
EMBEDDING_DIM = 1536
# OpenAI embedding token limit (8,191 with 1 token buffer for safety)
EMBEDDING_MAX_TOKENS = 8191
def build_searchable_text(
name: str,
description: str,
sub_heading: str,
categories: list[str],
) -> str:
"""
Build searchable text from listing version fields.
Combines relevant fields into a single string for embedding.
"""
parts = []
# Name is important - include it
if name:
parts.append(name)
# Sub-heading provides context
if sub_heading:
parts.append(sub_heading)
# Description is the main content
if description:
parts.append(description)
# Categories help with semantic matching
if categories:
parts.append(" ".join(categories))
return " ".join(parts)
async def generate_embedding(text: str) -> list[float] | None:
"""
Generate embedding for text using OpenAI API.
Returns None if embedding generation fails.
Fail-fast: no retries to maintain consistency with approval flow.
"""
try:
client = get_openai_client()
if not client:
logger.error("openai_internal_api_key not set, cannot generate embedding")
return None
# Truncate text to token limit using tiktoken
# Character-based truncation is unreliable because the characters-per-token ratio varies with content
enc = encoding_for_model(EMBEDDING_MODEL)
tokens = enc.encode(text)
if len(tokens) > EMBEDDING_MAX_TOKENS:
tokens = tokens[:EMBEDDING_MAX_TOKENS]
truncated_text = enc.decode(tokens)
logger.info(
f"Truncated text from {len(enc.encode(text))} to {len(tokens)} tokens"
)
else:
truncated_text = text
start_time = time.time()
response = await client.embeddings.create(
model=EMBEDDING_MODEL,
input=truncated_text,
)
latency_ms = (time.time() - start_time) * 1000
embedding = response.data[0].embedding
logger.info(
f"Generated embedding: {len(embedding)} dims, "
f"{len(tokens)} tokens, {latency_ms:.0f}ms"
)
return embedding
except Exception as e:
logger.error(f"Failed to generate embedding: {e}")
return None
async def store_embedding(
version_id: str,
embedding: list[float],
tx: prisma.Prisma | None = None,
) -> bool:
"""
Store embedding in the database.
BACKWARD COMPATIBILITY: Maintained for existing store listing usage.
DEPRECATED: Use ensure_embedding() instead (includes searchable_text).
"""
return await store_content_embedding(
content_type=ContentType.STORE_AGENT,
content_id=version_id,
embedding=embedding,
searchable_text="", # Empty for backward compat; ensure_embedding() populates this
metadata=None,
user_id=None, # Store agents are public
tx=tx,
)
async def store_content_embedding(
content_type: ContentType,
content_id: str,
embedding: list[float],
searchable_text: str,
metadata: dict | None = None,
user_id: str | None = None,
tx: prisma.Prisma | None = None,
) -> bool:
"""
Store embedding in the unified content embeddings table.
New function for unified content embedding storage.
Uses raw SQL since Prisma doesn't natively support pgvector.
"""
try:
client = tx if tx else prisma.get_client()
# Convert embedding to PostgreSQL vector format
embedding_str = embedding_to_vector_string(embedding)
metadata_json = dumps(metadata or {})
# Upsert the embedding
# WHERE clause in DO UPDATE prevents PostgreSQL 15 bug with NULLS NOT DISTINCT
await execute_raw_with_schema(
"""
INSERT INTO {schema_prefix}"UnifiedContentEmbedding" (
"id", "contentType", "contentId", "userId", "embedding", "searchableText", "metadata", "createdAt", "updatedAt"
)
VALUES (gen_random_uuid()::text, $1::{schema_prefix}"ContentType", $2, $3, $4::vector, $5, $6::jsonb, NOW(), NOW())
ON CONFLICT ("contentType", "contentId", "userId")
DO UPDATE SET
"embedding" = $4::vector,
"searchableText" = $5,
"metadata" = $6::jsonb,
"updatedAt" = NOW()
WHERE {schema_prefix}"UnifiedContentEmbedding"."contentType" = $1::{schema_prefix}"ContentType"
AND {schema_prefix}"UnifiedContentEmbedding"."contentId" = $2
AND ({schema_prefix}"UnifiedContentEmbedding"."userId" = $3 OR ($3 IS NULL AND {schema_prefix}"UnifiedContentEmbedding"."userId" IS NULL))
""",
content_type,
content_id,
user_id,
embedding_str,
searchable_text,
metadata_json,
client=client,
set_public_search_path=True,
)
logger.info(f"Stored embedding for {content_type}:{content_id}")
return True
except Exception as e:
logger.error(f"Failed to store embedding for {content_type}:{content_id}: {e}")
return False
async def get_embedding(version_id: str) -> dict[str, Any] | None:
"""
Retrieve embedding record for a listing version.
BACKWARD COMPATIBILITY: Maintained for existing store listing usage.
Returns dict with storeListingVersionId, embedding, timestamps or None if not found.
"""
result = await get_content_embedding(
ContentType.STORE_AGENT, version_id, user_id=None
)
if result:
# Transform to old format for backward compatibility
return {
"storeListingVersionId": result["contentId"],
"embedding": result["embedding"],
"createdAt": result["createdAt"],
"updatedAt": result["updatedAt"],
}
return None
async def get_content_embedding(
content_type: ContentType, content_id: str, user_id: str | None = None
) -> dict[str, Any] | None:
"""
Retrieve embedding record for any content type.
New function for unified content embedding retrieval.
Returns dict with contentType, contentId, embedding, timestamps or None if not found.
"""
try:
result = await query_raw_with_schema(
"""
SELECT
"contentType",
"contentId",
"userId",
"embedding"::text as "embedding",
"searchableText",
"metadata",
"createdAt",
"updatedAt"
FROM {schema_prefix}"UnifiedContentEmbedding"
WHERE "contentType" = $1::{schema_prefix}"ContentType" AND "contentId" = $2 AND ("userId" = $3 OR ($3 IS NULL AND "userId" IS NULL))
""",
content_type,
content_id,
user_id,
set_public_search_path=True,
)
if result and len(result) > 0:
return result[0]
return None
except Exception as e:
logger.error(f"Failed to get embedding for {content_type}:{content_id}: {e}")
return None
async def ensure_embedding(
version_id: str,
name: str,
description: str,
sub_heading: str,
categories: list[str],
force: bool = False,
tx: prisma.Prisma | None = None,
) -> bool:
"""
Ensure an embedding exists for the listing version.
Creates embedding if missing. Use force=True to regenerate.
Backward-compatible wrapper for store listings.
Args:
version_id: The StoreListingVersion ID
name: Agent name
description: Agent description
sub_heading: Agent sub-heading
categories: Agent categories
force: Force regeneration even if embedding exists
tx: Optional transaction client
Returns:
True if embedding exists/was created, False on failure
"""
try:
# Check if embedding already exists
if not force:
existing = await get_embedding(version_id)
if existing and existing.get("embedding"):
logger.debug(f"Embedding for version {version_id} already exists")
return True
# Build searchable text for embedding
searchable_text = build_searchable_text(
name, description, sub_heading, categories
)
# Generate new embedding
embedding = await generate_embedding(searchable_text)
if embedding is None:
logger.warning(f"Could not generate embedding for version {version_id}")
return False
# Store the embedding with metadata using new function
metadata = {
"name": name,
"subHeading": sub_heading,
"categories": categories,
}
return await store_content_embedding(
content_type=ContentType.STORE_AGENT,
content_id=version_id,
embedding=embedding,
searchable_text=searchable_text,
metadata=metadata,
user_id=None, # Store agents are public
tx=tx,
)
except Exception as e:
logger.error(f"Failed to ensure embedding for version {version_id}: {e}")
return False
async def delete_embedding(version_id: str) -> bool:
"""
Delete embedding for a listing version.
BACKWARD COMPATIBILITY: Maintained for existing store listing usage.
Note: This is usually handled automatically by CASCADE delete,
but provided for manual cleanup if needed.
"""
return await delete_content_embedding(ContentType.STORE_AGENT, version_id)
async def delete_content_embedding(
content_type: ContentType, content_id: str, user_id: str | None = None
) -> bool:
"""
Delete embedding for any content type.
New function for unified content embedding deletion.
Note: This is usually handled automatically by CASCADE delete,
but provided for manual cleanup if needed.
Args:
content_type: The type of content (STORE_AGENT, LIBRARY_AGENT, etc.)
content_id: The unique identifier for the content
user_id: Optional user ID. For public content (STORE_AGENT, BLOCK), pass None.
For user-scoped content (LIBRARY_AGENT), pass the user's ID to avoid
deleting embeddings belonging to other users.
Returns:
True if deletion succeeded, False otherwise
"""
try:
client = prisma.get_client()
await execute_raw_with_schema(
"""
DELETE FROM {schema_prefix}"UnifiedContentEmbedding"
WHERE "contentType" = $1::{schema_prefix}"ContentType"
AND "contentId" = $2
AND ("userId" = $3 OR ($3 IS NULL AND "userId" IS NULL))
""",
content_type,
content_id,
user_id,
client=client,
)
user_str = f" (user: {user_id})" if user_id else ""
logger.info(f"Deleted embedding for {content_type}:{content_id}{user_str}")
return True
except Exception as e:
logger.error(f"Failed to delete embedding for {content_type}:{content_id}: {e}")
return False
async def get_embedding_stats() -> dict[str, Any]:
"""
Get statistics about embedding coverage for all content types.
Returns stats per content type and overall totals.
"""
try:
stats_by_type = {}
total_items = 0
total_with_embeddings = 0
total_without_embeddings = 0
# Aggregate stats from all handlers
for content_type, handler in CONTENT_HANDLERS.items():
try:
stats = await handler.get_stats()
stats_by_type[content_type.value] = {
"total": stats["total"],
"with_embeddings": stats["with_embeddings"],
"without_embeddings": stats["without_embeddings"],
"coverage_percent": (
round(stats["with_embeddings"] / stats["total"] * 100, 1)
if stats["total"] > 0
else 0
),
}
total_items += stats["total"]
total_with_embeddings += stats["with_embeddings"]
total_without_embeddings += stats["without_embeddings"]
except Exception as e:
logger.error(f"Failed to get stats for {content_type.value}: {e}")
stats_by_type[content_type.value] = {
"total": 0,
"with_embeddings": 0,
"without_embeddings": 0,
"coverage_percent": 0,
"error": str(e),
}
return {
"by_type": stats_by_type,
"totals": {
"total": total_items,
"with_embeddings": total_with_embeddings,
"without_embeddings": total_without_embeddings,
"coverage_percent": (
round(total_with_embeddings / total_items * 100, 1)
if total_items > 0
else 0
),
},
}
except Exception as e:
logger.error(f"Failed to get embedding stats: {e}")
return {
"by_type": {},
"totals": {
"total": 0,
"with_embeddings": 0,
"without_embeddings": 0,
"coverage_percent": 0,
},
"error": str(e),
}
async def backfill_missing_embeddings(batch_size: int = 10) -> dict[str, Any]:
"""
Generate embeddings for approved listings that don't have them.
BACKWARD COMPATIBILITY: Maintained for existing usage.
This now delegates to backfill_all_content_types() to process all content types.
Args:
batch_size: Number of embeddings to generate per content type
Returns:
Dict with success/failure counts aggregated across all content types
"""
# Delegate to the new generic backfill system
result = await backfill_all_content_types(batch_size)
# Return in the old format for backward compatibility
return result["totals"]
async def backfill_all_content_types(batch_size: int = 10) -> dict[str, Any]:
"""
Generate embeddings for all content types using registered handlers.
Processes content types in order: BLOCK → STORE_AGENT → DOCUMENTATION.
This ensures foundational content (blocks) are searchable first.
Args:
batch_size: Number of embeddings to generate per content type
Returns:
Dict with stats per content type and overall totals
"""
results_by_type = {}
total_processed = 0
total_success = 0
total_failed = 0
# Process content types in explicit order
processing_order = [
ContentType.BLOCK,
ContentType.STORE_AGENT,
ContentType.DOCUMENTATION,
]
for content_type in processing_order:
handler = CONTENT_HANDLERS.get(content_type)
if not handler:
logger.warning(f"No handler registered for {content_type.value}")
continue
try:
logger.info(f"Processing {content_type.value} content type...")
# Get missing items from handler
missing_items = await handler.get_missing_items(batch_size)
if not missing_items:
results_by_type[content_type.value] = {
"processed": 0,
"success": 0,
"failed": 0,
"message": "No missing embeddings",
}
continue
# Process embeddings concurrently for better performance
embedding_tasks = [
ensure_content_embedding(
content_type=item.content_type,
content_id=item.content_id,
searchable_text=item.searchable_text,
metadata=item.metadata,
user_id=item.user_id,
)
for item in missing_items
]
results = await asyncio.gather(*embedding_tasks, return_exceptions=True)
success = sum(1 for result in results if result is True)
failed = len(results) - success
results_by_type[content_type.value] = {
"processed": len(missing_items),
"success": success,
"failed": failed,
"message": f"Backfilled {success} embeddings, {failed} failed",
}
total_processed += len(missing_items)
total_success += success
total_failed += failed
logger.info(
f"{content_type.value}: processed {len(missing_items)}, "
f"success {success}, failed {failed}"
)
except Exception as e:
logger.error(f"Failed to process {content_type.value}: {e}")
results_by_type[content_type.value] = {
"processed": 0,
"success": 0,
"failed": 0,
"error": str(e),
}
return {
"by_type": results_by_type,
"totals": {
"processed": total_processed,
"success": total_success,
"failed": total_failed,
"message": f"Overall: {total_success} succeeded, {total_failed} failed",
},
}
async def embed_query(query: str) -> list[float] | None:
"""
Generate embedding for a search query.
Same as generate_embedding but with clearer intent.
"""
return await generate_embedding(query)
def embedding_to_vector_string(embedding: list[float]) -> str:
"""Convert embedding list to PostgreSQL vector string format."""
return "[" + ",".join(str(x) for x in embedding) + "]"
async def ensure_content_embedding(
content_type: ContentType,
content_id: str,
searchable_text: str,
metadata: dict | None = None,
user_id: str | None = None,
force: bool = False,
tx: prisma.Prisma | None = None,
) -> bool:
"""
Ensure an embedding exists for any content type.
Generic function for creating embeddings for store agents, blocks, docs, etc.
Args:
content_type: ContentType enum value (STORE_AGENT, BLOCK, etc.)
content_id: Unique identifier for the content
searchable_text: Combined text for embedding generation
metadata: Optional metadata to store with embedding
force: Force regeneration even if embedding exists
tx: Optional transaction client
Returns:
True if embedding exists/was created, False on failure
"""
try:
# Check if embedding already exists
if not force:
existing = await get_content_embedding(content_type, content_id, user_id)
if existing and existing.get("embedding"):
logger.debug(
f"Embedding for {content_type}:{content_id} already exists"
)
return True
# Generate new embedding
embedding = await generate_embedding(searchable_text)
if embedding is None:
logger.warning(
f"Could not generate embedding for {content_type}:{content_id}"
)
return False
# Store the embedding
return await store_content_embedding(
content_type=content_type,
content_id=content_id,
embedding=embedding,
searchable_text=searchable_text,
metadata=metadata or {},
user_id=user_id,
tx=tx,
)
except Exception as e:
logger.error(f"Failed to ensure embedding for {content_type}:{content_id}: {e}")
return False
async def cleanup_orphaned_embeddings() -> dict[str, Any]:
"""
Clean up embeddings for content that no longer exists or is no longer valid.
Compares current content with embeddings in database and removes orphaned records:
- STORE_AGENT: Removes embeddings for rejected/deleted store listings
- BLOCK: Removes embeddings for blocks no longer registered
- DOCUMENTATION: Removes embeddings for deleted doc files
Returns:
Dict with cleanup statistics per content type
"""
results_by_type = {}
total_deleted = 0
# Cleanup orphaned embeddings for all content types
cleanup_types = [
ContentType.STORE_AGENT,
ContentType.BLOCK,
ContentType.DOCUMENTATION,
]
for content_type in cleanup_types:
try:
handler = CONTENT_HANDLERS.get(content_type)
if not handler:
logger.warning(f"No handler registered for {content_type}")
results_by_type[content_type.value] = {
"deleted": 0,
"error": "No handler registered",
}
continue
# Get all current content IDs from handler
if content_type == ContentType.STORE_AGENT:
# Get IDs of approved store listing versions from non-deleted listings
valid_agents = await query_raw_with_schema(
"""
SELECT slv.id
FROM {schema_prefix}"StoreListingVersion" slv
JOIN {schema_prefix}"StoreListing" sl ON slv."storeListingId" = sl.id
WHERE slv."submissionStatus" = 'APPROVED'
AND slv."isDeleted" = false
AND sl."isDeleted" = false
""",
)
current_ids = {row["id"] for row in valid_agents}
elif content_type == ContentType.BLOCK:
from backend.data.block import get_blocks
current_ids = set(get_blocks().keys())
elif content_type == ContentType.DOCUMENTATION:
# Use DocumentationHandler to get section-based content IDs
from backend.api.features.store.content_handlers import (
DocumentationHandler,
)
doc_handler = CONTENT_HANDLERS.get(ContentType.DOCUMENTATION)
if isinstance(doc_handler, DocumentationHandler):
docs_root = doc_handler._get_docs_root()
if docs_root.exists():
current_ids = doc_handler._get_all_section_content_ids(
docs_root
)
else:
current_ids = set()
else:
current_ids = set()
else:
# Skip unknown content types to avoid accidental deletion
logger.warning(
f"Skipping cleanup for unknown content type: {content_type}"
)
results_by_type[content_type.value] = {
"deleted": 0,
"error": "Unknown content type - skipped for safety",
}
continue
# Get all embedding IDs from database
db_embeddings = await query_raw_with_schema(
"""
SELECT "contentId"
FROM {schema_prefix}"UnifiedContentEmbedding"
WHERE "contentType" = $1::{schema_prefix}"ContentType"
""",
content_type,
)
db_ids = {row["contentId"] for row in db_embeddings}
# Find orphaned embeddings (in DB but not in current content)
orphaned_ids = db_ids - current_ids
if not orphaned_ids:
logger.info(f"{content_type.value}: No orphaned embeddings found")
results_by_type[content_type.value] = {
"deleted": 0,
"message": "No orphaned embeddings",
}
continue
# Delete orphaned embeddings in batch for better performance
orphaned_list = list(orphaned_ids)
try:
await execute_raw_with_schema(
"""
DELETE FROM {schema_prefix}"UnifiedContentEmbedding"
WHERE "contentType" = $1::{schema_prefix}"ContentType"
AND "contentId" = ANY($2::text[])
""",
content_type,
orphaned_list,
)
deleted = len(orphaned_list)
except Exception as e:
logger.error(f"Failed to batch delete orphaned embeddings: {e}")
deleted = 0
logger.info(
f"{content_type.value}: Deleted {deleted}/{len(orphaned_ids)} orphaned embeddings"
)
results_by_type[content_type.value] = {
"deleted": deleted,
"orphaned": len(orphaned_ids),
"message": f"Deleted {deleted} orphaned embeddings",
}
total_deleted += deleted
except Exception as e:
logger.error(f"Failed to cleanup {content_type.value}: {e}")
results_by_type[content_type.value] = {
"deleted": 0,
"error": str(e),
}
return {
"by_type": results_by_type,
"totals": {
"deleted": total_deleted,
"message": f"Deleted {total_deleted} orphaned embeddings",
},
}
async def semantic_search(
query: str,
content_types: list[ContentType] | None = None,
user_id: str | None = None,
limit: int = 20,
min_similarity: float = 0.5,
) -> list[dict[str, Any]]:
"""
Semantic search across content types using embeddings.
Performs vector similarity search on UnifiedContentEmbedding table.
Used directly for blocks/docs/library agents, or as the semantic component
within hybrid_search for store agents.
If embedding generation fails, falls back to lexical search on searchableText.
Args:
query: Search query string
content_types: List of ContentType to search. Defaults to [BLOCK, STORE_AGENT, DOCUMENTATION]
user_id: Optional user ID for searching private content (library agents)
limit: Maximum number of results to return (default: 20)
min_similarity: Minimum cosine similarity threshold (0-1, default: 0.5)
Returns:
List of search results with the following structure:
[
{
"content_id": str,
"content_type": str, # "BLOCK", "STORE_AGENT", "DOCUMENTATION", or "LIBRARY_AGENT"
"searchable_text": str,
"metadata": dict,
"similarity": float, # Cosine similarity score (0-1)
},
...
]
Examples:
# Search blocks only
results = await semantic_search("calculate", content_types=[ContentType.BLOCK])
# Search blocks and documentation
results = await semantic_search(
"how to use API",
content_types=[ContentType.BLOCK, ContentType.DOCUMENTATION]
)
# Search all public content (default)
results = await semantic_search("AI agent")
# Search user's library agents
results = await semantic_search(
"my custom agent",
content_types=[ContentType.LIBRARY_AGENT],
user_id="user123"
)
"""
# Default to searching all public content types
if content_types is None:
content_types = [
ContentType.BLOCK,
ContentType.STORE_AGENT,
ContentType.DOCUMENTATION,
]
# Validate inputs
if not content_types:
return [] # Empty content_types would cause invalid SQL (IN ())
query = query.strip()
if not query:
return []
if limit < 1:
limit = 1
if limit > 100:
limit = 100
# Generate query embedding
query_embedding = await embed_query(query)
if query_embedding is not None:
# Semantic search with embeddings
embedding_str = embedding_to_vector_string(query_embedding)
# Build params in order: limit, then user_id (if provided), then content types, then min_similarity (illustrated below)
params: list[Any] = [limit]
user_filter = ""
if user_id is not None:
user_filter = 'AND "userId" = ${}'.format(len(params) + 1)
params.append(user_id)
# Add content type parameters and build placeholders dynamically
content_type_start_idx = len(params) + 1
content_type_placeholders = ", ".join(
f'${content_type_start_idx + i}::{{{{schema_prefix}}}}"ContentType"'
for i in range(len(content_types))
)
params.extend([ct.value for ct in content_types])
sql = f"""
SELECT
"contentId" as content_id,
"contentType" as content_type,
"searchableText" as searchable_text,
metadata,
1 - (embedding <=> '{embedding_str}'::vector) as similarity
FROM {{{{schema_prefix}}}}"UnifiedContentEmbedding"
WHERE "contentType" IN ({content_type_placeholders})
{user_filter}
AND 1 - (embedding <=> '{embedding_str}'::vector) >= ${len(params) + 1}
ORDER BY similarity DESC
LIMIT $1
"""
params.append(min_similarity)
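# Illustrative parameter layout (hypothetical call, not executed): with
# user_id="user-123" and content_types=[BLOCK, DOCUMENTATION], the placeholders
# above resolve to $1 = limit, $2 = user_id, $3/$4 = content types, $5 = min_similarity,
# so params ends up as [20, "user-123", "BLOCK", "DOCUMENTATION", 0.5].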
try:
results = await query_raw_with_schema(
sql, *params, set_public_search_path=True
)
return [
{
"content_id": row["content_id"],
"content_type": row["content_type"],
"searchable_text": row["searchable_text"],
"metadata": row["metadata"],
"similarity": float(row["similarity"]),
}
for row in results
]
except Exception as e:
logger.error(f"Semantic search failed: {e}")
# Fall through to lexical search below
# Fallback to lexical search if embeddings are unavailable or the semantic query failed
logger.warning("Falling back to lexical search (semantic search unavailable)")
params_lexical: list[Any] = [limit]
user_filter = ""
if user_id is not None:
user_filter = 'AND "userId" = ${}'.format(len(params_lexical) + 1)
params_lexical.append(user_id)
# Add content type parameters and build placeholders dynamically
content_type_start_idx = len(params_lexical) + 1
content_type_placeholders_lexical = ", ".join(
f'${content_type_start_idx + i}::{{{{schema_prefix}}}}"ContentType"'
for i in range(len(content_types))
)
params_lexical.extend([ct.value for ct in content_types])
sql_lexical = f"""
SELECT
"contentId" as content_id,
"contentType" as content_type,
"searchableText" as searchable_text,
metadata,
0.0 as similarity
FROM {{{{schema_prefix}}}}"UnifiedContentEmbedding"
WHERE "contentType" IN ({content_type_placeholders_lexical})
{user_filter}
AND "searchableText" ILIKE ${len(params_lexical) + 1}
ORDER BY "updatedAt" DESC
LIMIT $1
"""
params_lexical.append(f"%{query}%")
try:
results = await query_raw_with_schema(
sql_lexical, *params_lexical, set_public_search_path=True
)
return [
{
"content_id": row["content_id"],
"content_type": row["content_type"],
"searchable_text": row["searchable_text"],
"metadata": row["metadata"],
"similarity": 0.0, # Lexical search doesn't provide similarity
}
for row in results
]
except Exception as e:
logger.error(f"Lexical search failed: {e}")
return []

View File

@@ -0,0 +1,666 @@
"""
End-to-end database tests for embeddings and hybrid search.
These tests hit the actual database to verify SQL queries work correctly.
Tests cover:
1. Embedding storage (store_content_embedding)
2. Embedding retrieval (get_content_embedding)
3. Embedding deletion (delete_content_embedding)
4. Unified hybrid search across content types
5. Store agent hybrid search
"""
import uuid
from typing import AsyncGenerator
import pytest
from prisma.enums import ContentType
from backend.api.features.store import embeddings
from backend.api.features.store.embeddings import EMBEDDING_DIM
from backend.api.features.store.hybrid_search import (
hybrid_search,
unified_hybrid_search,
)
# ============================================================================
# Test Fixtures
# ============================================================================
@pytest.fixture
def test_content_id() -> str:
"""Generate unique content ID for test isolation."""
return f"test-content-{uuid.uuid4()}"
@pytest.fixture
def test_user_id() -> str:
"""Generate unique user ID for test isolation."""
return f"test-user-{uuid.uuid4()}"
@pytest.fixture
def mock_embedding() -> list[float]:
"""Generate a mock embedding vector."""
# Create a normalized embedding vector
import math
raw = [float(i % 10) / 10.0 for i in range(EMBEDDING_DIM)]
# Normalize to unit length (matches real OpenAI embeddings, which are unit vectors)
magnitude = math.sqrt(sum(x * x for x in raw))
return [x / magnitude for x in raw]
@pytest.fixture
def similar_embedding() -> list[float]:
"""Generate an embedding similar to mock_embedding."""
import math
# Similar but slightly different values
raw = [float(i % 10) / 10.0 + 0.01 for i in range(EMBEDDING_DIM)]
magnitude = math.sqrt(sum(x * x for x in raw))
return [x / magnitude for x in raw]
@pytest.fixture
def different_embedding() -> list[float]:
"""Generate an embedding very different from mock_embedding."""
import math
# Reversed pattern so the vector is clearly different from mock_embedding
raw = [float((EMBEDDING_DIM - i) % 10) / 10.0 for i in range(EMBEDDING_DIM)]
magnitude = math.sqrt(sum(x * x for x in raw))
return [x / magnitude for x in raw]
@pytest.fixture
async def cleanup_embeddings(
server,
) -> AsyncGenerator[list[tuple[ContentType, str, str | None]], None]:
"""
Fixture that tracks created embeddings and cleans them up after tests.
Yields a list to which tests can append (content_type, content_id, user_id) tuples.
"""
created_embeddings: list[tuple[ContentType, str, str | None]] = []
yield created_embeddings
# Cleanup all created embeddings
for content_type, content_id, user_id in created_embeddings:
try:
await embeddings.delete_content_embedding(content_type, content_id, user_id)
except Exception:
pass # Ignore cleanup errors
# ============================================================================
# store_content_embedding Tests
# ============================================================================
@pytest.mark.asyncio(loop_scope="session")
async def test_store_content_embedding_store_agent(
server,
test_content_id: str,
mock_embedding: list[float],
cleanup_embeddings: list,
):
"""Test storing embedding for STORE_AGENT content type."""
# Track for cleanup
cleanup_embeddings.append((ContentType.STORE_AGENT, test_content_id, None))
result = await embeddings.store_content_embedding(
content_type=ContentType.STORE_AGENT,
content_id=test_content_id,
embedding=mock_embedding,
searchable_text="AI assistant for productivity tasks",
metadata={"name": "Test Agent", "categories": ["productivity"]},
user_id=None, # Store agents are public
)
assert result is True
# Verify it was stored
stored = await embeddings.get_content_embedding(
ContentType.STORE_AGENT, test_content_id, user_id=None
)
assert stored is not None
assert stored["contentId"] == test_content_id
assert stored["contentType"] == "STORE_AGENT"
assert stored["searchableText"] == "AI assistant for productivity tasks"
@pytest.mark.asyncio(loop_scope="session")
async def test_store_content_embedding_block(
server,
test_content_id: str,
mock_embedding: list[float],
cleanup_embeddings: list,
):
"""Test storing embedding for BLOCK content type."""
cleanup_embeddings.append((ContentType.BLOCK, test_content_id, None))
result = await embeddings.store_content_embedding(
content_type=ContentType.BLOCK,
content_id=test_content_id,
embedding=mock_embedding,
searchable_text="HTTP request block for API calls",
metadata={"name": "HTTP Request Block"},
user_id=None, # Blocks are public
)
assert result is True
stored = await embeddings.get_content_embedding(
ContentType.BLOCK, test_content_id, user_id=None
)
assert stored is not None
assert stored["contentType"] == "BLOCK"
@pytest.mark.asyncio(loop_scope="session")
async def test_store_content_embedding_documentation(
server,
test_content_id: str,
mock_embedding: list[float],
cleanup_embeddings: list,
):
"""Test storing embedding for DOCUMENTATION content type."""
cleanup_embeddings.append((ContentType.DOCUMENTATION, test_content_id, None))
result = await embeddings.store_content_embedding(
content_type=ContentType.DOCUMENTATION,
content_id=test_content_id,
embedding=mock_embedding,
searchable_text="Getting started guide for AutoGPT platform",
metadata={"title": "Getting Started", "url": "/docs/getting-started"},
user_id=None, # Docs are public
)
assert result is True
stored = await embeddings.get_content_embedding(
ContentType.DOCUMENTATION, test_content_id, user_id=None
)
assert stored is not None
assert stored["contentType"] == "DOCUMENTATION"
@pytest.mark.asyncio(loop_scope="session")
async def test_store_content_embedding_upsert(
server,
test_content_id: str,
mock_embedding: list[float],
cleanup_embeddings: list,
):
"""Test that storing embedding twice updates instead of duplicates."""
cleanup_embeddings.append((ContentType.BLOCK, test_content_id, None))
# Store first time
result1 = await embeddings.store_content_embedding(
content_type=ContentType.BLOCK,
content_id=test_content_id,
embedding=mock_embedding,
searchable_text="Original text",
metadata={"version": 1},
user_id=None,
)
assert result1 is True
# Store again with different text (upsert)
result2 = await embeddings.store_content_embedding(
content_type=ContentType.BLOCK,
content_id=test_content_id,
embedding=mock_embedding,
searchable_text="Updated text",
metadata={"version": 2},
user_id=None,
)
assert result2 is True
# Verify only one record with updated text
stored = await embeddings.get_content_embedding(
ContentType.BLOCK, test_content_id, user_id=None
)
assert stored is not None
assert stored["searchableText"] == "Updated text"
# ============================================================================
# get_content_embedding Tests
# ============================================================================
@pytest.mark.asyncio(loop_scope="session")
async def test_get_content_embedding_not_found(server):
"""Test retrieving non-existent embedding returns None."""
result = await embeddings.get_content_embedding(
ContentType.STORE_AGENT, "non-existent-id", user_id=None
)
assert result is None
@pytest.mark.asyncio(loop_scope="session")
async def test_get_content_embedding_with_metadata(
server,
test_content_id: str,
mock_embedding: list[float],
cleanup_embeddings: list,
):
"""Test that metadata is correctly stored and retrieved."""
cleanup_embeddings.append((ContentType.STORE_AGENT, test_content_id, None))
metadata = {
"name": "Test Agent",
"subHeading": "A test agent",
"categories": ["ai", "productivity"],
"customField": 123,
}
await embeddings.store_content_embedding(
content_type=ContentType.STORE_AGENT,
content_id=test_content_id,
embedding=mock_embedding,
searchable_text="test",
metadata=metadata,
user_id=None,
)
stored = await embeddings.get_content_embedding(
ContentType.STORE_AGENT, test_content_id, user_id=None
)
assert stored is not None
assert stored["metadata"]["name"] == "Test Agent"
assert stored["metadata"]["categories"] == ["ai", "productivity"]
assert stored["metadata"]["customField"] == 123
# ============================================================================
# delete_content_embedding Tests
# ============================================================================
@pytest.mark.asyncio(loop_scope="session")
async def test_delete_content_embedding(
server,
test_content_id: str,
mock_embedding: list[float],
):
"""Test deleting embedding removes it from database."""
# Store embedding
await embeddings.store_content_embedding(
content_type=ContentType.BLOCK,
content_id=test_content_id,
embedding=mock_embedding,
searchable_text="To be deleted",
metadata=None,
user_id=None,
)
# Verify it exists
stored = await embeddings.get_content_embedding(
ContentType.BLOCK, test_content_id, user_id=None
)
assert stored is not None
# Delete it
result = await embeddings.delete_content_embedding(
ContentType.BLOCK, test_content_id, user_id=None
)
assert result is True
# Verify it's gone
stored = await embeddings.get_content_embedding(
ContentType.BLOCK, test_content_id, user_id=None
)
assert stored is None
@pytest.mark.asyncio(loop_scope="session")
async def test_delete_content_embedding_not_found(server):
"""Test deleting non-existent embedding doesn't error."""
result = await embeddings.delete_content_embedding(
ContentType.BLOCK, "non-existent-id", user_id=None
)
# Should succeed even if nothing to delete
assert result is True
# ============================================================================
# unified_hybrid_search Tests
# ============================================================================
@pytest.mark.asyncio(loop_scope="session")
async def test_unified_hybrid_search_finds_matching_content(
server,
mock_embedding: list[float],
cleanup_embeddings: list,
):
"""Test unified search finds content matching the query."""
# Create unique content IDs
agent_id = f"test-agent-{uuid.uuid4()}"
block_id = f"test-block-{uuid.uuid4()}"
doc_id = f"test-doc-{uuid.uuid4()}"
cleanup_embeddings.append((ContentType.STORE_AGENT, agent_id, None))
cleanup_embeddings.append((ContentType.BLOCK, block_id, None))
cleanup_embeddings.append((ContentType.DOCUMENTATION, doc_id, None))
# Store embeddings for different content types
await embeddings.store_content_embedding(
content_type=ContentType.STORE_AGENT,
content_id=agent_id,
embedding=mock_embedding,
searchable_text="AI writing assistant for blog posts",
metadata={"name": "Writing Assistant"},
user_id=None,
)
await embeddings.store_content_embedding(
content_type=ContentType.BLOCK,
content_id=block_id,
embedding=mock_embedding,
searchable_text="Text generation block for creative writing",
metadata={"name": "Text Generator"},
user_id=None,
)
await embeddings.store_content_embedding(
content_type=ContentType.DOCUMENTATION,
content_id=doc_id,
embedding=mock_embedding,
searchable_text="How to use writing blocks in AutoGPT",
metadata={"title": "Writing Guide"},
user_id=None,
)
# Search for "writing" - should find all three
results, total = await unified_hybrid_search(
query="writing",
page=1,
page_size=20,
)
# Should find at least our test content (may find others too)
content_ids = [r["content_id"] for r in results]
    assert agent_id in content_ids or total >= 1  # the lexical match on "writing" should surface at least one result
@pytest.mark.asyncio(loop_scope="session")
async def test_unified_hybrid_search_filter_by_content_type(
server,
mock_embedding: list[float],
cleanup_embeddings: list,
):
"""Test unified search can filter by content type."""
agent_id = f"test-agent-{uuid.uuid4()}"
block_id = f"test-block-{uuid.uuid4()}"
cleanup_embeddings.append((ContentType.STORE_AGENT, agent_id, None))
cleanup_embeddings.append((ContentType.BLOCK, block_id, None))
# Store both types with same searchable text
await embeddings.store_content_embedding(
content_type=ContentType.STORE_AGENT,
content_id=agent_id,
embedding=mock_embedding,
searchable_text="unique_search_term_xyz123",
metadata={},
user_id=None,
)
await embeddings.store_content_embedding(
content_type=ContentType.BLOCK,
content_id=block_id,
embedding=mock_embedding,
searchable_text="unique_search_term_xyz123",
metadata={},
user_id=None,
)
# Search only for BLOCK type
results, total = await unified_hybrid_search(
query="unique_search_term_xyz123",
content_types=[ContentType.BLOCK],
page=1,
page_size=20,
)
# All results should be BLOCK type
for r in results:
assert r["content_type"] == "BLOCK"
@pytest.mark.asyncio(loop_scope="session")
async def test_unified_hybrid_search_empty_query(server):
"""Test unified search with empty query returns empty results."""
results, total = await unified_hybrid_search(
query="",
page=1,
page_size=20,
)
assert results == []
assert total == 0
@pytest.mark.asyncio(loop_scope="session")
async def test_unified_hybrid_search_pagination(
server,
mock_embedding: list[float],
cleanup_embeddings: list,
):
"""Test unified search pagination works correctly."""
# Create multiple items
content_ids = []
for i in range(5):
content_id = f"test-pagination-{uuid.uuid4()}"
content_ids.append(content_id)
cleanup_embeddings.append((ContentType.BLOCK, content_id, None))
await embeddings.store_content_embedding(
content_type=ContentType.BLOCK,
content_id=content_id,
embedding=mock_embedding,
searchable_text=f"pagination test item number {i}",
metadata={"index": i},
user_id=None,
)
# Get first page
page1_results, total1 = await unified_hybrid_search(
query="pagination test",
content_types=[ContentType.BLOCK],
page=1,
page_size=2,
)
# Get second page
page2_results, total2 = await unified_hybrid_search(
query="pagination test",
content_types=[ContentType.BLOCK],
page=2,
page_size=2,
)
# Total should be consistent
assert total1 == total2
# Pages should have different content (if we have enough results)
if len(page1_results) > 0 and len(page2_results) > 0:
page1_ids = {r["content_id"] for r in page1_results}
page2_ids = {r["content_id"] for r in page2_results}
# No overlap between pages
assert page1_ids.isdisjoint(page2_ids)
@pytest.mark.asyncio(loop_scope="session")
async def test_unified_hybrid_search_min_score_filtering(
server,
mock_embedding: list[float],
cleanup_embeddings: list,
):
"""Test unified search respects min_score threshold."""
content_id = f"test-minscore-{uuid.uuid4()}"
cleanup_embeddings.append((ContentType.BLOCK, content_id, None))
await embeddings.store_content_embedding(
content_type=ContentType.BLOCK,
content_id=content_id,
embedding=mock_embedding,
searchable_text="completely unrelated content about bananas",
metadata={},
user_id=None,
)
# Search with very high min_score - should filter out low relevance
results_high, _ = await unified_hybrid_search(
query="quantum computing algorithms",
content_types=[ContentType.BLOCK],
min_score=0.9, # Very high threshold
page=1,
page_size=20,
)
# Search with low min_score
results_low, _ = await unified_hybrid_search(
query="quantum computing algorithms",
content_types=[ContentType.BLOCK],
min_score=0.01, # Very low threshold
page=1,
page_size=20,
)
# High threshold should have fewer or equal results
assert len(results_high) <= len(results_low)
# ============================================================================
# hybrid_search (Store Agents) Tests
# ============================================================================
@pytest.mark.asyncio(loop_scope="session")
async def test_hybrid_search_store_agents_sql_valid(server):
"""Test that hybrid_search SQL executes without errors."""
# This test verifies the SQL is syntactically correct
# even if no results are found
results, total = await hybrid_search(
query="test agent",
page=1,
page_size=20,
)
# Should not raise - verifies SQL is valid
assert isinstance(results, list)
assert isinstance(total, int)
assert total >= 0
@pytest.mark.asyncio(loop_scope="session")
async def test_hybrid_search_with_filters(server):
"""Test hybrid_search with various filter options."""
# Test with all filter types
results, total = await hybrid_search(
query="productivity",
featured=True,
creators=["test-creator"],
category="productivity",
page=1,
page_size=10,
)
# Should not raise - verifies filter SQL is valid
assert isinstance(results, list)
assert isinstance(total, int)
@pytest.mark.asyncio(loop_scope="session")
async def test_hybrid_search_pagination(server):
"""Test hybrid_search pagination."""
# Page 1
results1, total1 = await hybrid_search(
query="agent",
page=1,
page_size=5,
)
# Page 2
results2, total2 = await hybrid_search(
query="agent",
page=2,
page_size=5,
)
# Verify SQL executes without error
assert isinstance(results1, list)
assert isinstance(results2, list)
assert isinstance(total1, int)
assert isinstance(total2, int)
# If page 1 has results, total should be > 0
# Note: total from page 2 may be 0 if no results on that page (COUNT(*) OVER limitation)
if results1:
assert total1 > 0
# ============================================================================
# SQL Validity Tests (verify queries don't break)
# ============================================================================
@pytest.mark.asyncio(loop_scope="session")
async def test_all_content_types_searchable(server):
"""Test that all content types can be searched without SQL errors."""
for content_type in [
ContentType.STORE_AGENT,
ContentType.BLOCK,
ContentType.DOCUMENTATION,
]:
results, total = await unified_hybrid_search(
query="test",
content_types=[content_type],
page=1,
page_size=10,
)
# Should not raise
assert isinstance(results, list)
assert isinstance(total, int)
@pytest.mark.asyncio(loop_scope="session")
async def test_multiple_content_types_searchable(server):
"""Test searching multiple content types at once."""
results, total = await unified_hybrid_search(
query="test",
content_types=[ContentType.BLOCK, ContentType.DOCUMENTATION],
page=1,
page_size=20,
)
# Should not raise
assert isinstance(results, list)
assert isinstance(total, int)
@pytest.mark.asyncio(loop_scope="session")
async def test_search_all_content_types_default(server):
"""Test searching all content types (default behavior)."""
results, total = await unified_hybrid_search(
query="test",
content_types=None, # Should search all
page=1,
page_size=20,
)
# Should not raise
assert isinstance(results, list)
assert isinstance(total, int)
if __name__ == "__main__":
pytest.main([__file__, "-v", "-s"])


@@ -0,0 +1,315 @@
"""
Integration tests for embeddings with schema handling.
These tests verify that embeddings operations work correctly across different database schemas.
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from prisma.enums import ContentType
from backend.api.features.store import embeddings
from backend.api.features.store.embeddings import EMBEDDING_DIM
# Schema prefix tests removed - functionality moved to db.raw_with_schema() helper
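# Illustrative note (an assumption based on the assertions below): the schema-aware
# helpers render a template placeholder such as
#     SELECT ... FROM {schema_prefix}"UnifiedContentEmbedding"
# into
#     SELECT ... FROM "platform"."UnifiedContentEmbedding"
# when the configured database schema is "platform", which is what these tests check for.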
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_store_content_embedding_with_schema():
"""Test storing embeddings with proper schema handling."""
with patch("backend.data.db.get_database_schema") as mock_schema:
mock_schema.return_value = "platform"
with patch("prisma.get_client") as mock_get_client:
mock_client = AsyncMock()
mock_get_client.return_value = mock_client
result = await embeddings.store_content_embedding(
content_type=ContentType.STORE_AGENT,
content_id="test-id",
embedding=[0.1] * EMBEDDING_DIM,
searchable_text="test text",
metadata={"test": "data"},
user_id=None,
)
# Verify the query was called
assert mock_client.execute_raw.called
# Get the SQL query that was executed
call_args = mock_client.execute_raw.call_args
sql_query = call_args[0][0]
# Verify schema prefix is in the query
assert '"platform"."UnifiedContentEmbedding"' in sql_query
# Verify result
assert result is True
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_get_content_embedding_with_schema():
"""Test retrieving embeddings with proper schema handling."""
with patch("backend.data.db.get_database_schema") as mock_schema:
mock_schema.return_value = "platform"
with patch("prisma.get_client") as mock_get_client:
mock_client = AsyncMock()
mock_client.query_raw.return_value = [
{
"contentType": "STORE_AGENT",
"contentId": "test-id",
"userId": None,
"embedding": "[0.1, 0.2]",
"searchableText": "test",
"metadata": {},
"createdAt": "2024-01-01",
"updatedAt": "2024-01-01",
}
]
mock_get_client.return_value = mock_client
result = await embeddings.get_content_embedding(
ContentType.STORE_AGENT,
"test-id",
user_id=None,
)
# Verify the query was called
assert mock_client.query_raw.called
# Get the SQL query that was executed
call_args = mock_client.query_raw.call_args
sql_query = call_args[0][0]
# Verify schema prefix is in the query
assert '"platform"."UnifiedContentEmbedding"' in sql_query
# Verify result
assert result is not None
assert result["contentId"] == "test-id"
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_delete_content_embedding_with_schema():
"""Test deleting embeddings with proper schema handling."""
with patch("backend.data.db.get_database_schema") as mock_schema:
mock_schema.return_value = "platform"
with patch("prisma.get_client") as mock_get_client:
mock_client = AsyncMock()
mock_get_client.return_value = mock_client
result = await embeddings.delete_content_embedding(
ContentType.STORE_AGENT,
"test-id",
)
# Verify the query was called
assert mock_client.execute_raw.called
# Get the SQL query that was executed
call_args = mock_client.execute_raw.call_args
sql_query = call_args[0][0]
# Verify schema prefix is in the query
assert '"platform"."UnifiedContentEmbedding"' in sql_query
# Verify result
assert result is True
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_get_embedding_stats_with_schema():
"""Test embedding statistics with proper schema handling via content handlers."""
# Mock handler to return stats
mock_handler = MagicMock()
mock_handler.get_stats = AsyncMock(
return_value={
"total": 100,
"with_embeddings": 80,
"without_embeddings": 20,
}
)
with patch(
"backend.api.features.store.embeddings.CONTENT_HANDLERS",
{ContentType.STORE_AGENT: mock_handler},
):
result = await embeddings.get_embedding_stats()
# Verify handler was called
mock_handler.get_stats.assert_called_once()
# Verify new result structure
assert "by_type" in result
assert "totals" in result
assert result["totals"]["total"] == 100
assert result["totals"]["with_embeddings"] == 80
assert result["totals"]["without_embeddings"] == 20
assert result["totals"]["coverage_percent"] == 80.0
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_backfill_missing_embeddings_with_schema():
"""Test backfilling embeddings via content handlers."""
from backend.api.features.store.content_handlers import ContentItem
# Create mock content item
mock_item = ContentItem(
content_id="version-1",
content_type=ContentType.STORE_AGENT,
searchable_text="Test Agent Test description",
metadata={"name": "Test Agent"},
)
# Mock handler
mock_handler = MagicMock()
mock_handler.get_missing_items = AsyncMock(return_value=[mock_item])
with patch(
"backend.api.features.store.embeddings.CONTENT_HANDLERS",
{ContentType.STORE_AGENT: mock_handler},
):
with patch(
"backend.api.features.store.embeddings.generate_embedding",
return_value=[0.1] * EMBEDDING_DIM,
):
with patch(
"backend.api.features.store.embeddings.store_content_embedding",
return_value=True,
):
result = await embeddings.backfill_missing_embeddings(batch_size=10)
# Verify handler was called
mock_handler.get_missing_items.assert_called_once_with(10)
# Verify results
assert result["processed"] == 1
assert result["success"] == 1
assert result["failed"] == 0
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_ensure_content_embedding_with_schema():
"""Test ensuring embeddings exist with proper schema handling."""
with patch("backend.data.db.get_database_schema") as mock_schema:
mock_schema.return_value = "platform"
with patch(
"backend.api.features.store.embeddings.get_content_embedding"
) as mock_get:
# Simulate no existing embedding
mock_get.return_value = None
with patch(
"backend.api.features.store.embeddings.generate_embedding"
) as mock_generate:
mock_generate.return_value = [0.1] * EMBEDDING_DIM
with patch(
"backend.api.features.store.embeddings.store_content_embedding"
) as mock_store:
mock_store.return_value = True
result = await embeddings.ensure_content_embedding(
content_type=ContentType.STORE_AGENT,
content_id="test-id",
searchable_text="test text",
metadata={"test": "data"},
user_id=None,
force=False,
)
# Verify the flow
assert mock_get.called
assert mock_generate.called
assert mock_store.called
assert result is True
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_backward_compatibility_store_embedding():
"""Test backward compatibility wrapper for store_embedding."""
with patch(
"backend.api.features.store.embeddings.store_content_embedding"
) as mock_store:
mock_store.return_value = True
result = await embeddings.store_embedding(
version_id="test-version-id",
embedding=[0.1] * EMBEDDING_DIM,
tx=None,
)
# Verify it calls the new function with correct parameters
assert mock_store.called
call_args = mock_store.call_args
assert call_args[1]["content_type"] == ContentType.STORE_AGENT
assert call_args[1]["content_id"] == "test-version-id"
assert call_args[1]["user_id"] is None
assert result is True
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_backward_compatibility_get_embedding():
"""Test backward compatibility wrapper for get_embedding."""
with patch(
"backend.api.features.store.embeddings.get_content_embedding"
) as mock_get:
mock_get.return_value = {
"contentType": "STORE_AGENT",
"contentId": "test-version-id",
"embedding": "[0.1, 0.2]",
"createdAt": "2024-01-01",
"updatedAt": "2024-01-01",
}
result = await embeddings.get_embedding("test-version-id")
# Verify it calls the new function
assert mock_get.called
# Verify it transforms to old format
assert result is not None
assert result["storeListingVersionId"] == "test-version-id"
assert "embedding" in result
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_schema_handling_error_cases():
"""Test error handling in schema-aware operations."""
with patch("backend.data.db.get_database_schema") as mock_schema:
mock_schema.return_value = "platform"
with patch("prisma.get_client") as mock_get_client:
mock_client = AsyncMock()
mock_client.execute_raw.side_effect = Exception("Database error")
mock_get_client.return_value = mock_client
result = await embeddings.store_content_embedding(
content_type=ContentType.STORE_AGENT,
content_id="test-id",
embedding=[0.1] * EMBEDDING_DIM,
searchable_text="test",
metadata=None,
user_id=None,
)
# Should return False on error, not raise
assert result is False
if __name__ == "__main__":
pytest.main([__file__, "-v", "-s"])


@@ -0,0 +1,407 @@
from unittest.mock import AsyncMock, MagicMock, patch
import prisma
import pytest
from prisma import Prisma
from prisma.enums import ContentType
from backend.api.features.store import embeddings
@pytest.fixture(autouse=True)
async def setup_prisma():
"""Setup Prisma client for tests."""
try:
Prisma()
except prisma.errors.ClientAlreadyRegisteredError:
pass
yield
@pytest.mark.asyncio(loop_scope="session")
async def test_build_searchable_text():
"""Test searchable text building from listing fields."""
result = embeddings.build_searchable_text(
name="AI Assistant",
description="A helpful AI assistant for productivity",
sub_heading="Boost your productivity",
categories=["AI", "Productivity"],
)
expected = "AI Assistant Boost your productivity A helpful AI assistant for productivity AI Productivity"
assert result == expected
@pytest.mark.asyncio(loop_scope="session")
async def test_build_searchable_text_empty_fields():
"""Test searchable text building with empty fields."""
result = embeddings.build_searchable_text(
name="", description="Test description", sub_heading="", categories=[]
)
assert result == "Test description"
@pytest.mark.asyncio(loop_scope="session")
async def test_generate_embedding_success():
"""Test successful embedding generation."""
# Mock OpenAI response
mock_client = MagicMock()
mock_response = MagicMock()
mock_response.data = [MagicMock()]
mock_response.data[0].embedding = [0.1, 0.2, 0.3] * 512 # 1536 dimensions
# Use AsyncMock for async embeddings.create method
mock_client.embeddings.create = AsyncMock(return_value=mock_response)
# Patch at the point of use in embeddings.py
with patch(
"backend.api.features.store.embeddings.get_openai_client"
) as mock_get_client:
mock_get_client.return_value = mock_client
result = await embeddings.generate_embedding("test text")
assert result is not None
assert len(result) == embeddings.EMBEDDING_DIM
assert result[0] == 0.1
mock_client.embeddings.create.assert_called_once_with(
model="text-embedding-3-small", input="test text"
)
@pytest.mark.asyncio(loop_scope="session")
async def test_generate_embedding_no_api_key():
"""Test embedding generation without API key."""
# Patch at the point of use in embeddings.py
with patch(
"backend.api.features.store.embeddings.get_openai_client"
) as mock_get_client:
mock_get_client.return_value = None
result = await embeddings.generate_embedding("test text")
assert result is None
@pytest.mark.asyncio(loop_scope="session")
async def test_generate_embedding_api_error():
"""Test embedding generation with API error."""
mock_client = MagicMock()
mock_client.embeddings.create = AsyncMock(side_effect=Exception("API Error"))
# Patch at the point of use in embeddings.py
with patch(
"backend.api.features.store.embeddings.get_openai_client"
) as mock_get_client:
mock_get_client.return_value = mock_client
result = await embeddings.generate_embedding("test text")
assert result is None
@pytest.mark.asyncio(loop_scope="session")
async def test_generate_embedding_text_truncation():
"""Test that long text is properly truncated using tiktoken."""
from tiktoken import encoding_for_model
mock_client = MagicMock()
mock_response = MagicMock()
mock_response.data = [MagicMock()]
mock_response.data[0].embedding = [0.1] * embeddings.EMBEDDING_DIM
# Use AsyncMock for async embeddings.create method
mock_client.embeddings.create = AsyncMock(return_value=mock_response)
# Patch at the point of use in embeddings.py
with patch(
"backend.api.features.store.embeddings.get_openai_client"
) as mock_get_client:
mock_get_client.return_value = mock_client
# Create text that will exceed 8191 tokens
        # Use distinct numbered words so the text is comfortably token-heavy
        words = [f"word{i}" for i in range(10000)]
        long_text = " ".join(words)  # well over the 8191-token limit
await embeddings.generate_embedding(long_text)
# Verify text was truncated to 8191 tokens
call_args = mock_client.embeddings.create.call_args
truncated_text = call_args.kwargs["input"]
# Count actual tokens in truncated text
enc = encoding_for_model("text-embedding-3-small")
actual_tokens = len(enc.encode(truncated_text))
# Should be at or just under 8191 tokens
assert actual_tokens <= 8191
# Should be close to the limit (not over-truncated)
assert actual_tokens >= 8100
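# A minimal sketch of the truncation exercised above (illustrative, assuming the
# implementation uses tiktoken directly):
#     enc = encoding_for_model("text-embedding-3-small")
#     text = enc.decode(enc.encode(text)[:8191])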
@pytest.mark.asyncio(loop_scope="session")
async def test_store_embedding_success(mocker):
"""Test successful embedding storage."""
mock_client = mocker.AsyncMock()
mock_client.execute_raw = mocker.AsyncMock()
embedding = [0.1, 0.2, 0.3]
result = await embeddings.store_embedding(
version_id="test-version-id", embedding=embedding, tx=mock_client
)
assert result is True
# execute_raw is called twice: once for SET search_path, once for INSERT
assert mock_client.execute_raw.call_count == 2
# First call: SET search_path
first_call_args = mock_client.execute_raw.call_args_list[0][0]
assert "SET search_path" in first_call_args[0]
# Second call: INSERT query with the actual data
second_call_args = mock_client.execute_raw.call_args_list[1][0]
assert "test-version-id" in second_call_args
assert "[0.1,0.2,0.3]" in second_call_args
assert None in second_call_args # userId should be None for store agents
@pytest.mark.asyncio(loop_scope="session")
async def test_store_embedding_database_error(mocker):
"""Test embedding storage with database error."""
mock_client = mocker.AsyncMock()
mock_client.execute_raw.side_effect = Exception("Database error")
embedding = [0.1, 0.2, 0.3]
result = await embeddings.store_embedding(
version_id="test-version-id", embedding=embedding, tx=mock_client
)
assert result is False
@pytest.mark.asyncio(loop_scope="session")
async def test_get_embedding_success():
"""Test successful embedding retrieval."""
mock_result = [
{
"contentType": "STORE_AGENT",
"contentId": "test-version-id",
"userId": None,
"embedding": "[0.1,0.2,0.3]",
"searchableText": "Test text",
"metadata": {},
"createdAt": "2024-01-01T00:00:00Z",
"updatedAt": "2024-01-01T00:00:00Z",
}
]
with patch(
"backend.api.features.store.embeddings.query_raw_with_schema",
return_value=mock_result,
):
result = await embeddings.get_embedding("test-version-id")
assert result is not None
assert result["storeListingVersionId"] == "test-version-id"
assert result["embedding"] == "[0.1,0.2,0.3]"
@pytest.mark.asyncio(loop_scope="session")
async def test_get_embedding_not_found():
"""Test embedding retrieval when not found."""
with patch(
"backend.api.features.store.embeddings.query_raw_with_schema",
return_value=[],
):
result = await embeddings.get_embedding("test-version-id")
assert result is None
@pytest.mark.asyncio(loop_scope="session")
@patch("backend.api.features.store.embeddings.generate_embedding")
@patch("backend.api.features.store.embeddings.store_embedding")
@patch("backend.api.features.store.embeddings.get_embedding")
async def test_ensure_embedding_already_exists(mock_get, mock_store, mock_generate):
"""Test ensure_embedding when embedding already exists."""
mock_get.return_value = {"embedding": "[0.1,0.2,0.3]"}
result = await embeddings.ensure_embedding(
version_id="test-id",
name="Test",
description="Test description",
sub_heading="Test heading",
categories=["test"],
)
assert result is True
mock_generate.assert_not_called()
mock_store.assert_not_called()
@pytest.mark.asyncio(loop_scope="session")
@patch("backend.api.features.store.embeddings.generate_embedding")
@patch("backend.api.features.store.embeddings.store_content_embedding")
@patch("backend.api.features.store.embeddings.get_embedding")
async def test_ensure_embedding_create_new(mock_get, mock_store, mock_generate):
"""Test ensure_embedding creating new embedding."""
mock_get.return_value = None
mock_generate.return_value = [0.1, 0.2, 0.3]
mock_store.return_value = True
result = await embeddings.ensure_embedding(
version_id="test-id",
name="Test",
description="Test description",
sub_heading="Test heading",
categories=["test"],
)
assert result is True
mock_generate.assert_called_once_with("Test Test heading Test description test")
mock_store.assert_called_once_with(
content_type=ContentType.STORE_AGENT,
content_id="test-id",
embedding=[0.1, 0.2, 0.3],
searchable_text="Test Test heading Test description test",
metadata={"name": "Test", "subHeading": "Test heading", "categories": ["test"]},
user_id=None,
tx=None,
)
@pytest.mark.asyncio(loop_scope="session")
@patch("backend.api.features.store.embeddings.generate_embedding")
@patch("backend.api.features.store.embeddings.get_embedding")
async def test_ensure_embedding_generation_fails(mock_get, mock_generate):
"""Test ensure_embedding when generation fails."""
mock_get.return_value = None
mock_generate.return_value = None
result = await embeddings.ensure_embedding(
version_id="test-id",
name="Test",
description="Test description",
sub_heading="Test heading",
categories=["test"],
)
assert result is False
@pytest.mark.asyncio(loop_scope="session")
async def test_get_embedding_stats():
"""Test embedding statistics retrieval."""
# Mock handler stats for each content type
mock_handler = MagicMock()
mock_handler.get_stats = AsyncMock(
return_value={
"total": 100,
"with_embeddings": 75,
"without_embeddings": 25,
}
)
# Patch the CONTENT_HANDLERS where it's used (in embeddings module)
with patch(
"backend.api.features.store.embeddings.CONTENT_HANDLERS",
{ContentType.STORE_AGENT: mock_handler},
):
result = await embeddings.get_embedding_stats()
assert "by_type" in result
assert "totals" in result
assert result["totals"]["total"] == 100
assert result["totals"]["with_embeddings"] == 75
assert result["totals"]["without_embeddings"] == 25
assert result["totals"]["coverage_percent"] == 75.0
@pytest.mark.asyncio(loop_scope="session")
@patch("backend.api.features.store.embeddings.store_content_embedding")
async def test_backfill_missing_embeddings_success(mock_store):
"""Test backfill with successful embedding generation."""
# Mock ContentItem from handlers
from backend.api.features.store.content_handlers import ContentItem
mock_items = [
ContentItem(
content_id="version-1",
content_type=ContentType.STORE_AGENT,
searchable_text="Agent 1 Description 1",
metadata={"name": "Agent 1"},
),
ContentItem(
content_id="version-2",
content_type=ContentType.STORE_AGENT,
searchable_text="Agent 2 Description 2",
metadata={"name": "Agent 2"},
),
]
# Mock handler to return missing items
mock_handler = MagicMock()
mock_handler.get_missing_items = AsyncMock(return_value=mock_items)
# Mock store_content_embedding to succeed for first, fail for second
mock_store.side_effect = [True, False]
with patch(
"backend.api.features.store.embeddings.CONTENT_HANDLERS",
{ContentType.STORE_AGENT: mock_handler},
):
with patch(
"backend.api.features.store.embeddings.generate_embedding",
return_value=[0.1] * embeddings.EMBEDDING_DIM,
):
result = await embeddings.backfill_missing_embeddings(batch_size=5)
assert result["processed"] == 2
assert result["success"] == 1
assert result["failed"] == 1
assert mock_store.call_count == 2
@pytest.mark.asyncio(loop_scope="session")
async def test_backfill_missing_embeddings_no_missing():
"""Test backfill when no embeddings are missing."""
# Mock handler to return no missing items
mock_handler = MagicMock()
mock_handler.get_missing_items = AsyncMock(return_value=[])
with patch(
"backend.api.features.store.embeddings.CONTENT_HANDLERS",
{ContentType.STORE_AGENT: mock_handler},
):
result = await embeddings.backfill_missing_embeddings(batch_size=5)
assert result["processed"] == 0
assert result["success"] == 0
assert result["failed"] == 0
@pytest.mark.asyncio(loop_scope="session")
async def test_embedding_to_vector_string():
"""Test embedding to PostgreSQL vector string conversion."""
embedding = [0.1, 0.2, 0.3, -0.4]
result = embeddings.embedding_to_vector_string(embedding)
assert result == "[0.1,0.2,0.3,-0.4]"
@pytest.mark.asyncio(loop_scope="session")
async def test_embed_query():
"""Test embed_query function (alias for generate_embedding)."""
with patch(
"backend.api.features.store.embeddings.generate_embedding"
) as mock_generate:
mock_generate.return_value = [0.1, 0.2, 0.3]
result = await embeddings.embed_query("test query")
assert result == [0.1, 0.2, 0.3]
mock_generate.assert_called_once_with("test query")


@@ -0,0 +1,727 @@
"""
Unified Hybrid Search
Combines semantic (embedding) search with lexical (tsvector) search
for improved relevance across all content types (agents, blocks, docs).
Includes BM25 reranking to sharpen lexical relevance, especially for exact term matches.
"""
import logging
import re
from dataclasses import dataclass
from typing import Any, Literal
from prisma.enums import ContentType
from rank_bm25 import BM25Okapi
from backend.api.features.store.embeddings import (
EMBEDDING_DIM,
embed_query,
embedding_to_vector_string,
)
from backend.data.db import query_raw_with_schema
logger = logging.getLogger(__name__)
# ============================================================================
# BM25 Reranking
# ============================================================================
def tokenize(text: str) -> list[str]:
"""Simple tokenizer for BM25 - lowercase and split on non-alphanumeric."""
if not text:
return []
# Lowercase and split on non-alphanumeric characters
tokens = re.findall(r"\b\w+\b", text.lower())
return tokens
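# Example (illustrative): tokenize("HTTP Request block!") -> ["http", "request", "block"]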
def bm25_rerank(
query: str,
results: list[dict[str, Any]],
text_field: str = "searchable_text",
bm25_weight: float = 0.3,
original_score_field: str = "combined_score",
) -> list[dict[str, Any]]:
"""
Rerank search results using BM25.
Combines the original combined_score with BM25 score for improved
lexical relevance, especially for exact term matches.
Args:
query: The search query
results: List of result dicts with text_field and original_score_field
text_field: Field name containing the text to score
bm25_weight: Weight for BM25 score (0-1). Original score gets (1 - bm25_weight)
original_score_field: Field name containing the original score
Returns:
Results list sorted by combined score (BM25 + original)
"""
if not results or not query:
return results
# Extract texts and tokenize
corpus = [tokenize(r.get(text_field, "") or "") for r in results]
# Handle edge case where all documents are empty
if all(len(doc) == 0 for doc in corpus):
return results
# Build BM25 index
bm25 = BM25Okapi(corpus)
# Score query against corpus
query_tokens = tokenize(query)
if not query_tokens:
return results
bm25_scores = bm25.get_scores(query_tokens)
# Normalize BM25 scores to 0-1 range
    max_bm25 = max(bm25_scores)
    if max_bm25 <= 0:
        max_bm25 = 1.0
normalized_bm25 = [s / max_bm25 for s in bm25_scores]
# Combine scores
original_weight = 1.0 - bm25_weight
for i, result in enumerate(results):
original_score = result.get(original_score_field, 0) or 0
result["bm25_score"] = normalized_bm25[i]
final_score = (
original_weight * original_score + bm25_weight * normalized_bm25[i]
)
result["final_score"] = final_score
result["relevance"] = final_score
# Sort by relevance descending
results.sort(key=lambda x: x.get("relevance", 0), reverse=True)
return results
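# Worked example (illustrative): with the default bm25_weight of 0.3, a result whose
# combined_score is 0.6 and whose normalized BM25 score is 1.0 ends up with
#     relevance = 0.7 * 0.6 + 0.3 * 1.0 = 0.72
# so strong exact-term matches can move ahead of purely semantic hits.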
@dataclass
class UnifiedSearchWeights:
"""Weights for unified search (no popularity signal)."""
semantic: float = 0.40 # Embedding cosine similarity
lexical: float = 0.40 # tsvector ts_rank_cd score
category: float = 0.10 # Category match boost (for types that have categories)
recency: float = 0.10 # Newer content ranked higher
def __post_init__(self):
"""Validate weights are non-negative and sum to approximately 1.0."""
total = self.semantic + self.lexical + self.category + self.recency
if any(
w < 0 for w in [self.semantic, self.lexical, self.category, self.recency]
):
raise ValueError("All weights must be non-negative")
if not (0.99 <= total <= 1.01):
raise ValueError(f"Weights must sum to ~1.0, got {total:.3f}")
# Default weights for unified search
DEFAULT_UNIFIED_WEIGHTS = UnifiedSearchWeights()
# Minimum relevance score thresholds
DEFAULT_MIN_SCORE = 0.15 # For unified search (more permissive)
DEFAULT_STORE_AGENT_MIN_SCORE = 0.20 # For store agent search (original threshold)
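# Example (illustrative): a lexical-heavy weighting that still passes validation,
# since the four weights sum to 1.0:
#     UnifiedSearchWeights(semantic=0.25, lexical=0.55, category=0.10, recency=0.10)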
async def unified_hybrid_search(
query: str,
content_types: list[ContentType] | None = None,
category: str | None = None,
page: int = 1,
page_size: int = 20,
weights: UnifiedSearchWeights | None = None,
min_score: float | None = None,
user_id: str | None = None,
) -> tuple[list[dict[str, Any]], int]:
"""
Unified hybrid search across all content types.
Searches UnifiedContentEmbedding using both semantic (vector) and lexical (tsvector) signals.
Args:
query: Search query string
content_types: List of content types to search. Defaults to all public types.
category: Filter by category (for content types that support it)
page: Page number (1-indexed)
page_size: Results per page
weights: Custom weights for search signals
min_score: Minimum relevance score threshold (0-1)
user_id: User ID for searching private content (library agents)
Returns:
Tuple of (results list, total count)
"""
# Validate inputs
query = query.strip()
if not query:
return [], 0
if page < 1:
page = 1
if page_size < 1:
page_size = 1
if page_size > 100:
page_size = 100
if content_types is None:
content_types = [
ContentType.STORE_AGENT,
ContentType.BLOCK,
ContentType.DOCUMENTATION,
]
if weights is None:
weights = DEFAULT_UNIFIED_WEIGHTS
if min_score is None:
min_score = DEFAULT_MIN_SCORE
offset = (page - 1) * page_size
# Generate query embedding
query_embedding = await embed_query(query)
# Graceful degradation if embedding unavailable
if query_embedding is None or not query_embedding:
logger.warning(
"Failed to generate query embedding - falling back to lexical-only search. "
"Check that openai_internal_api_key is configured and OpenAI API is accessible."
)
query_embedding = [0.0] * EMBEDDING_DIM
# Redistribute semantic weight to lexical
total_non_semantic = weights.lexical + weights.category + weights.recency
if total_non_semantic > 0:
factor = 1.0 / total_non_semantic
weights = UnifiedSearchWeights(
semantic=0.0,
lexical=weights.lexical * factor,
category=weights.category * factor,
recency=weights.recency * factor,
)
else:
weights = UnifiedSearchWeights(
semantic=0.0, lexical=1.0, category=0.0, recency=0.0
)
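    # For example (illustrative): with the default weights (0.40/0.40/0.10/0.10),
    # dropping the semantic signal redistributes to roughly
    # lexical=0.667, category=0.167, recency=0.167.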
# Build parameters
params: list[Any] = []
param_idx = 1
# Query for lexical search
params.append(query)
query_param = f"${param_idx}"
param_idx += 1
# Query lowercase for category matching
params.append(query.lower())
query_lower_param = f"${param_idx}"
param_idx += 1
# Embedding
embedding_str = embedding_to_vector_string(query_embedding)
params.append(embedding_str)
embedding_param = f"${param_idx}"
param_idx += 1
# Content types
content_type_values = [ct.value for ct in content_types]
params.append(content_type_values)
content_types_param = f"${param_idx}"
param_idx += 1
# User ID filter (for private content)
user_filter = ""
if user_id is not None:
params.append(user_id)
user_filter = f'AND (uce."userId" = ${param_idx} OR uce."userId" IS NULL)'
param_idx += 1
else:
user_filter = 'AND uce."userId" IS NULL'
# Weights
params.append(weights.semantic)
w_semantic = f"${param_idx}"
param_idx += 1
params.append(weights.lexical)
w_lexical = f"${param_idx}"
param_idx += 1
params.append(weights.category)
w_category = f"${param_idx}"
param_idx += 1
params.append(weights.recency)
w_recency = f"${param_idx}"
param_idx += 1
# Min score
params.append(min_score)
min_score_param = f"${param_idx}"
param_idx += 1
# Pagination
params.append(page_size)
limit_param = f"${param_idx}"
param_idx += 1
params.append(offset)
offset_param = f"${param_idx}"
param_idx += 1
# Unified search query on UnifiedContentEmbedding
sql_query = f"""
WITH candidates AS (
-- Lexical matches (uses GIN index on search column)
SELECT uce.id, uce."contentType", uce."contentId"
FROM {{schema_prefix}}"UnifiedContentEmbedding" uce
WHERE uce."contentType" = ANY({content_types_param}::{{schema_prefix}}"ContentType"[])
{user_filter}
AND uce.search @@ plainto_tsquery('english', {query_param})
UNION
-- Semantic matches (uses HNSW index on embedding)
(
SELECT uce.id, uce."contentType", uce."contentId"
FROM {{schema_prefix}}"UnifiedContentEmbedding" uce
WHERE uce."contentType" = ANY({content_types_param}::{{schema_prefix}}"ContentType"[])
{user_filter}
ORDER BY uce.embedding <=> {embedding_param}::vector
LIMIT 200
)
),
search_scores AS (
SELECT
uce."contentType" as content_type,
uce."contentId" as content_id,
uce."searchableText" as searchable_text,
uce.metadata,
uce."updatedAt" as updated_at,
-- Semantic score: cosine similarity (1 - distance)
COALESCE(1 - (uce.embedding <=> {embedding_param}::vector), 0) as semantic_score,
-- Lexical score: ts_rank_cd
COALESCE(ts_rank_cd(uce.search, plainto_tsquery('english', {query_param})), 0) as lexical_raw,
-- Category match from metadata
CASE
WHEN uce.metadata ? 'categories' AND EXISTS (
SELECT 1 FROM jsonb_array_elements_text(uce.metadata->'categories') cat
WHERE LOWER(cat) LIKE '%' || {query_lower_param} || '%'
)
THEN 1.0
ELSE 0.0
END as category_score,
-- Recency score: linear decay over 90 days
GREATEST(0, 1 - EXTRACT(EPOCH FROM (NOW() - uce."updatedAt")) / (90 * 24 * 3600)) as recency_score
FROM candidates c
INNER JOIN {{schema_prefix}}"UnifiedContentEmbedding" uce ON c.id = uce.id
),
max_lexical AS (
SELECT GREATEST(MAX(lexical_raw), 0.001) as max_val FROM search_scores
),
normalized AS (
SELECT
ss.*,
ss.lexical_raw / ml.max_val as lexical_score
FROM search_scores ss
CROSS JOIN max_lexical ml
),
scored AS (
SELECT
content_type,
content_id,
searchable_text,
metadata,
updated_at,
semantic_score,
lexical_score,
category_score,
recency_score,
(
{w_semantic} * semantic_score +
{w_lexical} * lexical_score +
{w_category} * category_score +
{w_recency} * recency_score
) as combined_score
FROM normalized
),
filtered AS (
SELECT *, COUNT(*) OVER () as total_count
FROM scored
WHERE combined_score >= {min_score_param}
)
SELECT * FROM filtered
ORDER BY combined_score DESC
LIMIT {limit_param} OFFSET {offset_param}
"""
results = await query_raw_with_schema(
sql_query, *params, set_public_search_path=True
)
total = results[0]["total_count"] if results else 0
# Apply BM25 reranking
if results:
results = bm25_rerank(
query=query,
results=results,
text_field="searchable_text",
bm25_weight=0.3,
original_score_field="combined_score",
)
# Clean up results
for result in results:
result.pop("total_count", None)
logger.info(f"Unified hybrid search: {len(results)} results, {total} total")
return results, total
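# Usage sketch (illustrative): search blocks and documentation only, first page of 10.
#     results, total = await unified_hybrid_search(
#         query="http request",
#         content_types=[ContentType.BLOCK, ContentType.DOCUMENTATION],
#         page=1,
#         page_size=10,
#     )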
# ============================================================================
# Store Agent specific search (with full metadata)
# ============================================================================
@dataclass
class StoreAgentSearchWeights:
"""Weights for store agent search including popularity."""
semantic: float = 0.30
lexical: float = 0.30
category: float = 0.20
recency: float = 0.10
popularity: float = 0.10
def __post_init__(self):
total = (
self.semantic
+ self.lexical
+ self.category
+ self.recency
+ self.popularity
)
if any(
w < 0
for w in [
self.semantic,
self.lexical,
self.category,
self.recency,
self.popularity,
]
):
raise ValueError("All weights must be non-negative")
if not (0.99 <= total <= 1.01):
raise ValueError(f"Weights must sum to ~1.0, got {total:.3f}")
DEFAULT_STORE_AGENT_WEIGHTS = StoreAgentSearchWeights()
async def hybrid_search(
query: str,
featured: bool = False,
creators: list[str] | None = None,
category: str | None = None,
sorted_by: (
Literal["relevance", "rating", "runs", "name", "updated_at"] | None
) = None,
page: int = 1,
page_size: int = 20,
weights: StoreAgentSearchWeights | None = None,
min_score: float | None = None,
) -> tuple[list[dict[str, Any]], int]:
"""
Hybrid search for store agents with full metadata.
Uses UnifiedContentEmbedding for search, joins to StoreAgent for metadata.
"""
query = query.strip()
if not query:
return [], 0
if page < 1:
page = 1
if page_size < 1:
page_size = 1
if page_size > 100:
page_size = 100
if weights is None:
weights = DEFAULT_STORE_AGENT_WEIGHTS
if min_score is None:
min_score = (
DEFAULT_STORE_AGENT_MIN_SCORE # Use original threshold for store agents
)
offset = (page - 1) * page_size
# Generate query embedding
query_embedding = await embed_query(query)
# Graceful degradation
if query_embedding is None or not query_embedding:
logger.warning(
"Failed to generate query embedding - falling back to lexical-only search."
)
query_embedding = [0.0] * EMBEDDING_DIM
total_non_semantic = (
weights.lexical + weights.category + weights.recency + weights.popularity
)
if total_non_semantic > 0:
factor = 1.0 / total_non_semantic
weights = StoreAgentSearchWeights(
semantic=0.0,
lexical=weights.lexical * factor,
category=weights.category * factor,
recency=weights.recency * factor,
popularity=weights.popularity * factor,
)
else:
weights = StoreAgentSearchWeights(
semantic=0.0, lexical=1.0, category=0.0, recency=0.0, popularity=0.0
)
# Build parameters
params: list[Any] = []
param_idx = 1
params.append(query)
query_param = f"${param_idx}"
param_idx += 1
params.append(query.lower())
query_lower_param = f"${param_idx}"
param_idx += 1
embedding_str = embedding_to_vector_string(query_embedding)
params.append(embedding_str)
embedding_param = f"${param_idx}"
param_idx += 1
# Build WHERE clause for StoreAgent filters
where_parts = ["sa.is_available = true"]
if featured:
where_parts.append("sa.featured = true")
if creators:
params.append(creators)
where_parts.append(f"sa.creator_username = ANY(${param_idx})")
param_idx += 1
if category:
params.append(category)
where_parts.append(f"${param_idx} = ANY(sa.categories)")
param_idx += 1
where_clause = " AND ".join(where_parts)
# Weights
params.append(weights.semantic)
w_semantic = f"${param_idx}"
param_idx += 1
params.append(weights.lexical)
w_lexical = f"${param_idx}"
param_idx += 1
params.append(weights.category)
w_category = f"${param_idx}"
param_idx += 1
params.append(weights.recency)
w_recency = f"${param_idx}"
param_idx += 1
params.append(weights.popularity)
w_popularity = f"${param_idx}"
param_idx += 1
params.append(min_score)
min_score_param = f"${param_idx}"
param_idx += 1
params.append(page_size)
limit_param = f"${param_idx}"
param_idx += 1
params.append(offset)
offset_param = f"${param_idx}"
param_idx += 1
# Query using UnifiedContentEmbedding for search, StoreAgent for metadata
sql_query = f"""
WITH candidates AS (
-- Lexical matches via UnifiedContentEmbedding.search
SELECT uce."contentId" as "storeListingVersionId"
FROM {{schema_prefix}}"UnifiedContentEmbedding" uce
INNER JOIN {{schema_prefix}}"StoreAgent" sa
ON uce."contentId" = sa."storeListingVersionId"
WHERE uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
AND uce."userId" IS NULL
AND uce.search @@ plainto_tsquery('english', {query_param})
AND {where_clause}
UNION
-- Semantic matches via UnifiedContentEmbedding.embedding
SELECT uce."contentId" as "storeListingVersionId"
FROM (
SELECT uce."contentId", uce.embedding
FROM {{schema_prefix}}"UnifiedContentEmbedding" uce
INNER JOIN {{schema_prefix}}"StoreAgent" sa
ON uce."contentId" = sa."storeListingVersionId"
WHERE uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
AND uce."userId" IS NULL
AND {where_clause}
ORDER BY uce.embedding <=> {embedding_param}::vector
LIMIT 200
) uce
),
search_scores AS (
SELECT
sa.slug,
sa.agent_name,
sa.agent_image,
sa.creator_username,
sa.creator_avatar,
sa.sub_heading,
sa.description,
sa.runs,
sa.rating,
sa.categories,
sa.featured,
sa.is_available,
sa.updated_at,
-- Searchable text for BM25 reranking
COALESCE(sa.agent_name, '') || ' ' || COALESCE(sa.sub_heading, '') || ' ' || COALESCE(sa.description, '') as searchable_text,
-- Semantic score
COALESCE(1 - (uce.embedding <=> {embedding_param}::vector), 0) as semantic_score,
-- Lexical score (raw, will normalize)
COALESCE(ts_rank_cd(uce.search, plainto_tsquery('english', {query_param})), 0) as lexical_raw,
-- Category match
CASE
WHEN EXISTS (
SELECT 1 FROM unnest(sa.categories) cat
WHERE LOWER(cat) LIKE '%' || {query_lower_param} || '%'
)
THEN 1.0
ELSE 0.0
END as category_score,
-- Recency
GREATEST(0, 1 - EXTRACT(EPOCH FROM (NOW() - sa.updated_at)) / (90 * 24 * 3600)) as recency_score,
-- Popularity (raw)
sa.runs as popularity_raw
FROM candidates c
INNER JOIN {{schema_prefix}}"StoreAgent" sa
ON c."storeListingVersionId" = sa."storeListingVersionId"
INNER JOIN {{schema_prefix}}"UnifiedContentEmbedding" uce
ON sa."storeListingVersionId" = uce."contentId"
AND uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
),
max_vals AS (
SELECT
GREATEST(MAX(lexical_raw), 0.001) as max_lexical,
GREATEST(MAX(popularity_raw), 1) as max_popularity
FROM search_scores
),
normalized AS (
SELECT
ss.*,
ss.lexical_raw / mv.max_lexical as lexical_score,
CASE
WHEN ss.popularity_raw > 0
THEN LN(1 + ss.popularity_raw) / LN(1 + mv.max_popularity)
ELSE 0
END as popularity_score
FROM search_scores ss
CROSS JOIN max_vals mv
),
scored AS (
SELECT
slug,
agent_name,
agent_image,
creator_username,
creator_avatar,
sub_heading,
description,
runs,
rating,
categories,
featured,
is_available,
updated_at,
searchable_text,
semantic_score,
lexical_score,
category_score,
recency_score,
popularity_score,
(
{w_semantic} * semantic_score +
{w_lexical} * lexical_score +
{w_category} * category_score +
{w_recency} * recency_score +
{w_popularity} * popularity_score
) as combined_score
FROM normalized
),
filtered AS (
SELECT *, COUNT(*) OVER () as total_count
FROM scored
WHERE combined_score >= {min_score_param}
)
SELECT * FROM filtered
ORDER BY combined_score DESC
LIMIT {limit_param} OFFSET {offset_param}
"""
results = await query_raw_with_schema(
sql_query, *params, set_public_search_path=True
)
total = results[0]["total_count"] if results else 0
# Apply BM25 reranking
if results:
results = bm25_rerank(
query=query,
results=results,
text_field="searchable_text",
bm25_weight=0.3,
original_score_field="combined_score",
)
for result in results:
result.pop("total_count", None)
result.pop("searchable_text", None)
logger.info(f"Hybrid search (store agents): {len(results)} results, {total} total")
return results, total
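# Usage sketch (illustrative): featured productivity agents from specific creators.
#     results, total = await hybrid_search(
#         query="productivity",
#         featured=True,
#         creators=["some-creator"],  # hypothetical creator username
#         category="productivity",
#         page=1,
#         page_size=10,
#     )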
async def hybrid_search_simple(
query: str,
page: int = 1,
page_size: int = 20,
) -> tuple[list[dict[str, Any]], int]:
"""Simplified hybrid search for store agents."""
return await hybrid_search(query=query, page=page, page_size=page_size)
# Backward compatibility alias - HybridSearchWeights maps to StoreAgentSearchWeights
# for existing code that expects the popularity parameter
HybridSearchWeights = StoreAgentSearchWeights


@@ -0,0 +1,726 @@
"""
Integration tests for hybrid search with schema handling.
These tests verify that hybrid search works correctly across different database schemas.
"""
from unittest.mock import patch
import pytest
from prisma.enums import ContentType
from backend.api.features.store import embeddings
from backend.api.features.store.hybrid_search import (
HybridSearchWeights,
UnifiedSearchWeights,
hybrid_search,
unified_hybrid_search,
)
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_hybrid_search_with_schema_handling():
"""Test that hybrid search correctly handles database schema prefixes."""
# Test with a mock query to ensure schema handling works
query = "test agent"
with patch(
"backend.api.features.store.hybrid_search.query_raw_with_schema"
) as mock_query:
# Mock the query result
mock_query.return_value = [
{
"slug": "test/agent",
"agent_name": "Test Agent",
"agent_image": "test.png",
"creator_username": "test",
"creator_avatar": "avatar.png",
"sub_heading": "Test sub-heading",
"description": "Test description",
"runs": 10,
"rating": 4.5,
"categories": ["test"],
"featured": False,
"is_available": True,
"updated_at": "2024-01-01T00:00:00Z",
"combined_score": 0.8,
"semantic_score": 0.7,
"lexical_score": 0.6,
"category_score": 0.5,
"recency_score": 0.4,
"total_count": 1,
}
]
with patch(
"backend.api.features.store.hybrid_search.embed_query"
) as mock_embed:
mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM # Mock embedding
results, total = await hybrid_search(
query=query,
page=1,
page_size=20,
)
# Verify the query was called
assert mock_query.called
# Verify the SQL template uses schema_prefix placeholder
call_args = mock_query.call_args
sql_template = call_args[0][0]
assert "{schema_prefix}" in sql_template
# Verify results
assert len(results) == 1
assert total == 1
assert results[0]["slug"] == "test/agent"
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_hybrid_search_with_public_schema():
"""Test hybrid search when using public schema (no prefix needed)."""
with patch("backend.data.db.get_database_schema") as mock_schema:
mock_schema.return_value = "public"
with patch(
"backend.api.features.store.hybrid_search.query_raw_with_schema"
) as mock_query:
mock_query.return_value = []
with patch(
"backend.api.features.store.hybrid_search.embed_query"
) as mock_embed:
mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM
results, total = await hybrid_search(
query="test",
page=1,
page_size=20,
)
# Verify the mock was set up correctly
assert mock_schema.return_value == "public"
# Results should work even with empty results
assert results == []
assert total == 0
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_hybrid_search_with_custom_schema():
"""Test hybrid search when using custom schema (e.g., 'platform')."""
with patch("backend.data.db.get_database_schema") as mock_schema:
mock_schema.return_value = "platform"
with patch(
"backend.api.features.store.hybrid_search.query_raw_with_schema"
) as mock_query:
mock_query.return_value = []
with patch(
"backend.api.features.store.hybrid_search.embed_query"
) as mock_embed:
mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM
results, total = await hybrid_search(
query="test",
page=1,
page_size=20,
)
# Verify the mock was set up correctly
assert mock_schema.return_value == "platform"
assert results == []
assert total == 0
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_hybrid_search_without_embeddings():
"""Test hybrid search gracefully degrades when embeddings are unavailable."""
# Mock database to return some results
mock_results = [
{
"slug": "test-agent",
"agent_name": "Test Agent",
"agent_image": "test.png",
"creator_username": "creator",
"creator_avatar": "avatar.png",
"sub_heading": "Test heading",
"description": "Test description",
"runs": 100,
"rating": 4.5,
"categories": ["AI"],
"featured": False,
"is_available": True,
"updated_at": "2025-01-01T00:00:00Z",
"semantic_score": 0.0, # Zero because no embedding
"lexical_score": 0.5,
"category_score": 0.0,
"recency_score": 0.1,
"popularity_score": 0.2,
"combined_score": 0.3,
"total_count": 1,
}
]
with patch("backend.api.features.store.hybrid_search.embed_query") as mock_embed:
with patch(
"backend.api.features.store.hybrid_search.query_raw_with_schema"
) as mock_query:
# Simulate embedding failure
mock_embed.return_value = None
mock_query.return_value = mock_results
# Should NOT raise - graceful degradation
results, total = await hybrid_search(
query="test",
page=1,
page_size=20,
)
# Verify it returns results even without embeddings
assert len(results) == 1
assert results[0]["slug"] == "test-agent"
assert total == 1
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_hybrid_search_with_filters():
"""Test hybrid search with various filters."""
with patch(
"backend.api.features.store.hybrid_search.query_raw_with_schema"
) as mock_query:
mock_query.return_value = []
with patch(
"backend.api.features.store.hybrid_search.embed_query"
) as mock_embed:
mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM
# Test with featured filter
results, total = await hybrid_search(
query="test",
featured=True,
creators=["user1", "user2"],
category="productivity",
page=1,
page_size=10,
)
# Verify filters were applied in the query
call_args = mock_query.call_args
params = call_args[0][1:] # Skip SQL template
# Should have query, query_lower, creators array, category
assert len(params) >= 4
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_hybrid_search_weights():
"""Test hybrid search with custom weights."""
custom_weights = HybridSearchWeights(
semantic=0.5,
lexical=0.3,
category=0.1,
recency=0.1,
popularity=0.0,
)
with patch(
"backend.api.features.store.hybrid_search.query_raw_with_schema"
) as mock_query:
mock_query.return_value = []
with patch(
"backend.api.features.store.hybrid_search.embed_query"
) as mock_embed:
mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM
results, total = await hybrid_search(
query="test",
weights=custom_weights,
page=1,
page_size=20,
)
# Verify custom weights were used in the query
call_args = mock_query.call_args
sql_template = call_args[0][0]
params = call_args[0][1:] # Get all parameters passed
# Check that SQL uses parameterized weights (not f-string interpolation)
assert "$" in sql_template # Verify parameterization is used
# Check that custom weights are in the params
assert 0.5 in params # semantic weight
assert 0.3 in params # lexical weight
assert 0.1 in params # category and recency weights
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_hybrid_search_min_score_filtering():
"""Test hybrid search minimum score threshold."""
with patch(
"backend.api.features.store.hybrid_search.query_raw_with_schema"
) as mock_query:
# Return results with varying scores
mock_query.return_value = [
{
"slug": "high-score/agent",
"agent_name": "High Score Agent",
"combined_score": 0.8,
"total_count": 1,
# ... other fields
}
]
with patch(
"backend.api.features.store.hybrid_search.embed_query"
) as mock_embed:
mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM
# Test with custom min_score
results, total = await hybrid_search(
query="test",
min_score=0.5, # High threshold
page=1,
page_size=20,
)
# Verify min_score was applied in query
call_args = mock_query.call_args
sql_template = call_args[0][0]
params = call_args[0][1:] # Get all parameters
# Check that SQL uses parameterized min_score
assert "combined_score >=" in sql_template
assert "$" in sql_template # Verify parameterization
# Check that custom min_score is in the params
assert 0.5 in params
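These weight and min_score tests pin down the query contract rather than the exact SQL: thresholds, weights, and pagination travel as bind parameters ($N placeholders) instead of being interpolated into the string. A minimal sketch of the call shape those assertions imply; the table and column names are placeholders, not the actual query in hybrid_search.py:

```python
# Hypothetical sketch only -- illustrates the parameterized style the tests assert on.
# query_raw_with_schema is the helper patched in the tests above.
async def fetch_scored_page(min_score: float, page: int, page_size: int):
    sql = """
        SELECT *, COUNT(*) OVER () AS total_count
        FROM {schema_prefix}"StoreAgent"   -- placeholder table name
        WHERE combined_score >= $1         -- threshold as a bind parameter, never an f-string
        ORDER BY combined_score DESC
        LIMIT $2 OFFSET $3
    """
    return await query_raw_with_schema(
        sql,
        min_score,               # $1, e.g. 0.5 in the test above
        page_size,               # $2
        (page - 1) * page_size,  # $3, the offset formula the pagination tests check
    )
```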
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_hybrid_search_pagination():
"""Test hybrid search pagination.
Pagination happens in SQL (LIMIT/OFFSET), then BM25 reranking is applied
to the paginated results.
"""
# Create mock results that SQL would return for a page
mock_results = [
{
"slug": f"agent-{i}",
"agent_name": f"Agent {i}",
"agent_image": "test.png",
"creator_username": "test",
"creator_avatar": "avatar.png",
"sub_heading": "Test",
"description": "Test description",
"runs": 100 - i,
"rating": 4.5,
"categories": ["test"],
"featured": False,
"is_available": True,
"updated_at": "2024-01-01T00:00:00Z",
"searchable_text": f"Agent {i} test description",
"combined_score": 0.9 - (i * 0.01),
"semantic_score": 0.7,
"lexical_score": 0.6,
"category_score": 0.5,
"recency_score": 0.4,
"popularity_score": 0.3,
"total_count": 25,
}
for i in range(10) # SQL returns page_size results
]
with patch(
"backend.api.features.store.hybrid_search.query_raw_with_schema"
) as mock_query:
mock_query.return_value = mock_results
with patch(
"backend.api.features.store.hybrid_search.embed_query"
) as mock_embed:
mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM
# Test page 2 with page_size 10
results, total = await hybrid_search(
query="test",
page=2,
page_size=10,
)
# Verify results returned
assert len(results) == 10
assert total == 25 # Total from SQL COUNT(*) OVER()
# Verify the SQL query uses page_size and offset
call_args = mock_query.call_args
params = call_args[0]
# Last two params are page_size and offset
page_size_param = params[-2]
offset_param = params[-1]
assert page_size_param == 10
assert offset_param == 10 # (page 2 - 1) * 10
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_hybrid_search_error_handling():
"""Test hybrid search error handling."""
with patch(
"backend.api.features.store.hybrid_search.query_raw_with_schema"
) as mock_query:
# Simulate database error
mock_query.side_effect = Exception("Database connection error")
with patch(
"backend.api.features.store.hybrid_search.embed_query"
) as mock_embed:
mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM
# Should raise exception
with pytest.raises(Exception) as exc_info:
await hybrid_search(
query="test",
page=1,
page_size=20,
)
assert "Database connection error" in str(exc_info.value)
# =============================================================================
# Unified Hybrid Search Tests
# =============================================================================
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_unified_hybrid_search_basic():
"""Test basic unified hybrid search across all content types."""
mock_results = [
{
"content_type": "STORE_AGENT",
"content_id": "agent-1",
"searchable_text": "Test Agent Description",
"metadata": {"name": "Test Agent"},
"updated_at": "2025-01-01T00:00:00Z",
"semantic_score": 0.7,
"lexical_score": 0.8,
"category_score": 0.5,
"recency_score": 0.3,
"combined_score": 0.6,
"total_count": 2,
},
{
"content_type": "BLOCK",
"content_id": "block-1",
"searchable_text": "Test Block Description",
"metadata": {"name": "Test Block"},
"updated_at": "2025-01-01T00:00:00Z",
"semantic_score": 0.6,
"lexical_score": 0.7,
"category_score": 0.4,
"recency_score": 0.2,
"combined_score": 0.5,
"total_count": 2,
},
]
with patch(
"backend.api.features.store.hybrid_search.query_raw_with_schema"
) as mock_query:
with patch(
"backend.api.features.store.hybrid_search.embed_query"
) as mock_embed:
mock_query.return_value = mock_results
mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM
results, total = await unified_hybrid_search(
query="test",
page=1,
page_size=20,
)
assert len(results) == 2
assert total == 2
assert results[0]["content_type"] == "STORE_AGENT"
assert results[1]["content_type"] == "BLOCK"
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_unified_hybrid_search_filter_by_content_type():
"""Test unified search filtering by specific content types."""
mock_results = [
{
"content_type": "BLOCK",
"content_id": "block-1",
"searchable_text": "Test Block",
"metadata": {},
"updated_at": "2025-01-01T00:00:00Z",
"semantic_score": 0.7,
"lexical_score": 0.8,
"category_score": 0.0,
"recency_score": 0.3,
"combined_score": 0.5,
"total_count": 1,
},
]
with patch(
"backend.api.features.store.hybrid_search.query_raw_with_schema"
) as mock_query:
with patch(
"backend.api.features.store.hybrid_search.embed_query"
) as mock_embed:
mock_query.return_value = mock_results
mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM
results, total = await unified_hybrid_search(
query="test",
content_types=[ContentType.BLOCK],
page=1,
page_size=20,
)
# Verify content_types parameter was passed correctly
call_args = mock_query.call_args
params = call_args[0][1:]
# The content types should be in the params as a list
assert ["BLOCK"] in params
assert len(results) == 1
assert total == 1
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_unified_hybrid_search_with_user_id():
"""Test unified search with user_id for private content."""
mock_results = [
{
"content_type": "STORE_AGENT",
"content_id": "agent-1",
"searchable_text": "My Private Agent",
"metadata": {},
"updated_at": "2025-01-01T00:00:00Z",
"semantic_score": 0.7,
"lexical_score": 0.8,
"category_score": 0.0,
"recency_score": 0.3,
"combined_score": 0.6,
"total_count": 1,
},
]
with patch(
"backend.api.features.store.hybrid_search.query_raw_with_schema"
) as mock_query:
with patch(
"backend.api.features.store.hybrid_search.embed_query"
) as mock_embed:
mock_query.return_value = mock_results
mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM
results, total = await unified_hybrid_search(
query="test",
user_id="user-123",
page=1,
page_size=20,
)
# Verify SQL contains user_id filter
call_args = mock_query.call_args
sql_template = call_args[0][0]
params = call_args[0][1:]
assert 'uce."userId"' in sql_template
assert "user-123" in params
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_unified_hybrid_search_custom_weights():
"""Test unified search with custom weights."""
custom_weights = UnifiedSearchWeights(
semantic=0.6,
lexical=0.2,
category=0.1,
recency=0.1,
)
with patch(
"backend.api.features.store.hybrid_search.query_raw_with_schema"
) as mock_query:
with patch(
"backend.api.features.store.hybrid_search.embed_query"
) as mock_embed:
mock_query.return_value = []
mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM
results, total = await unified_hybrid_search(
query="test",
weights=custom_weights,
page=1,
page_size=20,
)
# Verify custom weights are in parameters
call_args = mock_query.call_args
params = call_args[0][1:]
assert 0.6 in params # semantic weight
assert 0.2 in params # lexical weight
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_unified_hybrid_search_graceful_degradation():
"""Test unified search gracefully degrades when embeddings unavailable."""
mock_results = [
{
"content_type": "DOCUMENTATION",
"content_id": "doc-1",
"searchable_text": "API Documentation",
"metadata": {},
"updated_at": "2025-01-01T00:00:00Z",
"semantic_score": 0.0, # Zero because no embedding
"lexical_score": 0.8,
"category_score": 0.0,
"recency_score": 0.2,
"combined_score": 0.5,
"total_count": 1,
},
]
with patch(
"backend.api.features.store.hybrid_search.query_raw_with_schema"
) as mock_query:
with patch(
"backend.api.features.store.hybrid_search.embed_query"
) as mock_embed:
mock_query.return_value = mock_results
mock_embed.return_value = None # Embedding failure
# Should NOT raise - graceful degradation
results, total = await unified_hybrid_search(
query="test",
page=1,
page_size=20,
)
assert len(results) == 1
assert total == 1
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_unified_hybrid_search_empty_query():
"""Test unified search with empty query returns empty results."""
results, total = await unified_hybrid_search(
query="",
page=1,
page_size=20,
)
assert results == []
assert total == 0
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_unified_hybrid_search_pagination():
"""Test unified search pagination with BM25 reranking.
Pagination happens in SQL (LIMIT/OFFSET), then BM25 reranking is applied
to the paginated results.
"""
# Create mock results that SQL would return for a page
mock_results = [
{
"content_type": "STORE_AGENT",
"content_id": f"agent-{i}",
"searchable_text": f"Agent {i} description",
"metadata": {"name": f"Agent {i}"},
"updated_at": "2025-01-01T00:00:00Z",
"semantic_score": 0.7,
"lexical_score": 0.8 - (i * 0.01),
"category_score": 0.5,
"recency_score": 0.3,
"combined_score": 0.6 - (i * 0.01),
"total_count": 50,
}
for i in range(15) # SQL returns page_size results
]
with patch(
"backend.api.features.store.hybrid_search.query_raw_with_schema"
) as mock_query:
with patch(
"backend.api.features.store.hybrid_search.embed_query"
) as mock_embed:
mock_query.return_value = mock_results
mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM
results, total = await unified_hybrid_search(
query="test",
page=3,
page_size=15,
)
# Verify results returned
assert len(results) == 15
assert total == 50 # Total from SQL COUNT(*) OVER()
# Verify the SQL query uses page_size and offset
call_args = mock_query.call_args
params = call_args[0]
# Last two params are page_size and offset
page_size_param = params[-2]
offset_param = params[-1]
assert page_size_param == 15
assert offset_param == 30 # (page 3 - 1) * 15
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_unified_hybrid_search_schema_prefix():
"""Test unified search uses schema_prefix placeholder."""
with patch(
"backend.api.features.store.hybrid_search.query_raw_with_schema"
) as mock_query:
with patch(
"backend.api.features.store.hybrid_search.embed_query"
) as mock_embed:
mock_query.return_value = []
mock_embed.return_value = [0.1] * embeddings.EMBEDDING_DIM
await unified_hybrid_search(
query="test",
page=1,
page_size=20,
)
call_args = mock_query.call_args
sql_template = call_args[0][0]
# Verify schema_prefix placeholder is used for table references
assert "{schema_prefix}" in sql_template
assert '"UnifiedContentEmbedding"' in sql_template
if __name__ == "__main__":
pytest.main([__file__, "-v", "-s"])
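Both pagination tests above encode the same flow: SQL pages with LIMIT/OFFSET, and BM25 reranking then reorders only the rows of that page. A minimal sketch of such a per-page reranker, assuming a BM25 scorer like rank-bm25's BM25Okapi; the helper name and whitespace tokenisation are illustrative, not the code under test:

```python
from rank_bm25 import BM25Okapi


def rerank_page(results: list[dict], query: str) -> list[dict]:
    """Reorder a single page of SQL results by BM25 score over their searchable_text."""
    if not results or not query.strip():
        return results
    corpus = [r.get("searchable_text", "").lower().split() for r in results]
    bm25 = BM25Okapi(corpus)
    scores = bm25.get_scores(query.lower().split())
    # Highest BM25 score first; pagination itself already happened in SQL.
    order = sorted(range(len(results)), key=lambda i: scores[i], reverse=True)
    return [results[i] for i in order]
```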

View File

@@ -110,6 +110,7 @@ class Profile(pydantic.BaseModel):
class StoreSubmission(pydantic.BaseModel):
listing_id: str
agent_id: str
agent_version: int
name: str
@@ -164,8 +165,12 @@ class StoreListingsWithVersionsResponse(pydantic.BaseModel):
class StoreSubmissionRequest(pydantic.BaseModel):
agent_id: str
agent_version: int
agent_id: str = pydantic.Field(
..., min_length=1, description="Agent ID cannot be empty"
)
agent_version: int = pydantic.Field(
..., gt=0, description="Agent version must be greater than 0"
)
slug: str
name: str
sub_heading: str
@@ -216,3 +221,23 @@ class ReviewSubmissionRequest(pydantic.BaseModel):
is_approved: bool
comments: str # External comments visible to creator
internal_comments: str | None = None # Private admin notes
class UnifiedSearchResult(pydantic.BaseModel):
"""A single result from unified hybrid search across all content types."""
content_type: str # STORE_AGENT, BLOCK, DOCUMENTATION
content_id: str
searchable_text: str
metadata: dict | None = None
updated_at: datetime.datetime | None = None
combined_score: float | None = None
semantic_score: float | None = None
lexical_score: float | None = None
class UnifiedSearchResponse(pydantic.BaseModel):
"""Response model for unified search across all content types."""
results: list[UnifiedSearchResult]
pagination: Pagination
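The tightened StoreSubmissionRequest fields above reject empty agent IDs and non-positive versions at validation time. A standalone sketch of those two constraints on a hypothetical mirror model (the real request model has additional required fields):

```python
import pydantic


class SubmissionIdsSketch(pydantic.BaseModel):
    # Mirrors only the two constrained fields from StoreSubmissionRequest.
    agent_id: str = pydantic.Field(..., min_length=1, description="Agent ID cannot be empty")
    agent_version: int = pydantic.Field(..., gt=0, description="Agent version must be greater than 0")


SubmissionIdsSketch(agent_id="agent123", agent_version=1)  # valid

try:
    SubmissionIdsSketch(agent_id="", agent_version=0)  # violates both constraints
except pydantic.ValidationError as err:
    print(len(err.errors()), "validation errors")  # -> 2 validation errors
```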

View File

@@ -138,6 +138,7 @@ def test_creator_details():
def test_store_submission():
submission = store_model.StoreSubmission(
listing_id="listing123",
agent_id="agent123",
agent_version=1,
sub_heading="Test subheading",
@@ -159,6 +160,7 @@ def test_store_submissions_response():
response = store_model.StoreSubmissionsResponse(
submissions=[
store_model.StoreSubmission(
listing_id="listing123",
agent_id="agent123",
agent_version=1,
sub_heading="Test subheading",

View File

@@ -7,12 +7,15 @@ from typing import Literal
import autogpt_libs.auth
import fastapi
import fastapi.responses
import prisma.enums
import backend.data.graph
import backend.util.json
from backend.util.models import Pagination
from . import cache as store_cache
from . import db as store_db
from . import hybrid_search as store_hybrid_search
from . import image_gen as store_image_gen
from . import media as store_media
from . import model as store_model
@@ -146,6 +149,102 @@ async def get_agents(
return agents
##############################################
############### Search Endpoints #############
##############################################
@router.get(
"/search",
summary="Unified search across all content types",
tags=["store", "public"],
response_model=store_model.UnifiedSearchResponse,
)
async def unified_search(
query: str,
content_types: list[str] | None = fastapi.Query(
default=None,
description="Content types to search: STORE_AGENT, BLOCK, DOCUMENTATION. If not specified, searches all.",
),
page: int = 1,
page_size: int = 20,
user_id: str | None = fastapi.Security(
autogpt_libs.auth.get_optional_user_id, use_cache=False
),
):
"""
Search across all content types (store agents, blocks, documentation) using hybrid search.
Combines semantic (embedding-based) and lexical (text-based) search for best results.
Args:
query: The search query string
content_types: Optional list of content types to filter by (STORE_AGENT, BLOCK, DOCUMENTATION)
page: Page number for pagination (default 1)
page_size: Number of results per page (default 20)
user_id: Optional authenticated user ID (for user-scoped content in future)
Returns:
UnifiedSearchResponse: Paginated list of search results with relevance scores
"""
if page < 1:
raise fastapi.HTTPException(
status_code=422, detail="Page must be greater than 0"
)
if page_size < 1:
raise fastapi.HTTPException(
status_code=422, detail="Page size must be greater than 0"
)
# Convert string content types to enum
content_type_enums: list[prisma.enums.ContentType] | None = None
if content_types:
try:
content_type_enums = [prisma.enums.ContentType(ct) for ct in content_types]
except ValueError as e:
raise fastapi.HTTPException(
status_code=422,
detail=f"Invalid content type. Valid values: STORE_AGENT, BLOCK, DOCUMENTATION. Error: {e}",
)
# Perform unified hybrid search
results, total = await store_hybrid_search.unified_hybrid_search(
query=query,
content_types=content_type_enums,
user_id=user_id,
page=page,
page_size=page_size,
)
# Convert results to response model
search_results = [
store_model.UnifiedSearchResult(
content_type=r["content_type"],
content_id=r["content_id"],
searchable_text=r.get("searchable_text", ""),
metadata=r.get("metadata"),
updated_at=r.get("updated_at"),
combined_score=r.get("combined_score"),
semantic_score=r.get("semantic_score"),
lexical_score=r.get("lexical_score"),
)
for r in results
]
total_pages = (total + page_size - 1) // page_size if total > 0 else 0
return store_model.UnifiedSearchResponse(
results=search_results,
pagination=Pagination(
total_items=total,
total_pages=total_pages,
current_page=page,
page_size=page_size,
),
)
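A hypothetical client call against the endpoint above; the host, port, and route prefix are assumptions, while the query parameters mirror the signature:

```python
import httpx

# Assumed local URL and prefix; content_types repeats as a query parameter per FastAPI's list handling.
resp = httpx.get(
    "http://localhost:8006/api/store/search",
    params={
        "query": "summarize a webpage",
        "content_types": ["BLOCK", "DOCUMENTATION"],
        "page": 1,
        "page_size": 20,
    },
)
resp.raise_for_status()
payload = resp.json()  # shaped like UnifiedSearchResponse: {"results": [...], "pagination": {...}}
for item in payload["results"]:
    print(item["content_type"], item["content_id"], item.get("combined_score"))
```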
@router.get(
"/agents/{username}/{agent_name}",
summary="Get specific agent",

View File

@@ -521,6 +521,7 @@ def test_get_submissions_success(
mocked_value = store_model.StoreSubmissionsResponse(
submissions=[
store_model.StoreSubmission(
listing_id="test-listing-id",
name="Test Agent",
description="Test agent description",
image_urls=["test.jpg"],

View File

@@ -0,0 +1,272 @@
"""Tests for the semantic_search function."""
import pytest
from prisma.enums import ContentType
from backend.api.features.store.embeddings import EMBEDDING_DIM, semantic_search
@pytest.mark.asyncio
async def test_search_blocks_only(mocker):
"""Test searching only BLOCK content type."""
# Mock embed_query to return a test embedding
mock_embedding = [0.1] * EMBEDDING_DIM
mocker.patch(
"backend.api.features.store.embeddings.embed_query",
return_value=mock_embedding,
)
# Mock query_raw_with_schema to return test results
mock_results = [
{
"content_id": "block-123",
"content_type": "BLOCK",
"searchable_text": "Calculator Block - Performs arithmetic operations",
"metadata": {"name": "Calculator", "categories": ["Math"]},
"similarity": 0.85,
}
]
mocker.patch(
"backend.api.features.store.embeddings.query_raw_with_schema",
return_value=mock_results,
)
results = await semantic_search(
query="calculate numbers",
content_types=[ContentType.BLOCK],
)
assert len(results) == 1
assert results[0]["content_type"] == "BLOCK"
assert results[0]["content_id"] == "block-123"
assert results[0]["similarity"] == 0.85
@pytest.mark.asyncio
async def test_search_multiple_content_types(mocker):
"""Test searching multiple content types simultaneously."""
mock_embedding = [0.1] * EMBEDDING_DIM
mocker.patch(
"backend.api.features.store.embeddings.embed_query",
return_value=mock_embedding,
)
mock_results = [
{
"content_id": "block-123",
"content_type": "BLOCK",
"searchable_text": "Calculator Block",
"metadata": {},
"similarity": 0.85,
},
{
"content_id": "doc-456",
"content_type": "DOCUMENTATION",
"searchable_text": "How to use Calculator",
"metadata": {},
"similarity": 0.75,
},
]
mocker.patch(
"backend.api.features.store.embeddings.query_raw_with_schema",
return_value=mock_results,
)
results = await semantic_search(
query="calculator",
content_types=[ContentType.BLOCK, ContentType.DOCUMENTATION],
)
assert len(results) == 2
assert results[0]["content_type"] == "BLOCK"
assert results[1]["content_type"] == "DOCUMENTATION"
@pytest.mark.asyncio
async def test_search_with_min_similarity_threshold(mocker):
"""Test that results below min_similarity are filtered out."""
mock_embedding = [0.1] * EMBEDDING_DIM
mocker.patch(
"backend.api.features.store.embeddings.embed_query",
return_value=mock_embedding,
)
# Only return results above 0.7 similarity
mock_results = [
{
"content_id": "block-123",
"content_type": "BLOCK",
"searchable_text": "Calculator Block",
"metadata": {},
"similarity": 0.85,
}
]
mocker.patch(
"backend.api.features.store.embeddings.query_raw_with_schema",
return_value=mock_results,
)
results = await semantic_search(
query="calculate",
content_types=[ContentType.BLOCK],
min_similarity=0.7,
)
assert len(results) == 1
assert results[0]["similarity"] >= 0.7
@pytest.mark.asyncio
async def test_search_fallback_to_lexical(mocker):
"""Test fallback to lexical search when embeddings fail."""
# Mock embed_query to return None (embeddings unavailable)
mocker.patch(
"backend.api.features.store.embeddings.embed_query",
return_value=None,
)
mock_lexical_results = [
{
"content_id": "block-123",
"content_type": "BLOCK",
"searchable_text": "Calculator Block performs calculations",
"metadata": {},
"similarity": 0.0,
}
]
mocker.patch(
"backend.api.features.store.embeddings.query_raw_with_schema",
return_value=mock_lexical_results,
)
results = await semantic_search(
query="calculator",
content_types=[ContentType.BLOCK],
)
assert len(results) == 1
assert results[0]["similarity"] == 0.0 # Lexical search returns 0 similarity
@pytest.mark.asyncio
async def test_search_empty_query():
"""Test that empty query returns no results."""
results = await semantic_search(query="")
assert results == []
results = await semantic_search(query=" ")
assert results == []
@pytest.mark.asyncio
async def test_search_with_user_id_filter(mocker):
"""Test searching with user_id filter for private content."""
mock_embedding = [0.1] * EMBEDDING_DIM
mocker.patch(
"backend.api.features.store.embeddings.embed_query",
return_value=mock_embedding,
)
mock_results = [
{
"content_id": "agent-789",
"content_type": "LIBRARY_AGENT",
"searchable_text": "My Custom Agent",
"metadata": {},
"similarity": 0.9,
}
]
mocker.patch(
"backend.api.features.store.embeddings.query_raw_with_schema",
return_value=mock_results,
)
results = await semantic_search(
query="custom agent",
content_types=[ContentType.LIBRARY_AGENT],
user_id="user-123",
)
assert len(results) == 1
assert results[0]["content_type"] == "LIBRARY_AGENT"
@pytest.mark.asyncio
async def test_search_limit_parameter(mocker):
"""Test that limit parameter correctly limits results."""
mock_embedding = [0.1] * EMBEDDING_DIM
mocker.patch(
"backend.api.features.store.embeddings.embed_query",
return_value=mock_embedding,
)
# Return 5 results
mock_results = [
{
"content_id": f"block-{i}",
"content_type": "BLOCK",
"searchable_text": f"Block {i}",
"metadata": {},
"similarity": 0.8,
}
for i in range(5)
]
mocker.patch(
"backend.api.features.store.embeddings.query_raw_with_schema",
return_value=mock_results,
)
results = await semantic_search(
query="block",
content_types=[ContentType.BLOCK],
limit=5,
)
assert len(results) == 5
@pytest.mark.asyncio
async def test_search_default_content_types(mocker):
"""Test that default content_types includes BLOCK, STORE_AGENT, and DOCUMENTATION."""
mock_embedding = [0.1] * EMBEDDING_DIM
mocker.patch(
"backend.api.features.store.embeddings.embed_query",
return_value=mock_embedding,
)
mock_query_raw = mocker.patch(
"backend.api.features.store.embeddings.query_raw_with_schema",
return_value=[],
)
await semantic_search(query="test")
# Check that the SQL query includes all three default content types
call_args = mock_query_raw.call_args
assert "BLOCK" in str(call_args)
assert "STORE_AGENT" in str(call_args)
assert "DOCUMENTATION" in str(call_args)
@pytest.mark.asyncio
async def test_search_handles_database_error(mocker):
"""Test that database errors are handled gracefully."""
mock_embedding = [0.1] * EMBEDDING_DIM
mocker.patch(
"backend.api.features.store.embeddings.embed_query",
return_value=mock_embedding,
)
# Simulate database error
mocker.patch(
"backend.api.features.store.embeddings.query_raw_with_schema",
side_effect=Exception("Database connection failed"),
)
results = await semantic_search(
query="test",
content_types=[ContentType.BLOCK],
)
# Should return empty list on error
assert results == []
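Taken together, the tests above define the call surface of semantic_search. A hedged usage sketch; defaults for limit and min_similarity are whatever embeddings.py defines, and the lexical fallback shown above applies when no embedding backend is configured:

```python
from prisma.enums import ContentType

from backend.api.features.store.embeddings import semantic_search


async def find_calculator_content():
    hits = await semantic_search(
        query="calculate numbers",
        content_types=[ContentType.BLOCK, ContentType.DOCUMENTATION],
        min_similarity=0.7,
        limit=5,
    )
    for hit in hits:
        # Result rows carry content_type, content_id, searchable_text, metadata, similarity.
        print(hit["content_type"], hit["content_id"], hit["similarity"])
    return hits
```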

View File

@@ -64,7 +64,6 @@ from backend.data.onboarding import (
complete_re_run_agent,
get_recommended_agents,
get_user_onboarding,
increment_runs,
onboarding_enabled,
reset_user_onboarding,
update_user_onboarding,
@@ -975,7 +974,6 @@ async def execute_graph(
# Record successful graph execution
record_graph_execution(graph_id=graph_id, status="success", user_id=user_id)
record_graph_operation(operation="execute", status="success")
await increment_runs(user_id)
await complete_re_run_agent(user_id, graph_id)
if source == "library":
await complete_onboarding_step(

View File

@@ -39,7 +39,7 @@ import backend.data.user
import backend.integrations.webhooks.utils
import backend.util.service
import backend.util.settings
from backend.blocks.llm import LlmModel
from backend.blocks.llm import DEFAULT_LLM_MODEL
from backend.data.model import Credentials
from backend.integrations.providers import ProviderName
from backend.monitoring.instrumentation import instrument_fastapi
@@ -113,7 +113,7 @@ async def lifespan_context(app: fastapi.FastAPI):
await backend.data.user.migrate_and_encrypt_user_integrations()
await backend.data.graph.fix_llm_provider_credentials()
await backend.data.graph.migrate_llm_models(LlmModel.GPT4O)
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
with launch_darkly_context():

View File

@@ -1,6 +1,7 @@
from typing import Any
from backend.blocks.llm import (
DEFAULT_LLM_MODEL,
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
AIBlockBase,
@@ -49,7 +50,7 @@ class AIConditionBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default=LlmModel.GPT4O,
default=DEFAULT_LLM_MODEL,
description="The language model to use for evaluating the condition.",
advanced=False,
)
@@ -81,7 +82,7 @@ class AIConditionBlock(AIBlockBase):
"condition": "the input is an email address",
"yes_value": "Valid email",
"no_value": "Not an email",
"model": LlmModel.GPT4O,
"model": DEFAULT_LLM_MODEL,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,

View File

@@ -6,6 +6,9 @@ import hashlib
import hmac
import logging
from enum import Enum
from typing import cast
from prisma.types import Serializable
from backend.sdk import (
BaseWebhooksManager,
@@ -84,7 +87,9 @@ class AirtableWebhookManager(BaseWebhooksManager):
# update webhook config
await update_webhook(
webhook.id,
config={"base_id": base_id, "cursor": response.cursor},
config=cast(
dict[str, Serializable], {"base_id": base_id, "cursor": response.cursor}
),
)
event_type = "notification"

View File

@@ -182,13 +182,10 @@ class DataForSeoRelatedKeywordsBlock(Block):
if results and len(results) > 0:
# results is a list, get the first element
first_result = results[0] if isinstance(results, list) else results
items = (
first_result.get("items", [])
if isinstance(first_result, dict)
else []
)
# Ensure items is never None
if items is None:
# Handle missing key, null value, or valid list value
if isinstance(first_result, dict):
items = first_result.get("items") or []
else:
items = []
for item in items:
# Extract keyword_data from the item

File diff suppressed because it is too large.

View File

@@ -0,0 +1,184 @@
"""
Shared helpers for Human-In-The-Loop (HITL) review functionality.
Used by both the dedicated HumanInTheLoopBlock and blocks that require human review.
"""
import logging
from typing import Any, Optional
from prisma.enums import ReviewStatus
from pydantic import BaseModel
from backend.data.execution import ExecutionContext, ExecutionStatus
from backend.data.human_review import ReviewResult
from backend.executor.manager import async_update_node_execution_status
from backend.util.clients import get_database_manager_async_client
logger = logging.getLogger(__name__)
class ReviewDecision(BaseModel):
"""Result of a review decision."""
should_proceed: bool
message: str
review_result: ReviewResult
class HITLReviewHelper:
"""Helper class for Human-In-The-Loop review operations."""
@staticmethod
async def get_or_create_human_review(**kwargs) -> Optional[ReviewResult]:
"""Create or retrieve a human review from the database."""
return await get_database_manager_async_client().get_or_create_human_review(
**kwargs
)
@staticmethod
async def update_node_execution_status(**kwargs) -> None:
"""Update the execution status of a node."""
await async_update_node_execution_status(
db_client=get_database_manager_async_client(), **kwargs
)
@staticmethod
async def update_review_processed_status(
node_exec_id: str, processed: bool
) -> None:
"""Update the processed status of a review."""
return await get_database_manager_async_client().update_review_processed_status(
node_exec_id, processed
)
@staticmethod
async def _handle_review_request(
input_data: Any,
user_id: str,
node_exec_id: str,
graph_exec_id: str,
graph_id: str,
graph_version: int,
execution_context: ExecutionContext,
block_name: str = "Block",
editable: bool = False,
) -> Optional[ReviewResult]:
"""
Handle a review request for a block that requires human review.
Args:
input_data: The input data to be reviewed
user_id: ID of the user requesting the review
node_exec_id: ID of the node execution
graph_exec_id: ID of the graph execution
graph_id: ID of the graph
graph_version: Version of the graph
execution_context: Current execution context
block_name: Name of the block requesting review
editable: Whether the reviewer can edit the data
Returns:
ReviewResult if review is complete, None if waiting for human input
Raises:
Exception: If review creation or status update fails
"""
# Skip review if safe mode is disabled - return auto-approved result
if not execution_context.safe_mode:
logger.info(
f"Block {block_name} skipping review for node {node_exec_id} - safe mode disabled"
)
return ReviewResult(
data=input_data,
status=ReviewStatus.APPROVED,
message="Auto-approved (safe mode disabled)",
processed=True,
node_exec_id=node_exec_id,
)
result = await HITLReviewHelper.get_or_create_human_review(
user_id=user_id,
node_exec_id=node_exec_id,
graph_exec_id=graph_exec_id,
graph_id=graph_id,
graph_version=graph_version,
input_data=input_data,
message=f"Review required for {block_name} execution",
editable=editable,
)
if result is None:
logger.info(
f"Block {block_name} pausing execution for node {node_exec_id} - awaiting human review"
)
await HITLReviewHelper.update_node_execution_status(
exec_id=node_exec_id,
status=ExecutionStatus.REVIEW,
)
return None # Signal that execution should pause
# Mark review as processed if not already done
if not result.processed:
await HITLReviewHelper.update_review_processed_status(
node_exec_id=node_exec_id, processed=True
)
return result
@staticmethod
async def handle_review_decision(
input_data: Any,
user_id: str,
node_exec_id: str,
graph_exec_id: str,
graph_id: str,
graph_version: int,
execution_context: ExecutionContext,
block_name: str = "Block",
editable: bool = False,
) -> Optional[ReviewDecision]:
"""
Handle a review request and return the decision in a single call.
Args:
input_data: The input data to be reviewed
user_id: ID of the user requesting the review
node_exec_id: ID of the node execution
graph_exec_id: ID of the graph execution
graph_id: ID of the graph
graph_version: Version of the graph
execution_context: Current execution context
block_name: Name of the block requesting review
editable: Whether the reviewer can edit the data
Returns:
ReviewDecision if review is complete (approved/rejected),
None if execution should pause (awaiting review)
"""
review_result = await HITLReviewHelper._handle_review_request(
input_data=input_data,
user_id=user_id,
node_exec_id=node_exec_id,
graph_exec_id=graph_exec_id,
graph_id=graph_id,
graph_version=graph_version,
execution_context=execution_context,
block_name=block_name,
editable=editable,
)
if review_result is None:
# Still awaiting review - return None to pause execution
return None
# Review is complete, determine outcome
should_proceed = review_result.status == ReviewStatus.APPROVED
message = review_result.message or (
"Execution approved by reviewer"
if should_proceed
else "Execution rejected by reviewer"
)
return ReviewDecision(
should_proceed=should_proceed, message=message, review_result=review_result
)
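For any block that wants gated execution, the intended pattern is a single handle_review_decision call inside run; a minimal sketch under that assumption, with the block name and input field invented for illustration:

```python
# Illustrative only: a block (other than HumanInTheLoopBlock) gating its output on review.
async def run(self, input_data, *, user_id, node_exec_id, graph_exec_id,
              graph_id, graph_version, execution_context, **_kwargs):
    decision = await HITLReviewHelper.handle_review_decision(
        input_data=input_data.payload,    # hypothetical input field
        user_id=user_id,
        node_exec_id=node_exec_id,
        graph_exec_id=graph_exec_id,
        graph_id=graph_id,
        graph_version=graph_version,
        execution_context=execution_context,
        block_name="MyReviewedBlock",     # hypothetical block name
        editable=False,
    )
    if decision is None:
        return  # node is now in REVIEW status; execution resumes after a human responds
    if decision.should_proceed:
        yield "approved_data", decision.review_result.data
    else:
        yield "rejected_data", decision.review_result.data
    if decision.message:
        yield "review_message", decision.message
```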

View File

@@ -3,6 +3,7 @@ from typing import Any
from prisma.enums import ReviewStatus
from backend.blocks.helpers.review import HITLReviewHelper
from backend.data.block import (
Block,
BlockCategory,
@@ -11,11 +12,9 @@ from backend.data.block import (
BlockSchemaOutput,
BlockType,
)
from backend.data.execution import ExecutionContext, ExecutionStatus
from backend.data.execution import ExecutionContext
from backend.data.human_review import ReviewResult
from backend.data.model import SchemaField
from backend.executor.manager import async_update_node_execution_status
from backend.util.clients import get_database_manager_async_client
logger = logging.getLogger(__name__)
@@ -72,32 +71,26 @@ class HumanInTheLoopBlock(Block):
("approved_data", {"name": "John Doe", "age": 30}),
],
test_mock={
"get_or_create_human_review": lambda *_args, **_kwargs: ReviewResult(
data={"name": "John Doe", "age": 30},
status=ReviewStatus.APPROVED,
message="",
processed=False,
node_exec_id="test-node-exec-id",
),
"update_node_execution_status": lambda *_args, **_kwargs: None,
"update_review_processed_status": lambda *_args, **_kwargs: None,
"handle_review_decision": lambda **kwargs: type(
"ReviewDecision",
(),
{
"should_proceed": True,
"message": "Test approval message",
"review_result": ReviewResult(
data={"name": "John Doe", "age": 30},
status=ReviewStatus.APPROVED,
message="",
processed=False,
node_exec_id="test-node-exec-id",
),
},
)(),
},
)
async def get_or_create_human_review(self, **kwargs):
return await get_database_manager_async_client().get_or_create_human_review(
**kwargs
)
async def update_node_execution_status(self, **kwargs):
return await async_update_node_execution_status(
db_client=get_database_manager_async_client(), **kwargs
)
async def update_review_processed_status(self, node_exec_id: str, processed: bool):
return await get_database_manager_async_client().update_review_processed_status(
node_exec_id, processed
)
async def handle_review_decision(self, **kwargs):
return await HITLReviewHelper.handle_review_decision(**kwargs)
async def run(
self,
@@ -109,7 +102,7 @@ class HumanInTheLoopBlock(Block):
graph_id: str,
graph_version: int,
execution_context: ExecutionContext,
**kwargs,
**_kwargs,
) -> BlockOutput:
if not execution_context.safe_mode:
logger.info(
@@ -119,48 +112,28 @@ class HumanInTheLoopBlock(Block):
yield "review_message", "Auto-approved (safe mode disabled)"
return
try:
result = await self.get_or_create_human_review(
user_id=user_id,
node_exec_id=node_exec_id,
graph_exec_id=graph_exec_id,
graph_id=graph_id,
graph_version=graph_version,
input_data=input_data.data,
message=input_data.name,
editable=input_data.editable,
)
except Exception as e:
logger.error(f"Error in HITL block for node {node_exec_id}: {str(e)}")
raise
decision = await self.handle_review_decision(
input_data=input_data.data,
user_id=user_id,
node_exec_id=node_exec_id,
graph_exec_id=graph_exec_id,
graph_id=graph_id,
graph_version=graph_version,
execution_context=execution_context,
block_name=self.name,
editable=input_data.editable,
)
if result is None:
logger.info(
f"HITL block pausing execution for node {node_exec_id} - awaiting human review"
)
try:
await self.update_node_execution_status(
exec_id=node_exec_id,
status=ExecutionStatus.REVIEW,
)
return
except Exception as e:
logger.error(
f"Failed to update node status for HITL block {node_exec_id}: {str(e)}"
)
raise
if decision is None:
return
if not result.processed:
await self.update_review_processed_status(
node_exec_id=node_exec_id, processed=True
)
status = decision.review_result.status
if status == ReviewStatus.APPROVED:
yield "approved_data", decision.review_result.data
elif status == ReviewStatus.REJECTED:
yield "rejected_data", decision.review_result.data
else:
raise RuntimeError(f"Unexpected review status: {status}")
if result.status == ReviewStatus.APPROVED:
yield "approved_data", result.data
if result.message:
yield "review_message", result.message
elif result.status == ReviewStatus.REJECTED:
yield "rejected_data", result.data
if result.message:
yield "review_message", result.message
if decision.message:
yield "review_message", decision.message

View File

@@ -92,8 +92,9 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
O1 = "o1"
O1_MINI = "o1-mini"
# GPT-5 models
GPT5 = "gpt-5-2025-08-07"
GPT5_2 = "gpt-5.2-2025-12-11"
GPT5_1 = "gpt-5.1-2025-11-13"
GPT5 = "gpt-5-2025-08-07"
GPT5_MINI = "gpt-5-mini-2025-08-07"
GPT5_NANO = "gpt-5-nano-2025-08-07"
GPT5_CHAT = "gpt-5-chat-latest"
@@ -194,8 +195,9 @@ MODEL_METADATA = {
LlmModel.O1: ModelMetadata("openai", 200000, 100000), # o1-2024-12-17
LlmModel.O1_MINI: ModelMetadata("openai", 128000, 65536), # o1-mini-2024-09-12
# GPT-5 models
LlmModel.GPT5: ModelMetadata("openai", 400000, 128000),
LlmModel.GPT5_2: ModelMetadata("openai", 400000, 128000),
LlmModel.GPT5_1: ModelMetadata("openai", 400000, 128000),
LlmModel.GPT5: ModelMetadata("openai", 400000, 128000),
LlmModel.GPT5_MINI: ModelMetadata("openai", 400000, 128000),
LlmModel.GPT5_NANO: ModelMetadata("openai", 400000, 128000),
LlmModel.GPT5_CHAT: ModelMetadata("openai", 400000, 16384),
@@ -303,6 +305,8 @@ MODEL_METADATA = {
LlmModel.V0_1_0_MD: ModelMetadata("v0", 128000, 64000),
}
DEFAULT_LLM_MODEL = LlmModel.GPT5_2
for model in LlmModel:
if model not in MODEL_METADATA:
raise ValueError(f"Missing MODEL_METADATA metadata for model: {model}")
@@ -790,7 +794,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default=LlmModel.GPT4O,
default=DEFAULT_LLM_MODEL,
description="The language model to use for answering the prompt.",
advanced=False,
)
@@ -855,7 +859,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
input_schema=AIStructuredResponseGeneratorBlock.Input,
output_schema=AIStructuredResponseGeneratorBlock.Output,
test_input={
"model": LlmModel.GPT4O,
"model": DEFAULT_LLM_MODEL,
"credentials": TEST_CREDENTIALS_INPUT,
"expected_format": {
"key1": "value1",
@@ -1221,7 +1225,7 @@ class AITextGeneratorBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default=LlmModel.GPT4O,
default=DEFAULT_LLM_MODEL,
description="The language model to use for answering the prompt.",
advanced=False,
)
@@ -1317,7 +1321,7 @@ class AITextSummarizerBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default=LlmModel.GPT4O,
default=DEFAULT_LLM_MODEL,
description="The language model to use for summarizing the text.",
)
focus: str = SchemaField(
@@ -1534,7 +1538,7 @@ class AIConversationBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default=LlmModel.GPT4O,
default=DEFAULT_LLM_MODEL,
description="The language model to use for the conversation.",
)
credentials: AICredentials = AICredentialsField()
@@ -1572,7 +1576,7 @@ class AIConversationBlock(AIBlockBase):
},
{"role": "user", "content": "Where was it played?"},
],
"model": LlmModel.GPT4O,
"model": DEFAULT_LLM_MODEL,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
@@ -1635,7 +1639,7 @@ class AIListGeneratorBlock(AIBlockBase):
)
model: LlmModel = SchemaField(
title="LLM Model",
default=LlmModel.GPT4O,
default=DEFAULT_LLM_MODEL,
description="The language model to use for generating the list.",
advanced=True,
)
@@ -1692,7 +1696,7 @@ class AIListGeneratorBlock(AIBlockBase):
"drawing explorers to uncover its mysteries. Each planet showcases the limitless possibilities of "
"fictional worlds."
),
"model": LlmModel.GPT4O,
"model": DEFAULT_LLM_MODEL,
"credentials": TEST_CREDENTIALS_INPUT,
"max_retries": 3,
"force_json_output": False,

File diff suppressed because it is too large.

View File

@@ -18,6 +18,7 @@ from backend.data.model import (
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.request import DEFAULT_USER_AGENT
class GetWikipediaSummaryBlock(Block, GetRequest):
@@ -39,17 +40,27 @@ class GetWikipediaSummaryBlock(Block, GetRequest):
output_schema=GetWikipediaSummaryBlock.Output,
test_input={"topic": "Artificial Intelligence"},
test_output=("summary", "summary content"),
test_mock={"get_request": lambda url, json: {"extract": "summary content"}},
test_mock={
"get_request": lambda url, headers, json: {"extract": "summary content"}
},
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
topic = input_data.topic
url = f"https://en.wikipedia.org/api/rest_v1/page/summary/{topic}"
# URL-encode the topic to handle spaces and special characters
encoded_topic = quote(topic, safe="")
url = f"https://en.wikipedia.org/api/rest_v1/page/summary/{encoded_topic}"
# Set headers per Wikimedia robot policy (https://w.wiki/4wJS)
# - User-Agent: Required, must identify the bot
# - Accept-Encoding: gzip recommended to reduce bandwidth
headers = {
"User-Agent": DEFAULT_USER_AGENT,
"Accept-Encoding": "gzip, deflate",
}
# Note: User-Agent is now automatically set by the request library
# to comply with Wikimedia's robot policy (https://w.wiki/4wJS)
try:
response = await self.get_request(url, json=True)
response = await self.get_request(url, headers=headers, json=True)
if "extract" not in response:
raise ValueError(f"Unable to parse Wikipedia response: {response}")
yield "summary", response["extract"]

View File

@@ -226,7 +226,7 @@ class SmartDecisionMakerBlock(Block):
)
model: llm.LlmModel = SchemaField(
title="LLM Model",
default=llm.LlmModel.GPT4O,
default=llm.DEFAULT_LLM_MODEL,
description="The language model to use for answering the prompt.",
advanced=False,
)
@@ -391,8 +391,12 @@ class SmartDecisionMakerBlock(Block):
"""
block = sink_node.block
# Use custom name from node metadata if set, otherwise fall back to block.name
custom_name = sink_node.metadata.get("customized_name")
tool_name = custom_name if custom_name else block.name
tool_function: dict[str, Any] = {
"name": SmartDecisionMakerBlock.cleanup(block.name),
"name": SmartDecisionMakerBlock.cleanup(tool_name),
"description": block.description,
}
sink_block_input_schema = block.input_schema
@@ -489,14 +493,24 @@ class SmartDecisionMakerBlock(Block):
f"Sink graph metadata not found: {graph_id} {graph_version}"
)
# Use custom name from node metadata if set, otherwise fall back to graph name
custom_name = sink_node.metadata.get("customized_name")
tool_name = custom_name if custom_name else sink_graph_meta.name
tool_function: dict[str, Any] = {
"name": SmartDecisionMakerBlock.cleanup(sink_graph_meta.name),
"name": SmartDecisionMakerBlock.cleanup(tool_name),
"description": sink_graph_meta.description,
}
properties = {}
field_mapping = {}
for link in links:
field_name = link.sink_name
clean_field_name = SmartDecisionMakerBlock.cleanup(field_name)
field_mapping[clean_field_name] = field_name
sink_block_input_schema = sink_node.input_default["input_schema"]
sink_block_properties = sink_block_input_schema.get("properties", {}).get(
link.sink_name, {}
@@ -506,7 +520,7 @@ class SmartDecisionMakerBlock(Block):
if "description" in sink_block_properties
else f"The {link.sink_name} of the tool"
)
properties[link.sink_name] = {
properties[clean_field_name] = {
"type": "string",
"description": description,
"default": json.dumps(sink_block_properties.get("default", None)),
@@ -519,7 +533,7 @@ class SmartDecisionMakerBlock(Block):
"strict": True,
}
# Store node info for later use in output processing
tool_function["_field_mapping"] = field_mapping
tool_function["_sink_node_id"] = sink_node.id
return {"type": "function", "function": tool_function}
@@ -975,10 +989,28 @@ class SmartDecisionMakerBlock(Block):
graph_version: int,
execution_context: ExecutionContext,
execution_processor: "ExecutionProcessor",
nodes_to_skip: set[str] | None = None,
**kwargs,
) -> BlockOutput:
tool_functions = await self._create_tool_node_signatures(node_id)
original_tool_count = len(tool_functions)
# Filter out tools for nodes that should be skipped (e.g., missing optional credentials)
if nodes_to_skip:
tool_functions = [
tf
for tf in tool_functions
if tf.get("function", {}).get("_sink_node_id") not in nodes_to_skip
]
# Only raise error if we had tools but they were all filtered out
if original_tool_count > 0 and not tool_functions:
raise ValueError(
"No available tools to execute - all downstream nodes are unavailable "
"(possibly due to missing optional credentials)"
)
yield "tool_functions", json.dumps(tool_functions)
conversation_history = input_data.conversation_history or []
@@ -1129,8 +1161,9 @@ class SmartDecisionMakerBlock(Block):
original_field_name = field_mapping.get(clean_arg_name, clean_arg_name)
arg_value = tool_args.get(clean_arg_name)
sanitized_arg_name = self.cleanup(original_field_name)
emit_key = f"tools_^_{sink_node_id}_~_{sanitized_arg_name}"
# Use original_field_name directly (not sanitized) to match link sink_name
# The field_mapping already translates from LLM's cleaned names to original names
emit_key = f"tools_^_{sink_node_id}_~_{original_field_name}"
logger.debug(
"[SmartDecisionMakerBlock|geid:%s|neid:%s] emit %s",

View File

@@ -196,6 +196,15 @@ class TestXMLParserBlockSecurity:
async for _ in block.run(XMLParserBlock.Input(input_xml=large_xml)):
pass
async def test_rejects_text_outside_root(self):
"""Ensure parser surfaces readable errors for invalid root text."""
block = XMLParserBlock()
invalid_xml = "<root><child>value</child></root> trailing"
with pytest.raises(ValueError, match="text outside the root element"):
async for _ in block.run(XMLParserBlock.Input(input_xml=invalid_xml)):
pass
class TestStoreMediaFileSecurity:
"""Test file storage security limits."""

View File

@@ -28,7 +28,7 @@ class TestLLMStatsTracking:
response = await llm.llm_call(
credentials=llm.TEST_CREDENTIALS,
llm_model=llm.LlmModel.GPT4O,
llm_model=llm.DEFAULT_LLM_MODEL,
prompt=[{"role": "user", "content": "Hello"}],
max_tokens=100,
)
@@ -65,7 +65,7 @@ class TestLLMStatsTracking:
input_data = llm.AIStructuredResponseGeneratorBlock.Input(
prompt="Test prompt",
expected_format={"key1": "desc1", "key2": "desc2"},
model=llm.LlmModel.GPT4O,
model=llm.DEFAULT_LLM_MODEL,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
)
@@ -109,7 +109,7 @@ class TestLLMStatsTracking:
# Run the block
input_data = llm.AITextGeneratorBlock.Input(
prompt="Generate text",
model=llm.LlmModel.GPT4O,
model=llm.DEFAULT_LLM_MODEL,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
)
@@ -170,7 +170,7 @@ class TestLLMStatsTracking:
input_data = llm.AIStructuredResponseGeneratorBlock.Input(
prompt="Test prompt",
expected_format={"key1": "desc1", "key2": "desc2"},
model=llm.LlmModel.GPT4O,
model=llm.DEFAULT_LLM_MODEL,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
retry=2,
)
@@ -228,7 +228,7 @@ class TestLLMStatsTracking:
input_data = llm.AITextSummarizerBlock.Input(
text=long_text,
model=llm.LlmModel.GPT4O,
model=llm.DEFAULT_LLM_MODEL,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
max_tokens=100, # Small chunks
chunk_overlap=10,
@@ -299,7 +299,7 @@ class TestLLMStatsTracking:
# Test with very short text (should only need 1 chunk + 1 final summary)
input_data = llm.AITextSummarizerBlock.Input(
text="This is a short text.",
model=llm.LlmModel.GPT4O,
model=llm.DEFAULT_LLM_MODEL,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
max_tokens=1000, # Large enough to avoid chunking
)
@@ -346,7 +346,7 @@ class TestLLMStatsTracking:
{"role": "assistant", "content": "Hi there!"},
{"role": "user", "content": "How are you?"},
],
model=llm.LlmModel.GPT4O,
model=llm.DEFAULT_LLM_MODEL,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
)
@@ -387,7 +387,7 @@ class TestLLMStatsTracking:
# Run the block
input_data = llm.AIListGeneratorBlock.Input(
focus="test items",
model=llm.LlmModel.GPT4O,
model=llm.DEFAULT_LLM_MODEL,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
max_retries=3,
)
@@ -469,7 +469,7 @@ class TestLLMStatsTracking:
input_data = llm.AIStructuredResponseGeneratorBlock.Input(
prompt="Test",
expected_format={"result": "desc"},
model=llm.LlmModel.GPT4O,
model=llm.DEFAULT_LLM_MODEL,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
)
@@ -513,7 +513,7 @@ class TestAITextSummarizerValidation:
# Create input data
input_data = llm.AITextSummarizerBlock.Input(
text="Some text to summarize",
model=llm.LlmModel.GPT4O,
model=llm.DEFAULT_LLM_MODEL,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
style=llm.SummaryStyle.BULLET_POINTS,
)
@@ -558,7 +558,7 @@ class TestAITextSummarizerValidation:
# Create input data
input_data = llm.AITextSummarizerBlock.Input(
text="Some text to summarize",
model=llm.LlmModel.GPT4O,
model=llm.DEFAULT_LLM_MODEL,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
style=llm.SummaryStyle.BULLET_POINTS,
max_tokens=1000,
@@ -593,7 +593,7 @@ class TestAITextSummarizerValidation:
# Create input data
input_data = llm.AITextSummarizerBlock.Input(
text="Some text to summarize",
model=llm.LlmModel.GPT4O,
model=llm.DEFAULT_LLM_MODEL,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
)
@@ -623,7 +623,7 @@ class TestAITextSummarizerValidation:
# Create input data
input_data = llm.AITextSummarizerBlock.Input(
text="Some text to summarize",
model=llm.LlmModel.GPT4O,
model=llm.DEFAULT_LLM_MODEL,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
max_tokens=1000,
)
@@ -654,7 +654,7 @@ class TestAITextSummarizerValidation:
# Create input data
input_data = llm.AITextSummarizerBlock.Input(
text="Some text to summarize",
model=llm.LlmModel.GPT4O,
model=llm.DEFAULT_LLM_MODEL,
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
)

View File

@@ -233,7 +233,7 @@ async def test_smart_decision_maker_tracks_llm_stats():
# Create test input
input_data = SmartDecisionMakerBlock.Input(
prompt="Should I continue with this task?",
model=llm_module.LlmModel.GPT4O,
model=llm_module.DEFAULT_LLM_MODEL,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
agent_mode_max_iterations=0,
)
@@ -335,7 +335,7 @@ async def test_smart_decision_maker_parameter_validation():
input_data = SmartDecisionMakerBlock.Input(
prompt="Search for keywords",
model=llm_module.LlmModel.GPT4O,
model=llm_module.DEFAULT_LLM_MODEL,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
retry=2, # Set retry to 2 for testing
agent_mode_max_iterations=0,
@@ -402,7 +402,7 @@ async def test_smart_decision_maker_parameter_validation():
input_data = SmartDecisionMakerBlock.Input(
prompt="Search for keywords",
model=llm_module.LlmModel.GPT4O,
model=llm_module.DEFAULT_LLM_MODEL,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
agent_mode_max_iterations=0,
)
@@ -462,7 +462,7 @@ async def test_smart_decision_maker_parameter_validation():
input_data = SmartDecisionMakerBlock.Input(
prompt="Search for keywords",
model=llm_module.LlmModel.GPT4O,
model=llm_module.DEFAULT_LLM_MODEL,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
agent_mode_max_iterations=0,
)
@@ -526,7 +526,7 @@ async def test_smart_decision_maker_parameter_validation():
input_data = SmartDecisionMakerBlock.Input(
prompt="Search for keywords",
model=llm_module.LlmModel.GPT4O,
model=llm_module.DEFAULT_LLM_MODEL,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
agent_mode_max_iterations=0,
)
@@ -648,7 +648,7 @@ async def test_smart_decision_maker_raw_response_conversion():
input_data = SmartDecisionMakerBlock.Input(
prompt="Test prompt",
model=llm_module.LlmModel.GPT4O,
model=llm_module.DEFAULT_LLM_MODEL,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
retry=2,
agent_mode_max_iterations=0,
@@ -722,7 +722,7 @@ async def test_smart_decision_maker_raw_response_conversion():
):
input_data = SmartDecisionMakerBlock.Input(
prompt="Simple prompt",
model=llm_module.LlmModel.GPT4O,
model=llm_module.DEFAULT_LLM_MODEL,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
agent_mode_max_iterations=0,
)
@@ -778,7 +778,7 @@ async def test_smart_decision_maker_raw_response_conversion():
):
input_data = SmartDecisionMakerBlock.Input(
prompt="Another test",
model=llm_module.LlmModel.GPT4O,
model=llm_module.DEFAULT_LLM_MODEL,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
agent_mode_max_iterations=0,
)
@@ -931,7 +931,7 @@ async def test_smart_decision_maker_agent_mode():
# Test agent mode with max_iterations = 3
input_data = SmartDecisionMakerBlock.Input(
prompt="Complete this task using tools",
model=llm_module.LlmModel.GPT4O,
model=llm_module.DEFAULT_LLM_MODEL,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
agent_mode_max_iterations=3, # Enable agent mode with 3 max iterations
)
@@ -1020,7 +1020,7 @@ async def test_smart_decision_maker_traditional_mode_default():
# Test default behavior (traditional mode)
input_data = SmartDecisionMakerBlock.Input(
prompt="Test prompt",
model=llm_module.LlmModel.GPT4O,
model=llm_module.DEFAULT_LLM_MODEL,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
agent_mode_max_iterations=0, # Traditional mode
)
@@ -1057,3 +1057,153 @@ async def test_smart_decision_maker_traditional_mode_default():
) # Should yield individual tool parameters
assert "tools_^_test-sink-node-id_~_max_keyword_difficulty" in outputs
assert "conversations" in outputs
@pytest.mark.asyncio
async def test_smart_decision_maker_uses_customized_name_for_blocks():
"""Test that SmartDecisionMakerBlock uses customized_name from node metadata for tool names."""
from unittest.mock import MagicMock
from backend.blocks.basic import StoreValueBlock
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
from backend.data.graph import Link, Node
# Create a mock node with customized_name in metadata
mock_node = MagicMock(spec=Node)
mock_node.id = "test-node-id"
mock_node.block_id = StoreValueBlock().id
mock_node.metadata = {"customized_name": "My Custom Tool Name"}
mock_node.block = StoreValueBlock()
# Create a mock link
mock_link = MagicMock(spec=Link)
mock_link.sink_name = "input"
# Call the function directly
result = await SmartDecisionMakerBlock._create_block_function_signature(
mock_node, [mock_link]
)
# Verify the tool name uses the customized name (cleaned up)
assert result["type"] == "function"
assert result["function"]["name"] == "my_custom_tool_name" # Cleaned version
assert result["function"]["_sink_node_id"] == "test-node-id"
@pytest.mark.asyncio
async def test_smart_decision_maker_falls_back_to_block_name():
"""Test that SmartDecisionMakerBlock falls back to block.name when no customized_name."""
from unittest.mock import MagicMock
from backend.blocks.basic import StoreValueBlock
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
from backend.data.graph import Link, Node
# Create a mock node without customized_name
mock_node = MagicMock(spec=Node)
mock_node.id = "test-node-id"
mock_node.block_id = StoreValueBlock().id
mock_node.metadata = {} # No customized_name
mock_node.block = StoreValueBlock()
# Create a mock link
mock_link = MagicMock(spec=Link)
mock_link.sink_name = "input"
# Call the function directly
result = await SmartDecisionMakerBlock._create_block_function_signature(
mock_node, [mock_link]
)
# Verify the tool name uses the block's default name
assert result["type"] == "function"
assert result["function"]["name"] == "storevalueblock" # Default block name cleaned
assert result["function"]["_sink_node_id"] == "test-node-id"
@pytest.mark.asyncio
async def test_smart_decision_maker_uses_customized_name_for_agents():
"""Test that SmartDecisionMakerBlock uses customized_name from metadata for agent nodes."""
from unittest.mock import AsyncMock, MagicMock, patch
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
from backend.data.graph import Link, Node
# Create a mock node with customized_name in metadata
mock_node = MagicMock(spec=Node)
mock_node.id = "test-agent-node-id"
mock_node.metadata = {"customized_name": "My Custom Agent"}
mock_node.input_default = {
"graph_id": "test-graph-id",
"graph_version": 1,
"input_schema": {"properties": {"test_input": {"description": "Test input"}}},
}
# Create a mock link
mock_link = MagicMock(spec=Link)
mock_link.sink_name = "test_input"
# Mock the database client
mock_graph_meta = MagicMock()
mock_graph_meta.name = "Original Agent Name"
mock_graph_meta.description = "Agent description"
mock_db_client = AsyncMock()
mock_db_client.get_graph_metadata.return_value = mock_graph_meta
with patch(
"backend.blocks.smart_decision_maker.get_database_manager_async_client",
return_value=mock_db_client,
):
result = await SmartDecisionMakerBlock._create_agent_function_signature(
mock_node, [mock_link]
)
# Verify the tool name uses the customized name (cleaned up)
assert result["type"] == "function"
assert result["function"]["name"] == "my_custom_agent" # Cleaned version
assert result["function"]["_sink_node_id"] == "test-agent-node-id"
@pytest.mark.asyncio
async def test_smart_decision_maker_agent_falls_back_to_graph_name():
"""Test that agent node falls back to graph name when no customized_name."""
from unittest.mock import AsyncMock, MagicMock, patch
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
from backend.data.graph import Link, Node
# Create a mock node without customized_name
mock_node = MagicMock(spec=Node)
mock_node.id = "test-agent-node-id"
mock_node.metadata = {} # No customized_name
mock_node.input_default = {
"graph_id": "test-graph-id",
"graph_version": 1,
"input_schema": {"properties": {"test_input": {"description": "Test input"}}},
}
# Create a mock link
mock_link = MagicMock(spec=Link)
mock_link.sink_name = "test_input"
# Mock the database client
mock_graph_meta = MagicMock()
mock_graph_meta.name = "Original Agent Name"
mock_graph_meta.description = "Agent description"
mock_db_client = AsyncMock()
mock_db_client.get_graph_metadata.return_value = mock_graph_meta
with patch(
"backend.blocks.smart_decision_maker.get_database_manager_async_client",
return_value=mock_db_client,
):
result = await SmartDecisionMakerBlock._create_agent_function_signature(
mock_node, [mock_link]
)
# Verify the tool name uses the graph's default name
assert result["type"] == "function"
assert result["function"]["name"] == "original_agent_name" # Graph name cleaned
assert result["function"]["_sink_node_id"] == "test-agent-node-id"

View File

@@ -15,6 +15,7 @@ async def test_smart_decision_maker_handles_dynamic_dict_fields():
mock_node.block = CreateDictionaryBlock()
mock_node.block_id = CreateDictionaryBlock().id
mock_node.input_default = {}
mock_node.metadata = {}
# Create mock links with dynamic dictionary fields
mock_links = [
@@ -77,6 +78,7 @@ async def test_smart_decision_maker_handles_dynamic_list_fields():
mock_node.block = AddToListBlock()
mock_node.block_id = AddToListBlock().id
mock_node.input_default = {}
mock_node.metadata = {}
# Create mock links with dynamic list fields
mock_links = [

View File

@@ -44,6 +44,7 @@ async def test_create_block_function_signature_with_dict_fields():
mock_node.block = CreateDictionaryBlock()
mock_node.block_id = CreateDictionaryBlock().id
mock_node.input_default = {}
mock_node.metadata = {}
# Create mock links with dynamic dictionary fields (source sanitized, sink original)
mock_links = [
@@ -106,6 +107,7 @@ async def test_create_block_function_signature_with_list_fields():
mock_node.block = AddToListBlock()
mock_node.block_id = AddToListBlock().id
mock_node.input_default = {}
mock_node.metadata = {}
# Create mock links with dynamic list fields
mock_links = [
@@ -159,6 +161,7 @@ async def test_create_block_function_signature_with_object_fields():
mock_node.block = MatchTextPatternBlock()
mock_node.block_id = MatchTextPatternBlock().id
mock_node.input_default = {}
mock_node.metadata = {}
# Create mock links with dynamic object fields
mock_links = [
@@ -208,11 +211,13 @@ async def test_create_tool_node_signatures():
mock_dict_node.block = CreateDictionaryBlock()
mock_dict_node.block_id = CreateDictionaryBlock().id
mock_dict_node.input_default = {}
mock_dict_node.metadata = {}
mock_list_node = Mock()
mock_list_node.block = AddToListBlock()
mock_list_node.block_id = AddToListBlock().id
mock_list_node.input_default = {}
mock_list_node.metadata = {}
# Mock links with dynamic fields
dict_link1 = Mock(
@@ -373,7 +378,7 @@ async def test_output_yielding_with_dynamic_fields():
input_data = block.input_schema(
prompt="Create a user dictionary",
credentials=llm.TEST_CREDENTIALS_INPUT,
model=llm.LlmModel.GPT4O,
model=llm.DEFAULT_LLM_MODEL,
agent_mode_max_iterations=0, # Use traditional mode to test output yielding
)
@@ -423,6 +428,7 @@ async def test_mixed_regular_and_dynamic_fields():
mock_node.block.name = "TestBlock"
mock_node.block.description = "A test block"
mock_node.block.input_schema = Mock()
mock_node.metadata = {}
# Mock the get_field_schema to return a proper schema for regular fields
def get_field_schema(field_name):
@@ -594,7 +600,7 @@ async def test_validation_errors_dont_pollute_conversation():
input_data = block.input_schema(
prompt="Test prompt",
credentials=llm.TEST_CREDENTIALS_INPUT,
model=llm.LlmModel.GPT4O,
model=llm.DEFAULT_LLM_MODEL,
retry=3, # Allow retries
agent_mode_max_iterations=1,
)

View File

@@ -1,3 +1,3 @@
from .blog import WordPressCreatePostBlock
from .blog import WordPressCreatePostBlock, WordPressGetAllPostsBlock
__all__ = ["WordPressCreatePostBlock"]
__all__ = ["WordPressCreatePostBlock", "WordPressGetAllPostsBlock"]

View File

@@ -161,7 +161,7 @@ async def oauth_exchange_code_for_tokens(
grant_type="authorization_code",
).model_dump(exclude_none=True)
response = await Requests().post(
response = await Requests(raise_for_status=False).post(
f"{WORDPRESS_BASE_URL}oauth2/token",
headers=headers,
data=data,
@@ -205,7 +205,7 @@ async def oauth_refresh_tokens(
grant_type="refresh_token",
).model_dump(exclude_none=True)
response = await Requests().post(
response = await Requests(raise_for_status=False).post(
f"{WORDPRESS_BASE_URL}oauth2/token",
headers=headers,
data=data,
@@ -252,7 +252,7 @@ async def validate_token(
"token": token,
}
response = await Requests().get(
response = await Requests(raise_for_status=False).get(
f"{WORDPRESS_BASE_URL}oauth2/token-info",
params=params,
)
@@ -296,7 +296,7 @@ async def make_api_request(
url = f"{WORDPRESS_BASE_URL.rstrip('/')}{endpoint}"
request_method = getattr(Requests(), method.lower())
request_method = getattr(Requests(raise_for_status=False), method.lower())
response = await request_method(
url,
headers=headers,
@@ -476,6 +476,7 @@ async def create_post(
data["tags"] = ",".join(str(t) for t in data["tags"])
# Make the API request
site = normalize_site(site)
endpoint = f"/rest/v1.1/sites/{site}/posts/new"
headers = {
@@ -483,7 +484,7 @@ async def create_post(
"Content-Type": "application/x-www-form-urlencoded",
}
response = await Requests().post(
response = await Requests(raise_for_status=False).post(
f"{WORDPRESS_BASE_URL.rstrip('/')}{endpoint}",
headers=headers,
data=data,
@@ -499,3 +500,132 @@ async def create_post(
)
error_message = error_data.get("message", response.text)
raise ValueError(f"Failed to create post: {response.status} - {error_message}")
class Post(BaseModel):
"""Response model for individual posts in a posts list response.
This is a simplified version compared to PostResponse, as the list endpoint
returns less detailed information than the create/get single post endpoints.
"""
ID: int
site_ID: int
author: PostAuthor
date: datetime
modified: datetime
title: str
URL: str
short_URL: str
content: str | None = None
excerpt: str | None = None
slug: str
guid: str
status: str
sticky: bool
password: str | None = ""
parent: Union[Dict[str, Any], bool, None] = None
type: str
discussion: Dict[str, Union[str, bool, int]] | None = None
likes_enabled: bool | None = None
sharing_enabled: bool | None = None
like_count: int | None = None
i_like: bool | None = None
is_reblogged: bool | None = None
is_following: bool | None = None
global_ID: str | None = None
featured_image: str | None = None
post_thumbnail: Dict[str, Any] | None = None
format: str | None = None
geo: Union[Dict[str, Any], bool, None] = None
menu_order: int | None = None
page_template: str | None = None
publicize_URLs: List[str] | None = None
terms: Dict[str, Dict[str, Any]] | None = None
tags: Dict[str, Dict[str, Any]] | None = None
categories: Dict[str, Dict[str, Any]] | None = None
attachments: Dict[str, Dict[str, Any]] | None = None
attachment_count: int | None = None
metadata: List[Dict[str, Any]] | None = None
meta: Dict[str, Any] | None = None
capabilities: Dict[str, bool] | None = None
revisions: List[int] | None = None
other_URLs: Dict[str, Any] | None = None
class PostsResponse(BaseModel):
"""Response model for WordPress posts list."""
found: int
posts: List[Post]
meta: Dict[str, Any]
def normalize_site(site: str) -> str:
"""
Normalize a site identifier by stripping protocol and trailing slashes.
Args:
site: Site URL, domain, or ID (e.g., "https://myblog.wordpress.com/", "myblog.wordpress.com", "123456789")
Returns:
Normalized site identifier (domain or ID only)
"""
site = site.strip()
if site.startswith("https://"):
site = site[8:]
elif site.startswith("http://"):
site = site[7:]
return site.rstrip("/")
async def get_posts(
credentials: Credentials,
site: str,
status: PostStatus | None = None,
number: int = 100,
offset: int = 0,
) -> PostsResponse:
"""
Get posts from a WordPress site.
Args:
credentials: OAuth credentials
site: Site ID or domain (e.g., "myblog.wordpress.com" or "123456789")
status: Filter by post status using PostStatus enum, or None for all
number: Number of posts to retrieve (max 100)
offset: Number of posts to skip (for pagination)
Returns:
PostsResponse with the list of posts
"""
site = normalize_site(site)
endpoint = f"/rest/v1.1/sites/{site}/posts"
headers = {
"Authorization": credentials.auth_header(),
}
params: Dict[str, Any] = {
"number": max(1, min(number, 100)), # 1100 posts per request
"offset": offset,
}
if status:
params["status"] = status.value
response = await Requests(raise_for_status=False).get(
f"{WORDPRESS_BASE_URL.rstrip('/')}{endpoint}",
headers=headers,
params=params,
)
if response.ok:
return PostsResponse.model_validate(response.json())
error_data = (
response.json()
if response.headers.get("content-type", "").startswith("application/json")
else {}
)
error_message = error_data.get("message", response.text)
raise ValueError(f"Failed to get posts: {response.status} - {error_message}")

View File

@@ -9,7 +9,15 @@ from backend.sdk import (
SchemaField,
)
from ._api import CreatePostRequest, PostResponse, PostStatus, create_post
from ._api import (
CreatePostRequest,
Post,
PostResponse,
PostsResponse,
PostStatus,
create_post,
get_posts,
)
from ._config import wordpress
@@ -49,8 +57,15 @@ class WordPressCreatePostBlock(Block):
media_urls: list[str] = SchemaField(
description="URLs of images to sideload and attach to the post", default=[]
)
publish_as_draft: bool = SchemaField(
description="If True, publishes the post as a draft. If False, publishes it publicly.",
default=False,
)
class Output(BlockSchemaOutput):
site: str = SchemaField(
description="The site ID or domain (pass-through for chaining with other blocks)"
)
post_id: int = SchemaField(description="The ID of the created post")
post_url: str = SchemaField(description="The full URL of the created post")
short_url: str = SchemaField(description="The shortened wp.me URL")
@@ -78,7 +93,9 @@ class WordPressCreatePostBlock(Block):
tags=input_data.tags,
featured_image=input_data.featured_image,
media_urls=input_data.media_urls,
status=PostStatus.PUBLISH,
status=(
PostStatus.DRAFT if input_data.publish_as_draft else PostStatus.PUBLISH
),
)
post_response: PostResponse = await create_post(
@@ -87,7 +104,69 @@ class WordPressCreatePostBlock(Block):
post_data=post_request,
)
yield "site", input_data.site
yield "post_id", post_response.ID
yield "post_url", post_response.URL
yield "short_url", post_response.short_URL
yield "post_data", post_response.model_dump()
class WordPressGetAllPostsBlock(Block):
"""
Fetches all posts from a WordPress.com site or Jetpack-enabled site.
Supports filtering by status and pagination.
"""
class Input(BlockSchemaInput):
credentials: CredentialsMetaInput = wordpress.credentials_field()
site: str = SchemaField(
description="Site ID or domain (e.g., 'myblog.wordpress.com' or '123456789')"
)
status: PostStatus | None = SchemaField(
description="Filter by post status, or None for all",
default=None,
)
number: int = SchemaField(
description="Number of posts to retrieve (max 100 per request)", default=20
)
offset: int = SchemaField(
description="Number of posts to skip (for pagination)", default=0
)
class Output(BlockSchemaOutput):
site: str = SchemaField(
description="The site ID or domain (pass-through for chaining with other blocks)"
)
found: int = SchemaField(description="Total number of posts found")
posts: list[Post] = SchemaField(
description="List of post objects with their details"
)
post: Post = SchemaField(
description="Individual post object (yielded for each post)"
)
def __init__(self):
super().__init__(
id="97728fa7-7f6f-4789-ba0c-f2c114119536",
description="Fetch all posts from WordPress.com or Jetpack sites",
categories={BlockCategory.SOCIAL},
input_schema=self.Input,
output_schema=self.Output,
)
async def run(
self, input_data: Input, *, credentials: Credentials, **kwargs
) -> BlockOutput:
posts_response: PostsResponse = await get_posts(
credentials=credentials,
site=input_data.site,
status=input_data.status,
number=input_data.number,
offset=input_data.offset,
)
yield "site", input_data.site
yield "found", posts_response.found
yield "posts", posts_response.posts
for post in posts_response.posts:
yield "post", post

View File

@@ -1,5 +1,5 @@
from gravitasml.parser import Parser
from gravitasml.token import tokenize
from gravitasml.token import Token, tokenize
from backend.data.block import Block, BlockOutput, BlockSchemaInput, BlockSchemaOutput
from backend.data.model import SchemaField
@@ -25,6 +25,38 @@ class XMLParserBlock(Block):
],
)
@staticmethod
def _validate_tokens(tokens: list[Token]) -> None:
"""Ensure the XML has a single root element and no stray text."""
if not tokens:
raise ValueError("XML input is empty.")
depth = 0
root_seen = False
for token in tokens:
if token.type == "TAG_OPEN":
if depth == 0 and root_seen:
raise ValueError("XML must have a single root element.")
depth += 1
if depth == 1:
root_seen = True
elif token.type == "TAG_CLOSE":
depth -= 1
if depth < 0:
raise SyntaxError("Unexpected closing tag in XML input.")
elif token.type in {"TEXT", "ESCAPE"}:
if depth == 0 and token.value:
raise ValueError(
"XML contains text outside the root element; "
"wrap content in a single root tag."
)
if depth != 0:
raise SyntaxError("Unclosed tag detected in XML input.")
if not root_seen:
raise ValueError("XML must include a root element.")
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
# Security fix: Add size limits to prevent XML bomb attacks
MAX_XML_SIZE = 10 * 1024 * 1024 # 10MB limit for XML input
@@ -35,7 +67,9 @@ class XMLParserBlock(Block):
)
try:
tokens = tokenize(input_data.input_xml)
tokens = list(tokenize(input_data.input_xml))
self._validate_tokens(tokens)
parser = Parser(tokens)
parsed_result = parser.parse()
yield "parsed_xml", parsed_result

View File

@@ -111,6 +111,8 @@ class TranscribeYoutubeVideoBlock(Block):
return parsed_url.path.split("/")[2]
if parsed_url.path[:3] == "/v/":
return parsed_url.path.split("/")[2]
if parsed_url.path.startswith("/shorts/"):
return parsed_url.path.split("/")[2]
raise ValueError(f"Invalid YouTube URL: {url}")
def get_transcript(

View File

@@ -50,6 +50,8 @@ from .model import (
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from backend.data.execution import ExecutionContext
from .graph import Link
app_config = Config()
@@ -472,6 +474,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
self.block_type = block_type
self.webhook_config = webhook_config
self.execution_stats: NodeExecutionStats = NodeExecutionStats()
self.requires_human_review: bool = False
if self.webhook_config:
if isinstance(self.webhook_config, BlockWebhookConfig):
@@ -614,7 +617,88 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
block_id=self.id,
) from ex
async def is_block_exec_need_review(
self,
input_data: BlockInput,
*,
user_id: str,
node_exec_id: str,
graph_exec_id: str,
graph_id: str,
graph_version: int,
execution_context: "ExecutionContext",
**kwargs,
) -> tuple[bool, BlockInput]:
"""
Check if this block execution needs human review and handle the review process.
Returns:
Tuple of (should_pause, input_data_to_use)
- should_pause: True if execution should be paused for review
- input_data_to_use: The input data to use (may be modified by reviewer)
"""
# Skip review if not required or safe mode is disabled
if not self.requires_human_review or not execution_context.safe_mode:
return False, input_data
from backend.blocks.helpers.review import HITLReviewHelper
# Handle the review request and get decision
decision = await HITLReviewHelper.handle_review_decision(
input_data=input_data,
user_id=user_id,
node_exec_id=node_exec_id,
graph_exec_id=graph_exec_id,
graph_id=graph_id,
graph_version=graph_version,
execution_context=execution_context,
block_name=self.name,
editable=True,
)
if decision is None:
# We're awaiting review - pause execution
return True, input_data
if not decision.should_proceed:
# Review was rejected, raise an error to stop execution
raise BlockExecutionError(
message=f"Block execution rejected by reviewer: {decision.message}",
block_name=self.name,
block_id=self.id,
)
# Review was approved - use the potentially modified data
# ReviewResult.data must be a dict for block inputs
reviewed_data = decision.review_result.data
if not isinstance(reviewed_data, dict):
raise BlockExecutionError(
message=f"Review data must be a dict for block input, got {type(reviewed_data).__name__}",
block_name=self.name,
block_id=self.id,
)
return False, reviewed_data
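# Illustrative sketch (assumption, not taken from this change set): a block
# opts into the review flow by setting the new flag in its own constructor;
# with ExecutionContext.safe_mode enabled, _execute() below then pauses the
# node until a reviewer approves, edits, or rejects the input.
#
#   class SendPaymentBlock(Block):
#       def __init__(self):
#           super().__init__(id="...", input_schema=self.Input, output_schema=self.Output)
#           self.requires_human_review = True  # hypothetical opt-in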
async def _execute(self, input_data: BlockInput, **kwargs) -> BlockOutput:
# Check for review requirement only if running within a graph execution context
# Direct block execution (e.g., from chat) skips the review process
has_graph_context = all(
key in kwargs
for key in (
"node_exec_id",
"graph_exec_id",
"graph_id",
"execution_context",
)
)
if has_graph_context:
should_pause, input_data = await self.is_block_exec_need_review(
input_data, **kwargs
)
if should_pause:
return
# Validate the input data (original or reviewer-modified) once
if error := self.input_schema.validate_data(input_data):
raise BlockInputError(
message=f"Unable to execute block with invalid input data: {error}",
@@ -622,6 +706,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
block_id=self.id,
)
# Use the validated input data
async for output_name, output_data in self.run(
self.input_schema(**{k: v for k, v in input_data.items() if v is not None}),
**kwargs,

View File

@@ -59,12 +59,13 @@ from backend.integrations.credentials_store import (
MODEL_COST: dict[LlmModel, int] = {
LlmModel.O3: 4,
LlmModel.O3_MINI: 2, # $1.10 / $4.40
LlmModel.O1: 16, # $15 / $60
LlmModel.O3_MINI: 2,
LlmModel.O1: 16,
LlmModel.O1_MINI: 4,
# GPT-5 models
LlmModel.GPT5: 2,
LlmModel.GPT5_2: 6,
LlmModel.GPT5_1: 5,
LlmModel.GPT5: 2,
LlmModel.GPT5_MINI: 1,
LlmModel.GPT5_NANO: 1,
LlmModel.GPT5_CHAT: 5,
@@ -87,7 +88,7 @@ MODEL_COST: dict[LlmModel, int] = {
LlmModel.AIML_API_LLAMA3_3_70B: 1,
LlmModel.AIML_API_META_LLAMA_3_1_70B: 1,
LlmModel.AIML_API_LLAMA_3_2_3B: 1,
LlmModel.LLAMA3_3_70B: 1, # $0.59 / $0.79
LlmModel.LLAMA3_3_70B: 1,
LlmModel.LLAMA3_1_8B: 1,
LlmModel.OLLAMA_LLAMA3_3: 1,
LlmModel.OLLAMA_LLAMA3_2: 1,

View File

@@ -341,6 +341,19 @@ class UserCreditBase(ABC):
if result:
# UserBalance is already updated by the CTE
# Clear insufficient funds notification flags when credits are added
# so the user can receive alerts again if they run out in the future.
if transaction.amount > 0 and transaction.type in [
CreditTransactionType.GRANT,
CreditTransactionType.TOP_UP,
]:
from backend.executor.manager import (
clear_insufficient_funds_notifications,
)
await clear_insufficient_funds_notifications(user_id)
return result[0]["balance"]
async def _add_transaction(
@@ -530,6 +543,22 @@ class UserCreditBase(ABC):
if result:
new_balance, tx_key = result[0]["balance"], result[0]["transactionKey"]
# UserBalance is already updated by the CTE
# Clear insufficient funds notification flags when credits are added
# so the user can receive alerts again if they run out in the future.
if (
amount > 0
and is_active
and transaction_type
in [CreditTransactionType.GRANT, CreditTransactionType.TOP_UP]
):
# Lazy import to avoid circular dependency with executor.manager
from backend.executor.manager import (
clear_insufficient_funds_notifications,
)
await clear_insufficient_funds_notifications(user_id)
return new_balance, tx_key
# If no result, either user doesn't exist or insufficient balance

View File

@@ -38,6 +38,20 @@ POOL_TIMEOUT = os.getenv("DB_POOL_TIMEOUT")
if POOL_TIMEOUT:
DATABASE_URL = add_param(DATABASE_URL, "pool_timeout", POOL_TIMEOUT)
# Add public schema to search_path for pgvector type access
# The vector extension is in the public schema, but search_path is determined by the schema parameter
# Extract the schema from DATABASE_URL or default to 'public' (matching get_database_schema())
parsed_url = urlparse(DATABASE_URL)
url_params = dict(parse_qsl(parsed_url.query))
db_schema = url_params.get("schema", "public")
# Build search_path, avoiding duplicates if db_schema is already 'public'
search_path_schemas = list(
dict.fromkeys([db_schema, "public"])
) # Preserves order, removes duplicates
search_path = ",".join(search_path_schemas)
# This allows using ::vector without schema qualification
DATABASE_URL = add_param(DATABASE_URL, "options", f"-c search_path={search_path}")
HTTP_TIMEOUT = int(POOL_TIMEOUT) if POOL_TIMEOUT else None
prisma = Prisma(
@@ -108,21 +122,102 @@ def get_database_schema() -> str:
return query_params.get("schema", "public")
async def query_raw_with_schema(query_template: str, *args) -> list[dict]:
"""Execute raw SQL query with proper schema handling."""
async def _raw_with_schema(
query_template: str,
*args,
execute: bool = False,
client: Prisma | None = None,
set_public_search_path: bool = False,
) -> list[dict] | int:
"""Internal: Execute raw SQL with proper schema handling.
Use query_raw_with_schema() or execute_raw_with_schema() instead.
Args:
query_template: SQL query with {schema_prefix} placeholder
*args: Query parameters
execute: If False, executes SELECT query. If True, executes INSERT/UPDATE/DELETE.
client: Optional Prisma client for transactions (only used when execute=True).
set_public_search_path: If True, sets search_path to include public schema.
Needed for pgvector types and other public schema objects.
Returns:
- list[dict] if execute=False (query results)
- int if execute=True (number of affected rows)
"""
schema = get_database_schema()
schema_prefix = f'"{schema}".' if schema != "public" else ""
formatted_query = query_template.format(schema_prefix=schema_prefix)
import prisma as prisma_module
result = await prisma_module.get_client().query_raw(
formatted_query, *args # type: ignore
)
db_client = client if client else prisma_module.get_client()
# Set search_path to include public schema if requested
# Prisma doesn't support the 'options' connection parameter, so we set it per-session
# This is idempotent and safe to call multiple times
if set_public_search_path:
await db_client.execute_raw(f"SET search_path = {schema}, public") # type: ignore
if execute:
result = await db_client.execute_raw(formatted_query, *args) # type: ignore
else:
result = await db_client.query_raw(formatted_query, *args) # type: ignore
return result
async def query_raw_with_schema(
query_template: str, *args, set_public_search_path: bool = False
) -> list[dict]:
"""Execute raw SQL SELECT query with proper schema handling.
Args:
query_template: SQL query with {schema_prefix} placeholder
*args: Query parameters
set_public_search_path: If True, sets search_path to include public schema.
Needed for pgvector types and other public schema objects.
Returns:
List of result rows as dictionaries
Example:
results = await query_raw_with_schema(
'SELECT * FROM {schema_prefix}"User" WHERE id = $1',
user_id
)
"""
return await _raw_with_schema(query_template, *args, execute=False, set_public_search_path=set_public_search_path) # type: ignore
async def execute_raw_with_schema(
query_template: str,
*args,
client: Prisma | None = None,
set_public_search_path: bool = False,
) -> int:
"""Execute raw SQL command (INSERT/UPDATE/DELETE) with proper schema handling.
Args:
query_template: SQL query with {schema_prefix} placeholder
*args: Query parameters
client: Optional Prisma client for transactions
set_public_search_path: If True, sets search_path to include public schema.
Needed for pgvector types and other public schema objects.
Returns:
Number of affected rows
Example:
await execute_raw_with_schema(
'INSERT INTO {schema_prefix}"User" (id, name) VALUES ($1, $2)',
user_id, name,
client=tx # Optional transaction client
)
"""
return await _raw_with_schema(query_template, *args, execute=True, client=client, set_public_search_path=set_public_search_path) # type: ignore
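# Illustrative sketch of when set_public_search_path matters (the "StoreAgent"
# table and "embedding" column are assumptions for the example): pgvector's
# `vector` type lives in the public schema, so a ::vector cast issued from a
# non-public schema needs the widened search_path enabled by the flag.
async def find_nearest_agents(embedding: list[float], limit: int = 5) -> list[dict]:
    return await query_raw_with_schema(
        'SELECT id FROM {schema_prefix}"StoreAgent" '
        "ORDER BY embedding <-> $1::vector LIMIT $2",
        str(embedding),  # pgvector accepts the '[x, y, ...]' text representation
        limit,
        set_public_search_path=True,
    )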
class BaseDbModel(BaseModel):
id: str = Field(default_factory=lambda: str(uuid4()))

View File

@@ -383,6 +383,7 @@ class GraphExecutionWithNodes(GraphExecution):
self,
execution_context: ExecutionContext,
compiled_nodes_input_masks: Optional[NodesInputMasks] = None,
nodes_to_skip: Optional[set[str]] = None,
):
return GraphExecutionEntry(
user_id=self.user_id,
@@ -390,6 +391,7 @@ class GraphExecutionWithNodes(GraphExecution):
graph_version=self.graph_version or 0,
graph_exec_id=self.id,
nodes_input_masks=compiled_nodes_input_masks,
nodes_to_skip=nodes_to_skip or set(),
execution_context=execution_context,
)
@@ -1145,6 +1147,8 @@ class GraphExecutionEntry(BaseModel):
graph_id: str
graph_version: int
nodes_input_masks: Optional[NodesInputMasks] = None
nodes_to_skip: set[str] = Field(default_factory=set)
"""Node IDs that should be skipped due to optional credentials not being configured."""
execution_context: ExecutionContext = Field(default_factory=ExecutionContext)

View File

@@ -94,6 +94,15 @@ class Node(BaseDbModel):
input_links: list[Link] = []
output_links: list[Link] = []
@property
def credentials_optional(self) -> bool:
"""
Whether credentials are optional for this node.
When True and credentials are not configured, the node will be skipped
during execution rather than causing a validation error.
"""
return self.metadata.get("credentials_optional", False)
@property
def block(self) -> AnyBlockSchema | "_UnknownBlockBase":
"""Get the block for this node. Returns UnknownBlock if block is deleted/missing."""
@@ -235,7 +244,10 @@ class BaseGraph(BaseDbModel):
return any(
node.block_id
for node in self.nodes
if node.block.block_type == BlockType.HUMAN_IN_THE_LOOP
if (
node.block.block_type == BlockType.HUMAN_IN_THE_LOOP
or node.block.requires_human_review
)
)
@property
@@ -326,7 +338,35 @@ class Graph(BaseGraph):
@computed_field
@property
def credentials_input_schema(self) -> dict[str, Any]:
return self._credentials_input_schema.jsonschema()
schema = self._credentials_input_schema.jsonschema()
# Determine which credential fields are required based on credentials_optional metadata
graph_credentials_inputs = self.aggregate_credentials_inputs()
required_fields = []
# Build a map of node_id -> node for quick lookup
all_nodes = {node.id: node for node in self.nodes}
for sub_graph in self.sub_graphs:
for node in sub_graph.nodes:
all_nodes[node.id] = node
for field_key, (
_field_info,
node_field_pairs,
) in graph_credentials_inputs.items():
# A field is required if ANY node using it has credentials_optional=False
is_required = False
for node_id, _field_name in node_field_pairs:
node = all_nodes.get(node_id)
if node and not node.credentials_optional:
is_required = True
break
if is_required:
required_fields.append(field_key)
schema["required"] = required_fields
return schema
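# Worked example (hypothetical nodes, not from this change set): if node A and
# node B share the same credentials field, A marked credentials_optional=True
# and B left at the default False, the field still lands in schema["required"];
# it is only omitted when every node that uses it marks credentials as optional.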
@property
def _credentials_input_schema(self) -> type[BlockSchema]:

View File

@@ -1,5 +1,6 @@
import json
from typing import Any
from unittest.mock import AsyncMock, patch
from uuid import UUID
import fastapi.exceptions
@@ -18,6 +19,17 @@ from backend.usecases.sample import create_test_user
from backend.util.test import SpinTestServer
@pytest.fixture(scope="session", autouse=True)
def mock_embedding_functions():
"""Mock embedding functions for all tests to avoid database/API dependencies."""
with patch(
"backend.api.features.store.db.ensure_embedding",
new_callable=AsyncMock,
return_value=True,
):
yield
@pytest.mark.asyncio(loop_scope="session")
async def test_graph_creation(server: SpinTestServer, snapshot: Snapshot):
"""
@@ -396,3 +408,58 @@ async def test_access_store_listing_graph(server: SpinTestServer):
created_graph.id, created_graph.version, "3e53486c-cf57-477e-ba2a-cb02dc828e1b"
)
assert got_graph is not None
# ============================================================================
# Tests for Optional Credentials Feature
# ============================================================================
def test_node_credentials_optional_default():
"""Test that credentials_optional defaults to False when not set in metadata."""
node = Node(
id="test_node",
block_id=StoreValueBlock().id,
input_default={},
metadata={},
)
assert node.credentials_optional is False
def test_node_credentials_optional_true():
"""Test that credentials_optional returns True when explicitly set."""
node = Node(
id="test_node",
block_id=StoreValueBlock().id,
input_default={},
metadata={"credentials_optional": True},
)
assert node.credentials_optional is True
def test_node_credentials_optional_false():
"""Test that credentials_optional returns False when explicitly set to False."""
node = Node(
id="test_node",
block_id=StoreValueBlock().id,
input_default={},
metadata={"credentials_optional": False},
)
assert node.credentials_optional is False
def test_node_credentials_optional_with_other_metadata():
"""Test that credentials_optional works correctly with other metadata present."""
node = Node(
id="test_node",
block_id=StoreValueBlock().id,
input_default={},
metadata={
"position": {"x": 100, "y": 200},
"customized_name": "My Custom Node",
"credentials_optional": True,
},
)
assert node.credentials_optional is True
assert node.metadata["position"] == {"x": 100, "y": 200}
assert node.metadata["customized_name"] == "My Custom Node"

View File

@@ -334,7 +334,7 @@ async def _get_user_timezone(user_id: str) -> str:
return get_user_timezone_or_utc(user.timezone if user else None)
async def increment_runs(user_id: str):
async def increment_onboarding_runs(user_id: str):
"""
Increment a user's run counters and trigger any onboarding milestones.
"""

View File

@@ -0,0 +1,404 @@
"""Data models and access layer for user business understanding."""
import logging
from datetime import datetime
from typing import Any, Optional, cast
import pydantic
from prisma.models import CoPilotUnderstanding
from backend.data.redis_client import get_redis_async
from backend.util.json import SafeJson
logger = logging.getLogger(__name__)
# Cache configuration
CACHE_KEY_PREFIX = "understanding"
CACHE_TTL_SECONDS = 48 * 60 * 60 # 48 hours
def _cache_key(user_id: str) -> str:
"""Generate cache key for user business understanding."""
return f"{CACHE_KEY_PREFIX}:{user_id}"
def _json_to_list(value: Any) -> list[str]:
"""Convert Json field to list[str], handling None."""
if value is None:
return []
if isinstance(value, list):
return cast(list[str], value)
return []
class BusinessUnderstandingInput(pydantic.BaseModel):
"""Input model for updating business understanding - all fields optional for incremental updates."""
# User info
user_name: Optional[str] = pydantic.Field(None, description="The user's name")
job_title: Optional[str] = pydantic.Field(None, description="The user's job title")
# Business basics
business_name: Optional[str] = pydantic.Field(
None, description="Name of the user's business"
)
industry: Optional[str] = pydantic.Field(None, description="Industry or sector")
business_size: Optional[str] = pydantic.Field(
None, description="Company size (e.g., '1-10', '11-50')"
)
user_role: Optional[str] = pydantic.Field(
None,
description="User's role in the organization (e.g., 'decision maker', 'implementer')",
)
# Processes & activities
key_workflows: Optional[list[str]] = pydantic.Field(
None, description="Key business workflows"
)
daily_activities: Optional[list[str]] = pydantic.Field(
None, description="Daily activities performed"
)
# Pain points & goals
pain_points: Optional[list[str]] = pydantic.Field(
None, description="Current pain points"
)
bottlenecks: Optional[list[str]] = pydantic.Field(
None, description="Process bottlenecks"
)
manual_tasks: Optional[list[str]] = pydantic.Field(
None, description="Manual/repetitive tasks"
)
automation_goals: Optional[list[str]] = pydantic.Field(
None, description="Desired automation goals"
)
# Current tools
current_software: Optional[list[str]] = pydantic.Field(
None, description="Software/tools currently used"
)
existing_automation: Optional[list[str]] = pydantic.Field(
None, description="Existing automations"
)
# Additional context
additional_notes: Optional[str] = pydantic.Field(
None, description="Any additional context"
)
class BusinessUnderstanding(pydantic.BaseModel):
"""Full business understanding model returned from database."""
id: str
user_id: str
created_at: datetime
updated_at: datetime
# User info
user_name: Optional[str] = None
job_title: Optional[str] = None
# Business basics
business_name: Optional[str] = None
industry: Optional[str] = None
business_size: Optional[str] = None
user_role: Optional[str] = None
# Processes & activities
key_workflows: list[str] = pydantic.Field(default_factory=list)
daily_activities: list[str] = pydantic.Field(default_factory=list)
# Pain points & goals
pain_points: list[str] = pydantic.Field(default_factory=list)
bottlenecks: list[str] = pydantic.Field(default_factory=list)
manual_tasks: list[str] = pydantic.Field(default_factory=list)
automation_goals: list[str] = pydantic.Field(default_factory=list)
# Current tools
current_software: list[str] = pydantic.Field(default_factory=list)
existing_automation: list[str] = pydantic.Field(default_factory=list)
# Additional context
additional_notes: Optional[str] = None
@classmethod
def from_db(cls, db_record: CoPilotUnderstanding) -> "BusinessUnderstanding":
"""Convert database record to Pydantic model."""
data = db_record.data if isinstance(db_record.data, dict) else {}
business = (
data.get("business", {}) if isinstance(data.get("business"), dict) else {}
)
return cls(
id=db_record.id,
user_id=db_record.userId,
created_at=db_record.createdAt,
updated_at=db_record.updatedAt,
user_name=data.get("name"),
job_title=business.get("job_title"),
business_name=business.get("business_name"),
industry=business.get("industry"),
business_size=business.get("business_size"),
user_role=business.get("user_role"),
key_workflows=_json_to_list(business.get("key_workflows")),
daily_activities=_json_to_list(business.get("daily_activities")),
pain_points=_json_to_list(business.get("pain_points")),
bottlenecks=_json_to_list(business.get("bottlenecks")),
manual_tasks=_json_to_list(business.get("manual_tasks")),
automation_goals=_json_to_list(business.get("automation_goals")),
current_software=_json_to_list(business.get("current_software")),
existing_automation=_json_to_list(business.get("existing_automation")),
additional_notes=business.get("additional_notes"),
)
def _merge_lists(existing: list | None, new: list | None) -> list | None:
"""Merge two lists, removing duplicates while preserving order."""
if new is None:
return existing
if existing is None:
return new
# Preserve order, add new items that don't exist
merged = list(existing)
for item in new:
if item not in merged:
merged.append(item)
return merged
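# Quick sanity examples of the merge semantics above (illustrative only):
#   _merge_lists(["a", "b"], ["b", "c"]) -> ["a", "b", "c"]  # dedupe, keep order
#   _merge_lists(None, ["x"])            -> ["x"]
#   _merge_lists(["x"], None)            -> ["x"]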
async def _get_from_cache(user_id: str) -> Optional[BusinessUnderstanding]:
"""Get business understanding from Redis cache."""
try:
redis = await get_redis_async()
cached_data = await redis.get(_cache_key(user_id))
if cached_data:
return BusinessUnderstanding.model_validate_json(cached_data)
except Exception as e:
logger.warning(f"Failed to get understanding from cache: {e}")
return None
async def _set_cache(user_id: str, understanding: BusinessUnderstanding) -> None:
"""Set business understanding in Redis cache with TTL."""
try:
redis = await get_redis_async()
await redis.setex(
_cache_key(user_id),
CACHE_TTL_SECONDS,
understanding.model_dump_json(),
)
except Exception as e:
logger.warning(f"Failed to set understanding in cache: {e}")
async def _delete_cache(user_id: str) -> None:
"""Delete business understanding from Redis cache."""
try:
redis = await get_redis_async()
await redis.delete(_cache_key(user_id))
except Exception as e:
logger.warning(f"Failed to delete understanding from cache: {e}")
async def get_business_understanding(
user_id: str,
) -> Optional[BusinessUnderstanding]:
"""Get the business understanding for a user.
Checks cache first, falls back to database if not cached.
Results are cached for 48 hours.
"""
# Try cache first
cached = await _get_from_cache(user_id)
if cached:
logger.debug(f"Business understanding cache hit for user {user_id}")
return cached
# Cache miss - load from database
logger.debug(f"Business understanding cache miss for user {user_id}")
record = await CoPilotUnderstanding.prisma().find_unique(where={"userId": user_id})
if record is None:
return None
understanding = BusinessUnderstanding.from_db(record)
# Store in cache for next time
await _set_cache(user_id, understanding)
return understanding
async def upsert_business_understanding(
user_id: str,
input_data: BusinessUnderstandingInput,
) -> BusinessUnderstanding:
"""
Create or update business understanding with incremental merge strategy.
- String fields: new value overwrites if provided (not None)
- List fields: new items are appended to existing (deduplicated)
Data is stored as: {name: ..., business: {version: 1, ...}}
"""
# Get existing record for merge
existing = await CoPilotUnderstanding.prisma().find_unique(
where={"userId": user_id}
)
# Get existing data structure or start fresh
existing_data: dict[str, Any] = {}
if existing and isinstance(existing.data, dict):
existing_data = dict(existing.data)
existing_business: dict[str, Any] = {}
if isinstance(existing_data.get("business"), dict):
existing_business = dict(existing_data["business"])
# Business fields (stored inside business object)
business_string_fields = [
"job_title",
"business_name",
"industry",
"business_size",
"user_role",
"additional_notes",
]
business_list_fields = [
"key_workflows",
"daily_activities",
"pain_points",
"bottlenecks",
"manual_tasks",
"automation_goals",
"current_software",
"existing_automation",
]
# Handle top-level name field
if input_data.user_name is not None:
existing_data["name"] = input_data.user_name
# Business string fields - overwrite if provided
for field in business_string_fields:
value = getattr(input_data, field)
if value is not None:
existing_business[field] = value
# Business list fields - merge with existing
for field in business_list_fields:
value = getattr(input_data, field)
if value is not None:
existing_list = _json_to_list(existing_business.get(field))
merged = _merge_lists(existing_list, value)
existing_business[field] = merged
# Set version and nest business data
existing_business["version"] = 1
existing_data["business"] = existing_business
# Upsert with the merged data
record = await CoPilotUnderstanding.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, "data": SafeJson(existing_data)},
"update": {"data": SafeJson(existing_data)},
},
)
understanding = BusinessUnderstanding.from_db(record)
# Update cache with new understanding
await _set_cache(user_id, understanding)
return understanding
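# Illustrative usage of the incremental merge (not part of the change set):
# string fields overwrite when provided, list fields accumulate without duplicates.
async def example_incremental_update(user_id: str) -> BusinessUnderstanding:
    await upsert_business_understanding(
        user_id,
        BusinessUnderstandingInput(industry="E-commerce", pain_points=["manual reporting"]),
    )
    # The second call overwrites industry but keeps "manual reporting" and
    # appends the new pain point to the existing list.
    return await upsert_business_understanding(
        user_id,
        BusinessUnderstandingInput(industry="Retail", pain_points=["slow fulfilment"]),
    )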
async def clear_business_understanding(user_id: str) -> bool:
"""Clear/delete business understanding for a user from both DB and cache."""
# Delete from cache first
await _delete_cache(user_id)
try:
await CoPilotUnderstanding.prisma().delete(where={"userId": user_id})
return True
except Exception:
# Record might not exist
return False
def format_understanding_for_prompt(understanding: BusinessUnderstanding) -> str:
"""Format business understanding as text for system prompt injection."""
sections = []
# User info section
user_info = []
if understanding.user_name:
user_info.append(f"Name: {understanding.user_name}")
if understanding.job_title:
user_info.append(f"Job Title: {understanding.job_title}")
if user_info:
sections.append("## User\n" + "\n".join(user_info))
# Business section
business_info = []
if understanding.business_name:
business_info.append(f"Company: {understanding.business_name}")
if understanding.industry:
business_info.append(f"Industry: {understanding.industry}")
if understanding.business_size:
business_info.append(f"Size: {understanding.business_size}")
if understanding.user_role:
business_info.append(f"Role Context: {understanding.user_role}")
if business_info:
sections.append("## Business\n" + "\n".join(business_info))
# Processes section
processes = []
if understanding.key_workflows:
processes.append(f"Key Workflows: {', '.join(understanding.key_workflows)}")
if understanding.daily_activities:
processes.append(
f"Daily Activities: {', '.join(understanding.daily_activities)}"
)
if processes:
sections.append("## Processes\n" + "\n".join(processes))
# Pain points section
pain_points = []
if understanding.pain_points:
pain_points.append(f"Pain Points: {', '.join(understanding.pain_points)}")
if understanding.bottlenecks:
pain_points.append(f"Bottlenecks: {', '.join(understanding.bottlenecks)}")
if understanding.manual_tasks:
pain_points.append(f"Manual Tasks: {', '.join(understanding.manual_tasks)}")
if pain_points:
sections.append("## Pain Points\n" + "\n".join(pain_points))
# Goals section
if understanding.automation_goals:
sections.append(
"## Automation Goals\n"
+ "\n".join(f"- {goal}" for goal in understanding.automation_goals)
)
# Current tools section
tools_info = []
if understanding.current_software:
tools_info.append(
f"Current Software: {', '.join(understanding.current_software)}"
)
if understanding.existing_automation:
tools_info.append(
f"Existing Automation: {', '.join(understanding.existing_automation)}"
)
if tools_info:
sections.append("## Current Tools\n" + "\n".join(tools_info))
# Additional notes
if understanding.additional_notes:
sections.append(f"## Additional Context\n{understanding.additional_notes}")
if not sections:
return ""
return "# User Business Context\n\n" + "\n\n".join(sections)

View File

@@ -7,6 +7,11 @@ from backend.api.features.library.db import (
list_library_agents,
)
from backend.api.features.store.db import get_store_agent_details, get_store_agents
from backend.api.features.store.embeddings import (
backfill_missing_embeddings,
cleanup_orphaned_embeddings,
get_embedding_stats,
)
from backend.data import db
from backend.data.analytics import (
get_accuracy_trends_and_alerts,
@@ -20,6 +25,7 @@ from backend.data.execution import (
get_execution_kv_data,
get_execution_outputs_by_node_exec_id,
get_frequently_executed_graphs,
get_graph_execution,
get_graph_execution_meta,
get_graph_executions,
get_graph_executions_count,
@@ -57,6 +63,7 @@ from backend.data.notifications import (
get_user_notification_oldest_message_in_batch,
remove_notifications_from_batch,
)
from backend.data.onboarding import increment_onboarding_runs
from backend.data.user import (
get_active_user_ids_in_timerange,
get_user_by_id,
@@ -140,6 +147,7 @@ class DatabaseManager(AppService):
get_child_graph_executions = _(get_child_graph_executions)
get_graph_executions = _(get_graph_executions)
get_graph_executions_count = _(get_graph_executions_count)
get_graph_execution = _(get_graph_execution)
get_graph_execution_meta = _(get_graph_execution_meta)
create_graph_execution = _(create_graph_execution)
get_node_execution = _(get_node_execution)
@@ -204,10 +212,18 @@ class DatabaseManager(AppService):
add_store_agent_to_library = _(add_store_agent_to_library)
validate_graph_execution_permissions = _(validate_graph_execution_permissions)
# Onboarding
increment_onboarding_runs = _(increment_onboarding_runs)
# Store
get_store_agents = _(get_store_agents)
get_store_agent_details = _(get_store_agent_details)
# Store Embeddings
get_embedding_stats = _(get_embedding_stats)
backfill_missing_embeddings = _(backfill_missing_embeddings)
cleanup_orphaned_embeddings = _(cleanup_orphaned_embeddings)
# Summary data - async
get_user_execution_summary_data = _(get_user_execution_summary_data)
@@ -259,6 +275,11 @@ class DatabaseManagerClient(AppServiceClient):
get_store_agents = _(d.get_store_agents)
get_store_agent_details = _(d.get_store_agent_details)
# Store Embeddings
get_embedding_stats = _(d.get_embedding_stats)
backfill_missing_embeddings = _(d.backfill_missing_embeddings)
cleanup_orphaned_embeddings = _(d.cleanup_orphaned_embeddings)
class DatabaseManagerAsyncClient(AppServiceClient):
d = DatabaseManager
@@ -274,6 +295,7 @@ class DatabaseManagerAsyncClient(AppServiceClient):
get_graph = d.get_graph
get_graph_metadata = d.get_graph_metadata
get_graph_settings = d.get_graph_settings
get_graph_execution = d.get_graph_execution
get_graph_execution_meta = d.get_graph_execution_meta
get_node = d.get_node
get_node_execution = d.get_node_execution
@@ -318,6 +340,9 @@ class DatabaseManagerAsyncClient(AppServiceClient):
add_store_agent_to_library = d.add_store_agent_to_library
validate_graph_execution_permissions = d.validate_graph_execution_permissions
# Onboarding
increment_onboarding_runs = d.increment_onboarding_runs
# Store
get_store_agents = d.get_store_agents
get_store_agent_details = d.get_store_agent_details

View File

@@ -114,6 +114,40 @@ utilization_gauge = Gauge(
"Ratio of active graph runs to max graph workers",
)
# Redis key prefix for tracking insufficient funds Discord notifications.
# We only send one notification per user per agent until they top up credits.
INSUFFICIENT_FUNDS_NOTIFIED_PREFIX = "insufficient_funds_discord_notified"
# TTL for the notification flag (30 days) - acts as a fallback cleanup
INSUFFICIENT_FUNDS_NOTIFIED_TTL_SECONDS = 30 * 24 * 60 * 60
async def clear_insufficient_funds_notifications(user_id: str) -> int:
"""
Clear all insufficient funds notification flags for a user.
This should be called when a user tops up their credits, allowing
Discord notifications to be sent again if they run out of funds.
Args:
user_id: The user ID to clear notifications for.
Returns:
The number of keys that were deleted.
"""
try:
redis_client = await redis.get_redis_async()
pattern = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:*"
keys = [key async for key in redis_client.scan_iter(match=pattern)]
if keys:
return await redis_client.delete(*keys)
return 0
except Exception as e:
logger.warning(
f"Failed to clear insufficient funds notification flags for user "
f"{user_id}: {e}"
)
return 0
# Thread-local storage for ExecutionProcessor instances
_tls = threading.local()
@@ -144,6 +178,7 @@ async def execute_node(
execution_processor: "ExecutionProcessor",
execution_stats: NodeExecutionStats | None = None,
nodes_input_masks: Optional[NodesInputMasks] = None,
nodes_to_skip: Optional[set[str]] = None,
) -> BlockOutput:
"""
Execute a node in the graph. This will trigger a block execution on a node,
@@ -211,6 +246,7 @@ async def execute_node(
"user_id": user_id,
"execution_context": execution_context,
"execution_processor": execution_processor,
"nodes_to_skip": nodes_to_skip or set(),
}
# Last-minute fetch credentials + acquire a system-wide read-write lock to prevent
@@ -508,6 +544,7 @@ class ExecutionProcessor:
node_exec_progress: NodeExecutionProgress,
nodes_input_masks: Optional[NodesInputMasks],
graph_stats_pair: tuple[GraphExecutionStats, threading.Lock],
nodes_to_skip: Optional[set[str]] = None,
) -> NodeExecutionStats:
log_metadata = LogMetadata(
logger=_logger,
@@ -530,6 +567,7 @@ class ExecutionProcessor:
db_client=db_client,
log_metadata=log_metadata,
nodes_input_masks=nodes_input_masks,
nodes_to_skip=nodes_to_skip,
)
if isinstance(status, BaseException):
raise status
@@ -575,6 +613,7 @@ class ExecutionProcessor:
db_client: "DatabaseManagerAsyncClient",
log_metadata: LogMetadata,
nodes_input_masks: Optional[NodesInputMasks] = None,
nodes_to_skip: Optional[set[str]] = None,
) -> ExecutionStatus:
status = ExecutionStatus.RUNNING
@@ -611,6 +650,7 @@ class ExecutionProcessor:
execution_processor=self,
execution_stats=stats,
nodes_input_masks=nodes_input_masks,
nodes_to_skip=nodes_to_skip,
):
await persist_output(output_name, output_data)
@@ -922,6 +962,21 @@ class ExecutionProcessor:
queued_node_exec = execution_queue.get()
# Check if this node should be skipped due to optional credentials
if queued_node_exec.node_id in graph_exec.nodes_to_skip:
log_metadata.info(
f"Skipping node execution {queued_node_exec.node_exec_id} "
f"for node {queued_node_exec.node_id} - optional credentials not configured"
)
# Mark the node as completed without executing
# No outputs will be produced, so downstream nodes won't trigger
update_node_execution_status(
db_client=db_client,
exec_id=queued_node_exec.node_exec_id,
status=ExecutionStatus.COMPLETED,
)
continue
log_metadata.debug(
f"Dispatching node execution {queued_node_exec.node_exec_id} "
f"for node {queued_node_exec.node_id}",
@@ -982,6 +1037,7 @@ class ExecutionProcessor:
execution_stats,
execution_stats_lock,
),
nodes_to_skip=graph_exec.nodes_to_skip,
),
self.node_execution_loop,
)
@@ -1261,12 +1317,40 @@ class ExecutionProcessor:
graph_id: str,
e: InsufficientBalanceError,
):
# Check if we've already sent a notification for this user+agent combo.
# We only send one notification per user per agent until they top up credits.
redis_key = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:{graph_id}"
try:
redis_client = redis.get_redis()
# SET NX returns True only if the key was newly set (didn't exist)
is_new_notification = redis_client.set(
redis_key,
"1",
nx=True,
ex=INSUFFICIENT_FUNDS_NOTIFIED_TTL_SECONDS,
)
if not is_new_notification:
# Already notified for this user+agent, skip all notifications
logger.debug(
f"Skipping duplicate insufficient funds notification for "
f"user={user_id}, graph={graph_id}"
)
return
except Exception as redis_error:
# If Redis fails, log and continue to send the notification
# (better to occasionally duplicate than to never notify)
logger.warning(
f"Failed to check/set insufficient funds notification flag in Redis: "
f"{redis_error}"
)
shortfall = abs(e.amount) - e.balance
metadata = db_client.get_graph_metadata(graph_id)
base_url = (
settings.config.frontend_base_url or settings.config.platform_base_url
)
# Queue user email notification
queue_notification(
NotificationEventModel(
user_id=user_id,
@@ -1280,6 +1364,7 @@ class ExecutionProcessor:
)
)
# Send Discord system alert
try:
user_email = db_client.get_user_email_by_id(user_id)

View File

@@ -0,0 +1,560 @@
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from prisma.enums import NotificationType
from backend.data.notifications import ZeroBalanceData
from backend.executor.manager import (
INSUFFICIENT_FUNDS_NOTIFIED_PREFIX,
ExecutionProcessor,
clear_insufficient_funds_notifications,
)
from backend.util.exceptions import InsufficientBalanceError
from backend.util.test import SpinTestServer
async def async_iter(items):
"""Helper to create an async iterator from a list."""
for item in items:
yield item
@pytest.mark.asyncio(loop_scope="session")
async def test_handle_insufficient_funds_sends_discord_alert_first_time(
server: SpinTestServer,
):
"""Test that the first insufficient funds notification sends a Discord alert."""
execution_processor = ExecutionProcessor()
user_id = "test-user-123"
graph_id = "test-graph-456"
error = InsufficientBalanceError(
message="Insufficient balance",
user_id=user_id,
balance=72, # $0.72
amount=-714, # Attempting to spend $7.14
)
with patch(
"backend.executor.manager.queue_notification"
) as mock_queue_notif, patch(
"backend.executor.manager.get_notification_manager_client"
) as mock_get_client, patch(
"backend.executor.manager.settings"
) as mock_settings, patch(
"backend.executor.manager.redis"
) as mock_redis_module:
# Setup mocks
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_settings.config.frontend_base_url = "https://test.com"
# Mock Redis to simulate first-time notification (set returns True)
mock_redis_client = MagicMock()
mock_redis_module.get_redis.return_value = mock_redis_client
mock_redis_client.set.return_value = True # Key was newly set
# Create mock database client
mock_db_client = MagicMock()
mock_graph_metadata = MagicMock()
mock_graph_metadata.name = "Test Agent"
mock_db_client.get_graph_metadata.return_value = mock_graph_metadata
mock_db_client.get_user_email_by_id.return_value = "test@example.com"
# Test the insufficient funds handler
execution_processor._handle_insufficient_funds_notif(
db_client=mock_db_client,
user_id=user_id,
graph_id=graph_id,
e=error,
)
# Verify notification was queued
mock_queue_notif.assert_called_once()
notification_call = mock_queue_notif.call_args[0][0]
assert notification_call.type == NotificationType.ZERO_BALANCE
assert notification_call.user_id == user_id
assert isinstance(notification_call.data, ZeroBalanceData)
assert notification_call.data.current_balance == 72
# Verify Redis was checked with correct key pattern
expected_key = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:{graph_id}"
mock_redis_client.set.assert_called_once()
call_args = mock_redis_client.set.call_args
assert call_args[0][0] == expected_key
assert call_args[1]["nx"] is True
# Verify Discord alert was sent
mock_client.discord_system_alert.assert_called_once()
discord_message = mock_client.discord_system_alert.call_args[0][0]
assert "Insufficient Funds Alert" in discord_message
assert "test@example.com" in discord_message
assert "Test Agent" in discord_message
@pytest.mark.asyncio(loop_scope="session")
async def test_handle_insufficient_funds_skips_duplicate_notifications(
server: SpinTestServer,
):
"""Test that duplicate insufficient funds notifications skip both email and Discord."""
execution_processor = ExecutionProcessor()
user_id = "test-user-123"
graph_id = "test-graph-456"
error = InsufficientBalanceError(
message="Insufficient balance",
user_id=user_id,
balance=72,
amount=-714,
)
with patch(
"backend.executor.manager.queue_notification"
) as mock_queue_notif, patch(
"backend.executor.manager.get_notification_manager_client"
) as mock_get_client, patch(
"backend.executor.manager.settings"
) as mock_settings, patch(
"backend.executor.manager.redis"
) as mock_redis_module:
# Setup mocks
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_settings.config.frontend_base_url = "https://test.com"
# Mock Redis to simulate duplicate notification (set returns False/None)
mock_redis_client = MagicMock()
mock_redis_module.get_redis.return_value = mock_redis_client
mock_redis_client.set.return_value = None # Key already existed
# Create mock database client
mock_db_client = MagicMock()
mock_db_client.get_graph_metadata.return_value = MagicMock(name="Test Agent")
# Test the insufficient funds handler
execution_processor._handle_insufficient_funds_notif(
db_client=mock_db_client,
user_id=user_id,
graph_id=graph_id,
e=error,
)
# Verify email notification was NOT queued (deduplication worked)
mock_queue_notif.assert_not_called()
# Verify Discord alert was NOT sent (deduplication worked)
mock_client.discord_system_alert.assert_not_called()
@pytest.mark.asyncio(loop_scope="session")
async def test_handle_insufficient_funds_different_agents_get_separate_alerts(
server: SpinTestServer,
):
"""Test that different agents for the same user get separate Discord alerts."""
execution_processor = ExecutionProcessor()
user_id = "test-user-123"
graph_id_1 = "test-graph-111"
graph_id_2 = "test-graph-222"
error = InsufficientBalanceError(
message="Insufficient balance",
user_id=user_id,
balance=72,
amount=-714,
)
with patch("backend.executor.manager.queue_notification"), patch(
"backend.executor.manager.get_notification_manager_client"
) as mock_get_client, patch(
"backend.executor.manager.settings"
) as mock_settings, patch(
"backend.executor.manager.redis"
) as mock_redis_module:
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_settings.config.frontend_base_url = "https://test.com"
mock_redis_client = MagicMock()
mock_redis_module.get_redis.return_value = mock_redis_client
# Both calls return True (first time for each agent)
mock_redis_client.set.return_value = True
mock_db_client = MagicMock()
mock_graph_metadata = MagicMock()
mock_graph_metadata.name = "Test Agent"
mock_db_client.get_graph_metadata.return_value = mock_graph_metadata
mock_db_client.get_user_email_by_id.return_value = "test@example.com"
# First agent notification
execution_processor._handle_insufficient_funds_notif(
db_client=mock_db_client,
user_id=user_id,
graph_id=graph_id_1,
e=error,
)
# Second agent notification
execution_processor._handle_insufficient_funds_notif(
db_client=mock_db_client,
user_id=user_id,
graph_id=graph_id_2,
e=error,
)
# Verify Discord alerts were sent for both agents
assert mock_client.discord_system_alert.call_count == 2
# Verify Redis was called with different keys
assert mock_redis_client.set.call_count == 2
calls = mock_redis_client.set.call_args_list
assert (
calls[0][0][0]
== f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:{graph_id_1}"
)
assert (
calls[1][0][0]
== f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:{graph_id_2}"
)
@pytest.mark.asyncio(loop_scope="session")
async def test_clear_insufficient_funds_notifications(server: SpinTestServer):
"""Test that clearing notifications removes all keys for a user."""
user_id = "test-user-123"
with patch("backend.executor.manager.redis") as mock_redis_module:
mock_redis_client = MagicMock()
# get_redis_async is an async function, so we need AsyncMock for it
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
# Mock scan_iter to return some keys as an async iterator
mock_keys = [
f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:graph-1",
f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:graph-2",
f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:graph-3",
]
mock_redis_client.scan_iter.return_value = async_iter(mock_keys)
# delete is awaited, so use AsyncMock
mock_redis_client.delete = AsyncMock(return_value=3)
# Clear notifications
result = await clear_insufficient_funds_notifications(user_id)
# Verify correct pattern was used
expected_pattern = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:*"
mock_redis_client.scan_iter.assert_called_once_with(match=expected_pattern)
# Verify delete was called with all keys
mock_redis_client.delete.assert_called_once_with(*mock_keys)
# Verify return value
assert result == 3
@pytest.mark.asyncio(loop_scope="session")
async def test_clear_insufficient_funds_notifications_no_keys(server: SpinTestServer):
"""Test clearing notifications when there are no keys to clear."""
user_id = "test-user-no-notifications"
with patch("backend.executor.manager.redis") as mock_redis_module:
mock_redis_client = MagicMock()
# get_redis_async is an async function, so we need AsyncMock for it
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
# Mock scan_iter to return no keys as an async iterator
mock_redis_client.scan_iter.return_value = async_iter([])
# Clear notifications
result = await clear_insufficient_funds_notifications(user_id)
# Verify delete was not called
mock_redis_client.delete.assert_not_called()
# Verify return value
assert result == 0
@pytest.mark.asyncio(loop_scope="session")
async def test_clear_insufficient_funds_notifications_handles_redis_error(
server: SpinTestServer,
):
"""Test that clearing notifications handles Redis errors gracefully."""
user_id = "test-user-redis-error"
with patch("backend.executor.manager.redis") as mock_redis_module:
# Mock get_redis_async to raise an error
mock_redis_module.get_redis_async = AsyncMock(
side_effect=Exception("Redis connection failed")
)
# Clear notifications should not raise, just return 0
result = await clear_insufficient_funds_notifications(user_id)
# Verify it returned 0 (graceful failure)
assert result == 0
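# Illustrative sketch (not the actual backend.executor.manager implementation) of the
# dedup pattern the tests above exercise. Only the key shape, nx=True, scan_iter(match=...),
# and delete(*keys) come from the assertions; the stored value and TTL below are assumptions.
def _was_already_notified_sketch(redis_client, user_id: str, graph_id: str) -> bool:
    key = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:{graph_id}"
    # SET ... NX returns a truthy value only when the key did not exist yet
    first_time = redis_client.set(key, "1", nx=True, ex=24 * 3600)  # value and TTL assumed
    return not first_time

async def _clear_notified_flags_sketch(redis_client, user_id: str) -> int:
    pattern = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:*"
    keys = [key async for key in redis_client.scan_iter(match=pattern)]
    return await redis_client.delete(*keys) if keys else 0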
@pytest.mark.asyncio(loop_scope="session")
async def test_handle_insufficient_funds_continues_on_redis_error(
server: SpinTestServer,
):
"""Test that both email and Discord notifications are still sent when Redis fails."""
execution_processor = ExecutionProcessor()
user_id = "test-user-123"
graph_id = "test-graph-456"
error = InsufficientBalanceError(
message="Insufficient balance",
user_id=user_id,
balance=72,
amount=-714,
)
with patch(
"backend.executor.manager.queue_notification"
) as mock_queue_notif, patch(
"backend.executor.manager.get_notification_manager_client"
) as mock_get_client, patch(
"backend.executor.manager.settings"
) as mock_settings, patch(
"backend.executor.manager.redis"
) as mock_redis_module:
mock_client = MagicMock()
mock_get_client.return_value = mock_client
mock_settings.config.frontend_base_url = "https://test.com"
# Mock Redis to raise an error
mock_redis_client = MagicMock()
mock_redis_module.get_redis.return_value = mock_redis_client
mock_redis_client.set.side_effect = Exception("Redis connection error")
mock_db_client = MagicMock()
mock_graph_metadata = MagicMock()
mock_graph_metadata.name = "Test Agent"
mock_db_client.get_graph_metadata.return_value = mock_graph_metadata
mock_db_client.get_user_email_by_id.return_value = "test@example.com"
# Test the insufficient funds handler
execution_processor._handle_insufficient_funds_notif(
db_client=mock_db_client,
user_id=user_id,
graph_id=graph_id,
e=error,
)
# Verify email notification was still queued despite Redis error
mock_queue_notif.assert_called_once()
# Verify Discord alert was still sent despite Redis error
mock_client.discord_system_alert.assert_called_once()
@pytest.mark.asyncio(loop_scope="session")
async def test_add_transaction_clears_notifications_on_grant(server: SpinTestServer):
"""Test that _add_transaction clears notification flags when adding GRANT credits."""
from prisma.enums import CreditTransactionType
from backend.data.credit import UserCredit
user_id = "test-user-grant-clear"
with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
"backend.executor.manager.redis"
) as mock_redis_module:
# Mock the query to return a successful transaction
mock_query.return_value = [{"balance": 1000, "transactionKey": "test-tx-key"}]
# Mock async Redis for notification clearing
mock_redis_client = MagicMock()
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
mock_redis_client.scan_iter.return_value = async_iter(
[f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:graph-1"]
)
mock_redis_client.delete = AsyncMock(return_value=1)
# Create a concrete instance
credit_model = UserCredit()
# Call _add_transaction with GRANT type (should clear notifications)
await credit_model._add_transaction(
user_id=user_id,
amount=500, # Positive amount
transaction_type=CreditTransactionType.GRANT,
is_active=True, # Active transaction
)
# Verify notification clearing was called
mock_redis_module.get_redis_async.assert_called_once()
mock_redis_client.scan_iter.assert_called_once_with(
match=f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:*"
)
@pytest.mark.asyncio(loop_scope="session")
async def test_add_transaction_clears_notifications_on_top_up(server: SpinTestServer):
"""Test that _add_transaction clears notification flags when adding TOP_UP credits."""
from prisma.enums import CreditTransactionType
from backend.data.credit import UserCredit
user_id = "test-user-topup-clear"
with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
"backend.executor.manager.redis"
) as mock_redis_module:
# Mock the query to return a successful transaction
mock_query.return_value = [{"balance": 2000, "transactionKey": "test-tx-key-2"}]
# Mock async Redis for notification clearing
mock_redis_client = MagicMock()
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
mock_redis_client.scan_iter.return_value = async_iter([])
mock_redis_client.delete = AsyncMock(return_value=0)
credit_model = UserCredit()
# Call _add_transaction with TOP_UP type (should clear notifications)
await credit_model._add_transaction(
user_id=user_id,
amount=1000, # Positive amount
transaction_type=CreditTransactionType.TOP_UP,
is_active=True,
)
# Verify notification clearing was attempted
mock_redis_module.get_redis_async.assert_called_once()
@pytest.mark.asyncio(loop_scope="session")
async def test_add_transaction_skips_clearing_for_inactive_transaction(
server: SpinTestServer,
):
"""Test that _add_transaction does NOT clear notifications for inactive transactions."""
from prisma.enums import CreditTransactionType
from backend.data.credit import UserCredit
user_id = "test-user-inactive"
with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
"backend.executor.manager.redis"
) as mock_redis_module:
# Mock the query to return a successful transaction
mock_query.return_value = [{"balance": 500, "transactionKey": "test-tx-key-3"}]
# Mock async Redis
mock_redis_client = MagicMock()
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
credit_model = UserCredit()
# Call _add_transaction with is_active=False (should NOT clear notifications)
await credit_model._add_transaction(
user_id=user_id,
amount=500,
transaction_type=CreditTransactionType.TOP_UP,
is_active=False, # Inactive - pending Stripe payment
)
# Verify notification clearing was NOT called
mock_redis_module.get_redis_async.assert_not_called()
@pytest.mark.asyncio(loop_scope="session")
async def test_add_transaction_skips_clearing_for_usage_transaction(
server: SpinTestServer,
):
"""Test that _add_transaction does NOT clear notifications for USAGE transactions."""
from prisma.enums import CreditTransactionType
from backend.data.credit import UserCredit
user_id = "test-user-usage"
with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
"backend.executor.manager.redis"
) as mock_redis_module:
# Mock the query to return a successful transaction
mock_query.return_value = [{"balance": 400, "transactionKey": "test-tx-key-4"}]
# Mock async Redis
mock_redis_client = MagicMock()
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
credit_model = UserCredit()
# Call _add_transaction with USAGE type (spending, should NOT clear)
await credit_model._add_transaction(
user_id=user_id,
amount=-100, # Negative - spending credits
transaction_type=CreditTransactionType.USAGE,
is_active=True,
)
# Verify notification clearing was NOT called
mock_redis_module.get_redis_async.assert_not_called()
@pytest.mark.asyncio(loop_scope="session")
async def test_enable_transaction_clears_notifications(server: SpinTestServer):
"""Test that _enable_transaction clears notification flags when enabling a TOP_UP."""
from prisma.enums import CreditTransactionType
from backend.data.credit import UserCredit
user_id = "test-user-enable"
with patch("backend.data.credit.CreditTransaction") as mock_credit_tx, patch(
"backend.data.credit.query_raw_with_schema"
) as mock_query, patch("backend.executor.manager.redis") as mock_redis_module:
# Mock finding the pending transaction
mock_transaction = MagicMock()
mock_transaction.amount = 1000
mock_transaction.type = CreditTransactionType.TOP_UP
mock_credit_tx.prisma.return_value.find_first = AsyncMock(
return_value=mock_transaction
)
# Mock the query to return updated balance
mock_query.return_value = [{"balance": 1500}]
# Mock async Redis for notification clearing
mock_redis_client = MagicMock()
mock_redis_module.get_redis_async = AsyncMock(return_value=mock_redis_client)
mock_redis_client.scan_iter.return_value = async_iter(
[f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:graph-1"]
)
mock_redis_client.delete = AsyncMock(return_value=1)
credit_model = UserCredit()
# Call _enable_transaction (simulates Stripe checkout completion)
from backend.util.json import SafeJson
await credit_model._enable_transaction(
transaction_key="cs_test_123",
user_id=user_id,
metadata=SafeJson({"payment": "completed"}),
)
# Verify notification clearing was called
mock_redis_module.get_redis_async.assert_called_once()
mock_redis_client.scan_iter.assert_called_once_with(
match=f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:*"
)

View File

@@ -1,4 +1,5 @@
import logging
from unittest.mock import AsyncMock, patch
import fastapi.responses
import pytest
@@ -19,6 +20,17 @@ from backend.util.test import SpinTestServer, wait_execution
logger = logging.getLogger(__name__)
@pytest.fixture(scope="session", autouse=True)
def mock_embedding_functions():
"""Mock embedding functions for all tests to avoid database/API dependencies."""
with patch(
"backend.api.features.store.db.ensure_embedding",
new_callable=AsyncMock,
return_value=True,
):
yield
async def create_graph(s: SpinTestServer, g: graph.Graph, u: User) -> graph.Graph:
logger.info(f"Creating graph for user {u.id}")
return await s.agent_server.test_create_graph(CreateGraph(graph=g), u.id)

View File

@@ -2,6 +2,7 @@ import asyncio
import logging
import os
import threading
import time
import uuid
from enum import Enum
from typing import Optional
@@ -27,7 +28,7 @@ from backend.data.auth.oauth import cleanup_expired_oauth_tokens
from backend.data.block import BlockInput
from backend.data.execution import GraphExecutionWithNodes
from backend.data.model import CredentialsMetaInput
from backend.data.onboarding import increment_runs
from backend.data.onboarding import increment_onboarding_runs
from backend.executor import utils as execution_utils
from backend.monitoring import (
NotificationJobArgs,
@@ -37,7 +38,7 @@ from backend.monitoring import (
report_execution_accuracy_alerts,
report_late_executions,
)
from backend.util.clients import get_scheduler_client
from backend.util.clients import get_database_manager_client, get_scheduler_client
from backend.util.cloud_storage import cleanup_expired_files_async
from backend.util.exceptions import (
GraphNotFoundError,
@@ -156,7 +157,7 @@ async def _execute_graph(**kwargs):
inputs=args.input_data,
graph_credentials_inputs=args.input_credentials,
)
await increment_runs(args.user_id)
await increment_onboarding_runs(args.user_id)
elapsed = asyncio.get_event_loop().time() - start_time
logger.info(
f"Graph execution started with ID {graph_exec.id} for graph {args.graph_id} "
@@ -254,6 +255,114 @@ def execution_accuracy_alerts():
return report_execution_accuracy_alerts()
def ensure_embeddings_coverage():
"""
Ensure all content types (store agents, blocks, docs) have embeddings for search.
Processes ALL missing embeddings in batches of 10 per content type until 100% coverage.
Missing embeddings = content invisible in hybrid search.
Schedule: Runs every 6 hours (balanced between coverage and API costs).
- Catches new content added between scheduled runs
- Batch size 10 per content type: gradual processing to avoid rate limits
- Manual trigger available via execute_ensure_embeddings_coverage endpoint
"""
db_client = get_database_manager_client()
stats = db_client.get_embedding_stats()
# Check for error from get_embedding_stats() first
if "error" in stats:
logger.error(
f"Failed to get embedding stats: {stats['error']} - skipping backfill"
)
return {
"backfill": {"processed": 0, "success": 0, "failed": 0},
"cleanup": {"deleted": 0},
"error": stats["error"],
}
# Extract totals from new stats structure
totals = stats.get("totals", {})
without_embeddings = totals.get("without_embeddings", 0)
coverage_percent = totals.get("coverage_percent", 0)
total_processed = 0
total_success = 0
total_failed = 0
if without_embeddings == 0:
logger.info("All content has embeddings, skipping backfill")
else:
# Log per-content-type stats for visibility
by_type = stats.get("by_type", {})
for content_type, type_stats in by_type.items():
if type_stats.get("without_embeddings", 0) > 0:
logger.info(
f"{content_type}: {type_stats['without_embeddings']} items without embeddings "
f"({type_stats['coverage_percent']}% coverage)"
)
logger.info(
f"Total: {without_embeddings} items without embeddings "
f"({coverage_percent}% coverage) - processing all"
)
# Process in batches until no more missing embeddings
while True:
result = db_client.backfill_missing_embeddings(batch_size=10)
total_processed += result["processed"]
total_success += result["success"]
total_failed += result["failed"]
if result["processed"] == 0:
# No more missing embeddings
break
if result["success"] == 0 and result["processed"] > 0:
# All attempts in this batch failed - stop to avoid infinite loop
logger.error(
f"All {result['processed']} embedding attempts failed - stopping backfill"
)
break
# Small delay between batches to avoid rate limits
time.sleep(1)
logger.info(
f"Embedding backfill completed: {total_success}/{total_processed} succeeded, "
f"{total_failed} failed"
)
# Clean up orphaned embeddings for blocks and docs
logger.info("Running cleanup for orphaned embeddings (blocks/docs)...")
cleanup_result = db_client.cleanup_orphaned_embeddings()
cleanup_totals = cleanup_result.get("totals", {})
cleanup_deleted = cleanup_totals.get("deleted", 0)
if cleanup_deleted > 0:
logger.info(f"Cleanup completed: deleted {cleanup_deleted} orphaned embeddings")
by_type = cleanup_result.get("by_type", {})
for content_type, type_result in by_type.items():
if type_result.get("deleted", 0) > 0:
logger.info(
f"{content_type}: deleted {type_result['deleted']} orphaned embeddings"
)
else:
logger.info("Cleanup completed: no orphaned embeddings found")
return {
"backfill": {
"processed": total_processed,
"success": total_success,
"failed": total_failed,
},
"cleanup": {
"deleted": cleanup_deleted,
},
}
# Monitoring functions are now imported from monitoring module
@@ -475,11 +584,36 @@ class Scheduler(AppService):
jobstore=Jobstores.EXECUTION.value,
)
# Embedding Coverage - Every 6 hours
# Ensures all content types (store agents, blocks, docs) have embeddings for hybrid search
# Critical: missing embeddings = content invisible in search
self.scheduler.add_job(
ensure_embeddings_coverage,
id="ensure_embeddings_coverage",
trigger="interval",
hours=6,
replace_existing=True,
max_instances=1, # Prevent overlapping runs
jobstore=Jobstores.EXECUTION.value,
)
self.scheduler.add_listener(job_listener, EVENT_JOB_EXECUTED | EVENT_JOB_ERROR)
self.scheduler.add_listener(job_missed_listener, EVENT_JOB_MISSED)
self.scheduler.add_listener(job_max_instances_listener, EVENT_JOB_MAX_INSTANCES)
self.scheduler.start()
# Run embedding backfill immediately on startup
# This ensures blocks/docs are searchable right away, not after 6 hours
# Safe to run on multiple pods - uses upserts and checks for existing embeddings
if self.register_system_tasks:
logger.info("Running embedding backfill on startup...")
try:
result = ensure_embeddings_coverage()
logger.info(f"Startup embedding backfill complete: {result}")
except Exception as e:
logger.error(f"Startup embedding backfill failed: {e}")
# Don't fail startup - the scheduled job will retry later
# Keep the service running since BackgroundScheduler doesn't block
super().run_service()
@@ -632,6 +766,11 @@ class Scheduler(AppService):
"""Manually trigger execution accuracy alert checking."""
return execution_accuracy_alerts()
@expose
def execute_ensure_embeddings_coverage(self):
"""Manually trigger embedding backfill for approved store agents."""
return ensure_embeddings_coverage()
class SchedulerClient(AppServiceClient):
@classmethod

View File

@@ -10,6 +10,7 @@ from pydantic import BaseModel, JsonValue, ValidationError
from backend.data import execution as execution_db
from backend.data import graph as graph_db
from backend.data import onboarding as onboarding_db
from backend.data import user as user_db
from backend.data.block import (
Block,
@@ -31,7 +32,6 @@ from backend.data.execution import (
GraphExecutionStats,
GraphExecutionWithNodes,
NodesInputMasks,
get_graph_execution,
)
from backend.data.graph import GraphModel, Node
from backend.data.model import USER_TIMEZONE_NOT_SET, CredentialsMetaInput
@@ -239,14 +239,19 @@ async def _validate_node_input_credentials(
graph: GraphModel,
user_id: str,
nodes_input_masks: Optional[NodesInputMasks] = None,
) -> dict[str, dict[str, str]]:
) -> tuple[dict[str, dict[str, str]], set[str]]:
"""
Checks all credentials for all nodes of the graph and returns structured errors.
Checks all credentials for all nodes of the graph and returns structured errors
and a set of nodes that should be skipped due to missing optional credentials.
Returns:
dict[node_id, dict[field_name, error_message]]: Credential validation errors per node
tuple[
dict[node_id, dict[field_name, error_message]]: Credential validation errors per node,
set[node_id]: Nodes that should be skipped (optional credentials not configured)
]
"""
credential_errors: dict[str, dict[str, str]] = defaultdict(dict)
nodes_to_skip: set[str] = set()
for node in graph.nodes:
block = node.block
@@ -256,27 +261,46 @@ async def _validate_node_input_credentials(
if not credentials_fields:
continue
# Track if any credential field is missing for this node
has_missing_credentials = False
for field_name, credentials_meta_type in credentials_fields.items():
try:
# Check nodes_input_masks first, then input_default
field_value = None
if (
nodes_input_masks
and (node_input_mask := nodes_input_masks.get(node.id))
and field_name in node_input_mask
):
credentials_meta = credentials_meta_type.model_validate(
node_input_mask[field_name]
)
field_value = node_input_mask[field_name]
elif field_name in node.input_default:
credentials_meta = credentials_meta_type.model_validate(
node.input_default[field_name]
)
else:
# Missing credentials
credential_errors[node.id][
field_name
] = "These credentials are required"
continue
# For optional credentials, don't use input_default - treat as missing
# This prevents stale credential IDs from failing validation
if node.credentials_optional:
field_value = None
else:
field_value = node.input_default[field_name]
# Check if credentials are missing (None, empty, or not present)
if field_value is None or (
isinstance(field_value, dict) and not field_value.get("id")
):
has_missing_credentials = True
# If node has credentials_optional flag, mark for skipping instead of error
if node.credentials_optional:
continue # Don't add error, will be marked for skip after loop
else:
credential_errors[node.id][
field_name
] = "These credentials are required"
continue
credentials_meta = credentials_meta_type.model_validate(field_value)
except ValidationError as e:
# Validation error means credentials were provided but invalid
# This should always be an error, even if optional
credential_errors[node.id][field_name] = f"Invalid credentials: {e}"
continue
@@ -287,6 +311,7 @@ async def _validate_node_input_credentials(
)
except Exception as e:
# Handle any errors fetching credentials
# If credentials were explicitly configured but unavailable, it's an error
credential_errors[node.id][
field_name
] = f"Credentials not available: {e}"
@@ -313,7 +338,19 @@ async def _validate_node_input_credentials(
] = "Invalid credentials: type/provider mismatch"
continue
return credential_errors
# If node has optional credentials and any are missing, mark for skipping
# But only if there are no other errors for this node
if (
has_missing_credentials
and node.credentials_optional
and node.id not in credential_errors
):
nodes_to_skip.add(node.id)
logger.info(
f"Node #{node.id} will be skipped: optional credentials not configured"
)
return credential_errors, nodes_to_skip
def make_node_credentials_input_map(
@@ -355,21 +392,25 @@ async def validate_graph_with_credentials(
graph: GraphModel,
user_id: str,
nodes_input_masks: Optional[NodesInputMasks] = None,
) -> Mapping[str, Mapping[str, str]]:
) -> tuple[Mapping[str, Mapping[str, str]], set[str]]:
"""
Validate graph including credentials and return structured errors per node.
Validate graph including credentials and return structured errors per node,
along with a set of nodes that should be skipped due to missing optional credentials.
Returns:
dict[node_id, dict[field_name, error_message]]: Validation errors per node
tuple[
dict[node_id, dict[field_name, error_message]]: Validation errors per node,
set[node_id]: Nodes that should be skipped (optional credentials not configured)
]
"""
# Get input validation errors
node_input_errors = GraphModel.validate_graph_get_errors(
graph, for_run=True, nodes_input_masks=nodes_input_masks
)
# Get credential input/availability/validation errors
node_credential_input_errors = await _validate_node_input_credentials(
graph, user_id, nodes_input_masks
# Get credential input/availability/validation errors and nodes to skip
node_credential_input_errors, nodes_to_skip = (
await _validate_node_input_credentials(graph, user_id, nodes_input_masks)
)
# Merge credential errors with structural errors
@@ -378,7 +419,7 @@ async def validate_graph_with_credentials(
node_input_errors[node_id] = {}
node_input_errors[node_id].update(field_errors)
return node_input_errors
return node_input_errors, nodes_to_skip
async def _construct_starting_node_execution_input(
@@ -386,7 +427,7 @@ async def _construct_starting_node_execution_input(
user_id: str,
graph_inputs: BlockInput,
nodes_input_masks: Optional[NodesInputMasks] = None,
) -> list[tuple[str, BlockInput]]:
) -> tuple[list[tuple[str, BlockInput]], set[str]]:
"""
Validates and prepares the input data for executing a graph.
This function checks the graph for starting nodes, validates the input data
@@ -400,11 +441,14 @@ async def _construct_starting_node_execution_input(
node_credentials_map: `dict[node_id, dict[input_name, CredentialsMetaInput]]`
Returns:
list[tuple[str, BlockInput]]: A list of tuples, each containing the node ID and
the corresponding input data for that node.
tuple[
list[tuple[str, BlockInput]]: A list of tuples, each containing the node ID
and the corresponding input data for that node.
set[str]: Node IDs that should be skipped (optional credentials not configured)
]
"""
# Use new validation function that includes credentials
validation_errors = await validate_graph_with_credentials(
validation_errors, nodes_to_skip = await validate_graph_with_credentials(
graph, user_id, nodes_input_masks
)
n_error_nodes = len(validation_errors)
@@ -445,7 +489,7 @@ async def _construct_starting_node_execution_input(
"No starting nodes found for the graph, make sure an AgentInput or blocks with no inbound links are present as starting nodes."
)
return nodes_input
return nodes_input, nodes_to_skip
async def validate_and_construct_node_execution_input(
@@ -456,7 +500,7 @@ async def validate_and_construct_node_execution_input(
graph_credentials_inputs: Optional[Mapping[str, CredentialsMetaInput]] = None,
nodes_input_masks: Optional[NodesInputMasks] = None,
is_sub_graph: bool = False,
) -> tuple[GraphModel, list[tuple[str, BlockInput]], NodesInputMasks]:
) -> tuple[GraphModel, list[tuple[str, BlockInput]], NodesInputMasks, set[str]]:
"""
Public wrapper that handles graph fetching, credential mapping, and validation+construction.
This centralizes the logic used by both scheduler validation and actual execution.
@@ -473,6 +517,7 @@ async def validate_and_construct_node_execution_input(
GraphModel: Full graph object for the given `graph_id`.
list[tuple[node_id, BlockInput]]: Starting node IDs with corresponding inputs.
dict[str, BlockInput]: Node input masks including all passed-in credentials.
set[str]: Node IDs that should be skipped (optional credentials not configured).
Raises:
NotFoundError: If the graph is not found.
@@ -514,14 +559,16 @@ async def validate_and_construct_node_execution_input(
nodes_input_masks or {},
)
starting_nodes_input = await _construct_starting_node_execution_input(
graph=graph,
user_id=user_id,
graph_inputs=graph_inputs,
nodes_input_masks=nodes_input_masks,
starting_nodes_input, nodes_to_skip = (
await _construct_starting_node_execution_input(
graph=graph,
user_id=user_id,
graph_inputs=graph_inputs,
nodes_input_masks=nodes_input_masks,
)
)
return graph, starting_nodes_input, nodes_input_masks
return graph, starting_nodes_input, nodes_input_masks, nodes_to_skip
def _merge_nodes_input_masks(
@@ -762,13 +809,14 @@ async def add_graph_execution(
edb = execution_db
udb = user_db
gdb = graph_db
odb = onboarding_db
else:
edb = udb = gdb = get_database_manager_async_client()
edb = udb = gdb = odb = get_database_manager_async_client()
# Get or create the graph execution
if graph_exec_id:
# Resume existing execution
graph_exec = await get_graph_execution(
graph_exec = await edb.get_graph_execution(
user_id=user_id,
execution_id=graph_exec_id,
include_node_executions=True,
@@ -779,6 +827,9 @@ async def add_graph_execution(
# Use existing execution's compiled input masks
compiled_nodes_input_masks = graph_exec.nodes_input_masks or {}
# For resumed executions, nodes_to_skip was already determined at creation time
# TODO: Consider storing nodes_to_skip in DB if we need to preserve it across resumes
nodes_to_skip: set[str] = set()
logger.info(f"Resuming graph execution #{graph_exec.id} for graph #{graph_id}")
else:
@@ -787,7 +838,7 @@ async def add_graph_execution(
)
# Create new execution
graph, starting_nodes_input, compiled_nodes_input_masks = (
graph, starting_nodes_input, compiled_nodes_input_masks, nodes_to_skip = (
await validate_and_construct_node_execution_input(
graph_id=graph_id,
user_id=user_id,
@@ -836,10 +887,12 @@ async def add_graph_execution(
try:
graph_exec_entry = graph_exec.to_graph_execution_entry(
compiled_nodes_input_masks=compiled_nodes_input_masks,
nodes_to_skip=nodes_to_skip,
execution_context=execution_context,
)
logger.info(f"Publishing execution {graph_exec.id} to execution queue")
# Publish to execution queue for executor to pick up
exec_queue = await get_async_execution_queue()
await exec_queue.publish_message(
routing_key=GRAPH_EXECUTION_ROUTING_KEY,
@@ -848,14 +901,12 @@ async def add_graph_execution(
)
logger.info(f"Published execution {graph_exec.id} to RabbitMQ queue")
# Update execution status to QUEUED
graph_exec.status = ExecutionStatus.QUEUED
await edb.update_graph_execution_stats(
graph_exec_id=graph_exec.id,
status=graph_exec.status,
)
await get_async_execution_event_bus().publish(graph_exec)
return graph_exec
except BaseException as e:
err = str(e) or type(e).__name__
if not graph_exec:
@@ -876,6 +927,24 @@ async def add_graph_execution(
)
raise
try:
await get_async_execution_event_bus().publish(graph_exec)
logger.info(f"Published update for execution #{graph_exec.id} to event bus")
except Exception as e:
logger.error(
f"Failed to publish execution event for graph exec #{graph_exec.id}: {e}"
)
try:
await odb.increment_onboarding_runs(user_id)
logger.info(
f"Incremented user #{user_id} onboarding runs for exec #{graph_exec.id}"
)
except Exception as e:
logger.error(f"Failed to increment onboarding runs for user #{user_id}: {e}")
return graph_exec
# ============ Execution Output Helpers ============ #

View File

@@ -367,10 +367,13 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
)
# Setup mock returns
# The function returns (graph, starting_nodes_input, compiled_nodes_input_masks, nodes_to_skip)
nodes_to_skip: set[str] = set()
mock_validate.return_value = (
mock_graph,
starting_nodes_input,
compiled_nodes_input_masks,
nodes_to_skip,
)
mock_prisma.is_connected.return_value = True
mock_edb.create_graph_execution = mocker.AsyncMock(return_value=mock_graph_exec)
@@ -456,3 +459,212 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
# Both executions should succeed (though they create different objects)
assert result1 == mock_graph_exec
assert result2 == mock_graph_exec_2
# ============================================================================
# Tests for Optional Credentials Feature
# ============================================================================
@pytest.mark.asyncio
async def test_validate_node_input_credentials_returns_nodes_to_skip(
mocker: MockerFixture,
):
"""
Test that _validate_node_input_credentials returns nodes_to_skip set
for nodes with credentials_optional=True and missing credentials.
"""
from backend.executor.utils import _validate_node_input_credentials
# Create a mock node with credentials_optional=True
mock_node = mocker.MagicMock()
mock_node.id = "node-with-optional-creds"
mock_node.credentials_optional = True
mock_node.input_default = {} # No credentials configured
# Create a mock block with credentials field
mock_block = mocker.MagicMock()
mock_credentials_field_type = mocker.MagicMock()
mock_block.input_schema.get_credentials_fields.return_value = {
"credentials": mock_credentials_field_type
}
mock_node.block = mock_block
# Create mock graph
mock_graph = mocker.MagicMock()
mock_graph.nodes = [mock_node]
# Call the function
errors, nodes_to_skip = await _validate_node_input_credentials(
graph=mock_graph,
user_id="test-user-id",
nodes_input_masks=None,
)
# Node should be in nodes_to_skip, not in errors
assert mock_node.id in nodes_to_skip
assert mock_node.id not in errors
@pytest.mark.asyncio
async def test_validate_node_input_credentials_required_missing_creds_error(
mocker: MockerFixture,
):
"""
Test that _validate_node_input_credentials returns errors
for nodes with credentials_optional=False and missing credentials.
"""
from backend.executor.utils import _validate_node_input_credentials
# Create a mock node with credentials_optional=False (required)
mock_node = mocker.MagicMock()
mock_node.id = "node-with-required-creds"
mock_node.credentials_optional = False
mock_node.input_default = {} # No credentials configured
# Create a mock block with credentials field
mock_block = mocker.MagicMock()
mock_credentials_field_type = mocker.MagicMock()
mock_block.input_schema.get_credentials_fields.return_value = {
"credentials": mock_credentials_field_type
}
mock_node.block = mock_block
# Create mock graph
mock_graph = mocker.MagicMock()
mock_graph.nodes = [mock_node]
# Call the function
errors, nodes_to_skip = await _validate_node_input_credentials(
graph=mock_graph,
user_id="test-user-id",
nodes_input_masks=None,
)
# Node should be in errors, not in nodes_to_skip
assert mock_node.id in errors
assert "credentials" in errors[mock_node.id]
assert "required" in errors[mock_node.id]["credentials"].lower()
assert mock_node.id not in nodes_to_skip
@pytest.mark.asyncio
async def test_validate_graph_with_credentials_returns_nodes_to_skip(
mocker: MockerFixture,
):
"""
Test that validate_graph_with_credentials returns nodes_to_skip set
from _validate_node_input_credentials.
"""
from backend.executor.utils import validate_graph_with_credentials
# Mock _validate_node_input_credentials to return specific values
mock_validate = mocker.patch(
"backend.executor.utils._validate_node_input_credentials"
)
expected_errors = {"node1": {"field": "error"}}
expected_nodes_to_skip = {"node2", "node3"}
mock_validate.return_value = (expected_errors, expected_nodes_to_skip)
# Mock GraphModel with validate_graph_get_errors method
mock_graph = mocker.MagicMock()
mock_graph.validate_graph_get_errors.return_value = {}
# Call the function
errors, nodes_to_skip = await validate_graph_with_credentials(
graph=mock_graph,
user_id="test-user-id",
nodes_input_masks=None,
)
# Verify nodes_to_skip is passed through
assert nodes_to_skip == expected_nodes_to_skip
assert "node1" in errors
@pytest.mark.asyncio
async def test_add_graph_execution_with_nodes_to_skip(mocker: MockerFixture):
"""
Test that add_graph_execution properly passes nodes_to_skip
to the graph execution entry.
"""
from backend.data.execution import GraphExecutionWithNodes
from backend.executor.utils import add_graph_execution
# Mock data
graph_id = "test-graph-id"
user_id = "test-user-id"
inputs = {"test_input": "test_value"}
graph_version = 1
# Mock the graph object
mock_graph = mocker.MagicMock()
mock_graph.version = graph_version
# Starting nodes and masks
starting_nodes_input = [("node1", {"input1": "value1"})]
compiled_nodes_input_masks = {}
nodes_to_skip = {"skipped-node-1", "skipped-node-2"}
# Mock the graph execution object
mock_graph_exec = mocker.MagicMock(spec=GraphExecutionWithNodes)
mock_graph_exec.id = "execution-id-123"
mock_graph_exec.node_executions = []
# Track what's passed to to_graph_execution_entry
captured_kwargs = {}
def capture_to_entry(**kwargs):
captured_kwargs.update(kwargs)
return mocker.MagicMock()
mock_graph_exec.to_graph_execution_entry.side_effect = capture_to_entry
# Setup mocks
mock_validate = mocker.patch(
"backend.executor.utils.validate_and_construct_node_execution_input"
)
mock_edb = mocker.patch("backend.executor.utils.execution_db")
mock_prisma = mocker.patch("backend.executor.utils.prisma")
mock_udb = mocker.patch("backend.executor.utils.user_db")
mock_gdb = mocker.patch("backend.executor.utils.graph_db")
mock_get_queue = mocker.patch("backend.executor.utils.get_async_execution_queue")
mock_get_event_bus = mocker.patch(
"backend.executor.utils.get_async_execution_event_bus"
)
# Setup returns - include nodes_to_skip in the tuple
mock_validate.return_value = (
mock_graph,
starting_nodes_input,
compiled_nodes_input_masks,
nodes_to_skip, # This should be passed through
)
mock_prisma.is_connected.return_value = True
mock_edb.create_graph_execution = mocker.AsyncMock(return_value=mock_graph_exec)
mock_edb.update_graph_execution_stats = mocker.AsyncMock(
return_value=mock_graph_exec
)
mock_edb.update_node_execution_status_batch = mocker.AsyncMock()
mock_user = mocker.MagicMock()
mock_user.timezone = "UTC"
mock_settings = mocker.MagicMock()
mock_settings.human_in_the_loop_safe_mode = True
mock_udb.get_user_by_id = mocker.AsyncMock(return_value=mock_user)
mock_gdb.get_graph_settings = mocker.AsyncMock(return_value=mock_settings)
mock_get_queue.return_value = mocker.AsyncMock()
mock_get_event_bus.return_value = mocker.MagicMock(publish=mocker.AsyncMock())
# Call the function
await add_graph_execution(
graph_id=graph_id,
user_id=user_id,
inputs=inputs,
graph_version=graph_version,
)
# Verify nodes_to_skip was passed to to_graph_execution_entry
assert "nodes_to_skip" in captured_kwargs
assert captured_kwargs["nodes_to_skip"] == nodes_to_skip

View File

@@ -245,6 +245,21 @@ DEFAULT_CREDENTIALS = [
webshare_proxy_credentials,
]
SYSTEM_CREDENTIAL_IDS = {cred.id for cred in DEFAULT_CREDENTIALS}
# Set of providers that have system credentials available
SYSTEM_PROVIDERS = {cred.provider for cred in DEFAULT_CREDENTIALS}
def is_system_credential(credential_id: str) -> bool:
"""Check if a credential ID belongs to a system-managed credential."""
return credential_id in SYSTEM_CREDENTIAL_IDS
def is_system_provider(provider: str) -> bool:
"""Check if a provider has system-managed credentials available."""
return provider in SYSTEM_PROVIDERS
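# Illustrative usage sketch (hypothetical caller, not part of this diff): one way a
# credentials check might combine the two helpers above before demanding user input.
def _can_fall_back_to_system_creds(provider: str, credential_id: str | None = None) -> bool:
    # Either the configured credential is platform-managed, or the provider
    # ships with platform credentials that can serve as a default.
    if credential_id is not None and is_system_credential(credential_id):
        return True
    return is_system_provider(provider)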
class IntegrationCredentialsStore:
def __init__(self):

View File

@@ -8,6 +8,7 @@ from .discord import DiscordOAuthHandler
from .github import GitHubOAuthHandler
from .google import GoogleOAuthHandler
from .notion import NotionOAuthHandler
from .reddit import RedditOAuthHandler
from .twitter import TwitterOAuthHandler
if TYPE_CHECKING:
@@ -20,6 +21,7 @@ _ORIGINAL_HANDLERS = [
GitHubOAuthHandler,
GoogleOAuthHandler,
NotionOAuthHandler,
RedditOAuthHandler,
TwitterOAuthHandler,
TodoistOAuthHandler,
]

View File

@@ -0,0 +1,208 @@
import time
import urllib.parse
from typing import ClassVar, Optional
from pydantic import SecretStr
from backend.data.model import OAuth2Credentials
from backend.integrations.oauth.base import BaseOAuthHandler
from backend.integrations.providers import ProviderName
from backend.util.request import Requests
from backend.util.settings import Settings
settings = Settings()
class RedditOAuthHandler(BaseOAuthHandler):
"""
Reddit OAuth 2.0 handler.
Based on the documentation at:
- https://github.com/reddit-archive/reddit/wiki/OAuth2
Notes:
- Reddit requires `duration=permanent` to get refresh tokens
- Access tokens expire after 1 hour (3600 seconds)
- Reddit requires HTTP Basic Auth for token requests
- Reddit requires a unique User-Agent header
"""
PROVIDER_NAME = ProviderName.REDDIT
DEFAULT_SCOPES: ClassVar[list[str]] = [
"identity", # Get username, verify auth
"read", # Access posts and comments
"submit", # Submit new posts and comments
"edit", # Edit own posts and comments
"history", # Access user's post history
"privatemessages", # Access inbox and send private messages
"flair", # Access and set flair on posts/subreddits
]
AUTHORIZE_URL = "https://www.reddit.com/api/v1/authorize"
TOKEN_URL = "https://www.reddit.com/api/v1/access_token"
USERNAME_URL = "https://oauth.reddit.com/api/v1/me"
REVOKE_URL = "https://www.reddit.com/api/v1/revoke_token"
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
self.client_id = client_id
self.client_secret = client_secret
self.redirect_uri = redirect_uri
def get_login_url(
self, scopes: list[str], state: str, code_challenge: Optional[str]
) -> str:
"""Generate Reddit OAuth 2.0 authorization URL"""
scopes = self.handle_default_scopes(scopes)
params = {
"response_type": "code",
"client_id": self.client_id,
"redirect_uri": self.redirect_uri,
"scope": " ".join(scopes),
"state": state,
"duration": "permanent", # Required for refresh tokens
}
return f"{self.AUTHORIZE_URL}?{urllib.parse.urlencode(params)}"
async def exchange_code_for_tokens(
self, code: str, scopes: list[str], code_verifier: Optional[str]
) -> OAuth2Credentials:
"""Exchange authorization code for access tokens"""
scopes = self.handle_default_scopes(scopes)
headers = {
"Content-Type": "application/x-www-form-urlencoded",
"User-Agent": settings.config.reddit_user_agent,
}
data = {
"grant_type": "authorization_code",
"code": code,
"redirect_uri": self.redirect_uri,
}
# Reddit requires HTTP Basic Auth for token requests
auth = (self.client_id, self.client_secret)
response = await Requests().post(
self.TOKEN_URL, headers=headers, data=data, auth=auth
)
if not response.ok:
error_text = response.text()
raise ValueError(
f"Reddit token exchange failed: {response.status} - {error_text}"
)
tokens = response.json()
if "error" in tokens:
raise ValueError(f"Reddit OAuth error: {tokens.get('error')}")
username = await self._get_username(tokens["access_token"])
return OAuth2Credentials(
provider=self.PROVIDER_NAME,
title=None,
username=username,
access_token=tokens["access_token"],
refresh_token=tokens.get("refresh_token"),
access_token_expires_at=int(time.time()) + tokens.get("expires_in", 3600),
refresh_token_expires_at=None, # Reddit refresh tokens don't expire
scopes=scopes,
)
async def _get_username(self, access_token: str) -> str:
"""Get the username from the access token"""
headers = {
"Authorization": f"Bearer {access_token}",
"User-Agent": settings.config.reddit_user_agent,
}
response = await Requests().get(self.USERNAME_URL, headers=headers)
if not response.ok:
raise ValueError(f"Failed to get Reddit username: {response.status}")
data = response.json()
return data.get("name", "unknown")
async def _refresh_tokens(
self, credentials: OAuth2Credentials
) -> OAuth2Credentials:
"""Refresh access tokens using refresh token"""
if not credentials.refresh_token:
raise ValueError("No refresh token available")
headers = {
"Content-Type": "application/x-www-form-urlencoded",
"User-Agent": settings.config.reddit_user_agent,
}
data = {
"grant_type": "refresh_token",
"refresh_token": credentials.refresh_token.get_secret_value(),
}
auth = (self.client_id, self.client_secret)
response = await Requests().post(
self.TOKEN_URL, headers=headers, data=data, auth=auth
)
if not response.ok:
error_text = response.text()
raise ValueError(
f"Reddit token refresh failed: {response.status} - {error_text}"
)
tokens = response.json()
if "error" in tokens:
raise ValueError(f"Reddit OAuth error: {tokens.get('error')}")
username = await self._get_username(tokens["access_token"])
# Reddit may or may not return a new refresh token
new_refresh_token = tokens.get("refresh_token")
if new_refresh_token:
refresh_token: SecretStr | None = SecretStr(new_refresh_token)
elif credentials.refresh_token:
# Keep the existing refresh token
refresh_token = credentials.refresh_token
else:
refresh_token = None
return OAuth2Credentials(
id=credentials.id,
provider=self.PROVIDER_NAME,
title=credentials.title,
username=username,
access_token=tokens["access_token"],
refresh_token=refresh_token,
access_token_expires_at=int(time.time()) + tokens.get("expires_in", 3600),
refresh_token_expires_at=None,
scopes=credentials.scopes,
)
async def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
"""Revoke the access token"""
headers = {
"Content-Type": "application/x-www-form-urlencoded",
"User-Agent": settings.config.reddit_user_agent,
}
data = {
"token": credentials.access_token.get_secret_value(),
"token_type_hint": "access_token",
}
auth = (self.client_id, self.client_secret)
response = await Requests().post(
self.REVOKE_URL, headers=headers, data=data, auth=auth
)
# Reddit returns 204 No Content on successful revocation
return response.ok
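# Illustrative end-to-end sketch (hypothetical wiring, not part of this diff) of how the
# handler above is meant to be used. The client id/secret, redirect URI, state value, and
# how the ?code= parameter reaches the callback are assumptions; passing scopes=[] is
# assumed to fall back to DEFAULT_SCOPES via handle_default_scopes().
async def _example_reddit_oauth_flow(code_from_callback: str) -> OAuth2Credentials:
    handler = RedditOAuthHandler(
        client_id="<reddit-app-id>",
        client_secret="<reddit-app-secret>",
        redirect_uri="https://example.com/oauth/callback",
    )
    # Step 1 (before the callback): send the user to Reddit's consent page.
    # get_login_url() adds duration=permanent so a refresh token is issued.
    _login_url = handler.get_login_url(scopes=[], state="opaque-state", code_challenge=None)
    # Step 2: exchange the authorization code from the callback for credentials.
    creds = await handler.exchange_code_for_tokens(
        code_from_callback, scopes=[], code_verifier=None
    )
    # Step 3: access tokens expire after ~1 hour; refresh using the stored refresh token.
    return await handler._refresh_tokens(creds)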

View File

@@ -16,7 +16,7 @@ import pickle
import threading
import time
from dataclasses import dataclass
from functools import wraps
from functools import cache, wraps
from typing import Any, Callable, ParamSpec, Protocol, TypeVar, cast, runtime_checkable
from redis import ConnectionPool, Redis
@@ -38,29 +38,34 @@ settings = Settings()
# maxmemory 2gb # Set memory limit (adjust based on your needs)
# save "" # Disable persistence if using Redis purely for caching
# Create a dedicated Redis connection pool for caching (binary mode for pickle)
_cache_pool: ConnectionPool | None = None
@conn_retry("Redis", "Acquiring cache connection pool")
@cache
def _get_cache_pool() -> ConnectionPool:
"""Get or create a connection pool for cache operations."""
global _cache_pool
if _cache_pool is None:
_cache_pool = ConnectionPool(
host=settings.config.redis_host,
port=settings.config.redis_port,
password=settings.config.redis_password or None,
decode_responses=False, # Binary mode for pickle
max_connections=50,
socket_keepalive=True,
socket_connect_timeout=5,
retry_on_timeout=True,
)
return _cache_pool
"""Get or create a connection pool for cache operations (lazy, thread-safe)."""
return ConnectionPool(
host=settings.config.redis_host,
port=settings.config.redis_port,
password=settings.config.redis_password or None,
decode_responses=False, # Binary mode for pickle
max_connections=50,
socket_keepalive=True,
socket_connect_timeout=5,
retry_on_timeout=True,
)
redis = Redis(connection_pool=_get_cache_pool())
@cache
@conn_retry("Redis", "Acquiring cache connection")
def _get_redis() -> Redis:
"""
Get the lazily-initialized Redis client for shared cache operations.
Uses @cache for thread-safe singleton behavior - connection is only
established when first accessed, allowing services that only use
in-memory caching to work without Redis configuration.
"""
r = Redis(connection_pool=_get_cache_pool())
r.ping() # Verify connection
return r
@dataclass
@@ -179,9 +184,9 @@ def cached(
try:
if refresh_ttl_on_get:
# Use GETEX to get value and refresh expiry atomically
cached_bytes = redis.getex(redis_key, ex=ttl_seconds)
cached_bytes = _get_redis().getex(redis_key, ex=ttl_seconds)
else:
cached_bytes = redis.get(redis_key)
cached_bytes = _get_redis().get(redis_key)
if cached_bytes and isinstance(cached_bytes, bytes):
return pickle.loads(cached_bytes)
@@ -195,7 +200,7 @@ def cached(
"""Set value in Redis with TTL."""
try:
pickled_value = pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL)
redis.setex(redis_key, ttl_seconds, pickled_value)
_get_redis().setex(redis_key, ttl_seconds, pickled_value)
except Exception as e:
logger.error(
f"Redis error storing cache for {target_func.__name__}: {e}"
@@ -333,14 +338,18 @@ def cached(
if pattern:
# Clear entries matching pattern
keys = list(
redis.scan_iter(f"cache:{target_func.__name__}:{pattern}")
_get_redis().scan_iter(
f"cache:{target_func.__name__}:{pattern}"
)
)
else:
# Clear all cache keys
keys = list(redis.scan_iter(f"cache:{target_func.__name__}:*"))
keys = list(
_get_redis().scan_iter(f"cache:{target_func.__name__}:*")
)
if keys:
pipeline = redis.pipeline()
pipeline = _get_redis().pipeline()
for key in keys:
pipeline.delete(key)
pipeline.execute()
@@ -355,7 +364,9 @@ def cached(
def cache_info() -> dict[str, int | None]:
if shared_cache:
cache_keys = list(redis.scan_iter(f"cache:{target_func.__name__}:*"))
cache_keys = list(
_get_redis().scan_iter(f"cache:{target_func.__name__}:*")
)
return {
"size": len(cache_keys),
"maxsize": None, # Redis manages its own size
@@ -373,10 +384,8 @@ def cached(
key = _make_hashable_key(args, kwargs)
if shared_cache:
redis_key = _make_redis_key(key, target_func.__name__)
if redis.exists(redis_key):
redis.delete(redis_key)
return True
return False
deleted_count = cast(int, _get_redis().delete(redis_key))
return deleted_count > 0
else:
if key in cache_storage:
del cache_storage[key]

View File

@@ -10,6 +10,7 @@ from backend.util.settings import Settings
settings = Settings()
if TYPE_CHECKING:
from openai import AsyncOpenAI
from supabase import AClient, Client
from backend.data.execution import (
@@ -139,6 +140,24 @@ async def get_async_supabase() -> "AClient":
)
# ============ OpenAI Client ============ #
@cached(ttl_seconds=3600)
def get_openai_client() -> "AsyncOpenAI | None":
"""
Get a process-cached async OpenAI client for embeddings.
Returns None if API key is not configured.
"""
from openai import AsyncOpenAI
api_key = settings.secrets.openai_internal_api_key
if not api_key:
return None
return AsyncOpenAI(api_key=api_key)
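# Illustrative call-site sketch (hypothetical, not part of this diff) for consuming the
# cached client; the embedding model name is an assumption.
async def _example_embed_texts(texts: list[str]) -> list[list[float]] | None:
    client = get_openai_client()
    if client is None:  # openai_internal_api_key not configured
        return None
    resp = await client.embeddings.create(model="text-embedding-3-small", input=texts)
    return [item.embedding for item in resp.data]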
# ============ Notification Queue Helpers ============ #

View File

@@ -264,7 +264,7 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
)
reddit_user_agent: str = Field(
default="AutoGPT:1.0 (by /u/autogpt)",
default="web:AutoGPT:v0.6.0 (by /u/autogpt)",
description="The user agent for the Reddit API",
)
@@ -658,6 +658,14 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
ayrshare_api_key: str = Field(default="", description="Ayrshare API Key")
ayrshare_jwt_key: str = Field(default="", description="Ayrshare private Key")
# Langfuse prompt management
langfuse_public_key: str = Field(default="", description="Langfuse public key")
langfuse_secret_key: str = Field(default="", description="Langfuse secret key")
langfuse_host: str = Field(
default="https://cloud.langfuse.com", description="Langfuse host URL"
)
# Add more secret fields as needed
model_config = SettingsConfigDict(
env_file=".env",

Some files were not shown because too many files have changed in this diff.