Compare commits

..

38 Commits

Author SHA1 Message Date
Lluis Agusti
ee11623735 chore: vercel preview 2025-07-14 16:14:58 +04:00
Lluis Agusti
0bb160e930 chore: generate 2025-07-14 15:38:33 +04:00
Lluis Agusti
81a09738dc chore: CAPTCHA 2025-07-14 15:23:39 +04:00
Lluis Agusti
6feedafd7d Merge 'dev' into 'feat/agent-notifications' 2025-07-14 15:05:36 +04:00
Lluis Agusti
547da633c4 Merge 'dev' into 'feat/agent-notifications' 2025-07-14 14:41:01 +04:00
Ubbe
fde3533943 fix(frontend): logout pages design adjustments (#10342)
## Changes 🏗️

- Put `Continue with Google` button below the other button on the forms
( _to confirm with design_ )
- Ensure some vertical spacing so the forms don't end touching the
header on small screens
- Apply style adjustments asked by design on navbar links

## Checklist 📋

### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Check the above

### For configuration changes:

None
2025-07-14 10:28:09 +00:00
Ubbe
a789f87734 fix(frontend): disable Cloudflare on Vercel previews (#10354)
## Changes 🏗️

Disable the Cloudflare check:

<img width="600" height="861" alt="Screenshot 2025-07-11 at 18 51 46"
src="https://github.com/user-attachments/assets/792ecca0-967e-4cef-a562-789125452d2f"
/>

On Vercel previews, so we can use previews for testing Front-end only
changes.

Vercel previews have dynamically generated URLs:
```
https://{branch}-{commit}-significant-gravitas.vercel.app/login
```

So if Cloudflare does not support URL wildcards we will neeed to do this
🙇🏽 ( _as an experiment_ )

## Checklist 📋

### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] You can login on the preview
  
### For configuration changes:

None
2025-07-14 10:27:56 +00:00
Abhimanyu Yadav
0b6e46d363 fix(frontend): fix my agent count in the library (#10357)
Currently, my agents count is showing the initial agent count loads on
the library and then adding more agents after pagination.

### Changes 🏗️
- I’ve used `total_items` inside the pagination response and shown the
correct result.

### Demo

https://github.com/user-attachments/assets/b9a2cf18-c9fc-42f8-b0d4-3f8a7ad3cbc5


### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Manually test everything, and it works fine.
2025-07-14 10:20:33 +00:00
Muhammad Ehsan
6ffe57c3df fix(docs): Updated Discord Badge in README for Better Visibility (#10360)
### Motivation 💡

The previous Discord badge in the README used `dcbadge.vercel.app`,
which often fails to render correctly and displays an invalid or broken
badge.

### Changes 🛠️

- Replaced the broken badge with a `shields.io` Discord badge that is
visually consistent with the Twitter badge
- Ensures clearer visual guidance and a more professional appearance

### Notes ✏️

This PR only updates the `README.md` no frontend, backend, or
configuration files are touched. This change improves the aesthetics and
onboarding experience for new contributors.

Screenshot of the issue:
<img width="405" height="47" alt="Screenshot 2025-07-12 175316"
src="https://github.com/user-attachments/assets/41f7355c-f795-4163-855f-3d01f2478dd7"
/>

---------

Co-authored-by: Ubbe <hi@ubbe.dev>
Co-authored-by: Bently <Github@bentlybro.com>
Co-authored-by: Bently <tomnoon9@gmail.com>
2025-07-14 09:56:32 +00:00
Bently
3ca0d04ea0 fix(readme): Removes MIT icon from readme (#10366)
This PR simply removes the MIT Icon from the main README.md
2025-07-14 09:40:29 +00:00
Zamil Majdy
c2eea593c0 fix(backend): Include node execution steps and cost of sub-graph execution (#10328)
## Summary
This PR enhances the node execution stats tracking system to properly
handle nested graph executions and additional cost/step metrics:

- **Add extra_cost and extra_steps fields** to `NodeExecutionStats`
model for tracking additional metrics from sub-graphs
- **Update AgentExecutorBlock** to merge nested execution stats from
sub-graphs into the parent execution
- **Fix stats update mechanism** in `execute_node` to use in-place
updates instead of `model_copy` for better performance
- **Add proper tracking** of extra costs and steps in graph execution
stats aggregation

## Changes Made
- Modified `backend/backend/data/model.py` to add `extra_cost` and
`extra_steps` fields
- Updated `backend/backend/blocks/agent.py` to merge stats from nested
graph executions
- Fixed `backend/backend/executor/manager.py` to properly update
execution stats and aggregate extra metrics

## Test Plan
- [x] Verify that nested graph executions properly propagate their stats
to parent graphs
- [x] Test that extra costs and steps are correctly tracked and
aggregated
- [x] Ensure debug logging provides useful information for monitoring
- [x] Run existing tests to ensure no regressions
- [x] Test with multi-level nested agent graphs

🤖 Generated with [Claude Code](https://claude.ai/code)

---------

Co-authored-by: Claude <noreply@anthropic.com>
2025-07-14 09:01:15 +00:00
Lluis Agusti
6d13dfc688 chore: empty state 2025-07-11 19:56:12 +04:00
Reinier van der Leer
36f5f24333 feat(platform/builder): Builder credentials support + UX improvements (#10323)
- Resolves #10313
- Resolves #10333

Before:


https://github.com/user-attachments/assets/a105b2b0-a90b-4bc6-89da-bef3f5a5fa1f
- No credentials input
- Stuttery experience when panning or zooming the viewport

After:


https://github.com/user-attachments/assets/f58d7864-055f-4e1c-a221-57154467c3aa
- Pretty much the same UX as in the Library, with fully-fledged
credentials input support
- Much smoother when moving around the canvas

### Changes 🏗️

Frontend:
- Add credentials input support to Run UX in Builder
  - Pass run inputs instead of storing them on the input nodes
- Re-implement `RunnerInputUI` using `AgentRunDetailsView`; rename to
`RunnerInputDialog`
    - Make `AgentRunDraftView` more flexible
    - Remove `RunnerInputList`, `RunnerInputBlock`
- Make moving around in the Builder *smooooth* by reducing unnecessary
re-renders
  - Clean up and partially re-write bead management logic
- Replace `request*` fire-and-forget methods in `useAgentGraph` with
direct action async callbacks
- Clean up run input UI components
  - Simplify `RunnerUIWrapper`
- Add `isEmpty` utility function in `@/lib/utils` (expanding on
`_.isEmpty`)
- Fix default value handling in `TypeBasedInput` (**Note:** after all
the changes I've made I'm not sure this is still necessary)
- Improve & clean up Builder test implementations

Backend + API:
- Fix front-end `Node`, `GraphMeta`, and `Block` types
- Small refactor of `Graph` to match naming of some `LibraryAgent`
attributes
- Fix typing of `list_graphs`,
`get_graph_meta_by_store_listing_version_id` endpoints
  - Add `GraphMeta` model and `GraphModel.meta()` shortcut
- Move `POST /library/agents/{library_agent_id}/setup-trigger` to `POST
/library/presets/setup-trigger`

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - Test the new functionality in the Builder:
    - [x] Running an agent with (credentials) inputs from the builder
      - [x] Beads behave correctly
    - [x] Running an agent without any inputs from the builder
    - [x] Scheduling an agent from the builder
    - [x] Adding and searching blocks in the block menu
- [x] Test that all existing `AgentRunDraftView` functionality in the
Library still works the same
    - [x] Run an agent
    - [x] Schedule an agent
    - [x] View past runs
- [x] Run an agent with inputs, then edit the agent's inputs and view
the agent in the Library (should be fine)
2025-07-11 15:46:06 +00:00
Lluis Agusti
d0d498fa66 chore: undo 2025-07-11 19:44:27 +04:00
Lluis Agusti
c843dee317 Merge 'dev' into 'feat/agent-notifications' 2025-07-11 19:36:01 +04:00
Lluis Agusti
db969c1bf8 chore: rename 2025-07-11 19:35:53 +04:00
Lluis Agusti
690fac91e4 chore: lint 2025-07-11 18:58:13 +04:00
Reinier van der Leer
309114a727 Merge commit from fork 2025-07-11 16:43:03 +02:00
Lluis Agusti
5368fdc998 chore: tests 2025-07-11 18:31:44 +04:00
Lluis Agusti
b9d293f181 chore: updates 2025-07-11 18:15:30 +04:00
Lluis Agusti
acbcef77b2 Merge 'dev' into 'feat/agent-notifications' 2025-07-11 17:40:50 +04:00
Zamil Majdy
4ffb99bfb0 feat(backend): Add block error rate monitoring and Discord alerts (#10332)
## Summary

This PR adds a simple block error rate monitoring system that runs every
24 hours (configurable) and sends Discord alerts when blocks exceed the
error rate threshold.

## Changes Made

**Modified Files:**
- `backend/executor/scheduler.py` - Added `report_block_error_rates`
function and scheduled job
- `backend/util/settings.py` - Added configuration options
- `backend/.env.example` - Added environment variable examples
- Refactor scheduled job logics in scheduler.py into seperate files

## Configuration

```bash
# Block Error Rate Monitoring
BLOCK_ERROR_RATE_THRESHOLD=0.5  # 50% error rate threshold
BLOCK_ERROR_RATE_CHECK_INTERVAL_SECS=86400  # 24 hours
```

## How It Works

1. **Scheduled Job**: Runs every 24 hours (configurable via
`BLOCK_ERROR_RATE_CHECK_INTERVAL_SECS`)
2. **Error Rate Calculation**: Queries last 24 hours of node executions
and calculates error rates per block
3. **Threshold Check**: Alerts on blocks with ≥50% error rate
(configurable via `BLOCK_ERROR_RATE_THRESHOLD`)
4. **Discord Alert**: Sends alert to Discord using existing
`discord_system_alert` function
5. **Manual Execution**: Available via
`execute_report_block_error_rates()` scheduler client method

## Alert Format

```
Block Error Rate Alert:
🚨 Block 'DeprecatedGPT3Block' has 75.0% error rate (75/100) in the last 24 hours
🚨 Block 'BrokenImageBlock' has 60.0% error rate (30/50) in the last 24 hours
```

## Testing

Can be tested manually via:
```python
from backend.executor.scheduler import SchedulerClient
client = SchedulerClient()
result = client.execute_report_block_error_rates()
```

## Implementation Notes

- Follows the same pattern as `report_late_executions` function
- Only checks blocks with ≥10 executions to avoid noise
- Uses existing Discord notification infrastructure
- Configurable threshold and check interval
- Proper error handling and logging

## Test plan

- [x] Verify configuration loads correctly
- [x] Test error rate calculation with existing database
- [x] Confirm Discord integration works
- [x] Test manual execution via scheduler client
- [x] Verify scheduled job runs correctly

🤖 Generated with [Claude Code](https://claude.ai/code)

---------

Co-authored-by: Claude AI <claude@anthropic.com>
Co-authored-by: Claude <noreply@anthropic.com>
2025-07-10 21:56:58 +00:00
Lluis Agusti
e902848e04 chore: fix 2025-07-10 23:06:07 +04:00
Lluis Agusti
cd917ec919 Merge 'dev' into 'feat/agent-notifications' 2025-07-10 22:48:24 +04:00
Ubbe
5741331250 feat(frontend): logged out pages UI updates (#10314)
## Changes 🏗️

<img width="800" alt="Screenshot 2025-07-07 at 13 16 44"
src="https://github.com/user-attachments/assets/0d404958-d4c9-454d-b71a-9dd677fe0fdc"
/>

<img width="800" alt="Screenshot 2025-07-07 at 13 17 08"
src="https://github.com/user-attachments/assets/1142f6d5-a6af-485d-b42e-98afd26de3ed"
/>

Update the UI of the logged-out pages ( _login, signup,
reset-password..._ ) using the new Design System components, so the app
starts to look a bit more cohesive 💆🏽

Some notes:

- I refactored the `<AuthCard />` components a bit to be easier to use
- I split the render from hook login on login/signup
- I added a couple of modals to improve the UX when logging in with
Google or using non-whitelisted emails
  -  _see below my comments for more context_ 
- When there are API errors, they are shown in a toast to prevent the
layout of the form from jumping
- When using the components in the UI, an issue with border-radius, see
comments for an explanation




## Checklist 📋

### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Logout on the platform
  - [x] Check the updated Login/Signup/Reset password pages
  - [x] The UI looks good and is consistent
  - [x]  The forms work as expected
2025-07-10 18:27:24 +00:00
Ubbe
2fda8dfd32 feat(frontend): new navbar design (#10341)
## Changes 🏗️

<img width="900" height="327" alt="Screenshot 2025-07-10 at 20 12 38"
src="https://github.com/user-attachments/assets/044f00ed-7e05-46b7-a821-ce1cb0ee9298"
/>
<br /><br />

Navbar updated to look pretty from the new designs:
- the logo is now centred instead of on the left
- menu items have been updated to a smaller font-size and less radius
- icons have been updated

I also generated the API files ( _sorry for the noise_ ). I had to do
some border-radius and button updates on the atoms/tokens for it to look
good.

## Checklist 📋

## For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Login/logout
  - [x] The new navbar looks good across screens 

## For configuration changes

No config changes
2025-07-10 18:06:12 +00:00
Ubbe
22c76eab61 feat(toast): update styles (#10339)
## Changes 🏗️

Style refinements on Toasts.

## Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Check Storybook toast stories
  - [x] They match Figma 

#### For configuration changes:

None
2025-07-10 15:04:14 +00:00
Swifty
7688a9701e perf(backend/db): Optimize StoreAgent and Creator views with database indexes and materialized views (#10084)
### Summary
Performance optimization for the platform's store and creator
functionality by adding targeted database indexes and implementing
materialized views to reduce query execution time.

### Changes 🏗️

**Database Performance Optimizations:**
- Added strategic database indexes for `StoreListing`,
`StoreListingVersion`, `StoreListingReview`, `AgentGraphExecution`, and
`Profile` tables
- Implemented materialized views (`mv_agent_run_counts`,
`mv_review_stats`) to cache expensive aggregation queries
- Optimized `StoreAgent` and `Creator` views to use materialized views
and improved query patterns
- Added automated refresh function with 15-minute scheduling for
materialized views (when pg_cron extension is available)

**Key Performance Improvements:**
- Filtered indexes on approved store listings to speed up marketplace
queries
- GIN index on categories for faster category-based searches
- Composite indexes for common query patterns (e.g., listing + version
lookups)
- Pre-computed agent run counts and review statistics to eliminate
expensive aggregations

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Verified migration runs successfully without errors
  - [x] Confirmed materialized views are created and populated correctly
- [x] Tested StoreAgent and Creator view queries return expected results
  - [x] Validated automatic refresh function works properly
  - [x] Confirmed rollback migration successfully removes all changes

#### For configuration changes:
- [x] `.env.example` is updated or already compatible with my changes
- [x] `docker-compose.yml` is updated or already compatible with my
changes
- [x] I have included a list of my configuration changes in the PR
description (under **Changes**)

**Note:** No configuration changes were required as this is purely a
database schema optimization.
2025-07-10 14:57:55 +00:00
Lluis Agusti
8ae37491e4 Merge 'dev' into 'feat/agent-notifications' 2025-07-10 18:41:12 +04:00
Swifty
243400e128 feat(platform): Add Block Development SDK with auto-registration system (#10074)
## Block Development SDK - Simplifying Block Creation

### Problem
Currently, creating a new block requires manual updates to **5+ files**
scattered across the codebase:
- `backend/data/block_cost_config.py` - Manually add block costs
- `backend/integrations/credentials_store.py` - Add default credentials
- `backend/integrations/providers.py` - Register new providers
- `backend/integrations/oauth/__init__.py` - Register OAuth handlers
- `backend/integrations/webhooks/__init__.py` - Register webhook
managers

This creates significant friction for developers, increases the chance
of configuration errors, and makes the platform difficult to scale.

### Solution
This PR introduces a **Block Development SDK** that provides:
- Single import for all block development needs: `from backend.sdk
import *`
- Automatic registration of all block configurations
- Zero external file modifications required
- Provider-based configuration with inheritance

### Changes 🏗️

#### 1. **New SDK Module** (`backend/sdk/`)
- **`__init__.py`**: Unified exports of 68+ block development components
- **`registry.py`**: Central auto-registration system for all block
configurations
- **`builder.py`**: `ProviderBuilder` class for fluent provider
configuration
- **`provider.py`**: Provider configuration management
- **`cost_integration.py`**: Automatic cost application system

#### 2. **Provider Builder Pattern**
```python
# Configure once, use everywhere
my_provider = (
    ProviderBuilder("my-service")
    .with_api_key("MY_SERVICE_API_KEY", "My Service API Key")
    .with_base_cost(5, BlockCostType.RUN)
    .build()
)
```

#### 3. **Automatic Cost System**
- Provider base costs automatically applied to all blocks using that
provider
- Override with `@cost` decorator for block-specific pricing
- Tiered pricing support with cost filters

#### 4. **Dynamic Provider Support**
- Modified `ProviderName` enum to accept any string via `_missing_`
method
- No more manual enum updates for new providers

#### 5. **Application Integration**
- Added `sync_all_provider_costs()` to `initialize_blocks()` for
automatic cost registration
- Maintains full backward compatibility with existing blocks

#### 6. **Comprehensive Examples** (`backend/blocks/examples/`)
- `simple_example_block.py` - Basic block structure
- `example_sdk_block.py` - Provider with credentials
- `cost_example_block.py` - Various cost patterns
- `advanced_provider_example.py` - Custom API clients
- `example_webhook_sdk_block.py` - Webhook configuration

#### 7. **Extensive Testing**
- 6 new test modules with 30+ test cases
- Integration tests for all SDK features
- Cost calculation verification
- Provider registration tests

### Before vs After

**Before SDK:**
```python
# 1. Multiple complex imports
from backend.data.block import Block, BlockCategory, BlockOutput
from backend.data.model import SchemaField, CredentialsField
# ... many more imports

# 2. Update block_cost_config.py
BLOCK_COSTS[MyBlock] = [BlockCost(...)]

# 3. Update credentials_store.py
DEFAULT_CREDENTIALS.append(...)

# 4. Update providers.py enum
# 5. Update oauth/__init__.py
# 6. Update webhooks/__init__.py
```

**After SDK:**
```python
from backend.sdk import *

# Everything configured in one place
my_provider = (
    ProviderBuilder("my-service")
    .with_api_key("MY_API_KEY", "My API Key")
    .with_base_cost(10, BlockCostType.RUN)
    .build()
)

class MyBlock(Block):
    class Input(BlockSchema):
        credentials: CredentialsMetaInput = my_provider.credentials_field()
        data: String = SchemaField(description="Input data")
    
    class Output(BlockSchema):
        result: String = SchemaField(description="Result")
    
    # That's it\! No external files to modify
```

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Created new blocks using SDK pattern with provider configuration
  - [x] Verified automatic cost registration for provider-based blocks
  - [x] Tested cost override with @cost decorator
  - [x] Confirmed custom providers work without enum modifications
  - [x] Verified all example blocks execute correctly
  - [x] Tested backward compatibility with existing blocks
  - [x] Ran all SDK tests (30+ tests, all passing)
  - [x] Created blocks with credentials and verified authentication
  - [x] Tested webhook block configuration
  - [x] Verified application startup with auto-registration

#### For configuration changes:
- [x] `.env.example` is updated or already compatible with my changes
(no changes needed)
- [x] `docker-compose.yml` is updated or already compatible with my
changes (no changes needed)
- [x] I have included a list of my configuration changes in the PR
description (under **Changes**)

### Impact

- **Developer Experience**: Block creation time reduced from hours to
minutes
- **Maintainability**: All block configuration in one place
- **Scalability**: Support hundreds of blocks without enum updates
- **Type Safety**: Full IDE support with proper type hints
- **Testing**: Easier to test blocks in isolation

---------

Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: Abhimanyu Yadav <122007096+Abhi1992002@users.noreply.github.com>
2025-07-10 16:17:55 +02:00
Reinier van der Leer
c77cb1fcfb fix(backend/library): Fix sub_graphs check in LibraryAgent.from_db(..) (#10316)
- Follow-up fix for #10301

The condition that determines whether
`LibraryAgent.credentials_input_schema` is set incorrectly handles empty
lists of sub-graphs.

### Changes 🏗️

- Check if `sub_graphs is not None` rather than using the boolean
interpretation of `sub_graphs`

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - Trivial change, no test needed.
2025-07-10 07:48:18 +00:00
Ubbe
b3b5eefe2c feat(frontend): change to use Sonner toast (#10334)
## Changes 🏗️

Makes changes to use [Sonner for Toasts](https://sonner.emilkowal.ski/)
rather than the [Radix UI
primitive](https://www.radix-ui.com/primitives/docs/components/toast).

<img width="431" alt="Screenshot 2025-07-09 at 15 49 47"
src="https://github.com/user-attachments/assets/c09c3c1e-fd80-44d2-9336-c955c2d4f288"
/>
<img width="444" alt="Screenshot 2025-07-09 at 15 51 05"
src="https://github.com/user-attachments/assets/cc2a3491-7b76-44e2-8bec-3ad0ac917148"
/>
<img width="450" alt="Screenshot 2025-07-09 at 15 51 50"
src="https://github.com/user-attachments/assets/e8ede05d-3488-43f4-aa43-7d3cba92a050"
/>


https://github.com/user-attachments/assets/deb4ce1c-13bb-4f69-890e-9b8680c848e7

<img width="500" alt="Screenshot 2025-07-09 at 15 59 09"
src="https://github.com/user-attachments/assets/5636969d-4c9a-41e6-acd1-afa49b8e70c6"
/>

Sonner is [the one used in
shadcn](https://ui.shadcn.com/docs/components/toast) nowadays, because
it brings great UX on touch devices:
- allows to swipe to dismiss
- they can stack nicely if multiple toasts appear ( see video 📹 )
- when stack, hovering over them reveals them all nicely ( see video 📹 )

I kept the existing `useToast()` API used on the pages, so I had to only
refactor the hook not the calls 🏁

## Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Login
  - [x] Click around the app and trigger toasts
  - [x] Toasts look good 

### For configuration changes

Nope
2025-07-09 17:09:16 +00:00
Lluis Agusti
f45e5e0d59 chore: prettier 2025-07-09 14:57:33 +04:00
Lluis Agusti
1231236d87 chore: lock 2025-07-09 14:46:43 +04:00
Lluis Agusti
4db0792ade Merge 'dev' into 'feat/agent-notifications' 2025-07-09 14:45:46 +04:00
Lluis Agusti
81cb6fb1e6 chore: fixes... 2025-07-08 19:45:09 +04:00
Lluis Agusti
c16598eed6 Merge 'dev' into 'feat/agent-notifications' 2025-07-08 19:32:04 +04:00
Lluis Agusti
7706740308 chore: agent notifications 2025-07-07 13:36:43 +04:00
237 changed files with 13971 additions and 4724 deletions

View File

@@ -148,6 +148,7 @@ jobs:
onlyChanged: true
workingDir: autogpt_platform/frontend
token: ${{ secrets.GITHUB_TOKEN }}
exitOnceUploaded: true
test:
runs-on: ubuntu-latest

View File

@@ -1,8 +1,7 @@
# AutoGPT: Build, Deploy, and Run AI Agents
[![Discord Follow](https://dcbadge.vercel.app/api/server/autogpt?style=flat)](https://discord.gg/autogpt) &ensp;
[![Discord Follow](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fdiscord.com%2Fapi%2Finvites%2Fautogpt%3Fwith_counts%3Dtrue&query=%24.approximate_member_count&label=total%20members&logo=discord&logoColor=white&color=7289da)](https://discord.gg/autogpt) &ensp;
[![Twitter Follow](https://img.shields.io/twitter/follow/Auto_GPT?style=social)](https://twitter.com/Auto_GPT) &ensp;
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
**AutoGPT** is a powerful platform that allows you to create, deploy, and manage continuous AI agents that automate complex workflows.

View File

@@ -1,7 +1,6 @@
# CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
## Repository Overview
AutoGPT Platform is a monorepo containing:
@@ -144,4 +143,4 @@ Key models (defined in `/backend/schema.prisma`):
- Cacheable paths include: static assets (`/static/*`, `/_next/static/*`), health checks, public store pages, documentation
- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies
- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware
- Applied to both main API server and external API applications
- Applied to both main API server and external API applications

View File

@@ -199,9 +199,18 @@ ZEROBOUNCE_API_KEY=
## ===== OPTIONAL API KEYS END ===== ##
# Block Error Rate Monitoring
BLOCK_ERROR_RATE_THRESHOLD=0.5
BLOCK_ERROR_RATE_CHECK_INTERVAL_SECS=86400
# Logging Configuration
LOG_LEVEL=INFO
ENABLE_CLOUD_LOGGING=false
ENABLE_FILE_LOGGING=false
# Use to manually set the log directory
# LOG_DIR=./logs
# Example Blocks Configuration
# Set to true to enable example blocks in development
# These blocks are disabled by default in production
ENABLE_EXAMPLE_BLOCKS=false

View File

@@ -0,0 +1,150 @@
# Test Data Scripts
This directory contains scripts for creating and updating test data in the AutoGPT Platform database, specifically designed to test the materialized views for the store functionality.
## Scripts
### test_data_creator.py
Creates a comprehensive set of test data including:
- Users with profiles
- Agent graphs, nodes, and executions
- Store listings with multiple versions
- Reviews and ratings
- Library agents
- Integration webhooks
- Onboarding data
- Credit transactions
**Image/Video Domains Used:**
- Images: `picsum.photos` (for all image URLs)
- Videos: `youtube.com` (for store listing videos)
### test_data_updater.py
Updates existing test data to simulate real-world changes:
- Adds new agent graph executions
- Creates new store listing reviews
- Updates store listing versions
- Adds credit transactions
- Refreshes materialized views
### check_db.py
Tests and verifies materialized views functionality:
- Checks pg_cron job status (for automatic refresh)
- Displays current materialized view counts
- Adds test data (executions and reviews)
- Creates store listings if none exist
- Manually refreshes materialized views
- Compares before/after counts to verify updates
- Provides a summary of test results
## Materialized Views
The scripts test three key database views:
1. **mv_agent_run_counts**: Tracks execution counts by agent
2. **mv_review_stats**: Tracks review statistics (count, average rating) by store listing
3. **StoreAgent**: A view that combines store listing data with execution counts and ratings for display
The materialized views (mv_agent_run_counts and mv_review_stats) are automatically refreshed every 15 minutes via pg_cron, or can be manually refreshed using the `refresh_store_materialized_views()` function.
## Usage
### Prerequisites
1. Ensure the database is running:
```bash
docker compose up -d
# or for test database:
docker compose -f docker-compose.test.yaml --env-file ../.env up -d
```
2. Run database migrations:
```bash
poetry run prisma migrate deploy
```
### Running the Scripts
#### Option 1: Use the helper script (from backend directory)
```bash
poetry run python run_test_data.py
```
#### Option 2: Run individually
```bash
# From backend/test directory:
# Create initial test data
poetry run python test_data_creator.py
# Update data to test materialized view changes
poetry run python test_data_updater.py
# From backend directory:
# Test materialized views functionality
poetry run python check_db.py
# Check store data status
poetry run python check_store_data.py
```
#### Option 3: Use the shell script (from backend directory)
```bash
./run_test_data_scripts.sh
```
### Manual Materialized View Refresh
To manually refresh the materialized views:
```sql
SELECT refresh_store_materialized_views();
```
## Configuration
The scripts use the database configuration from your `.env` file:
- `DATABASE_URL`: PostgreSQL connection string
- Database should have the platform schema
## Data Generation Limits
Configured in `test_data_creator.py`:
- 100 users
- 100 agent blocks
- 1-5 graphs per user
- 2-5 nodes per graph
- 1-5 presets per user
- 1-10 library agents per user
- 1-20 executions per graph
- 1-5 reviews per store listing version
## Notes
- All image URLs use `picsum.photos` for consistency with Next.js image configuration
- The scripts create realistic relationships between entities
- Materialized views are refreshed at the end of each script
- Data is designed to test both happy paths and edge cases
## Troubleshooting
### Reviews and StoreAgent view showing 0
If `check_db.py` shows that reviews remain at 0 and StoreAgent view shows 0 store agents:
1. **No store listings exist**: The script will automatically create test store listings if none exist
2. **No approved versions**: Store listings need approved versions to appear in the StoreAgent view
3. **Check with `check_store_data.py`**: This script provides detailed information about:
- Total store listings
- Store listing versions by status
- Existing reviews
- StoreAgent view contents
- Agent graph executions
### pg_cron not installed
The warning "pg_cron extension is not installed" is normal in local development environments. The materialized views can still be refreshed manually using the `refresh_store_materialized_views()` function, which all scripts do automatically.
### Common Issues
- **Type errors with None values**: Fixed in the latest version of check_db.py by using `or 0` for nullable numeric fields
- **Missing relations**: Ensure you're using the correct field names (e.g., `StoreListing` not `storeListing` in includes)
- **Column name mismatches**: The database uses camelCase for column names (e.g., `agentGraphId` not `agent_graph_id`)

View File

@@ -14,14 +14,27 @@ T = TypeVar("T")
@functools.cache
def load_all_blocks() -> dict[str, type["Block"]]:
from backend.data.block import Block
from backend.util.settings import Config
# Check if example blocks should be loaded from settings
config = Config()
load_examples = config.enable_example_blocks
# Dynamically load all modules under backend.blocks
current_dir = Path(__file__).parent
modules = [
str(f.relative_to(current_dir))[:-3].replace(os.path.sep, ".")
for f in current_dir.rglob("*.py")
if f.is_file() and f.name != "__init__.py" and not f.name.startswith("test_")
]
modules = []
for f in current_dir.rglob("*.py"):
if not f.is_file() or f.name == "__init__.py" or f.name.startswith("test_"):
continue
# Skip examples directory if not enabled
relative_path = f.relative_to(current_dir)
if not load_examples and relative_path.parts[0] == "examples":
continue
module_path = str(relative_path)[:-3].replace(os.path.sep, ".")
modules.append(module_path)
for module in modules:
if not re.match("^[a-z0-9_.]+$", module):
raise ValueError(

View File

@@ -14,7 +14,7 @@ from backend.data.block import (
get_block,
)
from backend.data.execution import ExecutionStatus
from backend.data.model import SchemaField
from backend.data.model import NodeExecutionStats, SchemaField
from backend.util import json, retry
_logger = logging.getLogger(__name__)
@@ -151,6 +151,12 @@ class AgentExecutorBlock(Block):
if event.event_type == ExecutionEventType.GRAPH_EXEC_UPDATE:
# If the graph execution is COMPLETED, TERMINATED, or FAILED,
# we can stop listening for further events.
self.merge_stats(
NodeExecutionStats(
extra_cost=event.stats.cost if event.stats else 0,
extra_steps=event.stats.node_exec_count if event.stats else 0,
)
)
break
logger.debug(

View File

@@ -1,32 +0,0 @@
from typing import Literal
from pydantic import SecretStr
from backend.data.model import APIKeyCredentials, CredentialsField, CredentialsMetaInput
from backend.integrations.providers import ProviderName
ExaCredentials = APIKeyCredentials
ExaCredentialsInput = CredentialsMetaInput[
Literal[ProviderName.EXA],
Literal["api_key"],
]
TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="exa",
api_key=SecretStr("mock-exa-api-key"),
title="Mock Exa API key",
expires_at=None,
)
TEST_CREDENTIALS_INPUT = {
"provider": TEST_CREDENTIALS.provider,
"id": TEST_CREDENTIALS.id,
"type": TEST_CREDENTIALS.type,
"title": TEST_CREDENTIALS.title,
}
def ExaCredentialsField() -> ExaCredentialsInput:
"""Creates an Exa credentials input on a block."""
return CredentialsField(description="The Exa integration requires an API Key.")

View File

@@ -0,0 +1,16 @@
"""
Shared configuration for all Exa blocks using the new SDK pattern.
"""
from backend.sdk import BlockCostType, ProviderBuilder
from ._webhook import ExaWebhookManager
# Configure the Exa provider once for all blocks
exa = (
ProviderBuilder("exa")
.with_api_key("EXA_API_KEY", "Exa API Key")
.with_webhook_manager(ExaWebhookManager)
.with_base_cost(1, BlockCostType.RUN)
.build()
)

View File

@@ -0,0 +1,134 @@
"""
Exa Webhook Manager implementation.
"""
import hashlib
import hmac
from enum import Enum
from backend.data.model import Credentials
from backend.sdk import (
APIKeyCredentials,
BaseWebhooksManager,
ProviderName,
Requests,
Webhook,
)
class ExaWebhookType(str, Enum):
"""Available webhook types for Exa."""
WEBSET = "webset"
class ExaEventType(str, Enum):
"""Available event types for Exa webhooks."""
WEBSET_CREATED = "webset.created"
WEBSET_DELETED = "webset.deleted"
WEBSET_PAUSED = "webset.paused"
WEBSET_IDLE = "webset.idle"
WEBSET_SEARCH_CREATED = "webset.search.created"
WEBSET_SEARCH_CANCELED = "webset.search.canceled"
WEBSET_SEARCH_COMPLETED = "webset.search.completed"
WEBSET_SEARCH_UPDATED = "webset.search.updated"
IMPORT_CREATED = "import.created"
IMPORT_COMPLETED = "import.completed"
IMPORT_PROCESSING = "import.processing"
WEBSET_ITEM_CREATED = "webset.item.created"
WEBSET_ITEM_ENRICHED = "webset.item.enriched"
WEBSET_EXPORT_CREATED = "webset.export.created"
WEBSET_EXPORT_COMPLETED = "webset.export.completed"
class ExaWebhookManager(BaseWebhooksManager):
"""Webhook manager for Exa API."""
PROVIDER_NAME = ProviderName("exa")
class WebhookType(str, Enum):
WEBSET = "webset"
@classmethod
async def validate_payload(cls, webhook: Webhook, request) -> tuple[dict, str]:
"""Validate incoming webhook payload and signature."""
payload = await request.json()
# Get event type from payload
event_type = payload.get("eventType", "unknown")
# Verify webhook signature if secret is available
if webhook.secret:
signature = request.headers.get("X-Exa-Signature")
if signature:
# Compute expected signature
body = await request.body()
expected_signature = hmac.new(
webhook.secret.encode(), body, hashlib.sha256
).hexdigest()
# Compare signatures
if not hmac.compare_digest(signature, expected_signature):
raise ValueError("Invalid webhook signature")
return payload, event_type
async def _register_webhook(
self,
credentials: Credentials,
webhook_type: str,
resource: str,
events: list[str],
ingress_url: str,
secret: str,
) -> tuple[str, dict]:
"""Register webhook with Exa API."""
if not isinstance(credentials, APIKeyCredentials):
raise ValueError("Exa webhooks require API key credentials")
api_key = credentials.api_key.get_secret_value()
# Create webhook via Exa API
response = await Requests().post(
"https://api.exa.ai/v0/webhooks",
headers={"x-api-key": api_key},
json={
"url": ingress_url,
"events": events,
"metadata": {
"resource": resource,
"webhook_type": webhook_type,
},
},
)
if not response.ok:
error_data = response.json()
raise Exception(f"Failed to create Exa webhook: {error_data}")
webhook_data = response.json()
# Store the secret returned by Exa
return webhook_data["id"], {
"events": events,
"resource": resource,
"exa_secret": webhook_data.get("secret"),
}
async def _deregister_webhook(
self, webhook: Webhook, credentials: Credentials
) -> None:
"""Deregister webhook from Exa API."""
if not isinstance(credentials, APIKeyCredentials):
raise ValueError("Exa webhooks require API key credentials")
api_key = credentials.api_key.get_secret_value()
# Delete webhook via Exa API
response = await Requests().delete(
f"https://api.exa.ai/v0/webhooks/{webhook.provider_webhook_id}",
headers={"x-api-key": api_key},
)
if not response.ok and response.status != 404:
error_data = response.json()
raise Exception(f"Failed to delete Exa webhook: {error_data}")

View File

@@ -0,0 +1,124 @@
from backend.sdk import (
APIKeyCredentials,
BaseModel,
Block,
BlockCategory,
BlockOutput,
BlockSchema,
CredentialsMetaInput,
Requests,
SchemaField,
)
from ._config import exa
class CostBreakdown(BaseModel):
keywordSearch: float
neuralSearch: float
contentText: float
contentHighlight: float
contentSummary: float
class SearchBreakdown(BaseModel):
search: float
contents: float
breakdown: CostBreakdown
class PerRequestPrices(BaseModel):
neuralSearch_1_25_results: float
neuralSearch_26_100_results: float
neuralSearch_100_plus_results: float
keywordSearch_1_100_results: float
keywordSearch_100_plus_results: float
class PerPagePrices(BaseModel):
contentText: float
contentHighlight: float
contentSummary: float
class CostDollars(BaseModel):
total: float
breakDown: list[SearchBreakdown]
perRequestPrices: PerRequestPrices
perPagePrices: PerPagePrices
class ExaAnswerBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput = exa.credentials_field(
description="The Exa integration requires an API Key."
)
query: str = SchemaField(
description="The question or query to answer",
placeholder="What is the latest valuation of SpaceX?",
)
text: bool = SchemaField(
default=False,
description="If true, the response includes full text content in the search results",
advanced=True,
)
model: str = SchemaField(
default="exa",
description="The search model to use (exa or exa-pro)",
placeholder="exa",
advanced=True,
)
class Output(BlockSchema):
answer: str = SchemaField(
description="The generated answer based on search results"
)
citations: list[dict] = SchemaField(
description="Search results used to generate the answer",
default_factory=list,
)
cost_dollars: CostDollars = SchemaField(
description="Cost breakdown of the request"
)
error: str = SchemaField(
description="Error message if the request failed", default=""
)
def __init__(self):
super().__init__(
id="b79ca4cc-9d5e-47d1-9d4f-e3a2d7f28df5",
description="Get an LLM answer to a question informed by Exa search results",
categories={BlockCategory.SEARCH, BlockCategory.AI},
input_schema=ExaAnswerBlock.Input,
output_schema=ExaAnswerBlock.Output,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
url = "https://api.exa.ai/answer"
headers = {
"Content-Type": "application/json",
"x-api-key": credentials.api_key.get_secret_value(),
}
# Build the payload
payload = {
"query": input_data.query,
"text": input_data.text,
"model": input_data.model,
}
try:
response = await Requests().post(url, headers=headers, json=payload)
data = response.json()
yield "answer", data.get("answer", "")
yield "citations", data.get("citations", [])
yield "cost_dollars", data.get("costDollars", {})
except Exception as e:
yield "error", str(e)
yield "answer", ""
yield "citations", []
yield "cost_dollars", {}

View File

@@ -1,57 +1,39 @@
from typing import List
from pydantic import BaseModel
from backend.blocks.exa._auth import (
ExaCredentials,
ExaCredentialsField,
ExaCredentialsInput,
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchema,
CredentialsMetaInput,
Requests,
SchemaField,
)
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.request import Requests
class ContentRetrievalSettings(BaseModel):
text: dict = SchemaField(
description="Text content settings",
default={"maxCharacters": 1000, "includeHtmlTags": False},
advanced=True,
)
highlights: dict = SchemaField(
description="Highlight settings",
default={
"numSentences": 3,
"highlightsPerUrl": 3,
"query": "",
},
advanced=True,
)
summary: dict = SchemaField(
description="Summary settings",
default={"query": ""},
advanced=True,
)
from ._config import exa
from .helpers import ContentSettings
class ExaContentsBlock(Block):
class Input(BlockSchema):
credentials: ExaCredentialsInput = ExaCredentialsField()
ids: List[str] = SchemaField(
description="Array of document IDs obtained from searches",
credentials: CredentialsMetaInput = exa.credentials_field(
description="The Exa integration requires an API Key."
)
contents: ContentRetrievalSettings = SchemaField(
ids: list[str] = SchemaField(
description="Array of document IDs obtained from searches"
)
contents: ContentSettings = SchemaField(
description="Content retrieval settings",
default=ContentRetrievalSettings(),
default=ContentSettings(),
advanced=True,
)
class Output(BlockSchema):
results: list = SchemaField(
description="List of document contents",
default_factory=list,
description="List of document contents", default_factory=list
)
error: str = SchemaField(
description="Error message if the request failed", default=""
)
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(
@@ -63,7 +45,7 @@ class ExaContentsBlock(Block):
)
async def run(
self, input_data: Input, *, credentials: ExaCredentials, **kwargs
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
url = "https://api.exa.ai/contents"
headers = {
@@ -71,6 +53,7 @@ class ExaContentsBlock(Block):
"x-api-key": credentials.api_key.get_secret_value(),
}
# Convert ContentSettings to API format
payload = {
"ids": input_data.ids,
"text": input_data.contents.text,

View File

@@ -1,8 +1,6 @@
from typing import Optional
from pydantic import BaseModel
from backend.data.model import SchemaField
from backend.sdk import BaseModel, SchemaField
class TextSettings(BaseModel):
@@ -42,13 +40,90 @@ class SummarySettings(BaseModel):
class ContentSettings(BaseModel):
text: TextSettings = SchemaField(
default=TextSettings(),
description="Text content settings",
)
highlights: HighlightSettings = SchemaField(
default=HighlightSettings(),
description="Highlight settings",
)
summary: SummarySettings = SchemaField(
default=SummarySettings(),
description="Summary settings",
)
# Websets Models
class WebsetEntitySettings(BaseModel):
type: Optional[str] = SchemaField(
default=None,
description="Entity type (e.g., 'company', 'person')",
placeholder="company",
)
class WebsetCriterion(BaseModel):
description: str = SchemaField(
description="Description of the criterion",
placeholder="Must be based in the US",
)
success_rate: Optional[int] = SchemaField(
default=None,
description="Success rate percentage",
ge=0,
le=100,
)
class WebsetSearchConfig(BaseModel):
query: str = SchemaField(
description="Search query",
placeholder="Marketing agencies based in the US",
)
count: int = SchemaField(
default=10,
description="Number of results to return",
ge=1,
le=100,
)
entity: Optional[WebsetEntitySettings] = SchemaField(
default=None,
description="Entity settings for the search",
)
criteria: Optional[list[WebsetCriterion]] = SchemaField(
default=None,
description="Search criteria",
)
behavior: Optional[str] = SchemaField(
default="override",
description="Behavior when updating results ('override' or 'append')",
placeholder="override",
)
class EnrichmentOption(BaseModel):
label: str = SchemaField(
description="Label for the enrichment option",
placeholder="Option 1",
)
class WebsetEnrichmentConfig(BaseModel):
title: str = SchemaField(
description="Title of the enrichment",
placeholder="Company Details",
)
description: str = SchemaField(
description="Description of what this enrichment does",
placeholder="Extract company information",
)
format: str = SchemaField(
default="text",
description="Format of the enrichment result",
placeholder="text",
)
instructions: Optional[str] = SchemaField(
default=None,
description="Instructions for the enrichment",
placeholder="Extract key company metrics",
)
options: Optional[list[EnrichmentOption]] = SchemaField(
default=None,
description="Options for the enrichment",
)

View File

@@ -1,71 +1,61 @@
from datetime import datetime
from typing import List
from backend.blocks.exa._auth import (
ExaCredentials,
ExaCredentialsField,
ExaCredentialsInput,
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchema,
CredentialsMetaInput,
Requests,
SchemaField,
)
from backend.blocks.exa.helpers import ContentSettings
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.request import Requests
from ._config import exa
from .helpers import ContentSettings
class ExaSearchBlock(Block):
class Input(BlockSchema):
credentials: ExaCredentialsInput = ExaCredentialsField()
credentials: CredentialsMetaInput = exa.credentials_field(
description="The Exa integration requires an API Key."
)
query: str = SchemaField(description="The search query")
use_auto_prompt: bool = SchemaField(
description="Whether to use autoprompt",
default=True,
advanced=True,
)
type: str = SchemaField(
description="Type of search",
default="",
advanced=True,
description="Whether to use autoprompt", default=True, advanced=True
)
type: str = SchemaField(description="Type of search", default="", advanced=True)
category: str = SchemaField(
description="Category to search within",
default="",
advanced=True,
description="Category to search within", default="", advanced=True
)
number_of_results: int = SchemaField(
description="Number of results to return",
default=10,
advanced=True,
description="Number of results to return", default=10, advanced=True
)
include_domains: List[str] = SchemaField(
description="Domains to include in search",
default_factory=list,
include_domains: list[str] = SchemaField(
description="Domains to include in search", default_factory=list
)
exclude_domains: List[str] = SchemaField(
exclude_domains: list[str] = SchemaField(
description="Domains to exclude from search",
default_factory=list,
advanced=True,
)
start_crawl_date: datetime = SchemaField(
description="Start date for crawled content",
description="Start date for crawled content"
)
end_crawl_date: datetime = SchemaField(
description="End date for crawled content",
description="End date for crawled content"
)
start_published_date: datetime = SchemaField(
description="Start date for published content",
description="Start date for published content"
)
end_published_date: datetime = SchemaField(
description="End date for published content",
description="End date for published content"
)
include_text: List[str] = SchemaField(
description="Text patterns to include",
default_factory=list,
advanced=True,
include_text: list[str] = SchemaField(
description="Text patterns to include", default_factory=list, advanced=True
)
exclude_text: List[str] = SchemaField(
description="Text patterns to exclude",
default_factory=list,
advanced=True,
exclude_text: list[str] = SchemaField(
description="Text patterns to exclude", default_factory=list, advanced=True
)
contents: ContentSettings = SchemaField(
description="Content retrieval settings",
@@ -75,8 +65,7 @@ class ExaSearchBlock(Block):
class Output(BlockSchema):
results: list = SchemaField(
description="List of search results",
default_factory=list,
description="List of search results", default_factory=list
)
error: str = SchemaField(
description="Error message if the request failed",
@@ -92,7 +81,7 @@ class ExaSearchBlock(Block):
)
async def run(
self, input_data: Input, *, credentials: ExaCredentials, **kwargs
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
url = "https://api.exa.ai/search"
headers = {
@@ -104,7 +93,7 @@ class ExaSearchBlock(Block):
"query": input_data.query,
"useAutoprompt": input_data.use_auto_prompt,
"numResults": input_data.number_of_results,
"contents": input_data.contents.dict(),
"contents": input_data.contents.model_dump(),
}
date_field_mapping = {

View File

@@ -1,57 +1,60 @@
from datetime import datetime
from typing import Any, List
from typing import Any
from backend.blocks.exa._auth import (
ExaCredentials,
ExaCredentialsField,
ExaCredentialsInput,
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchema,
CredentialsMetaInput,
Requests,
SchemaField,
)
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.request import Requests
from ._config import exa
from .helpers import ContentSettings
class ExaFindSimilarBlock(Block):
class Input(BlockSchema):
credentials: ExaCredentialsInput = ExaCredentialsField()
credentials: CredentialsMetaInput = exa.credentials_field(
description="The Exa integration requires an API Key."
)
url: str = SchemaField(
description="The url for which you would like to find similar links"
)
number_of_results: int = SchemaField(
description="Number of results to return",
default=10,
advanced=True,
description="Number of results to return", default=10, advanced=True
)
include_domains: List[str] = SchemaField(
include_domains: list[str] = SchemaField(
description="Domains to include in search",
default_factory=list,
advanced=True,
)
exclude_domains: List[str] = SchemaField(
exclude_domains: list[str] = SchemaField(
description="Domains to exclude from search",
default_factory=list,
advanced=True,
)
start_crawl_date: datetime = SchemaField(
description="Start date for crawled content",
description="Start date for crawled content"
)
end_crawl_date: datetime = SchemaField(
description="End date for crawled content",
description="End date for crawled content"
)
start_published_date: datetime = SchemaField(
description="Start date for published content",
description="Start date for published content"
)
end_published_date: datetime = SchemaField(
description="End date for published content",
description="End date for published content"
)
include_text: List[str] = SchemaField(
include_text: list[str] = SchemaField(
description="Text patterns to include (max 1 string, up to 5 words)",
default_factory=list,
advanced=True,
)
exclude_text: List[str] = SchemaField(
exclude_text: list[str] = SchemaField(
description="Text patterns to exclude (max 1 string, up to 5 words)",
default_factory=list,
advanced=True,
@@ -63,11 +66,13 @@ class ExaFindSimilarBlock(Block):
)
class Output(BlockSchema):
results: List[Any] = SchemaField(
results: list[Any] = SchemaField(
description="List of similar documents with title, URL, published date, author, and score",
default_factory=list,
)
error: str = SchemaField(description="Error message if the request failed")
error: str = SchemaField(
description="Error message if the request failed", default=""
)
def __init__(self):
super().__init__(
@@ -79,7 +84,7 @@ class ExaFindSimilarBlock(Block):
)
async def run(
self, input_data: Input, *, credentials: ExaCredentials, **kwargs
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
url = "https://api.exa.ai/findSimilar"
headers = {
@@ -90,7 +95,7 @@ class ExaFindSimilarBlock(Block):
payload = {
"url": input_data.url,
"numResults": input_data.number_of_results,
"contents": input_data.contents.dict(),
"contents": input_data.contents.model_dump(),
}
optional_field_mapping = {

View File

@@ -0,0 +1,201 @@
"""
Exa Webhook Blocks
These blocks handle webhook events from Exa's API for websets and other events.
"""
from backend.sdk import (
BaseModel,
Block,
BlockCategory,
BlockOutput,
BlockSchema,
BlockType,
BlockWebhookConfig,
CredentialsMetaInput,
Field,
ProviderName,
SchemaField,
)
from ._config import exa
from ._webhook import ExaEventType
class WebsetEventFilter(BaseModel):
"""Filter configuration for Exa webset events."""
webset_created: bool = Field(
default=True, description="Receive notifications when websets are created"
)
webset_deleted: bool = Field(
default=False, description="Receive notifications when websets are deleted"
)
webset_paused: bool = Field(
default=False, description="Receive notifications when websets are paused"
)
webset_idle: bool = Field(
default=False, description="Receive notifications when websets become idle"
)
search_created: bool = Field(
default=True,
description="Receive notifications when webset searches are created",
)
search_completed: bool = Field(
default=True, description="Receive notifications when webset searches complete"
)
search_canceled: bool = Field(
default=False,
description="Receive notifications when webset searches are canceled",
)
search_updated: bool = Field(
default=False,
description="Receive notifications when webset searches are updated",
)
item_created: bool = Field(
default=True, description="Receive notifications when webset items are created"
)
item_enriched: bool = Field(
default=True, description="Receive notifications when webset items are enriched"
)
export_created: bool = Field(
default=False,
description="Receive notifications when webset exports are created",
)
export_completed: bool = Field(
default=True, description="Receive notifications when webset exports complete"
)
import_created: bool = Field(
default=False, description="Receive notifications when imports are created"
)
import_completed: bool = Field(
default=True, description="Receive notifications when imports complete"
)
import_processing: bool = Field(
default=False, description="Receive notifications when imports are processing"
)
class ExaWebsetWebhookBlock(Block):
"""
Receives webhook notifications for Exa webset events.
This block allows you to monitor various events related to Exa websets,
including creation, updates, searches, and exports.
"""
class Input(BlockSchema):
credentials: CredentialsMetaInput = exa.credentials_field(
description="Exa API credentials for webhook management"
)
webhook_url: str = SchemaField(
description="URL to receive webhooks (auto-generated)",
default="",
hidden=True,
)
webset_id: str = SchemaField(
description="The webset ID to monitor (optional, monitors all if empty)",
default="",
)
event_filter: WebsetEventFilter = SchemaField(
description="Configure which events to receive", default=WebsetEventFilter()
)
payload: dict = SchemaField(
description="Webhook payload data", default={}, hidden=True
)
class Output(BlockSchema):
event_type: str = SchemaField(description="Type of event that occurred")
event_id: str = SchemaField(description="Unique identifier for this event")
webset_id: str = SchemaField(description="ID of the affected webset")
data: dict = SchemaField(description="Event-specific data")
timestamp: str = SchemaField(description="When the event occurred")
metadata: dict = SchemaField(description="Additional event metadata")
def __init__(self):
super().__init__(
id="d0204ed8-8b81-408d-8b8d-ed087a546228",
description="Receive webhook notifications for Exa webset events",
categories={BlockCategory.INPUT},
input_schema=ExaWebsetWebhookBlock.Input,
output_schema=ExaWebsetWebhookBlock.Output,
block_type=BlockType.WEBHOOK,
webhook_config=BlockWebhookConfig(
provider=ProviderName("exa"),
webhook_type="webset",
event_filter_input="event_filter",
resource_format="{webset_id}",
),
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
"""Process incoming Exa webhook payload."""
try:
payload = input_data.payload
# Extract event details
event_type = payload.get("eventType", "unknown")
event_id = payload.get("eventId", "")
# Get webset ID from payload or input
webset_id = payload.get("websetId", input_data.webset_id)
# Check if we should process this event based on filter
should_process = self._should_process_event(
event_type, input_data.event_filter
)
if not should_process:
# Skip events that don't match our filter
return
# Extract event data
event_data = payload.get("data", {})
timestamp = payload.get("occurredAt", payload.get("createdAt", ""))
metadata = payload.get("metadata", {})
yield "event_type", event_type
yield "event_id", event_id
yield "webset_id", webset_id
yield "data", event_data
yield "timestamp", timestamp
yield "metadata", metadata
except Exception as e:
# Handle errors gracefully
yield "event_type", "error"
yield "event_id", ""
yield "webset_id", input_data.webset_id
yield "data", {"error": str(e)}
yield "timestamp", ""
yield "metadata", {}
def _should_process_event(
self, event_type: str, event_filter: WebsetEventFilter
) -> bool:
"""Check if an event should be processed based on the filter."""
filter_mapping = {
ExaEventType.WEBSET_CREATED: event_filter.webset_created,
ExaEventType.WEBSET_DELETED: event_filter.webset_deleted,
ExaEventType.WEBSET_PAUSED: event_filter.webset_paused,
ExaEventType.WEBSET_IDLE: event_filter.webset_idle,
ExaEventType.WEBSET_SEARCH_CREATED: event_filter.search_created,
ExaEventType.WEBSET_SEARCH_COMPLETED: event_filter.search_completed,
ExaEventType.WEBSET_SEARCH_CANCELED: event_filter.search_canceled,
ExaEventType.WEBSET_SEARCH_UPDATED: event_filter.search_updated,
ExaEventType.WEBSET_ITEM_CREATED: event_filter.item_created,
ExaEventType.WEBSET_ITEM_ENRICHED: event_filter.item_enriched,
ExaEventType.WEBSET_EXPORT_CREATED: event_filter.export_created,
ExaEventType.WEBSET_EXPORT_COMPLETED: event_filter.export_completed,
ExaEventType.IMPORT_CREATED: event_filter.import_created,
ExaEventType.IMPORT_COMPLETED: event_filter.import_completed,
ExaEventType.IMPORT_PROCESSING: event_filter.import_processing,
}
# Try to convert string to ExaEventType enum
try:
event_type_enum = ExaEventType(event_type)
return filter_mapping.get(event_type_enum, True)
except ValueError:
# If event_type is not a valid enum value, process it by default
return True

View File

@@ -0,0 +1,456 @@
from typing import Any, Optional
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchema,
CredentialsMetaInput,
Requests,
SchemaField,
)
from ._config import exa
from .helpers import WebsetEnrichmentConfig, WebsetSearchConfig
class ExaCreateWebsetBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput = exa.credentials_field(
description="The Exa integration requires an API Key."
)
search: WebsetSearchConfig = SchemaField(
description="Initial search configuration for the Webset"
)
enrichments: Optional[list[WebsetEnrichmentConfig]] = SchemaField(
default=None,
description="Enrichments to apply to Webset items",
advanced=True,
)
external_id: Optional[str] = SchemaField(
default=None,
description="External identifier for the webset",
placeholder="my-webset-123",
advanced=True,
)
metadata: Optional[dict] = SchemaField(
default=None,
description="Key-value pairs to associate with this webset",
advanced=True,
)
class Output(BlockSchema):
webset_id: str = SchemaField(
description="The unique identifier for the created webset"
)
status: str = SchemaField(description="The status of the webset")
external_id: Optional[str] = SchemaField(
description="The external identifier for the webset", default=None
)
created_at: str = SchemaField(
description="The date and time the webset was created"
)
error: str = SchemaField(
description="Error message if the request failed", default=""
)
def __init__(self):
super().__init__(
id="0cda29ff-c549-4a19-8805-c982b7d4ec34",
description="Create a new Exa Webset for persistent web search collections",
categories={BlockCategory.SEARCH},
input_schema=ExaCreateWebsetBlock.Input,
output_schema=ExaCreateWebsetBlock.Output,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
url = "https://api.exa.ai/websets/v0/websets"
headers = {
"Content-Type": "application/json",
"x-api-key": credentials.api_key.get_secret_value(),
}
# Build the payload
payload: dict[str, Any] = {
"search": input_data.search.model_dump(exclude_none=True),
}
# Convert enrichments to API format
if input_data.enrichments:
enrichments_data = []
for enrichment in input_data.enrichments:
enrichments_data.append(enrichment.model_dump(exclude_none=True))
payload["enrichments"] = enrichments_data
if input_data.external_id:
payload["externalId"] = input_data.external_id
if input_data.metadata:
payload["metadata"] = input_data.metadata
try:
response = await Requests().post(url, headers=headers, json=payload)
data = response.json()
yield "webset_id", data.get("id", "")
yield "status", data.get("status", "")
yield "external_id", data.get("externalId")
yield "created_at", data.get("createdAt", "")
except Exception as e:
yield "error", str(e)
yield "webset_id", ""
yield "status", ""
yield "created_at", ""
class ExaUpdateWebsetBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput = exa.credentials_field(
description="The Exa integration requires an API Key."
)
webset_id: str = SchemaField(
description="The ID or external ID of the Webset to update",
placeholder="webset-id-or-external-id",
)
metadata: Optional[dict] = SchemaField(
default=None,
description="Key-value pairs to associate with this webset (set to null to clear)",
)
class Output(BlockSchema):
webset_id: str = SchemaField(description="The unique identifier for the webset")
status: str = SchemaField(description="The status of the webset")
external_id: Optional[str] = SchemaField(
description="The external identifier for the webset", default=None
)
metadata: dict = SchemaField(
description="Updated metadata for the webset", default_factory=dict
)
updated_at: str = SchemaField(
description="The date and time the webset was updated"
)
error: str = SchemaField(
description="Error message if the request failed", default=""
)
def __init__(self):
super().__init__(
id="89ccd99a-3c2b-4fbf-9e25-0ffa398d0314",
description="Update metadata for an existing Webset",
categories={BlockCategory.SEARCH},
input_schema=ExaUpdateWebsetBlock.Input,
output_schema=ExaUpdateWebsetBlock.Output,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
url = f"https://api.exa.ai/websets/v0/websets/{input_data.webset_id}"
headers = {
"Content-Type": "application/json",
"x-api-key": credentials.api_key.get_secret_value(),
}
# Build the payload
payload = {}
if input_data.metadata is not None:
payload["metadata"] = input_data.metadata
try:
response = await Requests().post(url, headers=headers, json=payload)
data = response.json()
yield "webset_id", data.get("id", "")
yield "status", data.get("status", "")
yield "external_id", data.get("externalId")
yield "metadata", data.get("metadata", {})
yield "updated_at", data.get("updatedAt", "")
except Exception as e:
yield "error", str(e)
yield "webset_id", ""
yield "status", ""
yield "metadata", {}
yield "updated_at", ""
class ExaListWebsetsBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput = exa.credentials_field(
description="The Exa integration requires an API Key."
)
cursor: Optional[str] = SchemaField(
default=None,
description="Cursor for pagination through results",
advanced=True,
)
limit: int = SchemaField(
default=25,
description="Number of websets to return (1-100)",
ge=1,
le=100,
advanced=True,
)
class Output(BlockSchema):
websets: list = SchemaField(description="List of websets", default_factory=list)
has_more: bool = SchemaField(
description="Whether there are more results to paginate through",
default=False,
)
next_cursor: Optional[str] = SchemaField(
description="Cursor for the next page of results", default=None
)
error: str = SchemaField(
description="Error message if the request failed", default=""
)
def __init__(self):
super().__init__(
id="1dcd8fd6-c13f-4e6f-bd4c-654428fa4757",
description="List all Websets with pagination support",
categories={BlockCategory.SEARCH},
input_schema=ExaListWebsetsBlock.Input,
output_schema=ExaListWebsetsBlock.Output,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
url = "https://api.exa.ai/websets/v0/websets"
headers = {
"x-api-key": credentials.api_key.get_secret_value(),
}
params: dict[str, Any] = {
"limit": input_data.limit,
}
if input_data.cursor:
params["cursor"] = input_data.cursor
try:
response = await Requests().get(url, headers=headers, params=params)
data = response.json()
yield "websets", data.get("data", [])
yield "has_more", data.get("hasMore", False)
yield "next_cursor", data.get("nextCursor")
except Exception as e:
yield "error", str(e)
yield "websets", []
yield "has_more", False
class ExaGetWebsetBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput = exa.credentials_field(
description="The Exa integration requires an API Key."
)
webset_id: str = SchemaField(
description="The ID or external ID of the Webset to retrieve",
placeholder="webset-id-or-external-id",
)
expand_items: bool = SchemaField(
default=False, description="Include items in the response", advanced=True
)
class Output(BlockSchema):
webset_id: str = SchemaField(description="The unique identifier for the webset")
status: str = SchemaField(description="The status of the webset")
external_id: Optional[str] = SchemaField(
description="The external identifier for the webset", default=None
)
searches: list[dict] = SchemaField(
description="The searches performed on the webset", default_factory=list
)
enrichments: list[dict] = SchemaField(
description="The enrichments applied to the webset", default_factory=list
)
monitors: list[dict] = SchemaField(
description="The monitors for the webset", default_factory=list
)
items: Optional[list[dict]] = SchemaField(
description="The items in the webset (if expand_items is true)",
default=None,
)
metadata: dict = SchemaField(
description="Key-value pairs associated with the webset",
default_factory=dict,
)
created_at: str = SchemaField(
description="The date and time the webset was created"
)
updated_at: str = SchemaField(
description="The date and time the webset was last updated"
)
error: str = SchemaField(
description="Error message if the request failed", default=""
)
def __init__(self):
super().__init__(
id="6ab8e12a-132c-41bf-b5f3-d662620fa832",
description="Retrieve a Webset by ID or external ID",
categories={BlockCategory.SEARCH},
input_schema=ExaGetWebsetBlock.Input,
output_schema=ExaGetWebsetBlock.Output,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
url = f"https://api.exa.ai/websets/v0/websets/{input_data.webset_id}"
headers = {
"x-api-key": credentials.api_key.get_secret_value(),
}
params = {}
if input_data.expand_items:
params["expand[]"] = "items"
try:
response = await Requests().get(url, headers=headers, params=params)
data = response.json()
yield "webset_id", data.get("id", "")
yield "status", data.get("status", "")
yield "external_id", data.get("externalId")
yield "searches", data.get("searches", [])
yield "enrichments", data.get("enrichments", [])
yield "monitors", data.get("monitors", [])
yield "items", data.get("items")
yield "metadata", data.get("metadata", {})
yield "created_at", data.get("createdAt", "")
yield "updated_at", data.get("updatedAt", "")
except Exception as e:
yield "error", str(e)
yield "webset_id", ""
yield "status", ""
yield "searches", []
yield "enrichments", []
yield "monitors", []
yield "metadata", {}
yield "created_at", ""
yield "updated_at", ""
class ExaDeleteWebsetBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput = exa.credentials_field(
description="The Exa integration requires an API Key."
)
webset_id: str = SchemaField(
description="The ID or external ID of the Webset to delete",
placeholder="webset-id-or-external-id",
)
class Output(BlockSchema):
webset_id: str = SchemaField(
description="The unique identifier for the deleted webset"
)
external_id: Optional[str] = SchemaField(
description="The external identifier for the deleted webset", default=None
)
status: str = SchemaField(description="The status of the deleted webset")
success: str = SchemaField(
description="Whether the deletion was successful", default="true"
)
error: str = SchemaField(
description="Error message if the request failed", default=""
)
def __init__(self):
super().__init__(
id="aa6994a2-e986-421f-8d4c-7671d3be7b7e",
description="Delete a Webset and all its items",
categories={BlockCategory.SEARCH},
input_schema=ExaDeleteWebsetBlock.Input,
output_schema=ExaDeleteWebsetBlock.Output,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
url = f"https://api.exa.ai/websets/v0/websets/{input_data.webset_id}"
headers = {
"x-api-key": credentials.api_key.get_secret_value(),
}
try:
response = await Requests().delete(url, headers=headers)
data = response.json()
yield "webset_id", data.get("id", "")
yield "external_id", data.get("externalId")
yield "status", data.get("status", "")
yield "success", "true"
except Exception as e:
yield "error", str(e)
yield "webset_id", ""
yield "status", ""
yield "success", "false"
class ExaCancelWebsetBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput = exa.credentials_field(
description="The Exa integration requires an API Key."
)
webset_id: str = SchemaField(
description="The ID or external ID of the Webset to cancel",
placeholder="webset-id-or-external-id",
)
class Output(BlockSchema):
webset_id: str = SchemaField(description="The unique identifier for the webset")
status: str = SchemaField(
description="The status of the webset after cancellation"
)
external_id: Optional[str] = SchemaField(
description="The external identifier for the webset", default=None
)
success: str = SchemaField(
description="Whether the cancellation was successful", default="true"
)
error: str = SchemaField(
description="Error message if the request failed", default=""
)
def __init__(self):
super().__init__(
id="e40a6420-1db8-47bb-b00a-0e6aecd74176",
description="Cancel all operations being performed on a Webset",
categories={BlockCategory.SEARCH},
input_schema=ExaCancelWebsetBlock.Input,
output_schema=ExaCancelWebsetBlock.Output,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
url = f"https://api.exa.ai/websets/v0/websets/{input_data.webset_id}/cancel"
headers = {
"x-api-key": credentials.api_key.get_secret_value(),
}
try:
response = await Requests().post(url, headers=headers)
data = response.json()
yield "webset_id", data.get("id", "")
yield "status", data.get("status", "")
yield "external_id", data.get("externalId")
yield "success", "true"
except Exception as e:
yield "error", str(e)
yield "webset_id", ""
yield "status", ""
yield "success", "false"

View File

@@ -0,0 +1,9 @@
# Import the provider builder to ensure it's registered
from backend.sdk.registry import AutoRegistry
from .triggers import GenericWebhookTriggerBlock, generic_webhook
# Ensure the SDK registry is patched to include our webhook manager
AutoRegistry.patch_integrations()
__all__ = ["GenericWebhookTriggerBlock", "generic_webhook"]

View File

@@ -3,10 +3,7 @@ import logging
from fastapi import Request
from strenum import StrEnum
from backend.data import integrations
from backend.integrations.providers import ProviderName
from ._manual_base import ManualWebhookManagerBase
from backend.sdk import ManualWebhookManagerBase, Webhook
logger = logging.getLogger(__name__)
@@ -16,12 +13,11 @@ class GenericWebhookType(StrEnum):
class GenericWebhooksManager(ManualWebhookManagerBase):
PROVIDER_NAME = ProviderName.GENERIC_WEBHOOK
WebhookType = GenericWebhookType
@classmethod
async def validate_payload(
cls, webhook: integrations.Webhook, request: Request
cls, webhook: Webhook, request: Request
) -> tuple[dict, str]:
payload = await request.json()
event_type = GenericWebhookType.PLAIN

View File

@@ -1,13 +1,21 @@
from backend.data.block import (
from backend.sdk import (
Block,
BlockCategory,
BlockManualWebhookConfig,
BlockOutput,
BlockSchema,
ProviderBuilder,
ProviderName,
SchemaField,
)
from ._webhook import GenericWebhooksManager, GenericWebhookType
generic_webhook = (
ProviderBuilder("generic_webhook")
.with_webhook_manager(GenericWebhooksManager)
.build()
)
from backend.data.model import SchemaField
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks.generic import GenericWebhookType
class GenericWebhookTriggerBlock(Block):
@@ -36,7 +44,7 @@ class GenericWebhookTriggerBlock(Block):
input_schema=GenericWebhookTriggerBlock.Input,
output_schema=GenericWebhookTriggerBlock.Output,
webhook_config=BlockManualWebhookConfig(
provider=ProviderName.GENERIC_WEBHOOK,
provider=ProviderName(generic_webhook.name),
webhook_type=GenericWebhookType.PLAIN,
),
test_input={"constants": {"key": "value"}, "payload": self.example_payload},

View File

@@ -0,0 +1,14 @@
"""
Linear integration blocks for AutoGPT Platform.
"""
from .comment import LinearCreateCommentBlock
from .issues import LinearCreateIssueBlock, LinearSearchIssuesBlock
from .projects import LinearSearchProjectsBlock
__all__ = [
"LinearCreateCommentBlock",
"LinearCreateIssueBlock",
"LinearSearchIssuesBlock",
"LinearSearchProjectsBlock",
]

View File

@@ -1,16 +1,11 @@
from __future__ import annotations
import json
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Union
from backend.blocks.linear._auth import LinearCredentials
from backend.blocks.linear.models import (
CreateCommentResponse,
CreateIssueResponse,
Issue,
Project,
)
from backend.util.request import Requests
from backend.sdk import APIKeyCredentials, OAuth2Credentials, Requests
from .models import CreateCommentResponse, CreateIssueResponse, Issue, Project
class LinearAPIException(Exception):
@@ -29,13 +24,12 @@ class LinearClient:
def __init__(
self,
credentials: LinearCredentials | None = None,
credentials: Union[OAuth2Credentials, APIKeyCredentials, None] = None,
custom_requests: Optional[Requests] = None,
):
if custom_requests:
self._requests = custom_requests
else:
headers: Dict[str, str] = {
"Content-Type": "application/json",
}

View File

@@ -1,31 +1,19 @@
"""
Shared configuration for all Linear blocks using the new SDK pattern.
"""
import os
from enum import Enum
from typing import Literal
from pydantic import SecretStr
from backend.data.model import (
from backend.sdk import (
APIKeyCredentials,
CredentialsField,
CredentialsMetaInput,
BlockCostType,
OAuth2Credentials,
)
from backend.integrations.providers import ProviderName
from backend.util.settings import Secrets
secrets = Secrets()
LINEAR_OAUTH_IS_CONFIGURED = bool(
secrets.linear_client_id and secrets.linear_client_secret
ProviderBuilder,
SecretStr,
)
LinearCredentials = OAuth2Credentials | APIKeyCredentials
# LinearCredentialsInput = CredentialsMetaInput[
# Literal[ProviderName.LINEAR],
# Literal["oauth2", "api_key"] if LINEAR_OAUTH_IS_CONFIGURED else Literal["oauth2"],
# ]
LinearCredentialsInput = CredentialsMetaInput[
Literal[ProviderName.LINEAR], Literal["oauth2"]
]
from ._oauth import LinearOAuthHandler
# (required) Comma separated list of scopes:
@@ -50,21 +38,35 @@ class LinearScope(str, Enum):
ADMIN = "admin"
def LinearCredentialsField(scopes: list[LinearScope]) -> LinearCredentialsInput:
"""
Creates a Linear credentials input on a block.
# Check if Linear OAuth is configured
client_id = os.getenv("LINEAR_CLIENT_ID")
client_secret = os.getenv("LINEAR_CLIENT_SECRET")
LINEAR_OAUTH_IS_CONFIGURED = bool(client_id and client_secret)
Params:
scope: The authorization scope needed for the block to work. ([list of available scopes](https://docs.github.com/en/apps/oauth-apps/building-oauth-apps/scopes-for-oauth-apps#available-scopes))
""" # noqa
return CredentialsField(
required_scopes=set([LinearScope.READ.value]).union(
set([scope.value for scope in scopes])
),
description="The Linear integration can be used with OAuth, "
"or any API key with sufficient permissions for the blocks it is used on.",
# Build the Linear provider
builder = (
ProviderBuilder("linear")
.with_api_key(env_var_name="LINEAR_API_KEY", title="Linear API Key")
.with_base_cost(1, BlockCostType.RUN)
)
# Linear only supports OAuth authentication
if LINEAR_OAUTH_IS_CONFIGURED:
builder = builder.with_oauth(
LinearOAuthHandler,
scopes=[
LinearScope.READ,
LinearScope.WRITE,
LinearScope.ISSUES_CREATE,
LinearScope.COMMENTS_CREATE,
],
client_id_env_var="LINEAR_CLIENT_ID",
client_secret_env_var="LINEAR_CLIENT_SECRET",
)
# Build the provider
linear = builder.build()
TEST_CREDENTIALS_OAUTH = OAuth2Credentials(
id="01234567-89ab-cdef-0123-456789abcdef",

View File

@@ -1,15 +1,27 @@
"""
Linear OAuth handler implementation.
"""
import json
from typing import Optional
from urllib.parse import urlencode
from pydantic import SecretStr
from backend.sdk import (
APIKeyCredentials,
BaseOAuthHandler,
OAuth2Credentials,
ProviderName,
Requests,
SecretStr,
)
from backend.blocks.linear._api import LinearAPIException
from backend.data.model import APIKeyCredentials, OAuth2Credentials
from backend.integrations.providers import ProviderName
from backend.util.request import Requests
from .base import BaseOAuthHandler
class LinearAPIException(Exception):
"""Exception for Linear API errors."""
def __init__(self, message: str, status_code: int):
super().__init__(message)
self.status_code = status_code
class LinearOAuthHandler(BaseOAuthHandler):
@@ -17,7 +29,9 @@ class LinearOAuthHandler(BaseOAuthHandler):
OAuth2 handler for Linear.
"""
PROVIDER_NAME = ProviderName.LINEAR
# Provider name will be set dynamically by the SDK when registered
# We use a placeholder that will be replaced by AutoRegistry.register_provider()
PROVIDER_NAME = ProviderName("linear")
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
self.client_id = client_id
@@ -30,7 +44,6 @@ class LinearOAuthHandler(BaseOAuthHandler):
def get_login_url(
self, scopes: list[str], state: str, code_challenge: Optional[str]
) -> str:
params = {
"client_id": self.client_id,
"redirect_uri": self.redirect_uri,
@@ -139,9 +152,10 @@ class LinearOAuthHandler(BaseOAuthHandler):
async def _request_username(self, access_token: str) -> Optional[str]:
# Use the LinearClient to fetch user details using GraphQL
from backend.blocks.linear._api import LinearClient
from ._api import LinearClient
try:
# Create a temporary OAuth2Credentials object for the LinearClient
linear_client = LinearClient(
APIKeyCredentials(
api_key=SecretStr(access_token),

View File

@@ -1,24 +1,32 @@
from backend.blocks.linear._api import LinearAPIException, LinearClient
from backend.blocks.linear._auth import (
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchema,
CredentialsMetaInput,
OAuth2Credentials,
SchemaField,
)
from ._api import LinearAPIException, LinearClient
from ._config import (
LINEAR_OAUTH_IS_CONFIGURED,
TEST_CREDENTIALS_INPUT_OAUTH,
TEST_CREDENTIALS_OAUTH,
LinearCredentials,
LinearCredentialsField,
LinearCredentialsInput,
LinearScope,
linear,
)
from backend.blocks.linear.models import CreateCommentResponse
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from .models import CreateCommentResponse
class LinearCreateCommentBlock(Block):
"""Block for creating comments on Linear issues"""
class Input(BlockSchema):
credentials: LinearCredentialsInput = LinearCredentialsField(
scopes=[LinearScope.COMMENTS_CREATE],
credentials: CredentialsMetaInput = linear.credentials_field(
description="Linear credentials with comment creation permissions",
required_scopes={LinearScope.COMMENTS_CREATE},
)
issue_id: str = SchemaField(description="ID of the issue to comment on")
comment: str = SchemaField(description="Comment text to add to the issue")
@@ -55,7 +63,7 @@ class LinearCreateCommentBlock(Block):
@staticmethod
async def create_comment(
credentials: LinearCredentials, issue_id: str, comment: str
credentials: OAuth2Credentials | APIKeyCredentials, issue_id: str, comment: str
) -> tuple[str, str]:
client = LinearClient(credentials=credentials)
response: CreateCommentResponse = await client.try_create_comment(
@@ -64,7 +72,11 @@ class LinearCreateCommentBlock(Block):
return response.comment.id, response.comment.body
async def run(
self, input_data: Input, *, credentials: LinearCredentials, **kwargs
self,
input_data: Input,
*,
credentials: OAuth2Credentials | APIKeyCredentials,
**kwargs,
) -> BlockOutput:
"""Execute the comment creation"""
try:

View File

@@ -1,24 +1,32 @@
from backend.blocks.linear._api import LinearAPIException, LinearClient
from backend.blocks.linear._auth import (
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchema,
CredentialsMetaInput,
OAuth2Credentials,
SchemaField,
)
from ._api import LinearAPIException, LinearClient
from ._config import (
LINEAR_OAUTH_IS_CONFIGURED,
TEST_CREDENTIALS_INPUT_OAUTH,
TEST_CREDENTIALS_OAUTH,
LinearCredentials,
LinearCredentialsField,
LinearCredentialsInput,
LinearScope,
linear,
)
from backend.blocks.linear.models import CreateIssueResponse, Issue
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from .models import CreateIssueResponse, Issue
class LinearCreateIssueBlock(Block):
"""Block for creating issues on Linear"""
class Input(BlockSchema):
credentials: LinearCredentialsInput = LinearCredentialsField(
scopes=[LinearScope.ISSUES_CREATE],
credentials: CredentialsMetaInput = linear.credentials_field(
description="Linear credentials with issue creation permissions",
required_scopes={LinearScope.ISSUES_CREATE},
)
title: str = SchemaField(description="Title of the issue")
description: str | None = SchemaField(description="Description of the issue")
@@ -68,7 +76,7 @@ class LinearCreateIssueBlock(Block):
@staticmethod
async def create_issue(
credentials: LinearCredentials,
credentials: OAuth2Credentials | APIKeyCredentials,
team_name: str,
title: str,
description: str | None = None,
@@ -94,7 +102,11 @@ class LinearCreateIssueBlock(Block):
return response.issue.identifier, response.issue.title
async def run(
self, input_data: Input, *, credentials: LinearCredentials, **kwargs
self,
input_data: Input,
*,
credentials: OAuth2Credentials,
**kwargs,
) -> BlockOutput:
"""Execute the issue creation"""
try:
@@ -121,8 +133,9 @@ class LinearSearchIssuesBlock(Block):
class Input(BlockSchema):
term: str = SchemaField(description="Term to search for issues")
credentials: LinearCredentialsInput = LinearCredentialsField(
scopes=[LinearScope.READ],
credentials: CredentialsMetaInput = linear.credentials_field(
description="Linear credentials with read permissions",
required_scopes={LinearScope.READ},
)
class Output(BlockSchema):
@@ -169,7 +182,7 @@ class LinearSearchIssuesBlock(Block):
@staticmethod
async def search_issues(
credentials: LinearCredentials,
credentials: OAuth2Credentials | APIKeyCredentials,
term: str,
) -> list[Issue]:
client = LinearClient(credentials=credentials)
@@ -177,7 +190,11 @@ class LinearSearchIssuesBlock(Block):
return response
async def run(
self, input_data: Input, *, credentials: LinearCredentials, **kwargs
self,
input_data: Input,
*,
credentials: OAuth2Credentials | APIKeyCredentials,
**kwargs,
) -> BlockOutput:
"""Execute the issue search"""
try:

View File

@@ -1,4 +1,4 @@
from pydantic import BaseModel
from backend.sdk import BaseModel
class Comment(BaseModel):

View File

@@ -1,24 +1,32 @@
from backend.blocks.linear._api import LinearAPIException, LinearClient
from backend.blocks.linear._auth import (
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockOutput,
BlockSchema,
CredentialsMetaInput,
OAuth2Credentials,
SchemaField,
)
from ._api import LinearAPIException, LinearClient
from ._config import (
LINEAR_OAUTH_IS_CONFIGURED,
TEST_CREDENTIALS_INPUT_OAUTH,
TEST_CREDENTIALS_OAUTH,
LinearCredentials,
LinearCredentialsField,
LinearCredentialsInput,
LinearScope,
linear,
)
from backend.blocks.linear.models import Project
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from .models import Project
class LinearSearchProjectsBlock(Block):
"""Block for searching projects on Linear"""
class Input(BlockSchema):
credentials: LinearCredentialsInput = LinearCredentialsField(
scopes=[LinearScope.READ],
credentials: CredentialsMetaInput = linear.credentials_field(
description="Linear credentials with read permissions",
required_scopes={LinearScope.READ},
)
term: str = SchemaField(description="Term to search for projects")
@@ -70,7 +78,7 @@ class LinearSearchProjectsBlock(Block):
@staticmethod
async def search_projects(
credentials: LinearCredentials,
credentials: OAuth2Credentials | APIKeyCredentials,
term: str,
) -> list[Project]:
client = LinearClient(credentials=credentials)
@@ -78,7 +86,11 @@ class LinearSearchProjectsBlock(Block):
return response
async def run(
self, input_data: Input, *, credentials: LinearCredentials, **kwargs
self,
input_data: Input,
*,
credentials: OAuth2Credentials | APIKeyCredentials,
**kwargs,
) -> BlockOutput:
"""Execute the project search"""
try:

View File

@@ -9,3 +9,117 @@ from backend.util.test import execute_block_test
@pytest.mark.parametrize("block", get_blocks().values(), ids=lambda b: b.name)
async def test_available_blocks(block: Type[Block]):
await execute_block_test(block())
@pytest.mark.parametrize("block", get_blocks().values(), ids=lambda b: b.name)
async def test_block_ids_valid(block: Type[Block]):
# add the tests here to check they are uuid4
import uuid
# Skip list for blocks with known invalid UUIDs
skip_blocks = {
"GetWeatherInformationBlock",
"CodeExecutionBlock",
"CountdownTimerBlock",
"TwitterGetListTweetsBlock",
"TwitterRemoveListMemberBlock",
"TwitterAddListMemberBlock",
"TwitterGetListMembersBlock",
"TwitterGetListMembershipsBlock",
"TwitterUnfollowListBlock",
"TwitterFollowListBlock",
"TwitterUnpinListBlock",
"TwitterPinListBlock",
"TwitterGetPinnedListsBlock",
"TwitterDeleteListBlock",
"TwitterUpdateListBlock",
"TwitterCreateListBlock",
"TwitterGetListBlock",
"TwitterGetOwnedListsBlock",
"TwitterGetSpacesBlock",
"TwitterGetSpaceByIdBlock",
"TwitterGetSpaceBuyersBlock",
"TwitterGetSpaceTweetsBlock",
"TwitterSearchSpacesBlock",
"TwitterGetUserMentionsBlock",
"TwitterGetHomeTimelineBlock",
"TwitterGetUserTweetsBlock",
"TwitterGetTweetBlock",
"TwitterGetTweetsBlock",
"TwitterGetQuoteTweetsBlock",
"TwitterLikeTweetBlock",
"TwitterGetLikingUsersBlock",
"TwitterGetLikedTweetsBlock",
"TwitterUnlikeTweetBlock",
"TwitterBookmarkTweetBlock",
"TwitterGetBookmarkedTweetsBlock",
"TwitterRemoveBookmarkTweetBlock",
"TwitterRetweetBlock",
"TwitterRemoveRetweetBlock",
"TwitterGetRetweetersBlock",
"TwitterHideReplyBlock",
"TwitterUnhideReplyBlock",
"TwitterPostTweetBlock",
"TwitterDeleteTweetBlock",
"TwitterSearchRecentTweetsBlock",
"TwitterUnfollowUserBlock",
"TwitterFollowUserBlock",
"TwitterGetFollowersBlock",
"TwitterGetFollowingBlock",
"TwitterUnmuteUserBlock",
"TwitterGetMutedUsersBlock",
"TwitterMuteUserBlock",
"TwitterGetBlockedUsersBlock",
"TwitterGetUserBlock",
"TwitterGetUsersBlock",
"TodoistCreateLabelBlock",
"TodoistListLabelsBlock",
"TodoistGetLabelBlock",
"TodoistUpdateLabelBlock",
"TodoistDeleteLabelBlock",
"TodoistGetSharedLabelsBlock",
"TodoistRenameSharedLabelsBlock",
"TodoistRemoveSharedLabelsBlock",
"TodoistCreateTaskBlock",
"TodoistGetTasksBlock",
"TodoistGetTaskBlock",
"TodoistUpdateTaskBlock",
"TodoistCloseTaskBlock",
"TodoistReopenTaskBlock",
"TodoistDeleteTaskBlock",
"TodoistListSectionsBlock",
"TodoistGetSectionBlock",
"TodoistDeleteSectionBlock",
"TodoistCreateProjectBlock",
"TodoistGetProjectBlock",
"TodoistUpdateProjectBlock",
"TodoistDeleteProjectBlock",
"TodoistListCollaboratorsBlock",
"TodoistGetCommentsBlock",
"TodoistGetCommentBlock",
"TodoistUpdateCommentBlock",
"TodoistDeleteCommentBlock",
"GithubListStargazersBlock",
"Slant3DSlicerBlock",
}
block_instance = block()
# Skip blocks with known invalid UUIDs
if block_instance.__class__.__name__ in skip_blocks:
pytest.skip(
f"Skipping UUID check for {block_instance.__class__.__name__} - known invalid UUID"
)
# Check that the ID is not empty
assert block_instance.id, f"Block {block.name} has empty ID"
# Check that the ID is a valid UUID4
try:
parsed_uuid = uuid.UUID(block_instance.id)
# Verify it's specifically UUID version 4
assert (
parsed_uuid.version == 4
), f"Block {block.name} ID is UUID version {parsed_uuid.version}, expected version 4"
except ValueError:
pytest.fail(f"Block {block.name} has invalid UUID format: {block_instance.id}")

View File

@@ -0,0 +1,359 @@
import asyncio
import random
from datetime import datetime
from faker import Faker
from prisma import Prisma
faker = Faker()
async def check_cron_job(db):
"""Check if the pg_cron job for refreshing materialized views exists."""
print("\n1. Checking pg_cron job...")
print("-" * 40)
try:
# Check if pg_cron extension exists
extension_check = await db.query_raw("CREATE EXTENSION pg_cron;")
print(extension_check)
extension_check = await db.query_raw(
"SELECT COUNT(*) as count FROM pg_extension WHERE extname = 'pg_cron'"
)
if extension_check[0]["count"] == 0:
print("⚠️ pg_cron extension is not installed")
return False
# Check if the refresh job exists
job_check = await db.query_raw(
"""
SELECT jobname, schedule, command
FROM cron.job
WHERE jobname = 'refresh-store-views'
"""
)
if job_check:
job = job_check[0]
print("✅ pg_cron job found:")
print(f" Name: {job['jobname']}")
print(f" Schedule: {job['schedule']} (every 15 minutes)")
print(f" Command: {job['command']}")
return True
else:
print("⚠️ pg_cron job 'refresh-store-views' not found")
return False
except Exception as e:
print(f"❌ Error checking pg_cron: {e}")
return False
async def get_materialized_view_counts(db):
"""Get current counts from materialized views."""
print("\n2. Getting current materialized view data...")
print("-" * 40)
# Get counts from mv_agent_run_counts
agent_runs = await db.query_raw(
"""
SELECT COUNT(*) as total_agents,
SUM(run_count) as total_runs,
MAX(run_count) as max_runs,
MIN(run_count) as min_runs
FROM mv_agent_run_counts
"""
)
# Get counts from mv_review_stats
review_stats = await db.query_raw(
"""
SELECT COUNT(*) as total_listings,
SUM(review_count) as total_reviews,
AVG(avg_rating) as overall_avg_rating
FROM mv_review_stats
"""
)
# Get sample data from StoreAgent view
store_agents = await db.query_raw(
"""
SELECT COUNT(*) as total_store_agents,
AVG(runs) as avg_runs,
AVG(rating) as avg_rating
FROM "StoreAgent"
"""
)
agent_run_data = agent_runs[0] if agent_runs else {}
review_data = review_stats[0] if review_stats else {}
store_data = store_agents[0] if store_agents else {}
print("📊 mv_agent_run_counts:")
print(f" Total agents: {agent_run_data.get('total_agents', 0)}")
print(f" Total runs: {agent_run_data.get('total_runs', 0)}")
print(f" Max runs per agent: {agent_run_data.get('max_runs', 0)}")
print(f" Min runs per agent: {agent_run_data.get('min_runs', 0)}")
print("\n📊 mv_review_stats:")
print(f" Total listings: {review_data.get('total_listings', 0)}")
print(f" Total reviews: {review_data.get('total_reviews', 0)}")
print(f" Overall avg rating: {review_data.get('overall_avg_rating') or 0:.2f}")
print("\n📊 StoreAgent view:")
print(f" Total store agents: {store_data.get('total_store_agents', 0)}")
print(f" Average runs: {store_data.get('avg_runs') or 0:.2f}")
print(f" Average rating: {store_data.get('avg_rating') or 0:.2f}")
return {
"agent_runs": agent_run_data,
"reviews": review_data,
"store_agents": store_data,
}
async def add_test_data(db):
"""Add some test data to verify materialized view updates."""
print("\n3. Adding test data...")
print("-" * 40)
# Get some existing data
users = await db.user.find_many(take=5)
graphs = await db.agentgraph.find_many(take=5)
if not users or not graphs:
print("❌ No existing users or graphs found. Run test_data_creator.py first.")
return False
# Add new executions
print("Adding new agent graph executions...")
new_executions = 0
for graph in graphs:
for _ in range(random.randint(2, 5)):
await db.agentgraphexecution.create(
data={
"agentGraphId": graph.id,
"agentGraphVersion": graph.version,
"userId": random.choice(users).id,
"executionStatus": "COMPLETED",
"startedAt": datetime.now(),
}
)
new_executions += 1
print(f"✅ Added {new_executions} new executions")
# Check if we need to create store listings first
store_versions = await db.storelistingversion.find_many(
where={"submissionStatus": "APPROVED"}, take=5
)
if not store_versions:
print("\nNo approved store listings found. Creating test store listings...")
# Create store listings for existing agent graphs
for i, graph in enumerate(graphs[:3]): # Create up to 3 store listings
# Create a store listing
listing = await db.storelisting.create(
data={
"slug": f"test-agent-{graph.id[:8]}",
"agentGraphId": graph.id,
"agentGraphVersion": graph.version,
"hasApprovedVersion": True,
"owningUserId": graph.userId,
}
)
# Create an approved version
version = await db.storelistingversion.create(
data={
"storeListingId": listing.id,
"agentGraphId": graph.id,
"agentGraphVersion": graph.version,
"name": f"Test Agent {i+1}",
"subHeading": faker.catch_phrase(),
"description": faker.paragraph(nb_sentences=5),
"imageUrls": [faker.image_url()],
"categories": ["productivity", "automation"],
"submissionStatus": "APPROVED",
"submittedAt": datetime.now(),
}
)
# Update listing with active version
await db.storelisting.update(
where={"id": listing.id}, data={"activeVersionId": version.id}
)
print("✅ Created test store listings")
# Re-fetch approved versions
store_versions = await db.storelistingversion.find_many(
where={"submissionStatus": "APPROVED"}, take=5
)
# Add new reviews
print("\nAdding new store listing reviews...")
new_reviews = 0
for version in store_versions:
# Find users who haven't reviewed this version
existing_reviews = await db.storelistingreview.find_many(
where={"storeListingVersionId": version.id}
)
reviewed_user_ids = {r.reviewByUserId for r in existing_reviews}
available_users = [u for u in users if u.id not in reviewed_user_ids]
if available_users:
user = random.choice(available_users)
await db.storelistingreview.create(
data={
"storeListingVersionId": version.id,
"reviewByUserId": user.id,
"score": random.randint(3, 5),
"comments": faker.text(max_nb_chars=100),
}
)
new_reviews += 1
print(f"✅ Added {new_reviews} new reviews")
return True
async def refresh_materialized_views(db):
"""Manually refresh the materialized views."""
print("\n4. Manually refreshing materialized views...")
print("-" * 40)
try:
await db.execute_raw("SELECT refresh_store_materialized_views();")
print("✅ Materialized views refreshed successfully")
return True
except Exception as e:
print(f"❌ Error refreshing views: {e}")
return False
async def compare_counts(before, after):
"""Compare counts before and after refresh."""
print("\n5. Comparing counts before and after refresh...")
print("-" * 40)
# Compare agent runs
print("🔍 Agent run changes:")
before_runs = before["agent_runs"].get("total_runs") or 0
after_runs = after["agent_runs"].get("total_runs") or 0
print(
f" Total runs: {before_runs}{after_runs} " f"(+{after_runs - before_runs})"
)
# Compare reviews
print("\n🔍 Review changes:")
before_reviews = before["reviews"].get("total_reviews") or 0
after_reviews = after["reviews"].get("total_reviews") or 0
print(
f" Total reviews: {before_reviews}{after_reviews} "
f"(+{after_reviews - before_reviews})"
)
# Compare store agents
print("\n🔍 StoreAgent view changes:")
before_avg_runs = before["store_agents"].get("avg_runs", 0) or 0
after_avg_runs = after["store_agents"].get("avg_runs", 0) or 0
print(
f" Average runs: {before_avg_runs:.2f}{after_avg_runs:.2f} "
f"(+{after_avg_runs - before_avg_runs:.2f})"
)
# Verify changes occurred
runs_changed = (after["agent_runs"].get("total_runs") or 0) > (
before["agent_runs"].get("total_runs") or 0
)
reviews_changed = (after["reviews"].get("total_reviews") or 0) > (
before["reviews"].get("total_reviews") or 0
)
if runs_changed and reviews_changed:
print("\n✅ Materialized views are updating correctly!")
return True
else:
print("\n⚠️ Some materialized views may not have updated:")
if not runs_changed:
print(" - Agent run counts did not increase")
if not reviews_changed:
print(" - Review counts did not increase")
return False
async def main():
db = Prisma()
await db.connect()
print("=" * 60)
print("Materialized Views Test")
print("=" * 60)
try:
# Check if data exists
user_count = await db.user.count()
if user_count == 0:
print("❌ No data in database. Please run test_data_creator.py first.")
await db.disconnect()
return
# 1. Check cron job
cron_exists = await check_cron_job(db)
# 2. Get initial counts
counts_before = await get_materialized_view_counts(db)
# 3. Add test data
data_added = await add_test_data(db)
refresh_success = False
if data_added:
# Wait a moment for data to be committed
print("\nWaiting for data to be committed...")
await asyncio.sleep(2)
# 4. Manually refresh views
refresh_success = await refresh_materialized_views(db)
if refresh_success:
# 5. Get counts after refresh
counts_after = await get_materialized_view_counts(db)
# 6. Compare results
await compare_counts(counts_before, counts_after)
# Summary
print("\n" + "=" * 60)
print("Test Summary")
print("=" * 60)
print(f"✓ pg_cron job exists: {'Yes' if cron_exists else 'No'}")
print(f"✓ Test data added: {'Yes' if data_added else 'No'}")
print(f"✓ Manual refresh worked: {'Yes' if refresh_success else 'No'}")
print(
f"✓ Views updated correctly: {'Yes' if data_added and refresh_success else 'Cannot verify'}"
)
if cron_exists:
print(
"\n💡 The materialized views will also refresh automatically every 15 minutes via pg_cron."
)
else:
print(
"\n⚠️ Automatic refresh is not configured. Views must be refreshed manually."
)
except Exception as e:
print(f"\n❌ Test failed with error: {e}")
import traceback
traceback.print_exc()
await db.disconnect()
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -0,0 +1,159 @@
#!/usr/bin/env python3
"""Check store-related data in the database."""
import asyncio
from prisma import Prisma
async def check_store_data(db):
"""Check what store data exists in the database."""
print("============================================================")
print("Store Data Check")
print("============================================================")
# Check store listings
print("\n1. Store Listings:")
print("-" * 40)
listings = await db.storelisting.find_many()
print(f"Total store listings: {len(listings)}")
if listings:
for listing in listings[:5]:
print(f"\nListing ID: {listing.id}")
print(f" Name: {listing.name}")
print(f" Status: {listing.status}")
print(f" Slug: {listing.slug}")
# Check store listing versions
print("\n\n2. Store Listing Versions:")
print("-" * 40)
versions = await db.storelistingversion.find_many(include={"StoreListing": True})
print(f"Total store listing versions: {len(versions)}")
# Group by submission status
status_counts = {}
for version in versions:
status = version.submissionStatus
status_counts[status] = status_counts.get(status, 0) + 1
print("\nVersions by status:")
for status, count in status_counts.items():
print(f" {status}: {count}")
# Show approved versions
approved_versions = [v for v in versions if v.submissionStatus == "APPROVED"]
print(f"\nApproved versions: {len(approved_versions)}")
if approved_versions:
for version in approved_versions[:5]:
print(f"\n Version ID: {version.id}")
print(f" Listing: {version.StoreListing.name}")
print(f" Version: {version.version}")
# Check store listing reviews
print("\n\n3. Store Listing Reviews:")
print("-" * 40)
reviews = await db.storelistingreview.find_many(
include={"StoreListingVersion": {"include": {"StoreListing": True}}}
)
print(f"Total reviews: {len(reviews)}")
if reviews:
# Calculate average rating
total_score = sum(r.score for r in reviews)
avg_score = total_score / len(reviews) if reviews else 0
print(f"Average rating: {avg_score:.2f}")
# Show sample reviews
print("\nSample reviews:")
for review in reviews[:3]:
print(f"\n Review for: {review.StoreListingVersion.StoreListing.name}")
print(f" Score: {review.score}")
print(f" Comments: {review.comments[:100]}...")
# Check StoreAgent view data
print("\n\n4. StoreAgent View Data:")
print("-" * 40)
# Query the StoreAgent view
query = """
SELECT
sa.listing_id,
sa.slug,
sa.agent_name,
sa.description,
sa.featured,
sa.runs,
sa.rating,
sa.creator_username,
sa.categories,
sa.updated_at
FROM "StoreAgent" sa
LIMIT 10;
"""
store_agents = await db.query_raw(query)
print(f"Total store agents in view: {len(store_agents)}")
if store_agents:
for agent in store_agents[:5]:
print(f"\nStore Agent: {agent['agent_name']}")
print(f" Slug: {agent['slug']}")
print(f" Runs: {agent['runs']}")
print(f" Rating: {agent['rating']}")
print(f" Creator: {agent['creator_username']}")
# Check the underlying data that should populate StoreAgent
print("\n\n5. Data that should populate StoreAgent view:")
print("-" * 40)
# Check for any APPROVED store listing versions
query = """
SELECT COUNT(*) as count
FROM "StoreListingVersion"
WHERE "submissionStatus" = 'APPROVED'
"""
result = await db.query_raw(query)
approved_count = result[0]["count"] if result else 0
print(f"Approved store listing versions: {approved_count}")
# Check for store listings with hasApprovedVersion = true
query = """
SELECT COUNT(*) as count
FROM "StoreListing"
WHERE "hasApprovedVersion" = true AND "isDeleted" = false
"""
result = await db.query_raw(query)
has_approved_count = result[0]["count"] if result else 0
print(f"Store listings with approved versions: {has_approved_count}")
# Check agent graph executions
query = """
SELECT COUNT(DISTINCT "agentGraphId") as unique_agents,
COUNT(*) as total_executions
FROM "AgentGraphExecution"
"""
result = await db.query_raw(query)
if result:
print("\nAgent Graph Executions:")
print(f" Unique agents with executions: {result[0]['unique_agents']}")
print(f" Total executions: {result[0]['total_executions']}")
async def main():
"""Main function."""
db = Prisma()
await db.connect()
try:
await check_store_data(db)
finally:
await db.disconnect()
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -425,28 +425,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
raise ValueError(f"{self.name} did not produce any output for {output}")
def merge_stats(self, stats: NodeExecutionStats) -> NodeExecutionStats:
stats_dict = stats.model_dump()
current_stats = self.execution_stats.model_dump()
for key, value in stats_dict.items():
if key not in current_stats:
# Field doesn't exist yet, just set it, but this will probably
# not happen, just in case though so we throw for invalid when
# converting back in
current_stats[key] = value
elif isinstance(value, dict) and isinstance(current_stats[key], dict):
current_stats[key].update(value)
elif isinstance(value, (int, float)) and isinstance(
current_stats[key], (int, float)
):
current_stats[key] += value
elif isinstance(value, list) and isinstance(current_stats[key], list):
current_stats[key].extend(value)
else:
current_stats[key] = value
self.execution_stats = NodeExecutionStats(**current_stats)
self.execution_stats += stats
return self.execution_stats
@property
@@ -513,6 +492,12 @@ def get_blocks() -> dict[str, Type[Block]]:
async def initialize_blocks() -> None:
# First, sync all provider costs to blocks
# Imported here to avoid circular import
from backend.sdk.cost_integration import sync_all_provider_costs
sync_all_provider_costs()
for cls in get_blocks().values():
block = cls()
existing_block = await AgentBlock.prisma().find_first(

View File

@@ -93,6 +93,28 @@ async def locked_transaction(key: str):
yield tx
def get_database_schema() -> str:
"""Extract database schema from DATABASE_URL."""
parsed_url = urlparse(DATABASE_URL)
query_params = dict(parse_qsl(parsed_url.query))
return query_params.get("schema", "public")
async def query_raw_with_schema(query_template: str, *args) -> list[dict]:
"""Execute raw SQL query with proper schema handling."""
schema = get_database_schema()
schema_prefix = f"{schema}." if schema != "public" else ""
formatted_query = query_template.format(schema_prefix=schema_prefix)
import prisma as prisma_module
result = await prisma_module.get_client().query_raw(
formatted_query, *args # type: ignore
)
return result
class BaseDbModel(BaseModel):
id: str = Field(default_factory=lambda: str(uuid4()))

View File

@@ -49,7 +49,7 @@ from .block import (
get_io_block_ids,
get_webhook_block_ids,
)
from .db import BaseDbModel
from .db import BaseDbModel, query_raw_with_schema
from .event_bus import AsyncRedisEventBus, RedisEventBus
from .includes import (
EXECUTION_RESULT_INCLUDE,
@@ -68,6 +68,21 @@ config = Config()
# -------------------------- Models -------------------------- #
class BlockErrorStats(BaseModel):
"""Typed data structure for block error statistics."""
block_id: str
total_executions: int
failed_executions: int
@property
def error_rate(self) -> float:
"""Calculate error rate as a percentage."""
if self.total_executions == 0:
return 0.0
return (self.failed_executions / self.total_executions) * 100
ExecutionStatus = AgentExecutionStatus
@@ -357,6 +372,7 @@ async def get_graph_executions(
created_time_lte: datetime | None = None,
limit: int | None = None,
) -> list[GraphExecutionMeta]:
"""⚠️ **Optional `user_id` check**: MUST USE check in user-facing endpoints."""
where_filter: AgentGraphExecutionWhereInput = {
"isDeleted": False,
}
@@ -722,6 +738,7 @@ async def delete_graph_execution(
async def get_node_execution(node_exec_id: str) -> NodeExecutionResult | None:
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
execution = await AgentNodeExecution.prisma().find_first(
where={"id": node_exec_id},
include=EXECUTION_RESULT_INCLUDE,
@@ -732,15 +749,19 @@ async def get_node_execution(node_exec_id: str) -> NodeExecutionResult | None:
async def get_node_executions(
graph_exec_id: str,
graph_exec_id: str | None = None,
node_id: str | None = None,
block_ids: list[str] | None = None,
statuses: list[ExecutionStatus] | None = None,
limit: int | None = None,
created_time_gte: datetime | None = None,
created_time_lte: datetime | None = None,
include_exec_data: bool = True,
) -> list[NodeExecutionResult]:
where_clause: AgentNodeExecutionWhereInput = {
"agentGraphExecutionId": graph_exec_id,
}
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
where_clause: AgentNodeExecutionWhereInput = {}
if graph_exec_id:
where_clause["agentGraphExecutionId"] = graph_exec_id
if node_id:
where_clause["agentNodeId"] = node_id
if block_ids:
@@ -748,9 +769,19 @@ async def get_node_executions(
if statuses:
where_clause["OR"] = [{"executionStatus": status} for status in statuses]
if created_time_gte or created_time_lte:
where_clause["addedTime"] = {
"gte": created_time_gte or datetime.min.replace(tzinfo=timezone.utc),
"lte": created_time_lte or datetime.max.replace(tzinfo=timezone.utc),
}
executions = await AgentNodeExecution.prisma().find_many(
where=where_clause,
include=EXECUTION_RESULT_INCLUDE,
include=(
EXECUTION_RESULT_INCLUDE
if include_exec_data
else {"Node": True, "GraphExecution": True}
),
order=EXECUTION_RESULT_ORDER,
take=limit,
)
@@ -761,6 +792,7 @@ async def get_node_executions(
async def get_latest_node_execution(
node_id: str, graph_eid: str
) -> NodeExecutionResult | None:
"""⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
execution = await AgentNodeExecution.prisma().find_first(
where={
"agentGraphExecutionId": graph_eid,
@@ -963,3 +995,33 @@ async def set_execution_kv_data(
},
)
return type_utils.convert(resp.data, type[Any]) if resp and resp.data else None
async def get_block_error_stats(
start_time: datetime, end_time: datetime
) -> list[BlockErrorStats]:
"""Get block execution stats using efficient SQL aggregation."""
query_template = """
SELECT
n."agentBlockId" as block_id,
COUNT(*) as total_executions,
SUM(CASE WHEN ne."executionStatus" = 'FAILED' THEN 1 ELSE 0 END) as failed_executions
FROM {schema_prefix}"AgentNodeExecution" ne
JOIN {schema_prefix}"AgentNode" n ON ne."agentNodeId" = n.id
WHERE ne."addedTime" >= $1::timestamp AND ne."addedTime" <= $2::timestamp
GROUP BY n."agentBlockId"
HAVING COUNT(*) >= 10
"""
result = await query_raw_with_schema(query_template, start_time, end_time)
# Convert to typed data structures
return [
BlockErrorStats(
block_id=row["block_id"],
total_executions=int(row["total_executions"]),
failed_executions=int(row["failed_executions"]),
)
for row in result
]

View File

@@ -3,7 +3,6 @@ import uuid
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
import prisma
from prisma import Json
from prisma.enums import SubmissionStatus
from prisma.models import AgentGraph, AgentNode, AgentNodeLink, StoreListingVersion
@@ -14,7 +13,7 @@ from prisma.types import (
AgentNodeLinkCreateInput,
StoreListingVersionWhereInput,
)
from pydantic import JsonValue, create_model
from pydantic import Field, JsonValue, create_model
from pydantic.fields import computed_field
from backend.blocks.agent import AgentExecutorBlock
@@ -31,7 +30,7 @@ from backend.integrations.providers import ProviderName
from backend.util import type as type_utils
from .block import Block, BlockInput, BlockSchema, BlockType, get_block, get_blocks
from .db import BaseDbModel, transaction
from .db import BaseDbModel, query_raw_with_schema, transaction
from .includes import AGENT_GRAPH_INCLUDE, AGENT_NODE_INCLUDE
if TYPE_CHECKING:
@@ -189,6 +188,23 @@ class BaseGraph(BaseDbModel):
)
)
@computed_field
@property
def has_external_trigger(self) -> bool:
return self.webhook_input_node is not None
@property
def webhook_input_node(self) -> Node | None:
return next(
(
node
for node in self.nodes
if node.block.block_type
in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
),
None,
)
@staticmethod
def _generate_schema(
*props: tuple[type[AgentInputBlock.Input] | type[AgentOutputBlock.Input], dict],
@@ -326,11 +342,6 @@ class GraphModel(Graph):
user_id: str
nodes: list[NodeModel] = [] # type: ignore
@computed_field
@property
def has_webhook_trigger(self) -> bool:
return self.webhook_input_node is not None
@property
def starting_nodes(self) -> list[NodeModel]:
outbound_nodes = {link.sink_id for link in self.links}
@@ -343,17 +354,12 @@ class GraphModel(Graph):
if node.id not in outbound_nodes or node.id in input_nodes
]
@property
def webhook_input_node(self) -> NodeModel | None:
return next(
(
node
for node in self.nodes
if node.block.block_type
in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
),
None,
)
def meta(self) -> "GraphMeta":
"""
Returns a GraphMeta object with metadata about the graph.
This is used to return metadata about the graph without exposing nodes and links.
"""
return GraphMeta.from_graph(self)
def reassign_ids(self, user_id: str, reassign_graph_id: bool = False):
"""
@@ -612,6 +618,18 @@ class GraphModel(Graph):
)
class GraphMeta(Graph):
user_id: str
# Easy work-around to prevent exposing nodes and links in the API response
nodes: list[NodeModel] = Field(default=[], exclude=True) # type: ignore
links: list[Link] = Field(default=[], exclude=True)
@staticmethod
def from_graph(graph: GraphModel) -> "GraphMeta":
return GraphMeta(**graph.model_dump())
# --------------------- CRUD functions --------------------- #
@@ -640,10 +658,10 @@ async def set_node_webhook(node_id: str, webhook_id: str | None) -> NodeModel:
return NodeModel.from_db(node)
async def get_graphs(
async def list_graphs(
user_id: str,
filter_by: Literal["active"] | None = "active",
) -> list[GraphModel]:
) -> list[GraphMeta]:
"""
Retrieves graph metadata objects.
Default behaviour is to get all currently active graphs.
@@ -653,7 +671,7 @@ async def get_graphs(
user_id: The ID of the user that owns the graph.
Returns:
list[GraphModel]: A list of objects representing the retrieved graphs.
list[GraphMeta]: A list of objects representing the retrieved graphs.
"""
where_clause: AgentGraphWhereInput = {"userId": user_id}
@@ -667,13 +685,13 @@ async def get_graphs(
include=AGENT_GRAPH_INCLUDE,
)
graph_models = []
graph_models: list[GraphMeta] = []
for graph in graphs:
try:
graph_model = GraphModel.from_db(graph)
# Trigger serialization to validate that the graph is well formed.
graph_model.model_dump()
graph_models.append(graph_model)
graph_meta = GraphModel.from_db(graph).meta()
# Trigger serialization to validate that the graph is well formed
graph_meta.model_dump()
graph_models.append(graph_meta)
except Exception as e:
logger.error(f"Error processing graph {graph.id}: {e}")
continue
@@ -1040,13 +1058,13 @@ async def fix_llm_provider_credentials():
broken_nodes = []
try:
broken_nodes = await prisma.get_client().query_raw(
broken_nodes = await query_raw_with_schema(
"""
SELECT graph."userId" user_id,
node.id node_id,
node."constantInput" node_preset_input
FROM platform."AgentNode" node
LEFT JOIN platform."AgentGraph" graph
FROM {schema_prefix}"AgentNode" node
LEFT JOIN {schema_prefix}"AgentGraph" graph
ON node."agentGraphId" = graph.id
WHERE node."constantInput"::jsonb->'credentials'->>'provider' = 'llm'
ORDER BY graph."userId";

View File

@@ -42,6 +42,9 @@ from pydantic_core import (
from backend.integrations.providers import ProviderName
from backend.util.settings import Secrets
# Type alias for any provider name (including custom ones)
AnyProviderName = str # Will be validated as ProviderName at runtime
if TYPE_CHECKING:
from backend.data.block import BlockSchema
@@ -341,7 +344,7 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
type: CT
@classmethod
def allowed_providers(cls) -> tuple[ProviderName, ...]:
def allowed_providers(cls) -> tuple[ProviderName, ...] | None:
return get_args(cls.model_fields["provider"].annotation)
@classmethod
@@ -366,7 +369,12 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
f"{field_schema}"
) from e
if len(cls.allowed_providers()) > 1 and not schema_extra.discriminator:
providers = cls.allowed_providers()
if (
providers is not None
and len(providers) > 1
and not schema_extra.discriminator
):
raise TypeError(
f"Multi-provider CredentialsField '{field_name}' "
"requires discriminator!"
@@ -378,7 +386,12 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
if hasattr(model_class, "allowed_providers") and hasattr(
model_class, "allowed_cred_types"
):
schema["credentials_provider"] = model_class.allowed_providers()
allowed_providers = model_class.allowed_providers()
# If no specific providers (None), allow any string
if allowed_providers is None:
schema["credentials_provider"] = ["string"] # Allow any string provider
else:
schema["credentials_provider"] = allowed_providers
schema["credentials_types"] = model_class.allowed_cred_types()
# Do not return anything, just mutate schema in place
@@ -540,6 +553,11 @@ def CredentialsField(
if v is not None
}
# Merge any json_schema_extra passed in kwargs
if "json_schema_extra" in kwargs:
extra_schema = kwargs.pop("json_schema_extra")
field_schema_extra.update(extra_schema)
return Field(
title=title,
description=description,
@@ -618,6 +636,35 @@ class NodeExecutionStats(BaseModel):
llm_retry_count: int = 0
input_token_count: int = 0
output_token_count: int = 0
extra_cost: int = 0
extra_steps: int = 0
def __iadd__(self, other: "NodeExecutionStats") -> "NodeExecutionStats":
"""Mutate this instance by adding another NodeExecutionStats."""
if not isinstance(other, NodeExecutionStats):
return NotImplemented
stats_dict = other.model_dump()
current_stats = self.model_dump()
for key, value in stats_dict.items():
if key not in current_stats:
# Field doesn't exist yet, just set it
setattr(self, key, value)
elif isinstance(value, dict) and isinstance(current_stats[key], dict):
current_stats[key].update(value)
setattr(self, key, current_stats[key])
elif isinstance(value, (int, float)) and isinstance(
current_stats[key], (int, float)
):
setattr(self, key, current_stats[key] + value)
elif isinstance(value, list) and isinstance(current_stats[key], list):
current_stats[key].extend(value)
setattr(self, key, current_stats[key])
else:
setattr(self, key, value)
return self
class GraphExecutionStats(BaseModel):

View File

@@ -5,6 +5,7 @@ from backend.data import db
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
from backend.data.execution import (
create_graph_execution,
get_block_error_stats,
get_execution_kv_data,
get_graph_execution,
get_graph_execution_meta,
@@ -105,6 +106,7 @@ class DatabaseManager(AppService):
upsert_execution_output = _(upsert_execution_output)
get_execution_kv_data = _(get_execution_kv_data)
set_execution_kv_data = _(set_execution_kv_data)
get_block_error_stats = _(get_block_error_stats)
# Graphs
get_node = _(get_node)
@@ -199,6 +201,9 @@ class DatabaseManagerClient(AppServiceClient):
d.get_user_notification_oldest_message_in_batch
)
# Block error monitoring
get_block_error_stats = _(d.get_block_error_stats)
class DatabaseManagerAsyncClient(AppServiceClient):
d = DatabaseManager
@@ -226,3 +231,4 @@ class DatabaseManagerAsyncClient(AppServiceClient):
update_user_integrations = d.update_user_integrations
get_execution_kv_data = d.get_execution_kv_data
set_execution_kv_data = d.set_execution_kv_data
get_block_error_stats = d.get_block_error_stats

View File

@@ -207,9 +207,7 @@ async def execute_node(
# Update execution stats
if execution_stats is not None:
execution_stats = execution_stats.model_copy(
update=node_block.execution_stats.model_dump()
)
execution_stats += node_block.execution_stats
execution_stats.input_size = input_size
execution_stats.output_size = output_size
@@ -648,9 +646,10 @@ class Executor:
return
nonlocal execution_stats
execution_stats.node_count += 1
execution_stats.node_count += 1 + result.extra_steps
execution_stats.nodes_cputime += result.cputime
execution_stats.nodes_walltime += result.walltime
execution_stats.cost += result.extra_cost
if (err := result.error) and isinstance(err, Exception):
execution_stats.node_error_count += 1
update_node_execution_status(
@@ -877,6 +876,7 @@ class Executor:
ExecutionStatus.QUEUED,
ExecutionStatus.RUNNING,
],
include_exec_data=False,
)
db_client.update_node_execution_status_batch(
[node_exec.node_exec_id for node_exec in inflight_executions],

View File

@@ -1,7 +1,6 @@
import asyncio
import logging
import os
from datetime import datetime, timedelta, timezone
from enum import Enum
from typing import Optional
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
@@ -14,25 +13,23 @@ from apscheduler.schedulers.blocking import BlockingScheduler
from apscheduler.triggers.cron import CronTrigger
from autogpt_libs.utils.cache import thread_cached
from dotenv import load_dotenv
from prisma.enums import NotificationType
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy import MetaData, create_engine
from backend.data.block import BlockInput
from backend.data.execution import ExecutionStatus
from backend.data.execution import GraphExecutionWithNodes
from backend.data.model import CredentialsMetaInput
from backend.executor import utils as execution_utils
from backend.notifications.notifications import NotificationManagerClient
from backend.monitoring import (
NotificationJobArgs,
process_existing_batches,
process_weekly_summary,
report_block_error_rates,
report_late_executions,
)
from backend.util.exceptions import NotAuthorizedError, NotFoundError
from backend.util.logging import PrefixFilter
from backend.util.metrics import sentry_capture_error
from backend.util.service import (
AppService,
AppServiceClient,
endpoint_to_async,
expose,
get_service_client,
)
from backend.util.service import AppService, AppServiceClient, endpoint_to_async, expose
from backend.util.settings import Config
@@ -71,11 +68,6 @@ def job_listener(event):
logger.info(f"Job {event.job_id} completed successfully.")
@thread_cached
def get_notification_client():
return get_service_client(NotificationManagerClient)
@thread_cached
def get_event_loop():
return asyncio.new_event_loop()
@@ -89,7 +81,7 @@ async def _execute_graph(**kwargs):
args = GraphExecutionJobArgs(**kwargs)
try:
logger.info(f"Executing recurring job for graph #{args.graph_id}")
await execution_utils.add_graph_execution(
graph_exec: GraphExecutionWithNodes = await execution_utils.add_graph_execution(
user_id=args.user_id,
graph_id=args.graph_id,
graph_version=args.graph_version,
@@ -97,65 +89,14 @@ async def _execute_graph(**kwargs):
graph_credentials_inputs=args.input_credentials,
use_db_query=False,
)
logger.info(
f"Graph execution started with ID {graph_exec.id} for graph {args.graph_id}"
)
except Exception as e:
logger.error(f"Error executing graph {args.graph_id}: {e}")
class LateExecutionException(Exception):
pass
def report_late_executions() -> str:
late_executions = execution_utils.get_db_client().get_graph_executions(
statuses=[ExecutionStatus.QUEUED],
created_time_gte=datetime.now(timezone.utc)
- timedelta(seconds=config.execution_late_notification_checkrange_secs),
created_time_lte=datetime.now(timezone.utc)
- timedelta(seconds=config.execution_late_notification_threshold_secs),
limit=1000,
)
if not late_executions:
return "No late executions detected."
num_late_executions = len(late_executions)
num_users = len(set([r.user_id for r in late_executions]))
late_execution_details = [
f"* `Execution ID: {exec.id}, Graph ID: {exec.graph_id}v{exec.graph_version}, User ID: {exec.user_id}, Created At: {exec.started_at.isoformat()}`"
for exec in late_executions
]
error = LateExecutionException(
f"Late executions detected: {num_late_executions} late executions from {num_users} users "
f"in the last {config.execution_late_notification_checkrange_secs} seconds. "
f"Graph has been queued for more than {config.execution_late_notification_threshold_secs} seconds. "
"Please check the executor status. Details:\n"
+ "\n".join(late_execution_details)
)
msg = str(error)
sentry_capture_error(error)
get_notification_client().discord_system_alert(msg)
return msg
def process_existing_batches(**kwargs):
args = NotificationJobArgs(**kwargs)
try:
logger.info(
f"Processing existing batches for notification type {args.notification_types}"
)
get_notification_client().process_existing_batches(args.notification_types)
except Exception as e:
logger.error(f"Error processing existing batches: {e}")
def process_weekly_summary(**kwargs):
try:
logger.info("Processing weekly summary")
get_notification_client().queue_weekly_summary()
except Exception as e:
logger.error(f"Error processing weekly summary: {e}")
# Monitoring functions are now imported from monitoring module
class Jobstores(Enum):
@@ -190,11 +131,6 @@ class GraphExecutionJobInfo(GraphExecutionJobArgs):
)
class NotificationJobArgs(BaseModel):
notification_types: list[NotificationType]
cron: str
class NotificationJobInfo(NotificationJobArgs):
id: str
name: str
@@ -287,6 +223,16 @@ class Scheduler(AppService):
jobstore=Jobstores.EXECUTION.value,
)
# Block Error Rate Monitoring
self.scheduler.add_job(
report_block_error_rates,
id="report_block_error_rates",
trigger="interval",
replace_existing=True,
seconds=config.block_error_rate_check_interval_secs,
jobstore=Jobstores.EXECUTION.value,
)
self.scheduler.add_listener(job_listener, EVENT_JOB_EXECUTED | EVENT_JOB_ERROR)
self.scheduler.start()
@@ -379,6 +325,10 @@ class Scheduler(AppService):
def execute_report_late_executions(self):
return report_late_executions()
@expose
def execute_report_block_error_rates(self):
return report_block_error_rates()
class SchedulerClient(AppServiceClient):
@classmethod

View File

@@ -731,6 +731,7 @@ async def stop_graph_execution(
node_execs = await db.get_node_executions(
graph_exec_id=graph_exec_id,
statuses=[ExecutionStatus.QUEUED, ExecutionStatus.INCOMPLETE],
include_exec_data=False,
)
await db.update_node_execution_status_batch(
[node_exec.node_exec_id for node_exec in node_execs],

View File

@@ -1,29 +1,226 @@
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional
from pydantic import BaseModel
from backend.integrations.oauth.todoist import TodoistOAuthHandler
from .github import GitHubOAuthHandler
from .google import GoogleOAuthHandler
from .linear import LinearOAuthHandler
from .notion import NotionOAuthHandler
from .twitter import TwitterOAuthHandler
if TYPE_CHECKING:
from ..providers import ProviderName
from .base import BaseOAuthHandler
# --8<-- [start:HANDLERS_BY_NAMEExample]
HANDLERS_BY_NAME: dict["ProviderName", type["BaseOAuthHandler"]] = {
handler.PROVIDER_NAME: handler
for handler in [
GitHubOAuthHandler,
GoogleOAuthHandler,
NotionOAuthHandler,
TwitterOAuthHandler,
LinearOAuthHandler,
TodoistOAuthHandler,
]
# Build handlers dict with string keys for compatibility with SDK auto-registration
_ORIGINAL_HANDLERS = [
GitHubOAuthHandler,
GoogleOAuthHandler,
NotionOAuthHandler,
TwitterOAuthHandler,
TodoistOAuthHandler,
]
# Start with original handlers
_handlers_dict = {
(
handler.PROVIDER_NAME.value
if hasattr(handler.PROVIDER_NAME, "value")
else str(handler.PROVIDER_NAME)
): handler
for handler in _ORIGINAL_HANDLERS
}
class SDKAwareCredentials(BaseModel):
"""OAuth credentials configuration."""
use_secrets: bool = True
client_id_env_var: Optional[str] = None
client_secret_env_var: Optional[str] = None
_credentials_by_provider = {}
# Add default credentials for original handlers
for handler in _ORIGINAL_HANDLERS:
provider_name = (
handler.PROVIDER_NAME.value
if hasattr(handler.PROVIDER_NAME, "value")
else str(handler.PROVIDER_NAME)
)
_credentials_by_provider[provider_name] = SDKAwareCredentials(
use_secrets=True, client_id_env_var=None, client_secret_env_var=None
)
# Create a custom dict class that includes SDK handlers
class SDKAwareHandlersDict(dict):
"""Dictionary that automatically includes SDK-registered OAuth handlers."""
def __getitem__(self, key):
# First try the original handlers
if key in _handlers_dict:
return _handlers_dict[key]
# Then try SDK handlers
try:
from backend.sdk import AutoRegistry
sdk_handlers = AutoRegistry.get_oauth_handlers()
if key in sdk_handlers:
return sdk_handlers[key]
except ImportError:
pass
# If not found, raise KeyError
raise KeyError(key)
def get(self, key, default=None):
try:
return self[key]
except KeyError:
return default
def __contains__(self, key):
if key in _handlers_dict:
return True
try:
from backend.sdk import AutoRegistry
sdk_handlers = AutoRegistry.get_oauth_handlers()
return key in sdk_handlers
except ImportError:
return False
def keys(self):
# Combine all keys into a single dict and return its keys view
combined = dict(_handlers_dict)
try:
from backend.sdk import AutoRegistry
sdk_handlers = AutoRegistry.get_oauth_handlers()
combined.update(sdk_handlers)
except ImportError:
pass
return combined.keys()
def values(self):
combined = dict(_handlers_dict)
try:
from backend.sdk import AutoRegistry
sdk_handlers = AutoRegistry.get_oauth_handlers()
combined.update(sdk_handlers)
except ImportError:
pass
return combined.values()
def items(self):
combined = dict(_handlers_dict)
try:
from backend.sdk import AutoRegistry
sdk_handlers = AutoRegistry.get_oauth_handlers()
combined.update(sdk_handlers)
except ImportError:
pass
return combined.items()
class SDKAwareCredentialsDict(dict):
"""Dictionary that automatically includes SDK-registered OAuth credentials."""
def __getitem__(self, key):
# First try the original handlers
if key in _credentials_by_provider:
return _credentials_by_provider[key]
# Then try SDK credentials
try:
from backend.sdk import AutoRegistry
sdk_credentials = AutoRegistry.get_oauth_credentials()
if key in sdk_credentials:
# Convert from SDKOAuthCredentials to SDKAwareCredentials
sdk_cred = sdk_credentials[key]
return SDKAwareCredentials(
use_secrets=sdk_cred.use_secrets,
client_id_env_var=sdk_cred.client_id_env_var,
client_secret_env_var=sdk_cred.client_secret_env_var,
)
except ImportError:
pass
# If not found, raise KeyError
raise KeyError(key)
def get(self, key, default=None):
try:
return self[key]
except KeyError:
return default
def __contains__(self, key):
if key in _credentials_by_provider:
return True
try:
from backend.sdk import AutoRegistry
sdk_credentials = AutoRegistry.get_oauth_credentials()
return key in sdk_credentials
except ImportError:
return False
def keys(self):
# Combine all keys into a single dict and return its keys view
combined = dict(_credentials_by_provider)
try:
from backend.sdk import AutoRegistry
sdk_credentials = AutoRegistry.get_oauth_credentials()
combined.update(sdk_credentials)
except ImportError:
pass
return combined.keys()
def values(self):
combined = dict(_credentials_by_provider)
try:
from backend.sdk import AutoRegistry
sdk_credentials = AutoRegistry.get_oauth_credentials()
# Convert SDK credentials to SDKAwareCredentials
for key, sdk_cred in sdk_credentials.items():
combined[key] = SDKAwareCredentials(
use_secrets=sdk_cred.use_secrets,
client_id_env_var=sdk_cred.client_id_env_var,
client_secret_env_var=sdk_cred.client_secret_env_var,
)
except ImportError:
pass
return combined.values()
def items(self):
combined = dict(_credentials_by_provider)
try:
from backend.sdk import AutoRegistry
sdk_credentials = AutoRegistry.get_oauth_credentials()
# Convert SDK credentials to SDKAwareCredentials
for key, sdk_cred in sdk_credentials.items():
combined[key] = SDKAwareCredentials(
use_secrets=sdk_cred.use_secrets,
client_id_env_var=sdk_cred.client_id_env_var,
client_secret_env_var=sdk_cred.client_secret_env_var,
)
except ImportError:
pass
return combined.items()
HANDLERS_BY_NAME: dict[str, type["BaseOAuthHandler"]] = SDKAwareHandlersDict()
CREDENTIALS_BY_PROVIDER: dict[str, SDKAwareCredentials] = SDKAwareCredentialsDict()
# --8<-- [end:HANDLERS_BY_NAMEExample]
__all__ = ["HANDLERS_BY_NAME"]

View File

@@ -11,7 +11,7 @@ logger = logging.getLogger(__name__)
class BaseOAuthHandler(ABC):
# --8<-- [start:BaseOAuthHandler1]
PROVIDER_NAME: ClassVar[ProviderName]
PROVIDER_NAME: ClassVar[ProviderName | str]
DEFAULT_SCOPES: ClassVar[list[str]] = []
# --8<-- [end:BaseOAuthHandler1]
@@ -81,8 +81,6 @@ class BaseOAuthHandler(ABC):
"""Handles the default scopes for the provider"""
# If scopes are empty, use the default scopes for the provider
if not scopes:
logger.debug(
f"Using default scopes for provider {self.PROVIDER_NAME.value}"
)
logger.debug(f"Using default scopes for provider {str(self.PROVIDER_NAME)}")
scopes = self.DEFAULT_SCOPES
return scopes

View File

@@ -1,8 +1,16 @@
from enum import Enum
from typing import Any
# --8<-- [start:ProviderName]
class ProviderName(str, Enum):
"""
Provider names for integrations.
This enum extends str to accept any string value while maintaining
backward compatibility with existing provider constants.
"""
AIML_API = "aiml_api"
ANTHROPIC = "anthropic"
APOLLO = "apollo"
@@ -10,9 +18,7 @@ class ProviderName(str, Enum):
DISCORD = "discord"
D_ID = "d_id"
E2B = "e2b"
EXA = "exa"
FAL = "fal"
GENERIC_WEBHOOK = "generic_webhook"
GITHUB = "github"
GOOGLE = "google"
GOOGLE_MAPS = "google_maps"
@@ -21,7 +27,6 @@ class ProviderName(str, Enum):
HUBSPOT = "hubspot"
IDEOGRAM = "ideogram"
JINA = "jina"
LINEAR = "linear"
LLAMA_API = "llama_api"
MEDIUM = "medium"
MEM0 = "mem0"
@@ -43,4 +48,57 @@ class ProviderName(str, Enum):
TODOIST = "todoist"
UNREAL_SPEECH = "unreal_speech"
ZEROBOUNCE = "zerobounce"
@classmethod
def _missing_(cls, value: Any) -> "ProviderName":
"""
Allow any string value to be used as a ProviderName.
This enables SDK users to define custom providers without
modifying the enum.
"""
if isinstance(value, str):
# Create a pseudo-member that behaves like an enum member
pseudo_member = str.__new__(cls, value)
pseudo_member._name_ = value.upper()
pseudo_member._value_ = value
return pseudo_member
return None # type: ignore
@classmethod
def __get_pydantic_json_schema__(cls, schema, handler):
"""
Custom JSON schema generation that allows any string value,
not just the predefined enum values.
"""
# Get the default schema
json_schema = handler(schema)
# Remove the enum constraint to allow any string
if "enum" in json_schema:
del json_schema["enum"]
# Keep the type as string
json_schema["type"] = "string"
# Update description to indicate custom providers are allowed
json_schema["description"] = (
"Provider name for integrations. "
"Can be any string value, including custom provider names."
)
return json_schema
@classmethod
def __get_pydantic_core_schema__(cls, source_type, handler):
"""
Pydantic v2 core schema that allows any string value.
"""
from pydantic_core import core_schema
# Create a string schema that validates any string
return core_schema.no_info_after_validator_function(
cls,
core_schema.str_schema(),
)
# --8<-- [end:ProviderName]

View File

@@ -12,7 +12,6 @@ def load_webhook_managers() -> dict["ProviderName", type["BaseWebhooksManager"]]
webhook_managers = {}
from .compass import CompassWebhookManager
from .generic import GenericWebhooksManager
from .github import GithubWebhooksManager
from .slant3d import Slant3DWebhooksManager
@@ -23,7 +22,6 @@ def load_webhook_managers() -> dict["ProviderName", type["BaseWebhooksManager"]]
CompassWebhookManager,
GithubWebhooksManager,
Slant3DWebhooksManager,
GenericWebhooksManager,
]
}
)

View File

@@ -0,0 +1,24 @@
"""Monitoring module for platform health and alerting."""
from .block_error_monitor import BlockErrorMonitor, report_block_error_rates
from .late_execution_monitor import (
LateExecutionException,
LateExecutionMonitor,
report_late_executions,
)
from .notification_monitor import (
NotificationJobArgs,
process_existing_batches,
process_weekly_summary,
)
__all__ = [
"BlockErrorMonitor",
"LateExecutionMonitor",
"LateExecutionException",
"NotificationJobArgs",
"report_block_error_rates",
"report_late_executions",
"process_existing_batches",
"process_weekly_summary",
]

View File

@@ -0,0 +1,291 @@
"""Block error rate monitoring module."""
import logging
import re
from datetime import datetime, timedelta, timezone
from pydantic import BaseModel
from backend.data.block import get_block
from backend.data.execution import ExecutionStatus, NodeExecutionResult
from backend.executor import utils as execution_utils
from backend.notifications.notifications import NotificationManagerClient
from backend.util.metrics import sentry_capture_error
from backend.util.service import get_service_client
from backend.util.settings import Config
logger = logging.getLogger(__name__)
config = Config()
class BlockStatsWithSamples(BaseModel):
"""Enhanced block stats with error samples."""
block_id: str
block_name: str
total_executions: int
failed_executions: int
error_samples: list[str] = []
@property
def error_rate(self) -> float:
"""Calculate error rate as a percentage."""
if self.total_executions == 0:
return 0.0
return (self.failed_executions / self.total_executions) * 100
class BlockErrorMonitor:
"""Monitor block error rates and send alerts when thresholds are exceeded."""
def __init__(self, include_top_blocks: int | None = None):
self.config = config
self.notification_client = get_service_client(NotificationManagerClient)
self.include_top_blocks = (
include_top_blocks
if include_top_blocks is not None
else config.block_error_include_top_blocks
)
def check_block_error_rates(self) -> str:
"""Check block error rates and send Discord alerts if thresholds are exceeded."""
try:
logger.info("Checking block error rates")
# Get executions from the last 24 hours
end_time = datetime.now(timezone.utc)
start_time = end_time - timedelta(hours=24)
# Use SQL aggregation to efficiently count totals and failures by block
block_stats = self._get_block_stats_from_db(start_time, end_time)
# For blocks with high error rates, fetch error samples
threshold = self.config.block_error_rate_threshold
for block_name, stats in block_stats.items():
if stats.total_executions >= 10 and stats.error_rate >= threshold * 100:
# Only fetch error samples for blocks that exceed threshold
error_samples = self._get_error_samples_for_block(
stats.block_id, start_time, end_time, limit=3
)
stats.error_samples = error_samples
# Check thresholds and send alerts
critical_alerts = self._generate_critical_alerts(block_stats, threshold)
if critical_alerts:
msg = "Block Error Rate Alert:\n\n" + "\n\n".join(critical_alerts)
self.notification_client.discord_system_alert(msg)
logger.info(
f"Sent block error rate alert for {len(critical_alerts)} blocks"
)
return f"Alert sent for {len(critical_alerts)} blocks with high error rates"
# If no critical alerts, check if we should show top blocks
if self.include_top_blocks > 0:
top_blocks_msg = self._generate_top_blocks_alert(
block_stats, start_time, end_time
)
if top_blocks_msg:
self.notification_client.discord_system_alert(top_blocks_msg)
logger.info("Sent top blocks summary")
return "Sent top blocks summary"
logger.info("No blocks exceeded error rate threshold")
return "No errors reported for today"
except Exception as e:
logger.exception(f"Error checking block error rates: {e}")
error = Exception(f"Error checking block error rates: {e}")
msg = str(error)
sentry_capture_error(error)
self.notification_client.discord_system_alert(msg)
return msg
def _get_block_stats_from_db(
self, start_time: datetime, end_time: datetime
) -> dict[str, BlockStatsWithSamples]:
"""Get block execution stats using efficient SQL aggregation."""
result = execution_utils.get_db_client().get_block_error_stats(
start_time, end_time
)
block_stats = {}
for stats in result:
block_name = b.name if (b := get_block(stats.block_id)) else "Unknown"
block_stats[block_name] = BlockStatsWithSamples(
block_id=stats.block_id,
block_name=block_name,
total_executions=stats.total_executions,
failed_executions=stats.failed_executions,
error_samples=[],
)
return block_stats
def _generate_critical_alerts(
self, block_stats: dict[str, BlockStatsWithSamples], threshold: float
) -> list[str]:
"""Generate alerts for blocks that exceed the error rate threshold."""
alerts = []
for block_name, stats in block_stats.items():
if stats.total_executions >= 10 and stats.error_rate >= threshold * 100:
error_groups = self._group_similar_errors(stats.error_samples)
alert_msg = (
f"🚨 Block '{block_name}' has {stats.error_rate:.1f}% error rate "
f"({stats.failed_executions}/{stats.total_executions}) in the last 24 hours"
)
if error_groups:
alert_msg += "\n\n📊 Error Types:"
for error_pattern, count in error_groups.items():
alert_msg += f"\n{error_pattern} ({count}x)"
alerts.append(alert_msg)
return alerts
def _generate_top_blocks_alert(
self,
block_stats: dict[str, BlockStatsWithSamples],
start_time: datetime,
end_time: datetime,
) -> str | None:
"""Generate top blocks summary when no critical alerts exist."""
top_error_blocks = sorted(
[
(name, stats)
for name, stats in block_stats.items()
if stats.total_executions >= 10 and stats.failed_executions > 0
],
key=lambda x: x[1].failed_executions,
reverse=True,
)[: self.include_top_blocks]
if not top_error_blocks:
return "✅ No errors reported for today - all blocks are running smoothly!"
# Get error samples for top blocks
for block_name, stats in top_error_blocks:
if not stats.error_samples:
stats.error_samples = self._get_error_samples_for_block(
stats.block_id, start_time, end_time, limit=2
)
count_text = (
f"top {self.include_top_blocks}" if self.include_top_blocks > 1 else "top"
)
alert_msg = f"📊 Daily Error Summary - {count_text} blocks with most errors:"
for block_name, stats in top_error_blocks:
alert_msg += f"\n{block_name}: {stats.failed_executions} errors ({stats.error_rate:.1f}% of {stats.total_executions})"
if stats.error_samples:
error_groups = self._group_similar_errors(stats.error_samples)
if error_groups:
# Show most common error
most_common_error = next(iter(error_groups.items()))
alert_msg += f"\n └ Most common: {most_common_error[0]}"
return alert_msg
def _get_error_samples_for_block(
self, block_id: str, start_time: datetime, end_time: datetime, limit: int = 3
) -> list[str]:
"""Get error samples for a specific block - just a few recent ones."""
# Only fetch a small number of recent failed executions for this specific block
executions = execution_utils.get_db_client().get_node_executions(
block_ids=[block_id],
statuses=[ExecutionStatus.FAILED],
created_time_gte=start_time,
created_time_lte=end_time,
limit=limit, # Just get the limit we need
)
error_samples = []
for execution in executions:
if error_message := self._extract_error_message(execution):
masked_error = self._mask_sensitive_data(error_message)
error_samples.append(masked_error)
if len(error_samples) >= limit: # Stop once we have enough samples
break
return error_samples
def _extract_error_message(self, execution: NodeExecutionResult) -> str | None:
"""Extract error message from execution output."""
try:
if execution.output_data and (
error_msg := execution.output_data.get("error")
):
return str(error_msg[0])
return None
except Exception:
return None
def _mask_sensitive_data(self, error_message):
"""Mask sensitive data in error messages to enable grouping."""
if not error_message:
return ""
# Convert to string if not already
error_str = str(error_message)
# Mask numbers (replace with X)
error_str = re.sub(r"\d+", "X", error_str)
# Mask all caps words (likely constants/IDs)
error_str = re.sub(r"\b[A-Z_]{3,}\b", "MASKED", error_str)
# Mask words with underscores (likely internal variables)
error_str = re.sub(r"\b\w*_\w*\b", "MASKED", error_str)
# Mask UUIDs and long alphanumeric strings
error_str = re.sub(
r"\b[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}\b",
"UUID",
error_str,
)
error_str = re.sub(r"\b[a-f0-9]{20,}\b", "HASH", error_str)
# Mask file paths
error_str = re.sub(r"(/[^/\s]+)+", "/MASKED/path", error_str)
# Mask URLs
error_str = re.sub(r"https?://[^\s]+", "URL", error_str)
# Mask email addresses
error_str = re.sub(
r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b", "EMAIL", error_str
)
# Truncate if too long
if len(error_str) > 100:
error_str = error_str[:97] + "..."
return error_str.strip()
def _group_similar_errors(self, error_samples):
"""Group similar error messages and return counts."""
if not error_samples:
return {}
error_groups = {}
for error in error_samples:
if error in error_groups:
error_groups[error] += 1
else:
error_groups[error] = 1
# Sort by frequency, most common first
return dict(sorted(error_groups.items(), key=lambda x: x[1], reverse=True))
def report_block_error_rates(include_top_blocks: int | None = None):
"""Check block error rates and send Discord alerts if thresholds are exceeded."""
monitor = BlockErrorMonitor(include_top_blocks=include_top_blocks)
return monitor.check_block_error_rates()

View File

@@ -0,0 +1,71 @@
"""Late execution monitoring module."""
import logging
from datetime import datetime, timedelta, timezone
from backend.data.execution import ExecutionStatus
from backend.executor import utils as execution_utils
from backend.notifications.notifications import NotificationManagerClient
from backend.util.metrics import sentry_capture_error
from backend.util.service import get_service_client
from backend.util.settings import Config
logger = logging.getLogger(__name__)
config = Config()
class LateExecutionException(Exception):
"""Exception raised when late executions are detected."""
pass
class LateExecutionMonitor:
"""Monitor late executions and send alerts when thresholds are exceeded."""
def __init__(self):
self.config = config
self.notification_client = get_service_client(NotificationManagerClient)
def check_late_executions(self) -> str:
"""Check for late executions and send alerts if found."""
late_executions = execution_utils.get_db_client().get_graph_executions(
statuses=[ExecutionStatus.QUEUED],
created_time_gte=datetime.now(timezone.utc)
- timedelta(
seconds=self.config.execution_late_notification_checkrange_secs
),
created_time_lte=datetime.now(timezone.utc)
- timedelta(seconds=self.config.execution_late_notification_threshold_secs),
limit=1000,
)
if not late_executions:
return "No late executions detected."
num_late_executions = len(late_executions)
num_users = len(set([r.user_id for r in late_executions]))
late_execution_details = [
f"* `Execution ID: {exec.id}, Graph ID: {exec.graph_id}v{exec.graph_version}, User ID: {exec.user_id}, Created At: {exec.started_at.isoformat()}`"
for exec in late_executions
]
error = LateExecutionException(
f"Late executions detected: {num_late_executions} late executions from {num_users} users "
f"in the last {self.config.execution_late_notification_checkrange_secs} seconds. "
f"Graph has been queued for more than {self.config.execution_late_notification_threshold_secs} seconds. "
"Please check the executor status. Details:\n"
+ "\n".join(late_execution_details)
)
msg = str(error)
sentry_capture_error(error)
self.notification_client.discord_system_alert(msg)
return msg
def report_late_executions() -> str:
"""Check for late executions and send Discord alerts if found."""
monitor = LateExecutionMonitor()
return monitor.check_late_executions()

View File

@@ -0,0 +1,39 @@
"""Notification processing monitoring module."""
import logging
from prisma.enums import NotificationType
from pydantic import BaseModel
from backend.notifications.notifications import NotificationManagerClient
from backend.util.service import get_service_client
logger = logging.getLogger(__name__)
class NotificationJobArgs(BaseModel):
notification_types: list[NotificationType]
cron: str
def process_existing_batches(**kwargs):
"""Process existing notification batches."""
args = NotificationJobArgs(**kwargs)
try:
logging.info(
f"Processing existing batches for notification type {args.notification_types}"
)
get_service_client(NotificationManagerClient).process_existing_batches(
args.notification_types
)
except Exception as e:
logger.exception(f"Error processing existing batches: {e}")
def process_weekly_summary(**kwargs):
"""Process weekly summary notifications."""
try:
logging.info("Processing weekly summary")
get_service_client(NotificationManagerClient).queue_weekly_summary()
except Exception as e:
logger.exception(f"Error processing weekly summary: {e}")

View File

@@ -0,0 +1,169 @@
"""
AutoGPT Platform Block Development SDK
Complete re-export of all dependencies needed for block development.
Usage: from backend.sdk import *
This module provides:
- All block base classes and types
- All credential and authentication components
- All cost tracking components
- All webhook components
- All utility functions
- Auto-registration decorators
"""
# Third-party imports
from pydantic import BaseModel, Field, SecretStr
# === CORE BLOCK SYSTEM ===
from backend.data.block import (
Block,
BlockCategory,
BlockManualWebhookConfig,
BlockOutput,
BlockSchema,
BlockType,
BlockWebhookConfig,
)
from backend.data.integrations import Webhook
from backend.data.model import APIKeyCredentials, Credentials, CredentialsField
from backend.data.model import CredentialsMetaInput as _CredentialsMetaInput
from backend.data.model import (
NodeExecutionStats,
OAuth2Credentials,
SchemaField,
UserPasswordCredentials,
)
# === INTEGRATIONS ===
from backend.integrations.providers import ProviderName
from backend.sdk.builder import ProviderBuilder
from backend.sdk.cost_integration import cost
from backend.sdk.provider import Provider
# === NEW SDK COMPONENTS (imported early for patches) ===
from backend.sdk.registry import AutoRegistry, BlockConfiguration
# === UTILITIES ===
from backend.util import json
from backend.util.request import Requests
# === OPTIONAL IMPORTS WITH TRY/EXCEPT ===
# Webhooks
try:
from backend.integrations.webhooks._base import BaseWebhooksManager
except ImportError:
BaseWebhooksManager = None
try:
from backend.integrations.webhooks._manual_base import ManualWebhookManagerBase
except ImportError:
ManualWebhookManagerBase = None
# Cost System
try:
from backend.data.cost import BlockCost, BlockCostType
except ImportError:
from backend.data.block_cost_config import BlockCost, BlockCostType
try:
from backend.data.credit import UsageTransactionMetadata
except ImportError:
UsageTransactionMetadata = None
try:
from backend.executor.utils import block_usage_cost
except ImportError:
block_usage_cost = None
# Utilities
try:
from backend.util.file import store_media_file
except ImportError:
store_media_file = None
try:
from backend.util.type import MediaFileType, convert
except ImportError:
MediaFileType = None
convert = None
try:
from backend.util.text import TextFormatter
except ImportError:
TextFormatter = None
try:
from backend.util.logging import TruncatedLogger
except ImportError:
TruncatedLogger = None
# OAuth handlers
try:
from backend.integrations.oauth.base import BaseOAuthHandler
except ImportError:
BaseOAuthHandler = None
# Credential type with proper provider name
from typing import Literal as _Literal
CredentialsMetaInput = _CredentialsMetaInput[
ProviderName, _Literal["api_key", "oauth2", "user_password"]
]
# === COMPREHENSIVE __all__ EXPORT ===
__all__ = [
# Core Block System
"Block",
"BlockCategory",
"BlockOutput",
"BlockSchema",
"BlockType",
"BlockWebhookConfig",
"BlockManualWebhookConfig",
# Schema and Model Components
"SchemaField",
"Credentials",
"CredentialsField",
"CredentialsMetaInput",
"APIKeyCredentials",
"OAuth2Credentials",
"UserPasswordCredentials",
"NodeExecutionStats",
# Cost System
"BlockCost",
"BlockCostType",
"UsageTransactionMetadata",
"block_usage_cost",
# Integrations
"ProviderName",
"BaseWebhooksManager",
"ManualWebhookManagerBase",
"Webhook",
# Provider-Specific (when available)
"BaseOAuthHandler",
# Utilities
"json",
"store_media_file",
"MediaFileType",
"convert",
"TextFormatter",
"TruncatedLogger",
"BaseModel",
"Field",
"SecretStr",
"Requests",
# SDK Components
"AutoRegistry",
"BlockConfiguration",
"Provider",
"ProviderBuilder",
"cost",
]
# Remove None values from __all__
__all__ = [name for name in __all__ if globals().get(name) is not None]

View File

@@ -0,0 +1,161 @@
"""
Builder class for creating provider configurations with a fluent API.
"""
import os
from typing import Callable, List, Optional, Type
from pydantic import SecretStr
from backend.data.cost import BlockCost, BlockCostType
from backend.data.model import APIKeyCredentials, Credentials, UserPasswordCredentials
from backend.integrations.oauth.base import BaseOAuthHandler
from backend.integrations.webhooks._base import BaseWebhooksManager
from backend.sdk.provider import OAuthConfig, Provider
from backend.sdk.registry import AutoRegistry
from backend.util.settings import Settings
class ProviderBuilder:
"""Builder for creating provider configurations."""
def __init__(self, name: str):
self.name = name
self._oauth_config: Optional[OAuthConfig] = None
self._webhook_manager: Optional[Type[BaseWebhooksManager]] = None
self._default_credentials: List[Credentials] = []
self._base_costs: List[BlockCost] = []
self._supported_auth_types: set = set()
self._api_client_factory: Optional[Callable] = None
self._error_handler: Optional[Callable[[Exception], str]] = None
self._default_scopes: Optional[List[str]] = None
self._client_id_env_var: Optional[str] = None
self._client_secret_env_var: Optional[str] = None
self._extra_config: dict = {}
def with_oauth(
self,
handler_class: Type[BaseOAuthHandler],
scopes: Optional[List[str]] = None,
client_id_env_var: Optional[str] = None,
client_secret_env_var: Optional[str] = None,
) -> "ProviderBuilder":
"""Add OAuth support."""
self._oauth_config = OAuthConfig(
oauth_handler=handler_class,
scopes=scopes,
client_id_env_var=client_id_env_var,
client_secret_env_var=client_secret_env_var,
)
self._supported_auth_types.add("oauth2")
return self
def with_api_key(self, env_var_name: str, title: str) -> "ProviderBuilder":
"""Add API key support with environment variable name."""
self._supported_auth_types.add("api_key")
# Register the API key mapping
AutoRegistry.register_api_key(self.name, env_var_name)
# Check if API key exists in environment
api_key = os.getenv(env_var_name)
if api_key:
self._default_credentials.append(
APIKeyCredentials(
id=f"{self.name}-default",
provider=self.name,
api_key=SecretStr(api_key),
title=title,
)
)
return self
def with_api_key_from_settings(
self, settings_attr: str, title: str
) -> "ProviderBuilder":
"""Use existing API key from settings."""
self._supported_auth_types.add("api_key")
# Try to get the API key from settings
settings = Settings()
api_key = getattr(settings.secrets, settings_attr, None)
if api_key:
self._default_credentials.append(
APIKeyCredentials(
id=f"{self.name}-default",
provider=self.name,
api_key=api_key,
title=title,
)
)
return self
def with_user_password(
self, username_env_var: str, password_env_var: str, title: str
) -> "ProviderBuilder":
"""Add username/password support with environment variable names."""
self._supported_auth_types.add("user_password")
# Check if credentials exist in environment
username = os.getenv(username_env_var)
password = os.getenv(password_env_var)
if username and password:
self._default_credentials.append(
UserPasswordCredentials(
id=f"{self.name}-default",
provider=self.name,
username=SecretStr(username),
password=SecretStr(password),
title=title,
)
)
return self
def with_webhook_manager(
self, manager_class: Type[BaseWebhooksManager]
) -> "ProviderBuilder":
"""Register webhook manager for this provider."""
self._webhook_manager = manager_class
return self
def with_base_cost(
self, amount: int, cost_type: BlockCostType
) -> "ProviderBuilder":
"""Set base cost for all blocks using this provider."""
self._base_costs.append(BlockCost(cost_amount=amount, cost_type=cost_type))
return self
def with_api_client(self, factory: Callable) -> "ProviderBuilder":
"""Register API client factory."""
self._api_client_factory = factory
return self
def with_error_handler(
self, handler: Callable[[Exception], str]
) -> "ProviderBuilder":
"""Register error handler for provider-specific errors."""
self._error_handler = handler
return self
def with_config(self, **kwargs) -> "ProviderBuilder":
"""Add additional configuration options."""
self._extra_config.update(kwargs)
return self
def build(self) -> Provider:
"""Build and register the provider configuration."""
provider = Provider(
name=self.name,
oauth_config=self._oauth_config,
webhook_manager=self._webhook_manager,
default_credentials=self._default_credentials,
base_costs=self._base_costs,
supported_auth_types=self._supported_auth_types,
api_client_factory=self._api_client_factory,
error_handler=self._error_handler,
**self._extra_config,
)
# Auto-registration happens here
AutoRegistry.register_provider(provider)
return provider

View File

@@ -0,0 +1,163 @@
"""
Integration between SDK provider costs and the execution cost system.
This module provides the glue between provider-defined base costs and the
BLOCK_COSTS configuration used by the execution system.
"""
import logging
from typing import List, Type
from backend.data.block import Block
from backend.data.block_cost_config import BLOCK_COSTS
from backend.data.cost import BlockCost
from backend.sdk.registry import AutoRegistry
logger = logging.getLogger(__name__)
def register_provider_costs_for_block(block_class: Type[Block]) -> None:
"""
Register provider base costs for a specific block in BLOCK_COSTS.
This function checks if the block uses credentials from a provider that has
base costs defined, and automatically registers those costs for the block.
Args:
block_class: The block class to register costs for
"""
# Skip if block already has custom costs defined
if block_class in BLOCK_COSTS:
logger.debug(
f"Block {block_class.__name__} already has costs defined, skipping provider costs"
)
return
# Get the block's input schema
# We need to instantiate the block to get its input schema
try:
block_instance = block_class()
input_schema = block_instance.input_schema
except Exception as e:
logger.debug(f"Block {block_class.__name__} cannot be instantiated: {e}")
return
# Look for credentials fields
# The cost system works of filtering on credentials fields,
# without credentials fields, we can not apply costs
# TODO: Improve cost system to allow for costs witout a provider
credentials_fields = input_schema.get_credentials_fields()
if not credentials_fields:
logger.debug(f"Block {block_class.__name__} has no credentials fields")
return
# Get provider information from credentials fields
for field_name, field_info in credentials_fields.items():
# Get the field schema to extract provider information
field_schema = input_schema.get_field_schema(field_name)
# Extract provider names from json_schema_extra
providers = field_schema.get("credentials_provider", [])
if not providers:
continue
# For each provider, check if it has base costs
block_costs: List[BlockCost] = []
for provider_name in providers:
provider = AutoRegistry.get_provider(provider_name)
if not provider:
logger.debug(f"Provider {provider_name} not found in registry")
continue
# Add provider's base costs to the block
if provider.base_costs:
logger.info(
f"Registering {len(provider.base_costs)} base costs from provider {provider_name} for block {block_class.__name__}"
)
block_costs.extend(provider.base_costs)
# Register costs if any were found
if block_costs:
BLOCK_COSTS[block_class] = block_costs
logger.info(
f"Registered {len(block_costs)} total costs for block {block_class.__name__}"
)
def sync_all_provider_costs() -> None:
"""
Sync all provider base costs to blocks that use them.
This should be called after all providers and blocks are registered,
typically during application startup.
"""
from backend.blocks import load_all_blocks
logger.info("Syncing provider costs to blocks...")
blocks_with_costs = 0
total_costs = 0
for block_id, block_class in load_all_blocks().items():
initial_count = len(BLOCK_COSTS.get(block_class, []))
register_provider_costs_for_block(block_class)
final_count = len(BLOCK_COSTS.get(block_class, []))
if final_count > initial_count:
blocks_with_costs += 1
total_costs += final_count - initial_count
logger.info(f"Synced {total_costs} costs to {blocks_with_costs} blocks")
def get_block_costs(block_class: Type[Block]) -> List[BlockCost]:
"""
Get all costs for a block, including both explicit and provider costs.
Args:
block_class: The block class to get costs for
Returns:
List of BlockCost objects for the block
"""
# First ensure provider costs are registered
register_provider_costs_for_block(block_class)
# Return all costs for the block
return BLOCK_COSTS.get(block_class, [])
def cost(*costs: BlockCost):
"""
Decorator to set custom costs for a block.
This decorator allows blocks to define their own costs, which will override
any provider base costs. Multiple costs can be specified with different
filters for different pricing tiers (e.g., different models).
Example:
@cost(
BlockCost(cost_type=BlockCostType.RUN, cost_amount=10),
BlockCost(
cost_type=BlockCostType.RUN,
cost_amount=20,
cost_filter={"model": "premium"}
)
)
class MyBlock(Block):
...
Args:
*costs: Variable number of BlockCost objects
"""
def decorator(block_class: Type[Block]) -> Type[Block]:
# Register the costs for this block
if costs:
BLOCK_COSTS[block_class] = list(costs)
logger.info(
f"Registered {len(costs)} custom costs for block {block_class.__name__}"
)
return block_class
return decorator

View File

@@ -0,0 +1,114 @@
"""
Provider configuration class that holds all provider-related settings.
"""
from typing import Any, Callable, List, Optional, Set, Type
from pydantic import BaseModel
from backend.data.cost import BlockCost
from backend.data.model import Credentials, CredentialsField, CredentialsMetaInput
from backend.integrations.oauth.base import BaseOAuthHandler
from backend.integrations.webhooks._base import BaseWebhooksManager
class OAuthConfig(BaseModel):
"""Configuration for OAuth authentication."""
oauth_handler: Type[BaseOAuthHandler]
scopes: Optional[List[str]] = None
client_id_env_var: Optional[str] = None
client_secret_env_var: Optional[str] = None
class Provider:
"""A configured provider that blocks can use.
A Provider represents a service or platform that blocks can integrate with, like Linear, OpenAI, etc.
It contains configuration for:
- Authentication (OAuth, API keys)
- Default credentials
- Base costs for using the provider
- Webhook handling
- Error handling
- API client factory
Blocks use Provider instances to handle authentication, make API calls, and manage service-specific logic.
"""
def __init__(
self,
name: str,
oauth_config: Optional[OAuthConfig] = None,
webhook_manager: Optional[Type[BaseWebhooksManager]] = None,
default_credentials: Optional[List[Credentials]] = None,
base_costs: Optional[List[BlockCost]] = None,
supported_auth_types: Optional[Set[str]] = None,
api_client_factory: Optional[Callable] = None,
error_handler: Optional[Callable[[Exception], str]] = None,
**kwargs,
):
self.name = name
self.oauth_config = oauth_config
self.webhook_manager = webhook_manager
self.default_credentials = default_credentials or []
self.base_costs = base_costs or []
self.supported_auth_types = supported_auth_types or set()
self._api_client_factory = api_client_factory
self._error_handler = error_handler
# Store any additional configuration
self._extra_config = kwargs
def credentials_field(self, **kwargs) -> CredentialsMetaInput:
"""Return a CredentialsField configured for this provider."""
# Extract known CredentialsField parameters
title = kwargs.pop("title", None)
description = kwargs.pop("description", f"{self.name.title()} credentials")
required_scopes = kwargs.pop("required_scopes", set())
discriminator = kwargs.pop("discriminator", None)
discriminator_mapping = kwargs.pop("discriminator_mapping", None)
discriminator_values = kwargs.pop("discriminator_values", None)
# Create json_schema_extra with provider information
json_schema_extra = {
"credentials_provider": [self.name],
"credentials_types": (
list(self.supported_auth_types)
if self.supported_auth_types
else ["api_key"]
),
}
# Merge any existing json_schema_extra
if "json_schema_extra" in kwargs:
json_schema_extra.update(kwargs.pop("json_schema_extra"))
# Add json_schema_extra to kwargs
kwargs["json_schema_extra"] = json_schema_extra
return CredentialsField(
required_scopes=required_scopes,
discriminator=discriminator,
discriminator_mapping=discriminator_mapping,
discriminator_values=discriminator_values,
title=title,
description=description,
**kwargs,
)
def get_api(self, credentials: Credentials) -> Any:
"""Get API client instance for the given credentials."""
if self._api_client_factory:
return self._api_client_factory(credentials)
raise NotImplementedError(f"No API client factory registered for {self.name}")
def handle_error(self, error: Exception) -> str:
"""Handle provider-specific errors."""
if self._error_handler:
return self._error_handler(error)
return str(error)
def get_config(self, key: str, default: Any = None) -> Any:
"""Get additional configuration value."""
return self._extra_config.get(key, default)

View File

@@ -0,0 +1,220 @@
"""
Auto-registration system for blocks, providers, and their configurations.
"""
import logging
import threading
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
from pydantic import BaseModel, SecretStr
from backend.blocks.basic import Block
from backend.data.model import APIKeyCredentials, Credentials
from backend.integrations.oauth.base import BaseOAuthHandler
from backend.integrations.webhooks._base import BaseWebhooksManager
if TYPE_CHECKING:
from backend.sdk.provider import Provider
class SDKOAuthCredentials(BaseModel):
"""OAuth credentials configuration for SDK providers."""
use_secrets: bool = False
client_id_env_var: Optional[str] = None
client_secret_env_var: Optional[str] = None
class BlockConfiguration:
"""Configuration associated with a block."""
def __init__(
self,
provider: str,
costs: List[Any],
default_credentials: List[Credentials],
webhook_manager: Optional[Type[BaseWebhooksManager]] = None,
oauth_handler: Optional[Type[BaseOAuthHandler]] = None,
):
self.provider = provider
self.costs = costs
self.default_credentials = default_credentials
self.webhook_manager = webhook_manager
self.oauth_handler = oauth_handler
class AutoRegistry:
"""Central registry for all block-related configurations."""
_lock = threading.Lock()
_providers: Dict[str, "Provider"] = {}
_default_credentials: List[Credentials] = []
_oauth_handlers: Dict[str, Type[BaseOAuthHandler]] = {}
_oauth_credentials: Dict[str, SDKOAuthCredentials] = {}
_webhook_managers: Dict[str, Type[BaseWebhooksManager]] = {}
_block_configurations: Dict[Type[Block], BlockConfiguration] = {}
_api_key_mappings: Dict[str, str] = {} # provider -> env_var_name
@classmethod
def register_provider(cls, provider: "Provider") -> None:
"""Auto-register provider and all its configurations."""
with cls._lock:
cls._providers[provider.name] = provider
# Register OAuth handler if provided
if provider.oauth_config:
# Dynamically set PROVIDER_NAME if not already set
if (
not hasattr(provider.oauth_config.oauth_handler, "PROVIDER_NAME")
or provider.oauth_config.oauth_handler.PROVIDER_NAME is None
):
# Import ProviderName to create dynamic enum value
from backend.integrations.providers import ProviderName
# This works because ProviderName has _missing_ method
provider.oauth_config.oauth_handler.PROVIDER_NAME = ProviderName(
provider.name
)
cls._oauth_handlers[provider.name] = provider.oauth_config.oauth_handler
# Register OAuth credentials configuration
oauth_creds = SDKOAuthCredentials(
use_secrets=False, # SDK providers use custom env vars
client_id_env_var=provider.oauth_config.client_id_env_var,
client_secret_env_var=provider.oauth_config.client_secret_env_var,
)
cls._oauth_credentials[provider.name] = oauth_creds
# Register webhook manager if provided
if provider.webhook_manager:
# Dynamically set PROVIDER_NAME if not already set
if (
not hasattr(provider.webhook_manager, "PROVIDER_NAME")
or provider.webhook_manager.PROVIDER_NAME is None
):
# Import ProviderName to create dynamic enum value
from backend.integrations.providers import ProviderName
# This works because ProviderName has _missing_ method
provider.webhook_manager.PROVIDER_NAME = ProviderName(provider.name)
cls._webhook_managers[provider.name] = provider.webhook_manager
# Register default credentials
cls._default_credentials.extend(provider.default_credentials)
@classmethod
def register_api_key(cls, provider: str, env_var_name: str) -> None:
"""Register an environment variable as an API key for a provider."""
with cls._lock:
cls._api_key_mappings[provider] = env_var_name
# Dynamically check if the env var exists and create credential
import os
api_key = os.getenv(env_var_name)
if api_key:
credential = APIKeyCredentials(
id=f"{provider}-default",
provider=provider,
api_key=SecretStr(api_key),
title=f"Default {provider} credentials",
)
# Check if credential already exists to avoid duplicates
if not any(c.id == credential.id for c in cls._default_credentials):
cls._default_credentials.append(credential)
@classmethod
def get_all_credentials(cls) -> List[Credentials]:
"""Replace hardcoded get_all_creds() in credentials_store.py."""
with cls._lock:
return cls._default_credentials.copy()
@classmethod
def get_oauth_handlers(cls) -> Dict[str, Type[BaseOAuthHandler]]:
"""Replace HANDLERS_BY_NAME in oauth/__init__.py."""
with cls._lock:
return cls._oauth_handlers.copy()
@classmethod
def get_oauth_credentials(cls) -> Dict[str, SDKOAuthCredentials]:
"""Get OAuth credentials configuration for SDK providers."""
with cls._lock:
return cls._oauth_credentials.copy()
@classmethod
def get_webhook_managers(cls) -> Dict[str, Type[BaseWebhooksManager]]:
"""Replace load_webhook_managers() in webhooks/__init__.py."""
with cls._lock:
return cls._webhook_managers.copy()
@classmethod
def register_block_configuration(
cls, block_class: Type[Block], config: BlockConfiguration
) -> None:
"""Register configuration for a specific block class."""
with cls._lock:
cls._block_configurations[block_class] = config
@classmethod
def get_provider(cls, name: str) -> Optional["Provider"]:
"""Get a registered provider by name."""
with cls._lock:
return cls._providers.get(name)
@classmethod
def get_all_provider_names(cls) -> List[str]:
"""Get all registered provider names."""
with cls._lock:
return list(cls._providers.keys())
@classmethod
def clear(cls) -> None:
"""Clear all registrations (useful for testing)."""
with cls._lock:
cls._providers.clear()
cls._default_credentials.clear()
cls._oauth_handlers.clear()
cls._webhook_managers.clear()
cls._block_configurations.clear()
cls._api_key_mappings.clear()
@classmethod
def patch_integrations(cls) -> None:
"""Patch existing integration points to use AutoRegistry."""
# OAuth handlers are handled by SDKAwareHandlersDict in oauth/__init__.py
# No patching needed for OAuth handlers
# Patch webhook managers
try:
import sys
from typing import Any
# Get the module from sys.modules to respect mocking
if "backend.integrations.webhooks" in sys.modules:
webhooks: Any = sys.modules["backend.integrations.webhooks"]
else:
import backend.integrations.webhooks
webhooks: Any = backend.integrations.webhooks
if hasattr(webhooks, "load_webhook_managers"):
original_load = webhooks.load_webhook_managers
def patched_load():
# Get original managers
managers = original_load()
# Add SDK-registered managers
sdk_managers = cls.get_webhook_managers()
if isinstance(sdk_managers, dict):
# Import ProviderName for conversion
from backend.integrations.providers import ProviderName
# Convert string keys to ProviderName for consistency
for provider_str, manager in sdk_managers.items():
provider_name = ProviderName(provider_str)
managers[provider_name] = manager
return managers
webhooks.load_webhook_managers = patched_load
except Exception as e:
logging.warning(f"Failed to patch webhook managers: {e}")

View File

@@ -1,6 +1,6 @@
import logging
from collections import defaultdict
from typing import Annotated, Any, Dict, List, Optional, Sequence
from typing import Annotated, Any, Optional, Sequence
from fastapi import APIRouter, Body, Depends, HTTPException
from prisma.enums import AgentExecutionStatus, APIKeyPermission
@@ -11,7 +11,6 @@ from backend.data import execution as execution_db
from backend.data import graph as graph_db
from backend.data.api_key import APIKey
from backend.data.block import BlockInput, CompletedBlockOutput
from backend.data.execution import NodeExecutionResult
from backend.executor.utils import add_graph_execution
from backend.server.external.middleware import require_permission
from backend.util.settings import Settings
@@ -30,30 +29,19 @@ class NodeOutput(TypedDict):
class ExecutionNode(TypedDict):
node_id: str
input: Any
output: Dict[str, Any]
output: dict[str, Any]
class ExecutionNodeOutput(TypedDict):
node_id: str
outputs: List[NodeOutput]
outputs: list[NodeOutput]
class GraphExecutionResult(TypedDict):
execution_id: str
status: str
nodes: List[ExecutionNode]
output: Optional[List[Dict[str, str]]]
def get_outputs_with_names(results: list[NodeExecutionResult]) -> list[dict[str, str]]:
outputs = []
for result in results:
if "output" in result.output_data:
output_value = result.output_data["output"][0]
name = result.output_data.get("name", [None])[0]
if output_value and name:
outputs.append({name: output_value})
return outputs
nodes: list[ExecutionNode]
output: Optional[list[dict[str, str]]]
@v1_router.get(
@@ -122,23 +110,34 @@ async def get_graph_execution_results(
if not graph:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
results = await execution_db.get_node_executions(graph_exec_id)
last_result = results[-1] if results else None
execution_status = (
last_result.status if last_result else AgentExecutionStatus.INCOMPLETE
graph_exec = await execution_db.get_graph_execution(
user_id=api_key.user_id,
execution_id=graph_exec_id,
include_node_executions=True,
)
outputs = get_outputs_with_names(results)
if not graph_exec:
raise HTTPException(
status_code=404, detail=f"Graph execution #{graph_exec_id} not found."
)
return GraphExecutionResult(
execution_id=graph_exec_id,
status=execution_status,
status=graph_exec.status.value,
nodes=[
ExecutionNode(
node_id=result.node_id,
input=result.input_data.get("value", result.input_data),
output={k: v for k, v in result.output_data.items()},
node_id=node_exec.node_id,
input=node_exec.input_data.get("value", node_exec.input_data),
output={k: v for k, v in node_exec.output_data.items()},
)
for result in results
for node_exec in graph_exec.node_executions
],
output=outputs if execution_status == AgentExecutionStatus.COMPLETED else None,
output=(
[
{name: value}
for name, values in graph_exec.outputs.items()
for value in values
]
if graph_exec.status == AgentExecutionStatus.COMPLETED
else None
),
)

View File

@@ -0,0 +1,74 @@
"""
Models for integration-related data structures that need to be exposed in the OpenAPI schema.
This module provides models that will be included in the OpenAPI schema generation,
allowing frontend code generators like Orval to create corresponding TypeScript types.
"""
from pydantic import BaseModel, Field
from backend.integrations.providers import ProviderName
from backend.sdk.registry import AutoRegistry
def get_all_provider_names() -> list[str]:
"""
Collect all provider names from both ProviderName enum and AutoRegistry.
This function should be called at runtime to ensure we get all
dynamically registered providers.
Returns:
A sorted list of unique provider names.
"""
# Get static providers from enum
static_providers = [member.value for member in ProviderName]
# Get dynamic providers from registry
dynamic_providers = AutoRegistry.get_all_provider_names()
# Combine and deduplicate
all_providers = list(set(static_providers + dynamic_providers))
all_providers.sort()
return all_providers
# Note: We don't create a static enum here because providers are registered dynamically.
# Instead, we expose provider names through API endpoints that can be fetched at runtime.
class ProviderNamesResponse(BaseModel):
"""Response containing list of all provider names."""
providers: list[str] = Field(
description="List of all available provider names",
default_factory=get_all_provider_names,
)
class ProviderConstants(BaseModel):
"""
Model that exposes all provider names as a constant in the OpenAPI schema.
This is designed to be converted by Orval into a TypeScript constant.
"""
PROVIDER_NAMES: dict[str, str] = Field(
description="All available provider names as a constant mapping",
default_factory=lambda: {
name.upper().replace("-", "_"): name for name in get_all_provider_names()
},
)
class Config:
schema_extra = {
"example": {
"PROVIDER_NAMES": {
"OPENAI": "openai",
"ANTHROPIC": "anthropic",
"EXA": "exa",
"GEM": "gem",
"EXAMPLE_SERVICE": "example-service",
}
}
}

View File

@@ -1,6 +1,6 @@
import asyncio
import logging
from typing import TYPE_CHECKING, Annotated, Awaitable, Literal
from typing import TYPE_CHECKING, Annotated, Awaitable, List, Literal
from fastapi import (
APIRouter,
@@ -30,9 +30,14 @@ from backend.data.model import (
)
from backend.executor.utils import add_graph_execution
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.oauth import HANDLERS_BY_NAME
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks import get_webhook_manager
from backend.server.integrations.models import (
ProviderConstants,
ProviderNamesResponse,
get_all_provider_names,
)
from backend.server.v2.library.db import set_preset_webhook, update_preset
from backend.util.exceptions import NeedConfirmation, NotFoundError
from backend.util.settings import Settings
@@ -472,14 +477,49 @@ async def remove_all_webhooks_for_credentials(
def _get_provider_oauth_handler(
req: Request, provider_name: ProviderName
) -> "BaseOAuthHandler":
if provider_name not in HANDLERS_BY_NAME:
# Ensure blocks are loaded so SDK providers are available
try:
from backend.blocks import load_all_blocks
load_all_blocks() # This is cached, so it only runs once
except Exception as e:
logger.warning(f"Failed to load blocks: {e}")
# Convert provider_name to string for lookup
provider_key = (
provider_name.value if hasattr(provider_name, "value") else str(provider_name)
)
if provider_key not in HANDLERS_BY_NAME:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Provider '{provider_name.value}' does not support OAuth",
detail=f"Provider '{provider_key}' does not support OAuth",
)
# Check if this provider has custom OAuth credentials
oauth_credentials = CREDENTIALS_BY_PROVIDER.get(provider_key)
if oauth_credentials and not oauth_credentials.use_secrets:
# SDK provider with custom env vars
import os
client_id = (
os.getenv(oauth_credentials.client_id_env_var)
if oauth_credentials.client_id_env_var
else None
)
client_secret = (
os.getenv(oauth_credentials.client_secret_env_var)
if oauth_credentials.client_secret_env_var
else None
)
else:
# Original provider using settings.secrets
client_id = getattr(settings.secrets, f"{provider_name.value}_client_id", None)
client_secret = getattr(
settings.secrets, f"{provider_name.value}_client_secret", None
)
client_id = getattr(settings.secrets, f"{provider_name.value}_client_id")
client_secret = getattr(settings.secrets, f"{provider_name.value}_client_secret")
if not (client_id and client_secret):
logger.error(
f"Attempt to use unconfigured {provider_name.value} OAuth integration"
@@ -492,14 +532,84 @@ def _get_provider_oauth_handler(
},
)
handler_class = HANDLERS_BY_NAME[provider_name]
frontend_base_url = (
settings.config.frontend_base_url
or settings.config.platform_base_url
or str(req.base_url)
)
handler_class = HANDLERS_BY_NAME[provider_key]
frontend_base_url = settings.config.frontend_base_url
if not frontend_base_url:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Frontend base URL is not configured",
)
return handler_class(
client_id=client_id,
client_secret=client_secret,
redirect_uri=f"{frontend_base_url}/auth/integrations/oauth_callback",
)
# === PROVIDER DISCOVERY ENDPOINTS ===
@router.get("/providers", response_model=List[str])
async def list_providers() -> List[str]:
"""
Get a list of all available provider names.
Returns both statically defined providers (from ProviderName enum)
and dynamically registered providers (from SDK decorators).
Note: The complete list of provider names is also available as a constant
in the generated TypeScript client via PROVIDER_NAMES.
"""
# Get all providers at runtime
all_providers = get_all_provider_names()
return all_providers
@router.get("/providers/names", response_model=ProviderNamesResponse)
async def get_provider_names() -> ProviderNamesResponse:
"""
Get all provider names in a structured format.
This endpoint is specifically designed to expose the provider names
in the OpenAPI schema so that code generators like Orval can create
appropriate TypeScript constants.
"""
return ProviderNamesResponse()
@router.get("/providers/constants", response_model=ProviderConstants)
async def get_provider_constants() -> ProviderConstants:
"""
Get provider names as constants.
This endpoint returns a model with provider names as constants,
specifically designed for OpenAPI code generation tools to create
TypeScript constants.
"""
return ProviderConstants()
class ProviderEnumResponse(BaseModel):
"""Response containing a provider from the enum."""
provider: str = Field(
description="A provider name from the complete list of providers"
)
@router.get("/providers/enum-example", response_model=ProviderEnumResponse)
async def get_provider_enum_example() -> ProviderEnumResponse:
"""
Example endpoint that uses the CompleteProviderNames enum.
This endpoint exists to ensure that the CompleteProviderNames enum is included
in the OpenAPI schema, which will cause Orval to generate it as a
TypeScript enum/constant.
"""
# Return the first provider as an example
all_providers = get_all_provider_names()
return ProviderEnumResponse(
provider=all_providers[0] if all_providers else "openai"
)

View File

@@ -62,6 +62,10 @@ def launch_darkly_context():
async def lifespan_context(app: fastapi.FastAPI):
await backend.data.db.connect()
await backend.data.block.initialize_blocks()
# SDK auto-registration is now handled by AutoRegistry.patch_integrations()
# which is called when the SDK module is imported
await backend.data.user.migrate_and_encrypt_user_integrations()
await backend.data.graph.fix_llm_provider_credentials()
await backend.data.graph.migrate_llm_models(LlmModel.GPT4O)

View File

@@ -448,10 +448,10 @@ class DeleteGraphResponse(TypedDict):
tags=["graphs"],
dependencies=[Depends(auth_middleware)],
)
async def get_graphs(
async def list_graphs(
user_id: Annotated[str, Depends(get_user_id)],
) -> Sequence[graph_db.GraphModel]:
return await graph_db.get_graphs(filter_by="active", user_id=user_id)
) -> Sequence[graph_db.GraphMeta]:
return await graph_db.list_graphs(filter_by="active", user_id=user_id)
@v1_router.get(
@@ -680,22 +680,6 @@ async def stop_graph_run(
return res[0]
@v1_router.post(
path="/executions",
summary="Stop graph executions",
tags=["graphs"],
dependencies=[Depends(auth_middleware)],
)
async def stop_graph_runs(
graph_id: str, graph_exec_id: str, user_id: Annotated[str, Depends(get_user_id)]
) -> list[execution_db.GraphExecutionMeta]:
return await _stop_graph_run(
user_id=user_id,
graph_id=graph_id,
graph_exec_id=graph_exec_id,
)
async def _stop_graph_run(
user_id: str,
graph_id: Optional[str] = None,

View File

@@ -270,7 +270,7 @@ def test_get_graphs(
)
mocker.patch(
"backend.server.routers.v1.graph_db.get_graphs",
"backend.server.routers.v1.graph_db.list_graphs",
return_value=[mock_graph],
)

View File

@@ -187,7 +187,7 @@ async def get_library_agent(id: str, user_id: str) -> library_model.LibraryAgent
async def get_library_agent_by_store_version_id(
store_listing_version_id: str,
user_id: str,
):
) -> library_model.LibraryAgent | None:
"""
Get the library agent metadata for a given store listing version ID and user ID.
"""
@@ -202,7 +202,7 @@ async def get_library_agent_by_store_version_id(
)
if not store_listing_version:
logger.warning(f"Store listing version not found: {store_listing_version_id}")
raise store_exceptions.AgentNotFoundError(
raise NotFoundError(
f"Store listing version {store_listing_version_id} not found or invalid"
)
@@ -214,12 +214,9 @@ async def get_library_agent_by_store_version_id(
"agentGraphVersion": store_listing_version.agentGraphVersion,
"isDeleted": False,
},
include={"AgentGraph": True},
include=library_agent_include(user_id),
)
if agent:
return library_model.LibraryAgent.from_db(agent)
else:
return None
return library_model.LibraryAgent.from_db(agent) if agent else None
async def get_library_agent_by_graph_id(

View File

@@ -127,9 +127,9 @@ class LibraryAgent(pydantic.BaseModel):
description=graph.description,
input_schema=graph.input_schema,
credentials_input_schema=(
graph.credentials_input_schema if sub_graphs else None
graph.credentials_input_schema if sub_graphs is not None else None
),
has_external_trigger=graph.has_webhook_trigger,
has_external_trigger=graph.has_external_trigger,
trigger_setup_info=(
LibraryAgentTriggerInfo(
provider=trigger_block.webhook_config.provider,
@@ -262,6 +262,19 @@ class LibraryAgentPresetUpdatable(pydantic.BaseModel):
is_active: Optional[bool] = None
class TriggeredPresetSetupRequest(pydantic.BaseModel):
name: str
description: str = ""
graph_id: str
graph_version: int
trigger_config: dict[str, Any]
agent_credentials: dict[str, CredentialsMetaInput] = pydantic.Field(
default_factory=dict
)
class LibraryAgentPreset(LibraryAgentPresetCreatable):
"""Represents a preset configuration for a library agent."""

View File

@@ -1,18 +1,13 @@
import logging
from typing import Any, Optional
from typing import Optional
import autogpt_libs.auth as autogpt_auth_lib
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, status
from fastapi import APIRouter, Body, Depends, HTTPException, Query, status
from fastapi.responses import Response
from pydantic import BaseModel, Field
import backend.server.v2.library.db as library_db
import backend.server.v2.library.model as library_model
import backend.server.v2.store.exceptions as store_exceptions
from backend.data.graph import get_graph
from backend.data.model import CredentialsMetaInput
from backend.executor.utils import make_node_credentials_input_map
from backend.integrations.webhooks.utils import setup_webhook_for_block
from backend.util.exceptions import NotFoundError
logger = logging.getLogger(__name__)
@@ -113,12 +108,11 @@ async def get_library_agent_by_graph_id(
"/marketplace/{store_listing_version_id}",
summary="Get Agent By Store ID",
tags=["store, library"],
response_model=library_model.LibraryAgent | None,
)
async def get_library_agent_by_store_listing_version_id(
store_listing_version_id: str,
user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
):
) -> library_model.LibraryAgent | None:
"""
Get Library Agent from Store Listing Version ID.
"""
@@ -295,81 +289,3 @@ async def fork_library_agent(
library_agent_id=library_agent_id,
user_id=user_id,
)
class TriggeredPresetSetupParams(BaseModel):
name: str
description: str = ""
trigger_config: dict[str, Any]
agent_credentials: dict[str, CredentialsMetaInput] = Field(default_factory=dict)
@router.post("/{library_agent_id}/setup-trigger")
async def setup_trigger(
library_agent_id: str = Path(..., description="ID of the library agent"),
params: TriggeredPresetSetupParams = Body(),
user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
) -> library_model.LibraryAgentPreset:
"""
Sets up a webhook-triggered `LibraryAgentPreset` for a `LibraryAgent`.
Returns the correspondingly created `LibraryAgentPreset` with `webhook_id` set.
"""
library_agent = await library_db.get_library_agent(
id=library_agent_id, user_id=user_id
)
if not library_agent:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Library agent #{library_agent_id} not found",
)
graph = await get_graph(
library_agent.graph_id, version=library_agent.graph_version, user_id=user_id
)
if not graph:
raise HTTPException(
status.HTTP_410_GONE,
f"Graph #{library_agent.graph_id} not accessible (anymore)",
)
if not (trigger_node := graph.webhook_input_node):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Graph #{library_agent.graph_id} does not have a webhook node",
)
trigger_config_with_credentials = {
**params.trigger_config,
**(
make_node_credentials_input_map(graph, params.agent_credentials).get(
trigger_node.id
)
or {}
),
}
new_webhook, feedback = await setup_webhook_for_block(
user_id=user_id,
trigger_block=trigger_node.block,
trigger_config=trigger_config_with_credentials,
)
if not new_webhook:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Could not set up webhook: {feedback}",
)
new_preset = await library_db.create_preset(
user_id=user_id,
preset=library_model.LibraryAgentPresetCreatable(
graph_id=library_agent.graph_id,
graph_version=library_agent.graph_version,
name=params.name,
description=params.description,
inputs=trigger_config_with_credentials,
credentials=params.agent_credentials,
webhook_id=new_webhook.id,
is_active=True,
),
)
return new_preset

View File

@@ -138,6 +138,66 @@ async def create_preset(
)
@router.post("/presets/setup-trigger")
async def setup_trigger(
params: models.TriggeredPresetSetupRequest = Body(),
user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
) -> models.LibraryAgentPreset:
"""
Sets up a webhook-triggered `LibraryAgentPreset` for a `LibraryAgent`.
Returns the correspondingly created `LibraryAgentPreset` with `webhook_id` set.
"""
graph = await get_graph(
params.graph_id, version=params.graph_version, user_id=user_id
)
if not graph:
raise HTTPException(
status.HTTP_410_GONE,
f"Graph #{params.graph_id} not accessible (anymore)",
)
if not (trigger_node := graph.webhook_input_node):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Graph #{params.graph_id} does not have a webhook node",
)
trigger_config_with_credentials = {
**params.trigger_config,
**(
make_node_credentials_input_map(graph, params.agent_credentials).get(
trigger_node.id
)
or {}
),
}
new_webhook, feedback = await setup_webhook_for_block(
user_id=user_id,
trigger_block=trigger_node.block,
trigger_config=trigger_config_with_credentials,
)
if not new_webhook:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Could not set up webhook: {feedback}",
)
new_preset = await db.create_preset(
user_id=user_id,
preset=models.LibraryAgentPresetCreatable(
graph_id=params.graph_id,
graph_version=params.graph_version,
name=params.name,
description=params.description,
inputs=trigger_config_with_credentials,
credentials=params.agent_credentials,
webhook_id=new_webhook.id,
is_active=True,
),
)
return new_preset
@router.patch(
"/presets/{preset_id}",
summary="Update an existing preset",

View File

@@ -7,10 +7,15 @@ import prisma.errors
import prisma.models
import prisma.types
import backend.data.graph
import backend.server.v2.store.exceptions
import backend.server.v2.store.model
from backend.data.graph import GraphModel, get_sub_graphs
from backend.data.graph import (
GraphMeta,
GraphModel,
get_graph,
get_graph_as_admin,
get_sub_graphs,
)
from backend.data.includes import AGENT_GRAPH_INCLUDE
logger = logging.getLogger(__name__)
@@ -193,9 +198,7 @@ async def get_store_agent_details(
) from e
async def get_available_graph(
store_listing_version_id: str,
):
async def get_available_graph(store_listing_version_id: str) -> GraphMeta:
try:
# Get avaialble, non-deleted store listing version
store_listing_version = (
@@ -215,18 +218,7 @@ async def get_available_graph(
detail=f"Store listing version {store_listing_version_id} not found",
)
graph = GraphModel.from_db(store_listing_version.AgentGraph)
# We return graph meta, without nodes, they cannot be just removed
# because then input_schema would be empty
return {
"id": graph.id,
"version": graph.version,
"is_active": graph.is_active,
"name": graph.name,
"description": graph.description,
"input_schema": graph.input_schema,
"output_schema": graph.output_schema,
}
return GraphModel.from_db(store_listing_version.AgentGraph).meta()
except Exception as e:
logger.error(f"Error getting agent: {e}")
@@ -1024,7 +1016,7 @@ async def get_agent(
if not store_listing_version:
raise ValueError(f"Store listing version {store_listing_version_id} not found")
graph = await backend.data.graph.get_graph(
graph = await get_graph(
user_id=user_id,
graph_id=store_listing_version.agentGraphId,
version=store_listing_version.agentGraphVersion,
@@ -1383,7 +1375,7 @@ async def get_agent_as_admin(
if not store_listing_version:
raise ValueError(f"Store listing version {store_listing_version_id} not found")
graph = await backend.data.graph.get_graph_as_admin(
graph = await get_graph_as_admin(
user_id=user_id,
graph_id=store_listing_version.agentGraphId,
version=store_listing_version.agentGraphVersion,

View File

@@ -124,6 +124,19 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
description="Time in seconds for how far back to check for the late executions.",
)
block_error_rate_threshold: float = Field(
default=0.5,
description="Error rate threshold (0.0-1.0) for triggering block error alerts.",
)
block_error_rate_check_interval_secs: int = Field(
default=24 * 60 * 60, # 24 hours
description="Interval in seconds between block error rate checks.",
)
block_error_include_top_blocks: int = Field(
default=3,
description="Number of top blocks with most errors to show when no blocks exceed threshold (0 to disable).",
)
model_config = SettingsConfigDict(
env_file=".env",
extra="allow",
@@ -263,6 +276,11 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
description="Whether to mark failed scans as clean or not",
)
enable_example_blocks: bool = Field(
default=False,
description="Whether to enable example blocks in production",
)
@field_validator("platform_base_url", "frontend_base_url")
@classmethod
def validate_platform_base_url(cls, v: str, info: ValidationInfo) -> str:

View File

@@ -0,0 +1,101 @@
#!/usr/bin/env python3
"""
Clean the test database by removing all data while preserving the schema.
Usage:
poetry run python clean_test_db.py [--yes]
Options:
--yes Skip confirmation prompt
"""
import asyncio
import sys
from prisma import Prisma
async def main():
db = Prisma()
await db.connect()
print("=" * 60)
print("Cleaning Test Database")
print("=" * 60)
print()
# Get initial counts
user_count = await db.user.count()
agent_count = await db.agentgraph.count()
print(f"Current data: {user_count} users, {agent_count} agent graphs")
if user_count == 0 and agent_count == 0:
print("Database is already clean!")
await db.disconnect()
return
# Check for --yes flag
skip_confirm = "--yes" in sys.argv
if not skip_confirm:
response = input("\nDo you want to clean all data? (yes/no): ")
if response.lower() != "yes":
print("Aborted.")
await db.disconnect()
return
print("\nCleaning database...")
# Delete in reverse order of dependencies
tables = [
("UserNotificationBatch", db.usernotificationbatch),
("NotificationEvent", db.notificationevent),
("CreditRefundRequest", db.creditrefundrequest),
("StoreListingReview", db.storelistingreview),
("StoreListingVersion", db.storelistingversion),
("StoreListing", db.storelisting),
("AgentNodeExecutionInputOutput", db.agentnodeexecutioninputoutput),
("AgentNodeExecution", db.agentnodeexecution),
("AgentGraphExecution", db.agentgraphexecution),
("AgentNodeLink", db.agentnodelink),
("LibraryAgent", db.libraryagent),
("AgentPreset", db.agentpreset),
("IntegrationWebhook", db.integrationwebhook),
("AgentNode", db.agentnode),
("AgentGraph", db.agentgraph),
("AgentBlock", db.agentblock),
("APIKey", db.apikey),
("CreditTransaction", db.credittransaction),
("AnalyticsMetrics", db.analyticsmetrics),
("AnalyticsDetails", db.analyticsdetails),
("Profile", db.profile),
("UserOnboarding", db.useronboarding),
("User", db.user),
]
for table_name, table in tables:
try:
count = await table.count()
if count > 0:
await table.delete_many()
print(f"✓ Deleted {count} records from {table_name}")
except Exception as e:
print(f"⚠ Error cleaning {table_name}: {e}")
# Refresh materialized views (they should be empty now)
try:
await db.execute_raw("SELECT refresh_store_materialized_views();")
print("\n✓ Refreshed materialized views")
except Exception as e:
print(f"\n⚠ Could not refresh materialized views: {e}")
await db.disconnect()
print("\n" + "=" * 60)
print("Database cleaned successfully!")
print("=" * 60)
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -1,35 +1,60 @@
networks:
app-network:
name: app-network
shared-network:
name: shared-network
volumes:
supabase-config:
x-agpt-services:
&agpt-services
networks:
- app-network
- shared-network
x-supabase-services:
&supabase-services
networks:
- app-network
- shared-network
volumes:
clamav-data:
services:
postgres-test:
image: ankane/pgvector:latest
environment:
- POSTGRES_USER=${DB_USER:-postgres}
- POSTGRES_PASSWORD=${DB_PASS:-postgres}
- POSTGRES_DB=${DB_NAME:-postgres}
- POSTGRES_PORT=${DB_PORT:-5432}
healthcheck:
test: pg_isready -U $$POSTGRES_USER -d $$POSTGRES_DB
interval: 10s
timeout: 5s
retries: 5
db:
<<: *supabase-services
extends:
file: ../db/docker/docker-compose.yml
service: db
ports:
- "${DB_PORT:-5432}:5432"
networks:
- app-network-test
redis-test:
- ${POSTGRES_PORT}:5432 # We don't use Supavisor locally, so we expose the db directly.
vector:
<<: *supabase-services
extends:
file: ../db/docker/docker-compose.yml
service: vector
redis:
<<: *agpt-services
image: redis:latest
command: redis-server --requirepass password
ports:
- "6379:6379"
networks:
- app-network-test
healthcheck:
test: ["CMD", "redis-cli", "ping"]
interval: 10s
timeout: 5s
retries: 5
rabbitmq-test:
rabbitmq:
<<: *agpt-services
image: rabbitmq:management
container_name: rabbitmq-test
container_name: rabbitmq
healthcheck:
test: rabbitmq-diagnostics -q ping
interval: 30s
@@ -38,11 +63,28 @@ services:
start_period: 10s
environment:
- RABBITMQ_DEFAULT_USER=rabbitmq_user_default
- RABBITMQ_DEFAULT_PASS=k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7 # CHANGE THIS TO A RANDOM PASSWORD IN PRODUCTION -- everywhere lol
- RABBITMQ_DEFAULT_PASS=k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7
ports:
- "5672:5672"
- "15672:15672"
clamav:
image: clamav/clamav-debian:latest
ports:
- "3310:3310"
volumes:
- clamav-data:/var/lib/clamav
environment:
- CLAMAV_NO_FRESHCLAMD=false
- CLAMD_CONF_StreamMaxLength=50M
- CLAMD_CONF_MaxFileSize=100M
- CLAMD_CONF_MaxScanSize=100M
- CLAMD_CONF_MaxThreads=12
- CLAMD_CONF_ReadTimeout=300
healthcheck:
test: ["CMD-SHELL", "clamdscan --version || exit 1"]
interval: 30s
timeout: 10s
retries: 3
networks:
app-network-test:
driver: bridge

View File

@@ -0,0 +1,254 @@
-- This migration creates materialized views for performance optimization
--
-- IMPORTANT: For production environments, pg_cron is REQUIRED for automatic refresh
-- Prerequisites for production:
-- 1. pg_cron extension must be installed: CREATE EXTENSION pg_cron;
-- 2. pg_cron must be configured in postgresql.conf:
-- shared_preload_libraries = 'pg_cron'
-- cron.database_name = 'your_database_name'
--
-- For development environments without pg_cron:
-- The migration will succeed but you must manually refresh views with:
-- SELECT refresh_store_materialized_views();
-- Check if pg_cron extension is installed and set a flag
DO $$
DECLARE
has_pg_cron BOOLEAN;
BEGIN
SELECT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'pg_cron') INTO has_pg_cron;
IF NOT has_pg_cron THEN
RAISE WARNING 'pg_cron extension is not installed!';
RAISE WARNING 'Materialized views will be created but WILL NOT refresh automatically.';
RAISE WARNING 'For production use, install pg_cron with: CREATE EXTENSION pg_cron;';
RAISE WARNING 'For development, manually refresh with: SELECT refresh_store_materialized_views();';
-- For production deployments, uncomment the following line to make pg_cron mandatory:
-- RAISE EXCEPTION 'pg_cron is required for production deployments';
END IF;
-- Store the flag for later use in the migration
PERFORM set_config('migration.has_pg_cron', has_pg_cron::text, false);
END
$$;
-- CreateIndex
-- Optimized: Only include owningUserId in index columns since isDeleted and hasApprovedVersion are in WHERE clause
CREATE INDEX IF NOT EXISTS "idx_store_listing_approved" ON "StoreListing"("owningUserId") WHERE "isDeleted" = false AND "hasApprovedVersion" = true;
-- CreateIndex
-- Optimized: Only include storeListingId since submissionStatus is in WHERE clause
CREATE INDEX IF NOT EXISTS "idx_store_listing_version_status" ON "StoreListingVersion"("storeListingId") WHERE "submissionStatus" = 'APPROVED';
-- CreateIndex
CREATE INDEX IF NOT EXISTS "idx_slv_categories_gin" ON "StoreListingVersion" USING GIN ("categories") WHERE "submissionStatus" = 'APPROVED';
-- CreateIndex
CREATE INDEX IF NOT EXISTS "idx_slv_agent" ON "StoreListingVersion"("agentGraphId", "agentGraphVersion") WHERE "submissionStatus" = 'APPROVED';
-- CreateIndex
CREATE INDEX IF NOT EXISTS "idx_store_listing_review_version" ON "StoreListingReview"("storeListingVersionId");
-- CreateIndex
CREATE INDEX IF NOT EXISTS "idx_agent_graph_execution_agent" ON "AgentGraphExecution"("agentGraphId");
-- CreateIndex
CREATE INDEX IF NOT EXISTS "idx_profile_user" ON "Profile"("userId");
-- Additional performance indexes
CREATE INDEX IF NOT EXISTS "idx_store_listing_version_approved_listing" ON "StoreListingVersion"("storeListingId", "version") WHERE "submissionStatus" = 'APPROVED';
-- Create materialized view for agent run counts
CREATE MATERIALIZED VIEW IF NOT EXISTS "mv_agent_run_counts" AS
SELECT
"agentGraphId",
COUNT(*) AS run_count
FROM "AgentGraphExecution"
GROUP BY "agentGraphId";
-- CreateIndex
CREATE UNIQUE INDEX IF NOT EXISTS "idx_mv_agent_run_counts" ON "mv_agent_run_counts"("agentGraphId");
-- Create materialized view for review statistics
CREATE MATERIALIZED VIEW IF NOT EXISTS "mv_review_stats" AS
SELECT
sl.id AS "storeListingId",
COUNT(sr.id) AS review_count,
AVG(sr.score::numeric) AS avg_rating
FROM "StoreListing" sl
JOIN "StoreListingVersion" slv ON slv."storeListingId" = sl.id
LEFT JOIN "StoreListingReview" sr ON sr."storeListingVersionId" = slv.id
WHERE sl."isDeleted" = false
AND slv."submissionStatus" = 'APPROVED'
GROUP BY sl.id;
-- CreateIndex
CREATE UNIQUE INDEX IF NOT EXISTS "idx_mv_review_stats" ON "mv_review_stats"("storeListingId");
-- DropForeignKey (if any exist on the views)
-- None needed as views don't have foreign keys
-- DropView
DROP VIEW IF EXISTS "Creator";
-- DropView
DROP VIEW IF EXISTS "StoreAgent";
-- CreateView
CREATE OR REPLACE VIEW "StoreAgent" AS
WITH agent_versions AS (
SELECT
"storeListingId",
array_agg(DISTINCT version::text ORDER BY version::text) AS versions
FROM "StoreListingVersion"
WHERE "submissionStatus" = 'APPROVED'
GROUP BY "storeListingId"
)
SELECT
sl.id AS listing_id,
slv.id AS "storeListingVersionId",
slv."createdAt" AS updated_at,
sl.slug,
COALESCE(slv.name, '') AS agent_name,
slv."videoUrl" AS agent_video,
COALESCE(slv."imageUrls", ARRAY[]::text[]) AS agent_image,
slv."isFeatured" AS featured,
p.username AS creator_username,
p."avatarUrl" AS creator_avatar,
slv."subHeading" AS sub_heading,
slv.description,
slv.categories,
COALESCE(ar.run_count, 0::bigint) AS runs,
COALESCE(rs.avg_rating, 0.0)::double precision AS rating,
COALESCE(av.versions, ARRAY[slv.version::text]) AS versions
FROM "StoreListing" sl
INNER JOIN "StoreListingVersion" slv
ON slv."storeListingId" = sl.id
AND slv."submissionStatus" = 'APPROVED'
JOIN "AgentGraph" a
ON slv."agentGraphId" = a.id
AND slv."agentGraphVersion" = a.version
LEFT JOIN "Profile" p
ON sl."owningUserId" = p."userId"
LEFT JOIN "mv_review_stats" rs
ON sl.id = rs."storeListingId"
LEFT JOIN "mv_agent_run_counts" ar
ON a.id = ar."agentGraphId"
LEFT JOIN agent_versions av
ON sl.id = av."storeListingId"
WHERE sl."isDeleted" = false
AND sl."hasApprovedVersion" = true;
-- CreateView
CREATE OR REPLACE VIEW "Creator" AS
WITH creator_listings AS (
SELECT
sl."owningUserId",
sl.id AS listing_id,
slv."agentGraphId",
slv.categories,
sr.score,
ar.run_count
FROM "StoreListing" sl
INNER JOIN "StoreListingVersion" slv
ON slv."storeListingId" = sl.id
AND slv."submissionStatus" = 'APPROVED'
LEFT JOIN "StoreListingReview" sr
ON sr."storeListingVersionId" = slv.id
LEFT JOIN "mv_agent_run_counts" ar
ON ar."agentGraphId" = slv."agentGraphId"
WHERE sl."isDeleted" = false
AND sl."hasApprovedVersion" = true
),
creator_stats AS (
SELECT
cl."owningUserId",
COUNT(DISTINCT cl.listing_id) AS num_agents,
AVG(COALESCE(cl.score, 0)::numeric) AS agent_rating,
SUM(DISTINCT COALESCE(cl.run_count, 0)) AS agent_runs,
array_agg(DISTINCT cat ORDER BY cat) FILTER (WHERE cat IS NOT NULL) AS all_categories
FROM creator_listings cl
LEFT JOIN LATERAL unnest(COALESCE(cl.categories, ARRAY[]::text[])) AS cat ON true
GROUP BY cl."owningUserId"
)
SELECT
p.username,
p.name,
p."avatarUrl" AS avatar_url,
p.description,
cs.all_categories AS top_categories,
p.links,
p."isFeatured" AS is_featured,
COALESCE(cs.num_agents, 0::bigint) AS num_agents,
COALESCE(cs.agent_rating, 0.0) AS agent_rating,
COALESCE(cs.agent_runs, 0::numeric) AS agent_runs
FROM "Profile" p
LEFT JOIN creator_stats cs ON cs."owningUserId" = p."userId";
-- Create refresh function that works with the current schema
CREATE OR REPLACE FUNCTION refresh_store_materialized_views()
RETURNS void
LANGUAGE plpgsql
AS $$
DECLARE
current_schema_name text;
BEGIN
-- Get the current schema
current_schema_name := current_schema();
-- Use CONCURRENTLY for better performance during refresh
EXECUTE format('REFRESH MATERIALIZED VIEW CONCURRENTLY %I."mv_agent_run_counts"', current_schema_name);
EXECUTE format('REFRESH MATERIALIZED VIEW CONCURRENTLY %I."mv_review_stats"', current_schema_name);
RAISE NOTICE 'Materialized views refreshed in schema % at %', current_schema_name, NOW();
EXCEPTION
WHEN OTHERS THEN
-- Fallback to non-concurrent refresh if concurrent fails
EXECUTE format('REFRESH MATERIALIZED VIEW %I."mv_agent_run_counts"', current_schema_name);
EXECUTE format('REFRESH MATERIALIZED VIEW %I."mv_review_stats"', current_schema_name);
RAISE NOTICE 'Materialized views refreshed (non-concurrent) in schema % at % due to: %', current_schema_name, NOW(), SQLERRM;
END;
$$;
-- Initial refresh of materialized views
SELECT refresh_store_materialized_views();
-- Schedule automatic refresh every 15 minutes (only if pg_cron is available)
DO $$
DECLARE
has_pg_cron BOOLEAN;
current_schema_name text;
job_name text;
BEGIN
-- Get the flag we set earlier
has_pg_cron := current_setting('migration.has_pg_cron', true)::boolean;
-- Get current schema name
current_schema_name := current_schema();
-- Create a unique job name for this schema
job_name := format('refresh-store-views-%s', current_schema_name);
IF has_pg_cron THEN
-- Try to unschedule existing job (ignore errors if it doesn't exist)
BEGIN
PERFORM cron.unschedule(job_name);
EXCEPTION WHEN OTHERS THEN
-- Job doesn't exist, that's fine
NULL;
END;
-- Schedule the refresh job with schema-specific command
PERFORM cron.schedule(
job_name,
'*/15 * * * *',
format('SELECT %I.refresh_store_materialized_views();', current_schema_name)
);
RAISE NOTICE 'Scheduled automatic refresh of materialized views every 15 minutes for schema %', current_schema_name;
ELSE
RAISE WARNING '⚠️ Automatic refresh NOT configured - pg_cron is not available';
RAISE WARNING '⚠️ You must manually refresh views with: SELECT refresh_store_materialized_views();';
RAISE WARNING '⚠️ Or install pg_cron for automatic refresh in production';
END IF;
END;
$$;

View File

@@ -0,0 +1,155 @@
-- Unschedule cron job (if it exists)
DO $$
BEGIN
IF EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'pg_cron') THEN
PERFORM cron.unschedule('refresh-store-views');
RAISE NOTICE 'Unscheduled automatic refresh of materialized views';
END IF;
EXCEPTION
WHEN OTHERS THEN
RAISE NOTICE 'Could not unschedule cron job (may not exist): %', SQLERRM;
END;
$$;
-- DropView
DROP VIEW IF EXISTS "Creator";
-- DropView
DROP VIEW IF EXISTS "StoreAgent";
-- CreateView (restore original StoreAgent)
CREATE VIEW "StoreAgent" AS
WITH reviewstats AS (
SELECT sl_1.id AS "storeListingId",
count(sr.id) AS review_count,
avg(sr.score::numeric) AS avg_rating
FROM "StoreListing" sl_1
JOIN "StoreListingVersion" slv_1
ON slv_1."storeListingId" = sl_1.id
JOIN "StoreListingReview" sr
ON sr."storeListingVersionId" = slv_1.id
WHERE sl_1."isDeleted" = false
GROUP BY sl_1.id
), agentruns AS (
SELECT "AgentGraphExecution"."agentGraphId",
count(*) AS run_count
FROM "AgentGraphExecution"
GROUP BY "AgentGraphExecution"."agentGraphId"
)
SELECT sl.id AS listing_id,
slv.id AS "storeListingVersionId",
slv."createdAt" AS updated_at,
sl.slug,
COALESCE(slv.name, '') AS agent_name,
slv."videoUrl" AS agent_video,
COALESCE(slv."imageUrls", ARRAY[]::text[]) AS agent_image,
slv."isFeatured" AS featured,
p.username AS creator_username,
p."avatarUrl" AS creator_avatar,
slv."subHeading" AS sub_heading,
slv.description,
slv.categories,
COALESCE(ar.run_count, 0::bigint) AS runs,
COALESCE(rs.avg_rating, 0.0)::double precision AS rating,
array_agg(DISTINCT slv.version::text) AS versions
FROM "StoreListing" sl
JOIN "StoreListingVersion" slv
ON slv."storeListingId" = sl.id
JOIN "AgentGraph" a
ON slv."agentGraphId" = a.id
AND slv."agentGraphVersion" = a.version
LEFT JOIN "Profile" p
ON sl."owningUserId" = p."userId"
LEFT JOIN reviewstats rs
ON sl.id = rs."storeListingId"
LEFT JOIN agentruns ar
ON a.id = ar."agentGraphId"
WHERE sl."isDeleted" = false
AND sl."hasApprovedVersion" = true
AND slv."submissionStatus" = 'APPROVED'
GROUP BY sl.id, slv.id, sl.slug, slv."createdAt", slv.name, slv."videoUrl",
slv."imageUrls", slv."isFeatured", p.username, p."avatarUrl",
slv."subHeading", slv.description, slv.categories, ar.run_count,
rs.avg_rating;
-- CreateView (restore original Creator)
CREATE VIEW "Creator" AS
WITH agentstats AS (
SELECT p_1.username,
count(DISTINCT sl.id) AS num_agents,
avg(COALESCE(sr.score, 0)::numeric) AS agent_rating,
sum(COALESCE(age.run_count, 0::bigint)) AS agent_runs
FROM "Profile" p_1
LEFT JOIN "StoreListing" sl
ON sl."owningUserId" = p_1."userId"
LEFT JOIN "StoreListingVersion" slv
ON slv."storeListingId" = sl.id
LEFT JOIN "StoreListingReview" sr
ON sr."storeListingVersionId" = slv.id
LEFT JOIN (
SELECT "AgentGraphExecution"."agentGraphId",
count(*) AS run_count
FROM "AgentGraphExecution"
GROUP BY "AgentGraphExecution"."agentGraphId"
) age ON age."agentGraphId" = slv."agentGraphId"
WHERE sl."isDeleted" = false
AND sl."hasApprovedVersion" = true
AND slv."submissionStatus" = 'APPROVED'
GROUP BY p_1.username
)
SELECT p.username,
p.name,
p."avatarUrl" AS avatar_url,
p.description,
array_agg(DISTINCT cats.c) FILTER (WHERE cats.c IS NOT NULL) AS top_categories,
p.links,
p."isFeatured" AS is_featured,
COALESCE(ast.num_agents, 0::bigint) AS num_agents,
COALESCE(ast.agent_rating, 0.0) AS agent_rating,
COALESCE(ast.agent_runs, 0::numeric) AS agent_runs
FROM "Profile" p
LEFT JOIN agentstats ast
ON ast.username = p.username
LEFT JOIN LATERAL (
SELECT unnest(slv.categories) AS c
FROM "StoreListing" sl
JOIN "StoreListingVersion" slv
ON slv."storeListingId" = sl.id
WHERE sl."owningUserId" = p."userId"
AND sl."isDeleted" = false
AND sl."hasApprovedVersion" = true
AND slv."submissionStatus" = 'APPROVED'
) cats ON true
GROUP BY p.username, p.name, p."avatarUrl", p.description, p.links,
p."isFeatured", ast.num_agents, ast.agent_rating, ast.agent_runs;
-- Drop function
DROP FUNCTION IF EXISTS platform.refresh_store_materialized_views();
-- Drop materialized views
DROP MATERIALIZED VIEW IF EXISTS "mv_review_stats";
DROP MATERIALIZED VIEW IF EXISTS "mv_agent_run_counts";
-- DropIndex
DROP INDEX IF EXISTS "idx_profile_user";
-- DropIndex
DROP INDEX IF EXISTS "idx_agent_graph_execution_agent";
-- DropIndex
DROP INDEX IF EXISTS "idx_store_listing_review_version";
-- DropIndex
DROP INDEX IF EXISTS "idx_slv_agent";
-- DropIndex
DROP INDEX IF EXISTS "idx_slv_categories_gin";
-- DropIndex
DROP INDEX IF EXISTS "idx_store_listing_version_status";
-- DropIndex
DROP INDEX IF EXISTS "idx_store_listing_approved";
-- DropIndex
DROP INDEX IF EXISTS "idx_store_listing_version_approved_listing";

View File

@@ -123,3 +123,4 @@ filterwarnings = [
[tool.ruff]
target-version = "py310"

View File

@@ -0,0 +1,110 @@
#!/usr/bin/env python3
"""
Run test data creation and update scripts in sequence.
Usage:
poetry run python run_test_data.py
"""
import asyncio
import subprocess
import sys
from pathlib import Path
def run_command(cmd: list[str], cwd: Path | None = None) -> bool:
"""Run a command and return True if successful."""
try:
result = subprocess.run(
cmd, check=True, capture_output=True, text=True, cwd=cwd
)
if result.stdout:
print(result.stdout)
return True
except subprocess.CalledProcessError as e:
print(f"Error running command: {' '.join(cmd)}")
print(f"Error: {e.stderr}")
return False
async def main():
"""Main function to run test data scripts."""
print("=" * 60)
print("Running Test Data Scripts for AutoGPT Platform")
print("=" * 60)
print()
# Get the backend directory
backend_dir = Path(__file__).parent
test_dir = backend_dir / "test"
# Check if we're in the right directory
if not (backend_dir / "pyproject.toml").exists():
print("ERROR: This script must be run from the backend directory")
sys.exit(1)
print("1. Checking database connection...")
print("-" * 40)
# Import here to ensure proper environment setup
try:
from prisma import Prisma
db = Prisma()
await db.connect()
print("✓ Database connection successful")
await db.disconnect()
except Exception as e:
print(f"✗ Database connection failed: {e}")
print("\nPlease ensure:")
print("1. The database services are running (docker compose up -d)")
print("2. The DATABASE_URL in .env is correct")
print("3. Migrations have been run (poetry run prisma migrate deploy)")
sys.exit(1)
print()
print("2. Running test data creator...")
print("-" * 40)
# Run test_data_creator.py
if run_command(["poetry", "run", "python", "test_data_creator.py"], cwd=test_dir):
print()
print("✅ Test data created successfully!")
print()
print("3. Running test data updater...")
print("-" * 40)
# Run test_data_updater.py
if run_command(
["poetry", "run", "python", "test_data_updater.py"], cwd=test_dir
):
print()
print("✅ Test data updated successfully!")
else:
print()
print("❌ Test data updater failed!")
sys.exit(1)
else:
print()
print("❌ Test data creator failed!")
sys.exit(1)
print()
print("=" * 60)
print("Test data setup completed successfully!")
print("=" * 60)
print()
print("The materialized views have been populated with test data:")
print("- mv_agent_run_counts: Agent execution statistics")
print("- mv_review_stats: Store listing review statistics")
print()
print("You can now:")
print("1. Run tests: poetry run test")
print("2. Start the backend: poetry run serve")
print("3. View data in the database")
print()
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -13,8 +13,10 @@ def wait_for_postgres(max_retries=5, delay=5):
"compose",
"-f",
"docker-compose.test.yaml",
"--env-file",
"../.env",
"exec",
"postgres-test",
"db",
"pg_isready",
"-U",
"postgres",
@@ -51,6 +53,8 @@ def test():
"compose",
"-f",
"docker-compose.test.yaml",
"--env-file",
"../.env",
"up",
"-d",
]
@@ -74,11 +78,20 @@ def test():
# to their development database, running tests would wipe their local data!
test_env = os.environ.copy()
# Use environment variables if set, otherwise use defaults that match docker-compose.test.yaml
db_user = os.getenv("DB_USER", "postgres")
db_pass = os.getenv("DB_PASS", "postgres")
db_name = os.getenv("DB_NAME", "postgres")
db_port = os.getenv("DB_PORT", "5432")
# Load database configuration from .env file
dotenv_path = os.path.join(os.path.dirname(__file__), "../.env")
if os.path.exists(dotenv_path):
with open(dotenv_path) as f:
for line in f:
if line.strip() and not line.startswith("#"):
key, value = line.strip().split("=", 1)
os.environ[key] = value
# Get database config from environment (now populated from .env)
db_user = os.getenv("POSTGRES_USER", "postgres")
db_pass = os.getenv("POSTGRES_PASSWORD", "postgres")
db_name = os.getenv("POSTGRES_DB", "postgres")
db_port = os.getenv("POSTGRES_PORT", "5432")
# Construct the test database URL - this ensures we're always pointing to the test container
test_env["DATABASE_URL"] = (

View File

@@ -599,7 +599,23 @@ view Creator {
agent_runs Int
is_featured Boolean
// Index or unique are not applied to views
// Note: Prisma doesn't support indexes on views, but the following indexes exist in the database:
//
// Optimized indexes (partial indexes to reduce size and improve performance):
// - idx_profile_user on Profile(userId)
// - idx_store_listing_approved on StoreListing(owningUserId) WHERE isDeleted = false AND hasApprovedVersion = true
// - idx_store_listing_version_status on StoreListingVersion(storeListingId) WHERE submissionStatus = 'APPROVED'
// - idx_slv_categories_gin - GIN index on StoreListingVersion(categories) WHERE submissionStatus = 'APPROVED'
// - idx_slv_agent on StoreListingVersion(agentGraphId, agentGraphVersion) WHERE submissionStatus = 'APPROVED'
// - idx_store_listing_review_version on StoreListingReview(storeListingVersionId)
// - idx_store_listing_version_approved_listing on StoreListingVersion(storeListingId, version) WHERE submissionStatus = 'APPROVED'
// - idx_agent_graph_execution_agent on AgentGraphExecution(agentGraphId)
//
// Materialized views used (refreshed every 15 minutes via pg_cron):
// - mv_agent_run_counts - Pre-aggregated agent execution counts by agentGraphId
// - mv_review_stats - Pre-aggregated review statistics (count, avg rating) by storeListingId
//
// Query strategy: Uses CTEs to efficiently aggregate creator statistics leveraging materialized views
}
view StoreAgent {
@@ -622,7 +638,30 @@ view StoreAgent {
rating Float
versions String[]
// Index or unique are not applied to views
// Note: Prisma doesn't support indexes on views, but the following indexes exist in the database:
//
// Optimized indexes (partial indexes to reduce size and improve performance):
// - idx_store_listing_approved on StoreListing(owningUserId) WHERE isDeleted = false AND hasApprovedVersion = true
// - idx_store_listing_version_status on StoreListingVersion(storeListingId) WHERE submissionStatus = 'APPROVED'
// - idx_slv_categories_gin - GIN index on StoreListingVersion(categories) WHERE submissionStatus = 'APPROVED' for array searches
// - idx_slv_agent on StoreListingVersion(agentGraphId, agentGraphVersion) WHERE submissionStatus = 'APPROVED'
// - idx_store_listing_review_version on StoreListingReview(storeListingVersionId)
// - idx_store_listing_version_approved_listing on StoreListingVersion(storeListingId, version) WHERE submissionStatus = 'APPROVED'
// - idx_agent_graph_execution_agent on AgentGraphExecution(agentGraphId)
// - idx_profile_user on Profile(userId)
//
// Additional indexes from earlier migrations:
// - StoreListing_agentId_owningUserId_idx
// - StoreListing_isDeleted_isApproved_idx (replaced by idx_store_listing_approved)
// - StoreListing_isDeleted_idx
// - StoreListing_agentId_key (unique on agentGraphId)
// - StoreListingVersion_agentId_agentVersion_isDeleted_idx
//
// Materialized views used (refreshed every 15 minutes via pg_cron):
// - mv_agent_run_counts - Pre-aggregated agent execution counts by agentGraphId
// - mv_review_stats - Pre-aggregated review statistics (count, avg rating) by storeListingId
//
// Query strategy: Uses CTE for version aggregation and joins with materialized views for performance
}
view StoreSubmission {
@@ -649,6 +688,33 @@ view StoreSubmission {
// Index or unique are not applied to views
}
// Note: This is actually a MATERIALIZED VIEW in the database
// Refreshed automatically every 15 minutes via pg_cron (with fallback to manual refresh)
view mv_agent_run_counts {
agentGraphId String @unique
run_count Int
// Pre-aggregated count of AgentGraphExecution records by agentGraphId
// Used by StoreAgent and Creator views for performance optimization
// Unique index created automatically on agentGraphId for fast lookups
// Refresh uses CONCURRENTLY to avoid blocking reads
}
// Note: This is actually a MATERIALIZED VIEW in the database
// Refreshed automatically every 15 minutes via pg_cron (with fallback to manual refresh)
view mv_review_stats {
storeListingId String @unique
review_count Int
avg_rating Float
// Pre-aggregated review statistics from StoreListingReview
// Includes count of reviews and average rating per StoreListing
// Only includes approved versions (submissionStatus = 'APPROVED') and non-deleted listings
// Used by StoreAgent view for performance optimization
// Unique index created automatically on storeListingId for fast lookups
// Refresh uses CONCURRENTLY to avoid blocking reads
}
model StoreListing {
id String @id @default(uuid())
createdAt DateTime @default(now())

View File

@@ -7,7 +7,7 @@
"description": "A test graph",
"forked_from_id": null,
"forked_from_version": null,
"has_webhook_trigger": false,
"has_external_trigger": false,
"id": "graph-123",
"input_schema": {
"properties": {},

View File

@@ -8,7 +8,7 @@
"description": "A test graph",
"forked_from_id": null,
"forked_from_version": null,
"has_webhook_trigger": false,
"has_external_trigger": false,
"id": "graph-123",
"input_schema": {
"properties": {},
@@ -16,9 +16,7 @@
"type": "object"
},
"is_active": true,
"links": [],
"name": "Test Graph",
"nodes": [],
"output_schema": {
"properties": {},
"required": [],

View File

@@ -0,0 +1 @@
"""SDK test module."""

View File

@@ -0,0 +1,20 @@
"""
Shared configuration for SDK test providers using the SDK pattern.
"""
from backend.sdk import BlockCostType, ProviderBuilder
# Configure test providers
test_api = (
ProviderBuilder("test_api")
.with_api_key("TEST_API_KEY", "Test API Key")
.with_base_cost(5, BlockCostType.RUN)
.build()
)
test_service = (
ProviderBuilder("test_service")
.with_api_key("TEST_SERVICE_API_KEY", "Test Service API Key")
.with_base_cost(10, BlockCostType.RUN)
.build()
)

View File

@@ -0,0 +1,29 @@
"""
Configuration for SDK tests.
This conftest.py file provides basic test setup for SDK unit tests
without requiring the full server infrastructure.
"""
from unittest.mock import MagicMock
import pytest
@pytest.fixture(scope="session")
def server():
"""Mock server fixture for SDK tests."""
mock_server = MagicMock()
mock_server.agent_server = MagicMock()
mock_server.agent_server.test_create_graph = MagicMock()
return mock_server
@pytest.fixture(autouse=True)
def reset_registry():
"""Reset the AutoRegistry before each test."""
from backend.sdk.registry import AutoRegistry
AutoRegistry.clear()
yield
AutoRegistry.clear()

View File

@@ -0,0 +1,914 @@
"""
Tests for creating blocks using the SDK.
This test suite verifies that blocks can be created using only SDK imports
and that they work correctly without decorators.
"""
from typing import Any, Optional, Union
import pytest
from backend.sdk import (
APIKeyCredentials,
Block,
BlockCategory,
BlockCostType,
BlockOutput,
BlockSchema,
CredentialsMetaInput,
OAuth2Credentials,
ProviderBuilder,
SchemaField,
SecretStr,
)
from ._config import test_api, test_service
class TestBasicBlockCreation:
"""Test creating basic blocks using the SDK."""
@pytest.mark.asyncio
async def test_simple_block(self):
"""Test creating a simple block without any decorators."""
class SimpleBlock(Block):
"""A simple test block."""
class Input(BlockSchema):
text: str = SchemaField(description="Input text")
count: int = SchemaField(description="Repeat count", default=1)
class Output(BlockSchema):
result: str = SchemaField(description="Output result")
def __init__(self):
super().__init__(
id="simple-test-block",
description="A simple test block",
categories={BlockCategory.TEXT},
input_schema=SimpleBlock.Input,
output_schema=SimpleBlock.Output,
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
result = input_data.text * input_data.count
yield "result", result
# Create and test the block
block = SimpleBlock()
assert block.id == "simple-test-block"
assert BlockCategory.TEXT in block.categories
# Test execution
outputs = []
async for name, value in block.run(
SimpleBlock.Input(text="Hello ", count=3),
):
outputs.append((name, value))
assert len(outputs) == 1
assert outputs[0] == ("result", "Hello Hello Hello ")
@pytest.mark.asyncio
async def test_block_with_credentials(self):
"""Test creating a block that requires credentials."""
class APIBlock(Block):
"""A block that requires API credentials."""
class Input(BlockSchema):
credentials: CredentialsMetaInput = test_api.credentials_field(
description="API credentials for test service",
)
query: str = SchemaField(description="API query")
class Output(BlockSchema):
response: str = SchemaField(description="API response")
authenticated: bool = SchemaField(description="Was authenticated")
def __init__(self):
super().__init__(
id="api-test-block",
description="Test block with API credentials",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=APIBlock.Input,
output_schema=APIBlock.Output,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
# Simulate API call
api_key = credentials.api_key.get_secret_value()
authenticated = bool(api_key)
yield "response", f"API response for: {input_data.query}"
yield "authenticated", authenticated
# Create test credentials
test_creds = APIKeyCredentials(
id="test-creds",
provider="test_api",
api_key=SecretStr("test-api-key"),
title="Test API Key",
)
# Create and test the block
block = APIBlock()
outputs = []
async for name, value in block.run(
APIBlock.Input(
credentials={ # type: ignore
"provider": "test_api",
"id": "test-creds",
"type": "api_key",
},
query="test query",
),
credentials=test_creds,
):
outputs.append((name, value))
assert len(outputs) == 2
assert outputs[0] == ("response", "API response for: test query")
assert outputs[1] == ("authenticated", True)
@pytest.mark.asyncio
async def test_block_with_multiple_outputs(self):
"""Test block that yields multiple outputs."""
class MultiOutputBlock(Block):
"""Block with multiple outputs."""
class Input(BlockSchema):
text: str = SchemaField(description="Input text")
class Output(BlockSchema):
uppercase: str = SchemaField(description="Uppercase version")
lowercase: str = SchemaField(description="Lowercase version")
length: int = SchemaField(description="Text length")
is_empty: bool = SchemaField(description="Is text empty")
def __init__(self):
super().__init__(
id="multi-output-block",
description="Block with multiple outputs",
categories={BlockCategory.TEXT},
input_schema=MultiOutputBlock.Input,
output_schema=MultiOutputBlock.Output,
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
text = input_data.text
yield "uppercase", text.upper()
yield "lowercase", text.lower()
yield "length", len(text)
yield "is_empty", len(text) == 0
# Test the block
block = MultiOutputBlock()
outputs = []
async for name, value in block.run(MultiOutputBlock.Input(text="Hello World")):
outputs.append((name, value))
assert len(outputs) == 4
assert ("uppercase", "HELLO WORLD") in outputs
assert ("lowercase", "hello world") in outputs
assert ("length", 11) in outputs
assert ("is_empty", False) in outputs
class TestBlockWithProvider:
"""Test creating blocks associated with providers."""
@pytest.mark.asyncio
async def test_block_using_provider(self):
"""Test block that uses a registered provider."""
class TestServiceBlock(Block):
"""Block for test service."""
class Input(BlockSchema):
credentials: CredentialsMetaInput = test_service.credentials_field(
description="Test service credentials",
)
action: str = SchemaField(description="Action to perform")
class Output(BlockSchema):
result: str = SchemaField(description="Action result")
provider_name: str = SchemaField(description="Provider used")
def __init__(self):
super().__init__(
id="test-service-block",
description="Block using test service provider",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=TestServiceBlock.Input,
output_schema=TestServiceBlock.Output,
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
# The provider name should match
yield "result", f"Performed: {input_data.action}"
yield "provider_name", credentials.provider
# Create credentials for our provider
creds = APIKeyCredentials(
id="test-service-creds",
provider="test_service",
api_key=SecretStr("test-key"),
title="Test Service Key",
)
# Test the block
block = TestServiceBlock()
outputs = {}
async for name, value in block.run(
TestServiceBlock.Input(
credentials={ # type: ignore
"provider": "test_service",
"id": "test-service-creds",
"type": "api_key",
},
action="test action",
),
credentials=creds,
):
outputs[name] = value
assert outputs["result"] == "Performed: test action"
assert outputs["provider_name"] == "test_service"
class TestComplexBlockScenarios:
"""Test more complex block scenarios."""
@pytest.mark.asyncio
async def test_block_with_optional_fields(self):
"""Test block with optional input fields."""
# Optional is already imported at the module level
class OptionalFieldBlock(Block):
"""Block with optional fields."""
class Input(BlockSchema):
required_field: str = SchemaField(description="Required field")
optional_field: Optional[str] = SchemaField(
description="Optional field",
default=None,
)
optional_with_default: str = SchemaField(
description="Optional with default",
default="default value",
)
class Output(BlockSchema):
has_optional: bool = SchemaField(description="Has optional value")
optional_value: Optional[str] = SchemaField(
description="Optional value"
)
default_value: str = SchemaField(description="Default value")
def __init__(self):
super().__init__(
id="optional-field-block",
description="Block with optional fields",
categories={BlockCategory.TEXT},
input_schema=OptionalFieldBlock.Input,
output_schema=OptionalFieldBlock.Output,
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
yield "has_optional", input_data.optional_field is not None
yield "optional_value", input_data.optional_field
yield "default_value", input_data.optional_with_default
# Test with optional field provided
block = OptionalFieldBlock()
outputs = {}
async for name, value in block.run(
OptionalFieldBlock.Input(
required_field="test",
optional_field="provided",
)
):
outputs[name] = value
assert outputs["has_optional"] is True
assert outputs["optional_value"] == "provided"
assert outputs["default_value"] == "default value"
# Test without optional field
outputs = {}
async for name, value in block.run(
OptionalFieldBlock.Input(
required_field="test",
)
):
outputs[name] = value
assert outputs["has_optional"] is False
assert outputs["optional_value"] is None
assert outputs["default_value"] == "default value"
@pytest.mark.asyncio
async def test_block_with_complex_types(self):
"""Test block with complex input/output types."""
class ComplexBlock(Block):
"""Block with complex types."""
class Input(BlockSchema):
items: list[str] = SchemaField(description="List of items")
mapping: dict[str, int] = SchemaField(
description="String to int mapping"
)
class Output(BlockSchema):
item_count: int = SchemaField(description="Number of items")
total_value: int = SchemaField(description="Sum of mapping values")
combined: list[str] = SchemaField(description="Combined results")
def __init__(self):
super().__init__(
id="complex-types-block",
description="Block with complex types",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=ComplexBlock.Input,
output_schema=ComplexBlock.Output,
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
yield "item_count", len(input_data.items)
yield "total_value", sum(input_data.mapping.values())
# Combine items with their mapping values
combined = []
for item in input_data.items:
value = input_data.mapping.get(item, 0)
combined.append(f"{item}: {value}")
yield "combined", combined
# Test the block
block = ComplexBlock()
outputs = {}
async for name, value in block.run(
ComplexBlock.Input(
items=["apple", "banana", "orange"],
mapping={"apple": 5, "banana": 3, "orange": 4},
)
):
outputs[name] = value
assert outputs["item_count"] == 3
assert outputs["total_value"] == 12
assert outputs["combined"] == ["apple: 5", "banana: 3", "orange: 4"]
@pytest.mark.asyncio
async def test_block_error_handling(self):
"""Test block error handling."""
class ErrorHandlingBlock(Block):
"""Block that demonstrates error handling."""
class Input(BlockSchema):
value: int = SchemaField(description="Input value")
should_error: bool = SchemaField(
description="Whether to trigger an error",
default=False,
)
class Output(BlockSchema):
result: int = SchemaField(description="Result")
error_message: Optional[str] = SchemaField(
description="Error if any", default=None
)
def __init__(self):
super().__init__(
id="error-handling-block",
description="Block with error handling",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=ErrorHandlingBlock.Input,
output_schema=ErrorHandlingBlock.Output,
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
if input_data.should_error:
raise ValueError("Intentional error triggered")
if input_data.value < 0:
yield "error_message", "Value must be non-negative"
yield "result", 0
else:
yield "result", input_data.value * 2
yield "error_message", None
# Test normal operation
block = ErrorHandlingBlock()
outputs = {}
async for name, value in block.run(
ErrorHandlingBlock.Input(value=5, should_error=False)
):
outputs[name] = value
assert outputs["result"] == 10
assert outputs["error_message"] is None
# Test with negative value
outputs = {}
async for name, value in block.run(
ErrorHandlingBlock.Input(value=-5, should_error=False)
):
outputs[name] = value
assert outputs["result"] == 0
assert outputs["error_message"] == "Value must be non-negative"
# Test with error
with pytest.raises(ValueError, match="Intentional error triggered"):
async for _ in block.run(
ErrorHandlingBlock.Input(value=5, should_error=True)
):
pass
class TestAuthenticationVariants:
"""Test complex authentication scenarios including OAuth, API keys, and scopes."""
@pytest.mark.asyncio
async def test_oauth_block_with_scopes(self):
"""Test creating a block that uses OAuth2 with scopes."""
from backend.sdk import OAuth2Credentials, ProviderBuilder
# Create a test OAuth provider with scopes
# For testing, we don't need an actual OAuth handler
# In real usage, you would provide a proper OAuth handler class
oauth_provider = (
ProviderBuilder("test_oauth_provider")
.with_api_key("TEST_OAUTH_API", "Test OAuth API")
.with_base_cost(5, BlockCostType.RUN)
.build()
)
class OAuthScopedBlock(Block):
"""Block requiring OAuth2 with specific scopes."""
class Input(BlockSchema):
credentials: CredentialsMetaInput = oauth_provider.credentials_field(
description="OAuth2 credentials with scopes",
scopes=["read:user", "write:data"],
)
resource: str = SchemaField(description="Resource to access")
class Output(BlockSchema):
data: str = SchemaField(description="Retrieved data")
scopes_used: list[str] = SchemaField(
description="Scopes that were used"
)
token_info: dict[str, Any] = SchemaField(
description="Token information"
)
def __init__(self):
super().__init__(
id="oauth-scoped-block",
description="Test OAuth2 with scopes",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=OAuthScopedBlock.Input,
output_schema=OAuthScopedBlock.Output,
)
async def run(
self, input_data: Input, *, credentials: OAuth2Credentials, **kwargs
) -> BlockOutput:
# Simulate OAuth API call with scopes
token = credentials.access_token.get_secret_value()
yield "data", f"OAuth data for {input_data.resource}"
yield "scopes_used", credentials.scopes or []
yield "token_info", {
"has_token": bool(token),
"has_refresh": credentials.refresh_token is not None,
"provider": credentials.provider,
"expires_at": credentials.access_token_expires_at,
}
# Create test OAuth credentials
test_oauth_creds = OAuth2Credentials(
id="test-oauth-creds",
provider="test_oauth_provider",
access_token=SecretStr("test-access-token"),
refresh_token=SecretStr("test-refresh-token"),
scopes=["read:user", "write:data"],
title="Test OAuth Credentials",
)
# Test the block
block = OAuthScopedBlock()
outputs = {}
async for name, value in block.run(
OAuthScopedBlock.Input(
credentials={ # type: ignore
"provider": "test_oauth_provider",
"id": "test-oauth-creds",
"type": "oauth2",
},
resource="user/profile",
),
credentials=test_oauth_creds,
):
outputs[name] = value
assert outputs["data"] == "OAuth data for user/profile"
assert set(outputs["scopes_used"]) == {"read:user", "write:data"}
assert outputs["token_info"]["has_token"] is True
assert outputs["token_info"]["expires_at"] is None
assert outputs["token_info"]["has_refresh"] is True
@pytest.mark.asyncio
async def test_mixed_auth_block(self):
"""Test block that supports both OAuth2 and API key authentication."""
# No need to import these again, already imported at top
# Create provider supporting both auth types
# Create provider supporting API key auth
# In real usage, you would add OAuth support with .with_oauth()
mixed_provider = (
ProviderBuilder("mixed_auth_provider")
.with_api_key("MIXED_API_KEY", "Mixed Provider API Key")
.with_base_cost(8, BlockCostType.RUN)
.build()
)
class MixedAuthBlock(Block):
"""Block supporting multiple authentication methods."""
class Input(BlockSchema):
credentials: CredentialsMetaInput = mixed_provider.credentials_field(
description="API key or OAuth2 credentials",
supported_credential_types=["api_key", "oauth2"],
)
operation: str = SchemaField(description="Operation to perform")
class Output(BlockSchema):
result: str = SchemaField(description="Operation result")
auth_type: str = SchemaField(description="Authentication type used")
auth_details: dict[str, Any] = SchemaField(description="Auth details")
def __init__(self):
super().__init__(
id="mixed-auth-block",
description="Block supporting OAuth2 and API key",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=MixedAuthBlock.Input,
output_schema=MixedAuthBlock.Output,
)
async def run(
self,
input_data: Input,
*,
credentials: Union[APIKeyCredentials, OAuth2Credentials],
**kwargs,
) -> BlockOutput:
# Handle different credential types
if isinstance(credentials, APIKeyCredentials):
auth_type = "api_key"
auth_details = {
"has_key": bool(credentials.api_key.get_secret_value()),
"key_prefix": credentials.api_key.get_secret_value()[:5]
+ "...",
}
elif isinstance(credentials, OAuth2Credentials):
auth_type = "oauth2"
auth_details = {
"has_token": bool(credentials.access_token.get_secret_value()),
"scopes": credentials.scopes or [],
}
else:
auth_type = "unknown"
auth_details = {}
yield "result", f"Performed {input_data.operation} with {auth_type}"
yield "auth_type", auth_type
yield "auth_details", auth_details
# Test with API key
api_creds = APIKeyCredentials(
id="mixed-api-creds",
provider="mixed_auth_provider",
api_key=SecretStr("sk-1234567890"),
title="Mixed API Key",
)
block = MixedAuthBlock()
outputs = {}
async for name, value in block.run(
MixedAuthBlock.Input(
credentials={ # type: ignore
"provider": "mixed_auth_provider",
"id": "mixed-api-creds",
"type": "api_key",
},
operation="fetch_data",
),
credentials=api_creds,
):
outputs[name] = value
assert outputs["auth_type"] == "api_key"
assert outputs["result"] == "Performed fetch_data with api_key"
assert outputs["auth_details"]["key_prefix"] == "sk-12..."
# Test with OAuth2
oauth_creds = OAuth2Credentials(
id="mixed-oauth-creds",
provider="mixed_auth_provider",
access_token=SecretStr("oauth-token-123"),
scopes=["full_access"],
title="Mixed OAuth",
)
outputs = {}
async for name, value in block.run(
MixedAuthBlock.Input(
credentials={ # type: ignore
"provider": "mixed_auth_provider",
"id": "mixed-oauth-creds",
"type": "oauth2",
},
operation="update_data",
),
credentials=oauth_creds,
):
outputs[name] = value
assert outputs["auth_type"] == "oauth2"
assert outputs["result"] == "Performed update_data with oauth2"
assert outputs["auth_details"]["scopes"] == ["full_access"]
@pytest.mark.asyncio
async def test_multiple_credentials_block(self):
"""Test block requiring multiple different credentials."""
from backend.sdk import ProviderBuilder
# Create multiple providers
primary_provider = (
ProviderBuilder("primary_service")
.with_api_key("PRIMARY_API_KEY", "Primary Service Key")
.build()
)
# For testing purposes, using API key instead of OAuth handler
secondary_provider = (
ProviderBuilder("secondary_service")
.with_api_key("SECONDARY_API_KEY", "Secondary Service Key")
.build()
)
class MultiCredentialBlock(Block):
"""Block requiring credentials from multiple services."""
class Input(BlockSchema):
primary_credentials: CredentialsMetaInput = (
primary_provider.credentials_field(
description="Primary service API key"
)
)
secondary_credentials: CredentialsMetaInput = (
secondary_provider.credentials_field(
description="Secondary service OAuth"
)
)
merge_data: bool = SchemaField(
description="Whether to merge data from both services",
default=True,
)
class Output(BlockSchema):
primary_data: str = SchemaField(description="Data from primary service")
secondary_data: str = SchemaField(
description="Data from secondary service"
)
merged_result: Optional[str] = SchemaField(
description="Merged data if requested"
)
def __init__(self):
super().__init__(
id="multi-credential-block",
description="Block using multiple credentials",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=MultiCredentialBlock.Input,
output_schema=MultiCredentialBlock.Output,
)
async def run(
self,
input_data: Input,
*,
primary_credentials: APIKeyCredentials,
secondary_credentials: OAuth2Credentials,
**kwargs,
) -> BlockOutput:
# Simulate fetching data with primary API key
primary_data = f"Primary data using {primary_credentials.provider}"
yield "primary_data", primary_data
# Simulate fetching data with secondary OAuth
secondary_data = f"Secondary data with {len(secondary_credentials.scopes or [])} scopes"
yield "secondary_data", secondary_data
# Merge if requested
if input_data.merge_data:
merged = f"{primary_data} + {secondary_data}"
yield "merged_result", merged
else:
yield "merged_result", None
# Create test credentials
primary_creds = APIKeyCredentials(
id="primary-creds",
provider="primary_service",
api_key=SecretStr("primary-key-123"),
title="Primary Key",
)
secondary_creds = OAuth2Credentials(
id="secondary-creds",
provider="secondary_service",
access_token=SecretStr("secondary-token"),
scopes=["read", "write"],
title="Secondary OAuth",
)
# Test the block
block = MultiCredentialBlock()
outputs = {}
# Note: In real usage, the framework would inject the correct credentials
# based on the field names. Here we simulate that behavior.
async for name, value in block.run(
MultiCredentialBlock.Input(
primary_credentials={ # type: ignore
"provider": "primary_service",
"id": "primary-creds",
"type": "api_key",
},
secondary_credentials={ # type: ignore
"provider": "secondary_service",
"id": "secondary-creds",
"type": "oauth2",
},
merge_data=True,
),
primary_credentials=primary_creds,
secondary_credentials=secondary_creds,
):
outputs[name] = value
assert outputs["primary_data"] == "Primary data using primary_service"
assert outputs["secondary_data"] == "Secondary data with 2 scopes"
assert "Primary data" in outputs["merged_result"]
assert "Secondary data" in outputs["merged_result"]
@pytest.mark.asyncio
async def test_oauth_scope_validation(self):
"""Test OAuth scope validation and handling."""
from backend.sdk import OAuth2Credentials, ProviderBuilder
# Provider with specific required scopes
# For testing OAuth scope validation
scoped_provider = (
ProviderBuilder("scoped_oauth_service")
.with_api_key("SCOPED_OAUTH_KEY", "Scoped OAuth Service")
.build()
)
class ScopeValidationBlock(Block):
"""Block that validates OAuth scopes."""
class Input(BlockSchema):
credentials: CredentialsMetaInput = scoped_provider.credentials_field(
description="OAuth credentials with specific scopes",
scopes=["user:read", "user:write"], # Required scopes
)
require_admin: bool = SchemaField(
description="Whether admin scopes are required",
default=False,
)
class Output(BlockSchema):
allowed_operations: list[str] = SchemaField(
description="Operations allowed with current scopes"
)
missing_scopes: list[str] = SchemaField(
description="Scopes that are missing for full access"
)
has_required_scopes: bool = SchemaField(
description="Whether all required scopes are present"
)
def __init__(self):
super().__init__(
id="scope-validation-block",
description="Block that validates OAuth scopes",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=ScopeValidationBlock.Input,
output_schema=ScopeValidationBlock.Output,
)
async def run(
self, input_data: Input, *, credentials: OAuth2Credentials, **kwargs
) -> BlockOutput:
current_scopes = set(credentials.scopes or [])
required_scopes = {"user:read", "user:write"}
if input_data.require_admin:
required_scopes.update({"admin:read", "admin:write"})
# Determine allowed operations based on scopes
allowed_ops = []
if "user:read" in current_scopes:
allowed_ops.append("read_user_data")
if "user:write" in current_scopes:
allowed_ops.append("update_user_data")
if "admin:read" in current_scopes:
allowed_ops.append("read_admin_data")
if "admin:write" in current_scopes:
allowed_ops.append("update_admin_data")
missing = list(required_scopes - current_scopes)
has_required = len(missing) == 0
yield "allowed_operations", allowed_ops
yield "missing_scopes", missing
yield "has_required_scopes", has_required
# Test with partial scopes
partial_creds = OAuth2Credentials(
id="partial-oauth",
provider="scoped_oauth_service",
access_token=SecretStr("partial-token"),
scopes=["user:read"], # Only one of the required scopes
title="Partial OAuth",
)
block = ScopeValidationBlock()
outputs = {}
async for name, value in block.run(
ScopeValidationBlock.Input(
credentials={ # type: ignore
"provider": "scoped_oauth_service",
"id": "partial-oauth",
"type": "oauth2",
},
require_admin=False,
),
credentials=partial_creds,
):
outputs[name] = value
assert outputs["allowed_operations"] == ["read_user_data"]
assert "user:write" in outputs["missing_scopes"]
assert outputs["has_required_scopes"] is False
# Test with all required scopes
full_creds = OAuth2Credentials(
id="full-oauth",
provider="scoped_oauth_service",
access_token=SecretStr("full-token"),
scopes=["user:read", "user:write", "admin:read"],
title="Full OAuth",
)
outputs = {}
async for name, value in block.run(
ScopeValidationBlock.Input(
credentials={ # type: ignore
"provider": "scoped_oauth_service",
"id": "full-oauth",
"type": "oauth2",
},
require_admin=False,
),
credentials=full_creds,
):
outputs[name] = value
assert set(outputs["allowed_operations"]) == {
"read_user_data",
"update_user_data",
"read_admin_data",
}
assert outputs["missing_scopes"] == []
assert outputs["has_required_scopes"] is True
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,150 @@
"""
Tests for the SDK's integration patching mechanism.
This test suite verifies that the AutoRegistry correctly patches
existing integration points to include SDK-registered components.
"""
from unittest.mock import MagicMock, Mock, patch
import pytest
from backend.integrations.providers import ProviderName
from backend.sdk import (
AutoRegistry,
BaseOAuthHandler,
BaseWebhooksManager,
ProviderBuilder,
)
class MockOAuthHandler(BaseOAuthHandler):
"""Mock OAuth handler for testing."""
PROVIDER_NAME = ProviderName.GITHUB
@classmethod
async def authorize(cls, *args, **kwargs):
return "mock_auth"
class MockWebhookManager(BaseWebhooksManager):
"""Mock webhook manager for testing."""
PROVIDER_NAME = ProviderName.GITHUB
@classmethod
async def validate_payload(cls, webhook, request):
return {}, "test_event"
async def _register_webhook(self, *args, **kwargs):
return "mock_webhook_id", {}
async def _deregister_webhook(self, *args, **kwargs):
pass
class TestWebhookPatching:
"""Test webhook manager patching functionality."""
def setup_method(self):
"""Clear registry."""
AutoRegistry.clear()
def test_webhook_manager_patching(self):
"""Test that webhook managers are correctly patched."""
# Mock the original load_webhook_managers function
def mock_load_webhook_managers():
return {
"existing_webhook": Mock(spec=BaseWebhooksManager),
}
# Register a provider with webhooks
(
ProviderBuilder("webhook_provider")
.with_webhook_manager(MockWebhookManager)
.build()
)
# Mock the webhooks module
mock_webhooks_module = MagicMock()
mock_webhooks_module.load_webhook_managers = mock_load_webhook_managers
with patch.dict(
"sys.modules", {"backend.integrations.webhooks": mock_webhooks_module}
):
AutoRegistry.patch_integrations()
# Call the patched function
result = mock_webhooks_module.load_webhook_managers()
# Original webhook should still exist
assert "existing_webhook" in result
# New webhook should be added
assert "webhook_provider" in result
assert result["webhook_provider"] == MockWebhookManager
def test_webhook_patching_no_original_function(self):
"""Test webhook patching when load_webhook_managers doesn't exist."""
# Mock webhooks module without load_webhook_managers
mock_webhooks_module = MagicMock(spec=[])
# Register a provider
(
ProviderBuilder("test_provider")
.with_webhook_manager(MockWebhookManager)
.build()
)
with patch.dict(
"sys.modules", {"backend.integrations.webhooks": mock_webhooks_module}
):
# Should not raise an error
AutoRegistry.patch_integrations()
# Function should not be added if it didn't exist
assert not hasattr(mock_webhooks_module, "load_webhook_managers")
class TestPatchingIntegration:
"""Test the complete patching integration flow."""
def setup_method(self):
"""Clear registry."""
AutoRegistry.clear()
def test_complete_provider_registration_and_patching(self):
"""Test the complete flow from provider registration to patching."""
# Mock webhooks module
mock_webhooks = MagicMock()
mock_webhooks.load_webhook_managers = lambda: {"original": Mock()}
# Create a fully featured provider
(
ProviderBuilder("complete_provider")
.with_api_key("COMPLETE_KEY", "Complete API Key")
.with_oauth(MockOAuthHandler, scopes=["read", "write"])
.with_webhook_manager(MockWebhookManager)
.build()
)
# Apply patches
with patch.dict(
"sys.modules",
{
"backend.integrations.webhooks": mock_webhooks,
},
):
AutoRegistry.patch_integrations()
# Verify webhook patching
webhook_result = mock_webhooks.load_webhook_managers()
assert "complete_provider" in webhook_result
assert webhook_result["complete_provider"] == MockWebhookManager
assert "original" in webhook_result # Original preserved
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,482 @@
"""
Tests for the SDK auto-registration system via AutoRegistry.
This test suite verifies:
1. Provider registration and retrieval
2. OAuth handler registration via patches
3. Webhook manager registration via patches
4. Credential registration and management
5. Block configuration association
"""
from unittest.mock import MagicMock, Mock, patch
import pytest
from backend.integrations.providers import ProviderName
from backend.sdk import (
APIKeyCredentials,
AutoRegistry,
BaseOAuthHandler,
BaseWebhooksManager,
Block,
BlockConfiguration,
Provider,
ProviderBuilder,
)
class TestAutoRegistry:
"""Test the AutoRegistry functionality."""
def setup_method(self):
"""Clear registry before each test."""
AutoRegistry.clear()
def test_provider_registration(self):
"""Test that providers can be registered and retrieved."""
# Create a test provider
provider = Provider(
name="test_provider",
oauth_handler=None,
webhook_manager=None,
default_credentials=[],
base_costs=[],
supported_auth_types={"api_key"},
)
# Register it
AutoRegistry.register_provider(provider)
# Verify it's registered
assert "test_provider" in AutoRegistry._providers
assert AutoRegistry.get_provider("test_provider") == provider
def test_provider_with_oauth(self):
"""Test provider registration with OAuth handler."""
# Create a mock OAuth handler
class TestOAuthHandler(BaseOAuthHandler):
PROVIDER_NAME = ProviderName.GITHUB
from backend.sdk.provider import OAuthConfig
provider = Provider(
name="oauth_provider",
oauth_config=OAuthConfig(oauth_handler=TestOAuthHandler),
webhook_manager=None,
default_credentials=[],
base_costs=[],
supported_auth_types={"oauth2"},
)
AutoRegistry.register_provider(provider)
# Verify OAuth handler is registered
assert "oauth_provider" in AutoRegistry._oauth_handlers
assert AutoRegistry._oauth_handlers["oauth_provider"] == TestOAuthHandler
def test_provider_with_webhook_manager(self):
"""Test provider registration with webhook manager."""
# Create a mock webhook manager
class TestWebhookManager(BaseWebhooksManager):
PROVIDER_NAME = ProviderName.GITHUB
provider = Provider(
name="webhook_provider",
oauth_handler=None,
webhook_manager=TestWebhookManager,
default_credentials=[],
base_costs=[],
supported_auth_types={"api_key"},
)
AutoRegistry.register_provider(provider)
# Verify webhook manager is registered
assert "webhook_provider" in AutoRegistry._webhook_managers
assert AutoRegistry._webhook_managers["webhook_provider"] == TestWebhookManager
def test_default_credentials_registration(self):
"""Test that default credentials are registered."""
# Create test credentials
from backend.sdk import SecretStr
cred1 = APIKeyCredentials(
id="test-cred-1",
provider="test_provider",
api_key=SecretStr("test-key-1"),
title="Test Credential 1",
)
cred2 = APIKeyCredentials(
id="test-cred-2",
provider="test_provider",
api_key=SecretStr("test-key-2"),
title="Test Credential 2",
)
provider = Provider(
name="test_provider",
oauth_handler=None,
webhook_manager=None,
default_credentials=[cred1, cred2],
base_costs=[],
supported_auth_types={"api_key"},
)
AutoRegistry.register_provider(provider)
# Verify credentials are registered
all_creds = AutoRegistry.get_all_credentials()
assert cred1 in all_creds
assert cred2 in all_creds
def test_api_key_registration(self):
"""Test API key environment variable registration."""
import os
# Set up a test environment variable
os.environ["TEST_API_KEY"] = "test-api-key-value"
try:
AutoRegistry.register_api_key("test_provider", "TEST_API_KEY")
# Verify the mapping is stored
assert AutoRegistry._api_key_mappings["test_provider"] == "TEST_API_KEY"
# Verify a credential was created
all_creds = AutoRegistry.get_all_credentials()
test_cred = next(
(c for c in all_creds if c.id == "test_provider-default"), None
)
assert test_cred is not None
assert test_cred.provider == "test_provider"
assert test_cred.api_key.get_secret_value() == "test-api-key-value" # type: ignore
finally:
# Clean up
del os.environ["TEST_API_KEY"]
def test_get_oauth_handlers(self):
"""Test retrieving all OAuth handlers."""
# Register multiple providers with OAuth
class TestOAuth1(BaseOAuthHandler):
PROVIDER_NAME = ProviderName.GITHUB
class TestOAuth2(BaseOAuthHandler):
PROVIDER_NAME = ProviderName.GOOGLE
from backend.sdk.provider import OAuthConfig
provider1 = Provider(
name="provider1",
oauth_config=OAuthConfig(oauth_handler=TestOAuth1),
webhook_manager=None,
default_credentials=[],
base_costs=[],
supported_auth_types={"oauth2"},
)
provider2 = Provider(
name="provider2",
oauth_config=OAuthConfig(oauth_handler=TestOAuth2),
webhook_manager=None,
default_credentials=[],
base_costs=[],
supported_auth_types={"oauth2"},
)
AutoRegistry.register_provider(provider1)
AutoRegistry.register_provider(provider2)
handlers = AutoRegistry.get_oauth_handlers()
assert "provider1" in handlers
assert "provider2" in handlers
assert handlers["provider1"] == TestOAuth1
assert handlers["provider2"] == TestOAuth2
def test_block_configuration_registration(self):
"""Test registering block configuration."""
# Create a test block class
class TestBlock(Block):
pass
config = BlockConfiguration(
provider="test_provider",
costs=[],
default_credentials=[],
webhook_manager=None,
oauth_handler=None,
)
AutoRegistry.register_block_configuration(TestBlock, config)
# Verify it's registered
assert TestBlock in AutoRegistry._block_configurations
assert AutoRegistry._block_configurations[TestBlock] == config
def test_clear_registry(self):
"""Test clearing all registrations."""
# Add some registrations
provider = Provider(
name="test_provider",
oauth_handler=None,
webhook_manager=None,
default_credentials=[],
base_costs=[],
supported_auth_types={"api_key"},
)
AutoRegistry.register_provider(provider)
AutoRegistry.register_api_key("test", "TEST_KEY")
# Clear everything
AutoRegistry.clear()
# Verify everything is cleared
assert len(AutoRegistry._providers) == 0
assert len(AutoRegistry._default_credentials) == 0
assert len(AutoRegistry._oauth_handlers) == 0
assert len(AutoRegistry._webhook_managers) == 0
assert len(AutoRegistry._block_configurations) == 0
assert len(AutoRegistry._api_key_mappings) == 0
class TestAutoRegistryPatching:
"""Test the integration patching functionality."""
def setup_method(self):
"""Clear registry before each test."""
AutoRegistry.clear()
@patch("backend.integrations.webhooks.load_webhook_managers")
def test_webhook_manager_patching(self, mock_load_managers):
"""Test that webhook managers are patched into the system."""
# Set up the mock to return an empty dict
mock_load_managers.return_value = {}
# Create a test webhook manager
class TestWebhookManager(BaseWebhooksManager):
PROVIDER_NAME = ProviderName.GITHUB
# Register a provider with webhooks
provider = Provider(
name="webhook_provider",
oauth_handler=None,
webhook_manager=TestWebhookManager,
default_credentials=[],
base_costs=[],
supported_auth_types={"api_key"},
)
AutoRegistry.register_provider(provider)
# Mock the webhooks module
mock_webhooks = MagicMock()
mock_webhooks.load_webhook_managers = mock_load_managers
with patch.dict(
"sys.modules", {"backend.integrations.webhooks": mock_webhooks}
):
# Apply patches
AutoRegistry.patch_integrations()
# Call the patched function
result = mock_webhooks.load_webhook_managers()
# Verify our webhook manager is included
assert "webhook_provider" in result
assert result["webhook_provider"] == TestWebhookManager
class TestProviderBuilder:
"""Test the ProviderBuilder fluent API."""
def setup_method(self):
"""Clear registry before each test."""
AutoRegistry.clear()
def test_basic_provider_builder(self):
"""Test building a basic provider."""
provider = (
ProviderBuilder("test_provider")
.with_api_key("TEST_API_KEY", "Test API Key")
.build()
)
assert provider.name == "test_provider"
assert "api_key" in provider.supported_auth_types
assert AutoRegistry.get_provider("test_provider") == provider
def test_provider_builder_with_oauth(self):
"""Test building a provider with OAuth."""
class TestOAuth(BaseOAuthHandler):
PROVIDER_NAME = ProviderName.GITHUB
provider = (
ProviderBuilder("oauth_test")
.with_oauth(TestOAuth, scopes=["read", "write"])
.build()
)
assert provider.oauth_config is not None
assert provider.oauth_config.oauth_handler == TestOAuth
assert "oauth2" in provider.supported_auth_types
def test_provider_builder_with_webhook(self):
"""Test building a provider with webhook manager."""
class TestWebhook(BaseWebhooksManager):
PROVIDER_NAME = ProviderName.GITHUB
provider = (
ProviderBuilder("webhook_test").with_webhook_manager(TestWebhook).build()
)
assert provider.webhook_manager == TestWebhook
def test_provider_builder_with_base_cost(self):
"""Test building a provider with base costs."""
from backend.data.cost import BlockCostType
provider = (
ProviderBuilder("cost_test")
.with_base_cost(10, BlockCostType.RUN)
.with_base_cost(5, BlockCostType.BYTE)
.build()
)
assert len(provider.base_costs) == 2
assert provider.base_costs[0].cost_amount == 10
assert provider.base_costs[0].cost_type == BlockCostType.RUN
assert provider.base_costs[1].cost_amount == 5
assert provider.base_costs[1].cost_type == BlockCostType.BYTE
def test_provider_builder_with_api_client(self):
"""Test building a provider with API client factory."""
def mock_client_factory():
return Mock()
provider = (
ProviderBuilder("client_test").with_api_client(mock_client_factory).build()
)
assert provider._api_client_factory == mock_client_factory
def test_provider_builder_with_error_handler(self):
"""Test building a provider with error handler."""
def mock_error_handler(exc: Exception) -> str:
return f"Error: {str(exc)}"
provider = (
ProviderBuilder("error_test").with_error_handler(mock_error_handler).build()
)
assert provider._error_handler == mock_error_handler
def test_provider_builder_complete_example(self):
"""Test building a complete provider with all features."""
from backend.data.cost import BlockCostType
class TestOAuth(BaseOAuthHandler):
PROVIDER_NAME = ProviderName.GITHUB
class TestWebhook(BaseWebhooksManager):
PROVIDER_NAME = ProviderName.GITHUB
def client_factory():
return Mock()
def error_handler(exc):
return str(exc)
provider = (
ProviderBuilder("complete_test")
.with_api_key("COMPLETE_API_KEY", "Complete API Key")
.with_oauth(TestOAuth, scopes=["read"])
.with_webhook_manager(TestWebhook)
.with_base_cost(100, BlockCostType.RUN)
.with_api_client(client_factory)
.with_error_handler(error_handler)
.with_config(custom_setting="value")
.build()
)
# Verify all settings
assert provider.name == "complete_test"
assert "api_key" in provider.supported_auth_types
assert "oauth2" in provider.supported_auth_types
assert provider.oauth_config is not None
assert provider.oauth_config.oauth_handler == TestOAuth
assert provider.webhook_manager == TestWebhook
assert len(provider.base_costs) == 1
assert provider._api_client_factory == client_factory
assert provider._error_handler == error_handler
assert provider.get_config("custom_setting") == "value" # from with_config
# Verify it's registered
assert AutoRegistry.get_provider("complete_test") == provider
assert "complete_test" in AutoRegistry._oauth_handlers
assert "complete_test" in AutoRegistry._webhook_managers
class TestSDKImports:
"""Test that all expected exports are available from the SDK."""
def test_core_block_imports(self):
"""Test core block system imports."""
from backend.sdk import Block, BlockCategory
# Just verify they're importable
assert Block is not None
assert BlockCategory is not None
def test_schema_imports(self):
"""Test schema and model imports."""
from backend.sdk import APIKeyCredentials, SchemaField
assert SchemaField is not None
assert APIKeyCredentials is not None
def test_type_alias_imports(self):
"""Test type alias imports are removed."""
# Type aliases have been removed from SDK
# Users should import from typing or use built-in types directly
pass
def test_cost_system_imports(self):
"""Test cost system imports."""
from backend.sdk import BlockCost, BlockCostType
assert BlockCost is not None
assert BlockCostType is not None
def test_utility_imports(self):
"""Test utility imports."""
from backend.sdk import BaseModel, Requests, json
assert json is not None
assert BaseModel is not None
assert Requests is not None
def test_integration_imports(self):
"""Test integration imports."""
from backend.sdk import ProviderName
assert ProviderName is not None
def test_sdk_component_imports(self):
"""Test SDK-specific component imports."""
from backend.sdk import AutoRegistry, ProviderBuilder
assert AutoRegistry is not None
assert ProviderBuilder is not None
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,506 @@
"""
Tests for SDK webhook functionality.
This test suite verifies webhook blocks and webhook manager integration.
"""
from enum import Enum
import pytest
from backend.integrations.providers import ProviderName
from backend.sdk import (
APIKeyCredentials,
AutoRegistry,
BaseModel,
BaseWebhooksManager,
Block,
BlockCategory,
BlockOutput,
BlockSchema,
BlockWebhookConfig,
CredentialsField,
CredentialsMetaInput,
Field,
ProviderBuilder,
SchemaField,
SecretStr,
)
class TestWebhookTypes(str, Enum):
"""Test webhook event types."""
CREATED = "created"
UPDATED = "updated"
DELETED = "deleted"
class TestWebhooksManager(BaseWebhooksManager):
"""Test webhook manager implementation."""
PROVIDER_NAME = ProviderName.GITHUB # Reuse for testing
class WebhookType(str, Enum):
TEST = "test"
@classmethod
async def validate_payload(cls, webhook, request):
"""Validate incoming webhook payload."""
# Mock implementation
payload = {"test": "data"}
event_type = "test_event"
return payload, event_type
async def _register_webhook(
self,
credentials,
webhook_type: str,
resource: str,
events: list[str],
ingress_url: str,
secret: str,
) -> tuple[str, dict]:
"""Register webhook with external service."""
# Mock implementation
webhook_id = f"test_webhook_{resource}"
config = {
"webhook_type": webhook_type,
"resource": resource,
"events": events,
"url": ingress_url,
}
return webhook_id, config
async def _deregister_webhook(self, webhook, credentials) -> None:
"""Deregister webhook from external service."""
# Mock implementation
pass
class TestWebhookBlock(Block):
"""Test webhook block implementation."""
class Input(BlockSchema):
credentials: CredentialsMetaInput = CredentialsField(
provider="test_webhooks",
supported_credential_types={"api_key"},
description="Webhook service credentials",
)
webhook_url: str = SchemaField(
description="URL to receive webhooks",
)
resource_id: str = SchemaField(
description="Resource to monitor",
)
events: list[TestWebhookTypes] = SchemaField(
description="Events to listen for",
default=[TestWebhookTypes.CREATED],
)
payload: dict = SchemaField(
description="Webhook payload",
default={},
)
class Output(BlockSchema):
webhook_id: str = SchemaField(description="Registered webhook ID")
is_active: bool = SchemaField(description="Webhook is active")
event_count: int = SchemaField(description="Number of events configured")
def __init__(self):
super().__init__(
id="test-webhook-block",
description="Test webhook block",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=TestWebhookBlock.Input,
output_schema=TestWebhookBlock.Output,
webhook_config=BlockWebhookConfig(
provider="test_webhooks", # type: ignore
webhook_type="test",
resource_format="{resource_id}",
),
)
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
# Simulate webhook registration
webhook_id = f"webhook_{input_data.resource_id}"
yield "webhook_id", webhook_id
yield "is_active", True
yield "event_count", len(input_data.events)
class TestWebhookBlockCreation:
"""Test creating webhook blocks with the SDK."""
def setup_method(self):
"""Set up test environment."""
AutoRegistry.clear()
# Register a provider with webhook support
self.provider = (
ProviderBuilder("test_webhooks")
.with_api_key("TEST_WEBHOOK_KEY", "Test Webhook API Key")
.with_webhook_manager(TestWebhooksManager)
.build()
)
@pytest.mark.asyncio
async def test_basic_webhook_block(self):
"""Test creating a basic webhook block."""
block = TestWebhookBlock()
# Verify block configuration
assert block.webhook_config is not None
assert block.webhook_config.provider == "test_webhooks"
assert block.webhook_config.webhook_type == "test"
assert "{resource_id}" in block.webhook_config.resource_format # type: ignore
# Test block execution
test_creds = APIKeyCredentials(
id="test-webhook-creds",
provider="test_webhooks",
api_key=SecretStr("test-key"),
title="Test Webhook Key",
)
outputs = {}
async for name, value in block.run(
TestWebhookBlock.Input(
credentials={ # type: ignore
"provider": "test_webhooks",
"id": "test-webhook-creds",
"type": "api_key",
},
webhook_url="https://example.com/webhook",
resource_id="resource_123",
events=[TestWebhookTypes.CREATED, TestWebhookTypes.UPDATED],
),
credentials=test_creds,
):
outputs[name] = value
assert outputs["webhook_id"] == "webhook_resource_123"
assert outputs["is_active"] is True
assert outputs["event_count"] == 2
@pytest.mark.asyncio
async def test_webhook_block_with_filters(self):
"""Test webhook block with event filters."""
class EventFilterModel(BaseModel):
include_system: bool = Field(default=False)
severity_levels: list[str] = Field(
default_factory=lambda: ["info", "warning"]
)
class FilteredWebhookBlock(Block):
"""Webhook block with filtering."""
class Input(BlockSchema):
credentials: CredentialsMetaInput = CredentialsField(
provider="test_webhooks",
supported_credential_types={"api_key"},
)
resource: str = SchemaField(description="Resource to monitor")
filters: EventFilterModel = SchemaField(
description="Event filters",
default_factory=EventFilterModel,
)
payload: dict = SchemaField(
description="Webhook payload",
default={},
)
class Output(BlockSchema):
webhook_active: bool = SchemaField(description="Webhook active")
filter_summary: str = SchemaField(description="Active filters")
def __init__(self):
super().__init__(
id="filtered-webhook-block",
description="Webhook with filters",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=FilteredWebhookBlock.Input,
output_schema=FilteredWebhookBlock.Output,
webhook_config=BlockWebhookConfig(
provider="test_webhooks", # type: ignore
webhook_type="filtered",
resource_format="{resource}",
),
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
filters = input_data.filters
filter_parts = []
if filters.include_system:
filter_parts.append("system events")
filter_parts.append(f"{len(filters.severity_levels)} severity levels")
yield "webhook_active", True
yield "filter_summary", ", ".join(filter_parts)
# Test the block
block = FilteredWebhookBlock()
test_creds = APIKeyCredentials(
id="test-creds",
provider="test_webhooks",
api_key=SecretStr("key"),
title="Test Key",
)
# Test with default filters
outputs = {}
async for name, value in block.run(
FilteredWebhookBlock.Input(
credentials={ # type: ignore
"provider": "test_webhooks",
"id": "test-creds",
"type": "api_key",
},
resource="test_resource",
),
credentials=test_creds,
):
outputs[name] = value
assert outputs["webhook_active"] is True
assert "2 severity levels" in outputs["filter_summary"]
# Test with custom filters
custom_filters = EventFilterModel(
include_system=True,
severity_levels=["error", "critical"],
)
outputs = {}
async for name, value in block.run(
FilteredWebhookBlock.Input(
credentials={ # type: ignore
"provider": "test_webhooks",
"id": "test-creds",
"type": "api_key",
},
resource="test_resource",
filters=custom_filters,
),
credentials=test_creds,
):
outputs[name] = value
assert "system events" in outputs["filter_summary"]
assert "2 severity levels" in outputs["filter_summary"]
class TestWebhookManagerIntegration:
"""Test webhook manager integration with AutoRegistry."""
def setup_method(self):
"""Clear registry."""
AutoRegistry.clear()
def test_webhook_manager_registration(self):
"""Test that webhook managers are properly registered."""
# Create multiple webhook managers
class WebhookManager1(BaseWebhooksManager):
PROVIDER_NAME = ProviderName.GITHUB
class WebhookManager2(BaseWebhooksManager):
PROVIDER_NAME = ProviderName.GOOGLE
# Register providers with webhook managers
(
ProviderBuilder("webhook_service_1")
.with_webhook_manager(WebhookManager1)
.build()
)
(
ProviderBuilder("webhook_service_2")
.with_webhook_manager(WebhookManager2)
.build()
)
# Verify registration
managers = AutoRegistry.get_webhook_managers()
assert "webhook_service_1" in managers
assert "webhook_service_2" in managers
assert managers["webhook_service_1"] == WebhookManager1
assert managers["webhook_service_2"] == WebhookManager2
@pytest.mark.asyncio
async def test_webhook_block_with_provider_manager(self):
"""Test webhook block using a provider's webhook manager."""
# Register provider with webhook manager
(
ProviderBuilder("integrated_webhooks")
.with_api_key("INTEGRATED_KEY", "Integrated Webhook Key")
.with_webhook_manager(TestWebhooksManager)
.build()
)
# Create a block that uses this provider
class IntegratedWebhookBlock(Block):
"""Block using integrated webhook manager."""
class Input(BlockSchema):
credentials: CredentialsMetaInput = CredentialsField(
provider="integrated_webhooks",
supported_credential_types={"api_key"},
)
target: str = SchemaField(description="Webhook target")
payload: dict = SchemaField(
description="Webhook payload",
default={},
)
class Output(BlockSchema):
status: str = SchemaField(description="Webhook status")
manager_type: str = SchemaField(description="Manager type used")
def __init__(self):
super().__init__(
id="integrated-webhook-block",
description="Uses integrated webhook manager",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=IntegratedWebhookBlock.Input,
output_schema=IntegratedWebhookBlock.Output,
webhook_config=BlockWebhookConfig(
provider="integrated_webhooks", # type: ignore
webhook_type=TestWebhooksManager.WebhookType.TEST,
resource_format="{target}",
),
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
# Get the webhook manager for this provider
managers = AutoRegistry.get_webhook_managers()
manager_class = managers.get("integrated_webhooks")
yield "status", "configured"
yield "manager_type", (
manager_class.__name__ if manager_class else "none"
)
# Test the block
block = IntegratedWebhookBlock()
test_creds = APIKeyCredentials(
id="integrated-creds",
provider="integrated_webhooks",
api_key=SecretStr("key"),
title="Integrated Key",
)
outputs = {}
async for name, value in block.run(
IntegratedWebhookBlock.Input(
credentials={ # type: ignore
"provider": "integrated_webhooks",
"id": "integrated-creds",
"type": "api_key",
},
target="test_target",
),
credentials=test_creds,
):
outputs[name] = value
assert outputs["status"] == "configured"
assert outputs["manager_type"] == "TestWebhooksManager"
class TestWebhookEventHandling:
"""Test webhook event handling in blocks."""
@pytest.mark.asyncio
async def test_webhook_event_processing_block(self):
"""Test a block that processes webhook events."""
class WebhookEventBlock(Block):
"""Block that processes webhook events."""
class Input(BlockSchema):
event_type: str = SchemaField(description="Type of webhook event")
payload: dict = SchemaField(description="Webhook payload")
verify_signature: bool = SchemaField(
description="Whether to verify webhook signature",
default=True,
)
class Output(BlockSchema):
processed: bool = SchemaField(description="Event was processed")
event_summary: str = SchemaField(description="Summary of event")
action_required: bool = SchemaField(description="Action required")
def __init__(self):
super().__init__(
id="webhook-event-processor",
description="Processes incoming webhook events",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=WebhookEventBlock.Input,
output_schema=WebhookEventBlock.Output,
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
# Process based on event type
event_type = input_data.event_type
payload = input_data.payload
if event_type == "created":
summary = f"New item created: {payload.get('id', 'unknown')}"
action_required = True
elif event_type == "updated":
summary = f"Item updated: {payload.get('id', 'unknown')}"
action_required = False
elif event_type == "deleted":
summary = f"Item deleted: {payload.get('id', 'unknown')}"
action_required = True
else:
summary = f"Unknown event: {event_type}"
action_required = False
yield "processed", True
yield "event_summary", summary
yield "action_required", action_required
# Test the block with different events
block = WebhookEventBlock()
# Test created event
outputs = {}
async for name, value in block.run(
WebhookEventBlock.Input(
event_type="created",
payload={"id": "123", "name": "Test Item"},
)
):
outputs[name] = value
assert outputs["processed"] is True
assert "New item created: 123" in outputs["event_summary"]
assert outputs["action_required"] is True
# Test updated event
outputs = {}
async for name, value in block.run(
WebhookEventBlock.Input(
event_type="updated",
payload={"id": "456", "changes": ["name", "status"]},
)
):
outputs[name] = value
assert outputs["processed"] is True
assert "Item updated: 456" in outputs["event_summary"]
assert outputs["action_required"] is False
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -1,3 +1,21 @@
"""
Test Data Creator for AutoGPT Platform
This script creates test data for the AutoGPT platform database.
Image/Video URL Domains Used:
- Images: picsum.photos (for all image URLs - avatars, store listing images, etc.)
- Videos: youtube.com (for store listing video URLs)
Add these domains to your Next.js config:
```javascript
// next.config.js
images: {
domains: ['picsum.photos'],
}
```
"""
import asyncio
import random
from datetime import datetime
@@ -14,6 +32,7 @@ from prisma.types import (
AnalyticsMetricsCreateInput,
APIKeyCreateInput,
CreditTransactionCreateInput,
IntegrationWebhookCreateInput,
ProfileCreateInput,
StoreListingReviewCreateInput,
UserCreateInput,
@@ -53,10 +72,26 @@ MAX_REVIEWS_PER_VERSION = 5 # Total reviews depends on number of versions creat
def get_image():
url = faker.image_url()
while "placekitten.com" in url:
url = faker.image_url()
return url
"""Generate a consistent image URL using picsum.photos service."""
width = random.choice([200, 300, 400, 500, 600, 800])
height = random.choice([200, 300, 400, 500, 600, 800])
# Use a random seed to get different images
seed = random.randint(1, 1000)
return f"https://picsum.photos/seed/{seed}/{width}/{height}"
def get_video_url():
"""Generate a consistent video URL using a placeholder service."""
# Using YouTube as a consistent source for video URLs
video_ids = [
"dQw4w9WgXcQ", # Example video IDs
"9bZkp7q19f0",
"kJQP7kiw5Fk",
"RgKAFK5djSk",
"L_jWHffIx5E",
]
video_id = random.choice(video_ids)
return f"https://www.youtube.com/watch?v={video_id}"
async def main():
@@ -147,12 +182,27 @@ async def main():
)
agent_presets.append(preset)
# Insert UserAgents
user_agents = []
print(f"Inserting {NUM_USERS * MAX_AGENTS_PER_USER} user agents")
# Insert Profiles first (before LibraryAgents)
profiles = []
print(f"Inserting {NUM_USERS} profiles")
for user in users:
profile = await db.profile.create(
data=ProfileCreateInput(
userId=user.id,
name=user.name or faker.name(),
username=faker.unique.user_name(),
description=faker.text(),
links=[faker.url() for _ in range(3)],
avatarUrl=get_image(),
)
)
profiles.append(profile)
# Insert LibraryAgents
library_agents = []
print("Inserting library agents")
for user in users:
num_agents = random.randint(MIN_AGENTS_PER_USER, MAX_AGENTS_PER_USER)
# Get a shuffled list of graphs to ensure uniqueness per user
available_graphs = agent_graphs.copy()
random.shuffle(available_graphs)
@@ -162,18 +212,27 @@ async def main():
for i in range(num_agents):
graph = available_graphs[i] # Use unique graph for each library agent
user_agent = await db.libraryagent.create(
# Get creator profile for this graph's owner
creator_profile = next(
(p for p in profiles if p.userId == graph.userId), None
)
library_agent = await db.libraryagent.create(
data={
"userId": user.id,
"agentGraphId": graph.id,
"agentGraphVersion": graph.version,
"creatorId": creator_profile.id if creator_profile else None,
"imageUrl": get_image() if random.random() < 0.5 else None,
"useGraphIsActiveVersion": random.choice([True, False]),
"isFavorite": random.choice([True, False]),
"isCreatedByUser": random.choice([True, False]),
"isArchived": random.choice([True, False]),
"isDeleted": random.choice([True, False]),
}
)
user_agents.append(user_agent)
library_agents.append(library_agent)
# Insert AgentGraphExecutions
agent_graph_executions = []
@@ -325,25 +384,9 @@ async def main():
)
)
# Insert Profiles
profiles = []
print(f"Inserting {NUM_USERS} profiles")
for user in users:
profile = await db.profile.create(
data=ProfileCreateInput(
userId=user.id,
name=user.name or faker.name(),
username=faker.unique.user_name(),
description=faker.text(),
links=[faker.url() for _ in range(3)],
avatarUrl=get_image(),
)
)
profiles.append(profile)
# Insert StoreListings
store_listings = []
print(f"Inserting {NUM_USERS} store listings")
print("Inserting store listings")
for graph in agent_graphs:
user = random.choice(users)
slug = faker.slug()
@@ -360,7 +403,7 @@ async def main():
# Insert StoreListingVersions
store_listing_versions = []
print(f"Inserting {NUM_USERS} store listing versions")
print("Inserting store listing versions")
for listing in store_listings:
graph = [g for g in agent_graphs if g.id == listing.agentGraphId][0]
version = await db.storelistingversion.create(
@@ -369,7 +412,7 @@ async def main():
"agentGraphVersion": graph.version,
"name": graph.name or faker.sentence(nb_words=3),
"subHeading": faker.sentence(),
"videoUrl": faker.url(),
"videoUrl": get_video_url() if random.random() < 0.3 else None,
"imageUrls": [get_image() for _ in range(3)],
"description": faker.text(),
"categories": [faker.word() for _ in range(3)],
@@ -388,7 +431,7 @@ async def main():
store_listing_versions.append(version)
# Insert StoreListingReviews
print(f"Inserting {NUM_USERS * MAX_REVIEWS_PER_VERSION} store listing reviews")
print("Inserting store listing reviews")
for version in store_listing_versions:
# Create a copy of users list and shuffle it to avoid duplicates
available_reviewers = users.copy()
@@ -411,26 +454,92 @@ async def main():
)
)
# Update StoreListingVersions with submission status (StoreListingSubmissions table no longer exists)
print(f"Updating {NUM_USERS} store listing versions with submission status")
for version in store_listing_versions:
reviewer = random.choice(users)
status: prisma.enums.SubmissionStatus = random.choice(
[
prisma.enums.SubmissionStatus.PENDING,
prisma.enums.SubmissionStatus.APPROVED,
prisma.enums.SubmissionStatus.REJECTED,
]
)
await db.storelistingversion.update(
where={"id": version.id},
data={
"submissionStatus": status,
"Reviewer": {"connect": {"id": reviewer.id}},
"reviewComments": faker.text(),
"reviewedAt": datetime.now(),
},
)
# Insert UserOnboarding for some users
print("Inserting user onboarding data")
for user in random.sample(
users, k=int(NUM_USERS * 0.7)
): # 70% of users have onboarding data
completed_steps = []
possible_steps = list(prisma.enums.OnboardingStep)
# Randomly complete some steps
if random.random() < 0.8:
num_steps = random.randint(1, len(possible_steps))
completed_steps = random.sample(possible_steps, k=num_steps)
try:
await db.useronboarding.create(
data={
"userId": user.id,
"completedSteps": completed_steps,
"notificationDot": random.choice([True, False]),
"notified": (
random.sample(completed_steps, k=min(3, len(completed_steps)))
if completed_steps
else []
),
"rewardedFor": (
random.sample(completed_steps, k=min(2, len(completed_steps)))
if completed_steps
else []
),
"usageReason": (
random.choice(["personal", "business", "research", "learning"])
if random.random() < 0.7
else None
),
"integrations": random.sample(
["github", "google", "discord", "slack"], k=random.randint(0, 2)
),
"otherIntegrations": (
faker.word() if random.random() < 0.2 else None
),
"selectedStoreListingVersionId": (
random.choice(store_listing_versions).id
if store_listing_versions and random.random() < 0.5
else None
),
"agentInput": (
Json({"test": "data"}) if random.random() < 0.3 else None
),
"onboardingAgentExecutionId": (
random.choice(agent_graph_executions).id
if agent_graph_executions and random.random() < 0.3
else None
),
"agentRuns": random.randint(0, 10),
}
)
except Exception as e:
print(f"Error creating onboarding for user {user.id}: {e}")
# Try simpler version
await db.useronboarding.create(
data={
"userId": user.id,
}
)
# Insert IntegrationWebhooks for some users
print("Inserting integration webhooks")
for user in random.sample(
users, k=int(NUM_USERS * 0.3)
): # 30% of users have webhooks
for _ in range(random.randint(1, 3)):
await db.integrationwebhook.create(
data=IntegrationWebhookCreateInput(
userId=user.id,
provider=random.choice(["github", "slack", "discord"]),
credentialsId=str(faker.uuid4()),
webhookType=random.choice(["repo", "channel", "server"]),
resource=faker.slug(),
events=[
random.choice(["created", "updated", "deleted"])
for _ in range(random.randint(1, 3))
],
config=prisma.Json({"url": faker.url()}),
secret=str(faker.sha256()),
providerWebhookId=str(faker.uuid4()),
)
)
# Insert APIKeys
print(f"Inserting {NUM_USERS} api keys")
@@ -451,7 +560,12 @@ async def main():
)
)
# Refresh materialized views
print("Refreshing materialized views...")
await db.execute_raw("SELECT refresh_store_materialized_views();")
await db.disconnect()
print("Test data creation completed successfully!")
if __name__ == "__main__":

View File

@@ -0,0 +1,323 @@
#!/usr/bin/env python3
"""
Test Data Updater for Store Materialized Views
This script updates existing test data to trigger changes in the materialized views:
- mv_agent_run_counts: Updated by creating new AgentGraphExecution records
- mv_review_stats: Updated by creating new StoreListingReview records
Run this after test_data_creator.py to test that materialized views update correctly.
"""
import asyncio
import random
from datetime import datetime, timedelta
import prisma.enums
from faker import Faker
from prisma import Json, Prisma
faker = Faker()
async def main():
db = Prisma()
await db.connect()
print("Starting test data updates for materialized views...")
print("=" * 60)
# Get existing data
users = await db.user.find_many(take=50)
agent_graphs = await db.agentgraph.find_many(where={"isActive": True}, take=50)
store_listings = await db.storelisting.find_many(
where={"hasApprovedVersion": True}, include={"Versions": True}, take=30
)
agent_nodes = await db.agentnode.find_many(take=100)
if not all([users, agent_graphs, store_listings]):
print(
"ERROR: Not enough test data found. Please run test_data_creator.py first."
)
await db.disconnect()
return
print(
f"Found {len(users)} users, {len(agent_graphs)} graphs, {len(store_listings)} store listings"
)
print()
# 1. Add new AgentGraphExecutions to update mv_agent_run_counts
print("1. Adding new agent graph executions...")
print("-" * 40)
new_executions_count = 0
execution_data = []
for graph in random.sample(agent_graphs, min(20, len(agent_graphs))):
# Add 5-15 new executions per selected graph
num_new_executions = random.randint(5, 15)
for _ in range(num_new_executions):
user = random.choice(users)
execution_data.append(
{
"agentGraphId": graph.id,
"agentGraphVersion": graph.version,
"userId": user.id,
"executionStatus": random.choice(
[
prisma.enums.AgentExecutionStatus.COMPLETED,
prisma.enums.AgentExecutionStatus.FAILED,
prisma.enums.AgentExecutionStatus.RUNNING,
]
),
"startedAt": faker.date_time_between(
start_date="-7d", end_date="now"
),
"stats": Json(
{
"duration": random.randint(100, 5000),
"blocks_executed": random.randint(1, 10),
}
),
}
)
new_executions_count += 1
# Batch create executions
await db.agentgraphexecution.create_many(data=execution_data)
print(f"✓ Created {new_executions_count} new executions")
# Get the created executions for node executions
recent_executions = await db.agentgraphexecution.find_many(
take=new_executions_count, order={"createdAt": "desc"}
)
# 2. Add corresponding AgentNodeExecutions
print("\n2. Adding agent node executions...")
print("-" * 40)
node_execution_data = []
for execution in recent_executions:
# Get nodes for this graph
graph_nodes = [
n for n in agent_nodes if n.agentGraphId == execution.agentGraphId
]
if graph_nodes:
for node in random.sample(graph_nodes, min(3, len(graph_nodes))):
node_execution_data.append(
{
"agentGraphExecutionId": execution.id,
"agentNodeId": node.id,
"executionStatus": execution.executionStatus,
"addedTime": datetime.now(),
"startedTime": datetime.now()
- timedelta(minutes=random.randint(1, 10)),
"endedTime": (
datetime.now()
if execution.executionStatus
== prisma.enums.AgentExecutionStatus.COMPLETED
else None
),
}
)
await db.agentnodeexecution.create_many(data=node_execution_data)
print(f"✓ Created {len(node_execution_data)} node executions")
# 3. Add new StoreListingReviews to update mv_review_stats
print("\n3. Adding new store listing reviews...")
print("-" * 40)
new_reviews_count = 0
for listing in store_listings:
if not listing.Versions:
continue
# Get approved versions
approved_versions = [
v
for v in listing.Versions
if v.submissionStatus == prisma.enums.SubmissionStatus.APPROVED
]
if not approved_versions:
continue
# Pick a version to add reviews to
version = random.choice(approved_versions)
# Get existing reviews for this version to avoid duplicates
existing_reviews = await db.storelistingreview.find_many(
where={"storeListingVersionId": version.id}
)
existing_reviewer_ids = {r.reviewByUserId for r in existing_reviews}
# Find users who haven't reviewed this version yet
available_reviewers = [u for u in users if u.id not in existing_reviewer_ids]
if available_reviewers:
# Add 2-5 new reviews
num_new_reviews = min(random.randint(2, 5), len(available_reviewers))
selected_reviewers = random.sample(available_reviewers, num_new_reviews)
for reviewer in selected_reviewers:
# Bias towards positive reviews (4-5 stars)
score = random.choices([1, 2, 3, 4, 5], weights=[5, 10, 20, 40, 25])[0]
await db.storelistingreview.create(
data={
"storeListingVersionId": version.id,
"reviewByUserId": reviewer.id,
"score": score,
"comments": (
faker.text(max_nb_chars=200)
if random.random() < 0.7
else None
),
}
)
new_reviews_count += 1
print(f"✓ Created {new_reviews_count} new reviews")
# 4. Update some store listing versions (change categories, featured status)
print("\n4. Updating store listing versions...")
print("-" * 40)
updates_count = 0
for listing in random.sample(store_listings, min(10, len(store_listings))):
if listing.Versions:
version = random.choice(listing.Versions)
if version.submissionStatus == prisma.enums.SubmissionStatus.APPROVED:
# Toggle featured status or update categories
new_categories = random.sample(
[
"productivity",
"ai",
"automation",
"data",
"social",
"marketing",
"development",
"analytics",
],
k=random.randint(2, 4),
)
await db.storelistingversion.update(
where={"id": version.id},
data={
"isFeatured": (
not version.isFeatured
if random.random() < 0.3
else version.isFeatured
),
"categories": new_categories,
"updatedAt": datetime.now(),
},
)
updates_count += 1
print(f"✓ Updated {updates_count} store listing versions")
# 5. Create some new credit transactions
print("\n5. Adding credit transactions...")
print("-" * 40)
transaction_count = 0
for user in random.sample(users, min(30, len(users))):
# Add 1-3 transactions per user
for _ in range(random.randint(1, 3)):
transaction_type = random.choice(
[
prisma.enums.CreditTransactionType.USAGE,
prisma.enums.CreditTransactionType.TOP_UP,
prisma.enums.CreditTransactionType.GRANT,
]
)
amount = (
random.randint(10, 500)
if transaction_type == prisma.enums.CreditTransactionType.TOP_UP
else -random.randint(1, 50)
)
await db.credittransaction.create(
data={
"userId": user.id,
"amount": amount,
"type": transaction_type,
"metadata": Json(
{
"source": "test_updater",
"timestamp": datetime.now().isoformat(),
}
),
}
)
transaction_count += 1
print(f"✓ Created {transaction_count} credit transactions")
# 6. Refresh materialized views
print("\n6. Refreshing materialized views...")
print("-" * 40)
try:
await db.execute_raw("SELECT refresh_store_materialized_views();")
print("✓ Materialized views refreshed successfully")
except Exception as e:
print(f"⚠ Warning: Could not refresh materialized views: {e}")
print(
" You may need to refresh them manually with: SELECT refresh_store_materialized_views();"
)
# 7. Verify the updates
print("\n7. Verifying updates...")
print("-" * 40)
# Check agent run counts
run_counts = await db.query_raw(
"SELECT COUNT(*) as view_count FROM mv_agent_run_counts"
)
print(f"✓ mv_agent_run_counts has {run_counts[0]['view_count']} entries")
# Check review stats
review_stats = await db.query_raw(
"SELECT COUNT(*) as view_count FROM mv_review_stats"
)
print(f"✓ mv_review_stats has {review_stats[0]['view_count']} entries")
# Sample some data from the views
print("\nSample data from materialized views:")
sample_runs = await db.query_raw(
"SELECT * FROM mv_agent_run_counts ORDER BY run_count DESC LIMIT 5"
)
print("\nTop 5 agents by run count:")
for row in sample_runs:
print(f" - Agent {row['agentGraphId'][:8]}...: {row['run_count']} runs")
sample_reviews = await db.query_raw(
"SELECT * FROM mv_review_stats ORDER BY avg_rating DESC NULLS LAST LIMIT 5"
)
print("\nTop 5 store listings by rating:")
for row in sample_reviews:
avg_rating = row["avg_rating"] if row["avg_rating"] is not None else 0.0
print(
f" - Listing {row['storeListingId'][:8]}...: {avg_rating:.2f} ⭐ ({row['review_count']} reviews)"
)
await db.disconnect()
print("\n" + "=" * 60)
print("Test data update completed successfully!")
print("The materialized views should now reflect the updated data.")
print(
"\nTo manually refresh views, run: SELECT refresh_store_materialized_views();"
)
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -11,8 +11,6 @@ const nextConfig = {
"ideogram.ai", // for generated images
"picsum.photos", // for placeholder images
"dummyimage.com", // for placeholder images
"placekitten.com", // for placeholder images
],
},
output: "standalone",
@@ -30,6 +28,11 @@ export default isDevelopmentBuild
org: "significant-gravitas",
project: "builder",
// Expose Vercel env to the client
env: {
NEXT_PUBLIC_VERCEL_ENV: process.env.VERCEL_ENV,
},
// Only print logs for uploading source maps in CI
silent: !process.env.CI,

View File

@@ -54,7 +54,6 @@
"@supabase/supabase-js": "2.50.3",
"@tanstack/react-query": "5.81.5",
"@tanstack/react-table": "8.21.3",
"@tanstack/react-virtual": "3.13.12",
"@types/jaro-winkler": "0.2.4",
"@xyflow/react": "12.8.1",
"ajv": "8.17.1",
@@ -76,6 +75,7 @@
"moment": "2.30.1",
"next": "15.3.5",
"next-themes": "0.4.6",
"nuqs": "2.4.3",
"party-js": "2.2.0",
"react": "18.3.1",
"react-day-picker": "9.8.0",
@@ -88,6 +88,7 @@
"react-shepherd": "6.1.8",
"recharts": "2.15.3",
"shepherd.js": "14.5.0",
"sonner": "2.0.6",
"tailwind-merge": "2.6.0",
"tailwindcss-animate": "1.0.7",
"uuid": "11.1.0",

View File

@@ -92,9 +92,6 @@ importers:
'@tanstack/react-table':
specifier: 8.21.3
version: 8.21.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
'@tanstack/react-virtual':
specifier: 3.13.12
version: 3.13.12(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
'@types/jaro-winkler':
specifier: 0.2.4
version: 0.2.4
@@ -158,6 +155,9 @@ importers:
next-themes:
specifier: 0.4.6
version: 0.4.6(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
nuqs:
specifier: 2.4.3
version: 2.4.3(next@15.3.5(@babel/core@7.28.0)(@opentelemetry/api@1.9.0)(@playwright/test@1.53.2)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)
party-js:
specifier: 2.2.0
version: 2.2.0
@@ -194,6 +194,9 @@ importers:
shepherd.js:
specifier: 14.5.0
version: 14.5.0
sonner:
specifier: 2.0.6
version: 2.0.6(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
tailwind-merge:
specifier: 2.6.0
version: 2.6.0
@@ -2751,19 +2754,10 @@ packages:
react: '>=16.8'
react-dom: '>=16.8'
'@tanstack/react-virtual@3.13.12':
resolution: {integrity: sha512-Gd13QdxPSukP8ZrkbgS2RwoZseTTbQPLnQEn7HY/rqtM+8Zt95f7xKC7N0EsKs7aoz0WzZ+fditZux+F8EzYxA==}
peerDependencies:
react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0
react-dom: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0
'@tanstack/table-core@8.21.3':
resolution: {integrity: sha512-ldZXEhOBb8Is7xLs01fR3YEc3DERiz5silj8tnGkFZytt1abEvl/GhUmCE0PMLaMPTa3Jk4HbKmRlHmu+gCftg==}
engines: {node: '>=12'}
'@tanstack/virtual-core@3.13.12':
resolution: {integrity: sha512-1YBOJfRHV4sXUmWsFSf5rQor4Ss82G8dQWLRbnk3GA4jeP8hQt1hxXh0tmflpC0dz3VgEv/1+qwPyLeWkQuPFA==}
'@testing-library/dom@10.4.0':
resolution: {integrity: sha512-pemlzrSESWbdAloYml3bAJMEfNh1Z7EduzqPKprCH5S341frlpYnUEW0H72dLxa6IsYr+mPno20GiSm+h9dEdQ==}
engines: {node: '>=18'}
@@ -5338,6 +5332,9 @@ packages:
resolution: {integrity: sha512-qOOzS1cBTWYF4BH8fVePDBOO9iptMnGUEZwNc/cMWnTV2nVLZ7VoNWEPHkYczZA0pdoA7dl6e7FL659nX9S2aw==}
engines: {node: '>=16 || 14 >=14.17'}
mitt@3.0.1:
resolution: {integrity: sha512-vKivATfr97l2/QBCYAkXYDbrIWPM2IIKEl7YPhjCvKlG3kE2gm+uBo6nEXK3M5/Ffh/FLpKExzOQ3JJoJGFKBw==}
module-details-from-path@1.0.4:
resolution: {integrity: sha512-EGWKgxALGMgzvxYF1UyGTy0HXX/2vHLkw6+NvDKW2jypWbHpjQuj4UMcqQWXHERJhVGKikolT06G3bcKe4fi7w==}
@@ -5468,6 +5465,24 @@ packages:
nth-check@2.1.1:
resolution: {integrity: sha512-lqjrjmaOoAnWfMmBPL+XNnynZh2+swxiX3WUE0s4yEHI6m+AwrK2UZOimIRl3X/4QctVqS8AiZjFqyOGrMXb/w==}
nuqs@2.4.3:
resolution: {integrity: sha512-BgtlYpvRwLYiJuWzxt34q2bXu/AIS66sLU1QePIMr2LWkb+XH0vKXdbLSgn9t6p7QKzwI7f38rX3Wl9llTXQ8Q==}
peerDependencies:
'@remix-run/react': '>=2'
next: '>=14.2.0'
react: '>=18.2.0 || ^19.0.0-0'
react-router: ^6 || ^7
react-router-dom: ^6 || ^7
peerDependenciesMeta:
'@remix-run/react':
optional: true
next:
optional: true
react-router:
optional: true
react-router-dom:
optional: true
oas-kit-common@1.0.8:
resolution: {integrity: sha512-pJTS2+T0oGIwgjGpw7sIRU8RQMcUoKCDWFLdBqKB2BNmGpbBMH2sdqAaOXUg8OzonZHU0L7vfJu1mJFEiYDWOQ==}
@@ -6388,6 +6403,12 @@ packages:
resolution: {integrity: sha512-g9Q1haeby36OSStwb4ntCGGGaKsaVSjQ68fBxoQcutl5fS1vuY18H3wSt3jFyFtrkx+Kz0V1G85A4MyAdDMi2Q==}
engines: {node: '>=8'}
sonner@2.0.6:
resolution: {integrity: sha512-yHFhk8T/DK3YxjFQXIrcHT1rGEeTLliVzWbO0xN8GberVun2RiBnxAjXAYpZrqwEVHBG9asI/Li8TAAhN9m59Q==}
peerDependencies:
react: ^18.0.0 || ^19.0.0 || ^19.0.0-rc
react-dom: ^18.0.0 || ^19.0.0 || ^19.0.0-rc
source-map-js@1.2.1:
resolution: {integrity: sha512-UXWMKhLOwVKb728IUtQPXxfYU+usdybtUrK/8uGE8CQMvrhOpwvzDBwj0QhSL7MQc7vIsISBG8VQ8+IDQxpfQA==}
engines: {node: '>=0.10.0'}
@@ -9994,16 +10015,8 @@ snapshots:
react: 18.3.1
react-dom: 18.3.1(react@18.3.1)
'@tanstack/react-virtual@3.13.12(react-dom@18.3.1(react@18.3.1))(react@18.3.1)':
dependencies:
'@tanstack/virtual-core': 3.13.12
react: 18.3.1
react-dom: 18.3.1(react@18.3.1)
'@tanstack/table-core@8.21.3': {}
'@tanstack/virtual-core@3.13.12': {}
'@testing-library/dom@10.4.0':
dependencies:
'@babel/code-frame': 7.27.1
@@ -11653,8 +11666,8 @@ snapshots:
'@typescript-eslint/parser': 8.36.0(eslint@8.57.1)(typescript@5.8.3)
eslint: 8.57.1
eslint-import-resolver-node: 0.3.9
eslint-import-resolver-typescript: 3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.36.0(eslint@8.57.1)(typescript@5.8.3))(eslint@8.57.1))(eslint@8.57.1)
eslint-plugin-import: 2.32.0(@typescript-eslint/parser@8.36.0(eslint@8.57.1)(typescript@5.8.3))(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.36.0(eslint@8.57.1)(typescript@5.8.3))(eslint@8.57.1))(eslint@8.57.1))(eslint@8.57.1)
eslint-import-resolver-typescript: 3.10.1(eslint-plugin-import@2.32.0)(eslint@8.57.1)
eslint-plugin-import: 2.32.0(@typescript-eslint/parser@8.36.0(eslint@8.57.1)(typescript@5.8.3))(eslint-import-resolver-typescript@3.10.1)(eslint@8.57.1)
eslint-plugin-jsx-a11y: 6.10.2(eslint@8.57.1)
eslint-plugin-react: 7.37.5(eslint@8.57.1)
eslint-plugin-react-hooks: 5.2.0(eslint@8.57.1)
@@ -11673,7 +11686,7 @@ snapshots:
transitivePeerDependencies:
- supports-color
eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.36.0(eslint@8.57.1)(typescript@5.8.3))(eslint@8.57.1))(eslint@8.57.1):
eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0)(eslint@8.57.1):
dependencies:
'@nolyfill/is-core-module': 1.0.39
debug: 4.4.1
@@ -11684,22 +11697,22 @@ snapshots:
tinyglobby: 0.2.14
unrs-resolver: 1.11.0
optionalDependencies:
eslint-plugin-import: 2.32.0(@typescript-eslint/parser@8.36.0(eslint@8.57.1)(typescript@5.8.3))(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.36.0(eslint@8.57.1)(typescript@5.8.3))(eslint@8.57.1))(eslint@8.57.1))(eslint@8.57.1)
eslint-plugin-import: 2.32.0(@typescript-eslint/parser@8.36.0(eslint@8.57.1)(typescript@5.8.3))(eslint-import-resolver-typescript@3.10.1)(eslint@8.57.1)
transitivePeerDependencies:
- supports-color
eslint-module-utils@2.12.1(@typescript-eslint/parser@8.36.0(eslint@8.57.1)(typescript@5.8.3))(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.36.0(eslint@8.57.1)(typescript@5.8.3))(eslint@8.57.1))(eslint@8.57.1))(eslint@8.57.1):
eslint-module-utils@2.12.1(@typescript-eslint/parser@8.36.0(eslint@8.57.1)(typescript@5.8.3))(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@3.10.1)(eslint@8.57.1):
dependencies:
debug: 3.2.7
optionalDependencies:
'@typescript-eslint/parser': 8.36.0(eslint@8.57.1)(typescript@5.8.3)
eslint: 8.57.1
eslint-import-resolver-node: 0.3.9
eslint-import-resolver-typescript: 3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.36.0(eslint@8.57.1)(typescript@5.8.3))(eslint@8.57.1))(eslint@8.57.1)
eslint-import-resolver-typescript: 3.10.1(eslint-plugin-import@2.32.0)(eslint@8.57.1)
transitivePeerDependencies:
- supports-color
eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.36.0(eslint@8.57.1)(typescript@5.8.3))(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.36.0(eslint@8.57.1)(typescript@5.8.3))(eslint@8.57.1))(eslint@8.57.1))(eslint@8.57.1):
eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.36.0(eslint@8.57.1)(typescript@5.8.3))(eslint-import-resolver-typescript@3.10.1)(eslint@8.57.1):
dependencies:
'@rtsao/scc': 1.1.0
array-includes: 3.1.9
@@ -11710,7 +11723,7 @@ snapshots:
doctrine: 2.1.0
eslint: 8.57.1
eslint-import-resolver-node: 0.3.9
eslint-module-utils: 2.12.1(@typescript-eslint/parser@8.36.0(eslint@8.57.1)(typescript@5.8.3))(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.36.0(eslint@8.57.1)(typescript@5.8.3))(eslint@8.57.1))(eslint@8.57.1))(eslint@8.57.1)
eslint-module-utils: 2.12.1(@typescript-eslint/parser@8.36.0(eslint@8.57.1)(typescript@5.8.3))(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@3.10.1)(eslint@8.57.1)
hasown: 2.0.2
is-core-module: 2.16.1
is-glob: 4.0.3
@@ -13009,6 +13022,8 @@ snapshots:
minipass@7.1.2: {}
mitt@3.0.1: {}
module-details-from-path@1.0.4: {}
moment@2.30.1: {}
@@ -13171,6 +13186,13 @@ snapshots:
dependencies:
boolbase: 1.0.0
nuqs@2.4.3(next@15.3.5(@babel/core@7.28.0)(@opentelemetry/api@1.9.0)(@playwright/test@1.53.2)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1):
dependencies:
mitt: 3.0.1
react: 18.3.1
optionalDependencies:
next: 15.3.5(@babel/core@7.28.0)(@opentelemetry/api@1.9.0)(@playwright/test@1.53.2)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
oas-kit-common@1.0.8:
dependencies:
fast-safe-stringify: 2.1.1
@@ -14214,6 +14236,11 @@ snapshots:
slash@3.0.0: {}
sonner@2.0.6(react-dom@18.3.1(react@18.3.1))(react@18.3.1):
dependencies:
react: 18.3.1
react-dom: 18.3.1(react@18.3.1)
source-map-js@1.2.1: {}
source-map-support@0.5.21:

View File

@@ -0,0 +1,6 @@
<svg width="20" height="20" viewBox="0 0 20 20" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M19.99 10.1871C19.99 9.36767 19.9246 8.76973 19.7831 8.14966H10.1943V11.8493H15.8207C15.7062 12.7676 15.0943 14.1618 13.7567 15.0492L13.7398 15.1632L16.7444 17.4429L16.9637 17.4648C18.8825 15.7291 19.99 13.2042 19.99 10.1871Z" fill="#4285F4"/>
<path d="M10.1943 19.9313C12.9592 19.9313 15.2429 19.0454 16.9637 17.4648L13.7567 15.0492C12.8697 15.6438 11.7348 16.0244 10.1943 16.0244C7.50242 16.0244 5.25023 14.2886 4.39644 11.9036L4.28823 11.9125L1.17021 14.2775L1.13477 14.3808C2.84508 17.8028 6.1992 19.9313 10.1943 19.9313Z" fill="#34A853"/>
<path d="M4.39644 11.9036C4.1758 11.2746 4.04876 10.6013 4.04876 9.90569C4.04876 9.21011 4.1758 8.53684 4.38177 7.90781L4.37563 7.7883L1.20776 5.3801L1.13477 5.41253C0.436264 6.80439 0.0390625 8.35202 0.0390625 9.90569C0.0390625 11.4594 0.436264 13.007 1.13477 14.3808L4.39644 11.9036Z" fill="#FBBC05"/>
<path d="M10.1943 3.78682C12.1168 3.78682 13.397 4.66154 14.1236 5.33481L17.0194 2.59768C15.2373 0.953818 12.9592 0 10.1943 0C6.1992 0 2.84508 2.12847 1.13477 5.41253L4.38177 7.90781C5.25023 5.52278 7.50242 3.78682 10.1943 3.78682Z" fill="#EB4335"/>
</svg>

After

Width:  |  Height:  |  Size: 1.2 KiB

View File

@@ -11,7 +11,7 @@ import StarRating from "@/components/onboarding/StarRating";
import SchemaTooltip from "@/components/SchemaTooltip";
import { TypeBasedInput } from "@/components/type-based-input";
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
import { useToast } from "@/components/ui/use-toast";
import { useToast } from "@/components/molecules/Toast/use-toast";
import { GraphMeta, StoreAgentDetails } from "@/lib/autogpt-server-api";
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
import { cn } from "@/lib/utils";
@@ -46,13 +46,13 @@ export default function Page() {
setStoreAgent(storeAgent);
});
api
.getAgentMetaByStoreListingVersionId(state?.selectedStoreListingVersionId)
.getGraphMetaByStoreListingVersionID(state.selectedStoreListingVersionId)
.then((agent) => {
setAgent(agent);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const update: { [key: string]: any } = {};
// Set default values from schema
Object.entries(agent.input_schema?.properties || {}).forEach(
Object.entries(agent.input_schema.properties).forEach(
([key, value]) => {
// Skip if already set
if (state.agentInput && state.agentInput[key]) {
@@ -224,7 +224,7 @@ export default function Page() {
<CardTitle className="font-poppins text-lg">Input</CardTitle>
</CardHeader>
<CardContent className="flex flex-col gap-4">
{Object.entries(agent?.input_schema?.properties || {}).map(
{Object.entries(agent?.input_schema.properties || {}).map(
([key, inputSubSchema]) => (
<div key={key} className="flex flex-col space-y-2">
<label className="flex items-center gap-1 text-sm font-medium">

View File

@@ -1,12 +1,13 @@
"use client";
import { useSearchParams } from "next/navigation";
import { GraphID } from "@/lib/autogpt-server-api/types";
import FlowEditor from "@/components/Flow";
import { useOnboarding } from "@/components/onboarding/onboarding-provider";
import { useEffect } from "react";
import LoadingBox from "@/components/ui/loading";
import { GraphID } from "@/lib/autogpt-server-api/types";
import { useSearchParams } from "next/navigation";
import { Suspense, useEffect } from "react";
export default function BuilderPage() {
function BuilderContent() {
const query = useSearchParams();
const { completeStep } = useOnboarding();
@@ -15,12 +16,20 @@ export default function BuilderPage() {
}, [completeStep]);
const _graphVersion = query.get("flowVersion");
const graphVersion = _graphVersion ? parseInt(_graphVersion) : undefined
const graphVersion = _graphVersion ? parseInt(_graphVersion) : undefined;
return (
<FlowEditor
className="flow-container"
flowID={query.get("flowID") as GraphID | null ?? undefined}
flowID={(query.get("flowID") as GraphID | null) ?? undefined}
flowVersion={graphVersion}
/>
);
}
export default function BuilderPage() {
return (
<Suspense fallback={<LoadingBox className="h-[80vh]" />}>
<BuilderContent />
</Suspense>
);
}

View File

@@ -1,67 +1,10 @@
import { Navbar } from "@/components/layout/Navbar/Navbar";
import { ReactNode } from "react";
import { Navbar } from "@/components/agptui/Navbar";
import { IconType } from "@/components/ui/icons";
export default function PlatformLayout({ children }: { children: ReactNode }) {
return (
<>
<Navbar
links={[
{
name: "Marketplace",
href: "/marketplace",
},
{
name: "Library",
href: "/library",
},
{
name: "Build",
href: "/build",
},
]}
menuItemGroups={[
{
items: [
{
icon: IconType.Edit,
text: "Edit profile",
href: "/profile",
},
],
},
{
items: [
{
icon: IconType.LayoutDashboard,
text: "Creator Dashboard",
href: "/profile/dashboard",
},
{
icon: IconType.UploadCloud,
text: "Publish an agent",
},
],
},
{
items: [
{
icon: IconType.Settings,
text: "Settings",
href: "/profile/settings",
},
],
},
{
items: [
{
icon: IconType.LogOut,
text: "Log out",
},
],
},
]}
/>
<Navbar />
<main>{children}</main>
</>
);

View File

@@ -1,5 +1,6 @@
"use client";
import { useParams, useRouter } from "next/navigation";
import { useQueryState } from "nuqs";
import React, {
useCallback,
useEffect,
@@ -41,10 +42,11 @@ import {
DialogTitle,
} from "@/components/ui/dialog";
import LoadingBox, { LoadingSpinner } from "@/components/ui/loading";
import { useToast } from "@/components/ui/use-toast";
import { useToast } from "@/components/molecules/Toast/use-toast";
export default function AgentRunsPage(): React.ReactElement {
const { id: agentID }: { id: LibraryAgentID } = useParams();
const [executionId, setExecutionId] = useQueryState("executionId");
const { toast } = useToast();
const router = useRouter();
const api = useBackendAPI();
@@ -202,6 +204,13 @@ export default function AgentRunsPage(): React.ReactElement {
selectPreset,
]);
useEffect(() => {
if (executionId) {
selectRun(executionId as GraphExecutionID);
setExecutionId(null);
}
}, [executionId, selectRun, setExecutionId]);
// Initial load
useEffect(() => {
refreshPageData();
@@ -468,7 +477,7 @@ export default function AgentRunsPage(): React.ReactElement {
}
return (
<div className="container justify-stretch p-0 lg:flex">
<div className="container justify-stretch p-0 pt-16 lg:flex">
{/* Sidebar w/ list of runs */}
{/* TODO: render this below header in sm and md layouts */}
<AgentRunsSelectorList
@@ -512,7 +521,8 @@ export default function AgentRunsPage(): React.ReactElement {
) : selectedView.type == "run" ? (
/* Draft new runs / Create new presets */
<AgentRunDraftView
agent={agent}
graph={graph}
triggerSetupInfo={agent.trigger_setup_info}
onRun={selectRun}
onCreateSchedule={onCreateSchedule}
onCreatePreset={onCreatePreset}
@@ -521,7 +531,8 @@ export default function AgentRunsPage(): React.ReactElement {
) : selectedView.type == "preset" ? (
/* Edit & update presets */
<AgentRunDraftView
agent={agent}
graph={graph}
triggerSetupInfo={agent.trigger_setup_info}
agentPreset={
agentPresets.find((preset) => preset.id == selectedView.id)!
}

View File

@@ -6,6 +6,7 @@ import { useLibraryAgentList } from "./useLibraryAgentList";
export default function LibraryAgentList() {
const {
agentLoading,
agentCount,
allAgents: agents,
isFetchingNextPage,
isSearching,
@@ -18,7 +19,7 @@ export default function LibraryAgentList() {
return (
<>
{/* TODO: We need a new endpoint on backend that returns total number of agents */}
<LibraryActionSubHeader agentCount={agents.length} />
<LibraryActionSubHeader agentCount={agentCount} />
<div className="px-2">
{agentLoading ? (
<div className="flex h-[200px] items-center justify-center">

View File

@@ -56,11 +56,16 @@ export const useLibraryAgentList = () => {
return data.agents;
}) ?? [];
const agentCount = agents?.pages[0]
? (agents.pages[0].data as LibraryAgentResponse).pagination.total_items
: 0;
return {
allAgents,
agentLoading,
isFetchingNextPage,
hasNextPage,
agentCount,
isSearching: isFetching && !isFetchingNextPage,
};
};

View File

@@ -4,7 +4,7 @@ import { z } from "zod";
import { uploadAgentFormSchema } from "./LibraryUploadAgentDialog";
import { usePostV1CreateNewGraph } from "@/app/api/__generated__/endpoints/graphs/graphs";
import { GraphModel } from "@/app/api/__generated__/models/graphModel";
import { useToast } from "@/components/ui/use-toast";
import { useToast } from "@/components/molecules/Toast/use-toast";
import { useState } from "react";
import { Graph } from "@/app/api/__generated__/models/graph";
import { sanitizeImportedGraph } from "@/lib/autogpt-server-api";

View File

@@ -1,15 +1,15 @@
"use client";
import Link from "next/link";
import { Alert, AlertDescription } from "@/components/ui/alert";
import {
ArrowBottomRightIcon,
QuestionMarkCircledIcon,
} from "@radix-ui/react-icons";
import { Alert, AlertDescription } from "@/components/ui/alert";
import { LibraryPageStateProvider } from "./components/state-provider";
import LibraryActionHeader from "./components/LibraryActionHeader/LibraryActionHeader";
import LibraryAgentList from "./components/LibraryAgentList/LibraryAgentList";
import { LibraryPageStateProvider } from "./components/state-provider";
/**
* LibraryPage Component
@@ -17,7 +17,7 @@ import LibraryAgentList from "./components/LibraryAgentList/LibraryAgentList";
*/
export default function LibraryPage() {
return (
<main className="container min-h-screen space-y-4 pb-20 sm:px-8 md:px-12">
<main className="pt-160 container min-h-screen space-y-4 pb-20 pt-16 sm:px-8 md:px-12">
<LibraryPageStateProvider>
<LibraryActionHeader />
<LibraryAgentList />

View File

@@ -1,12 +1,12 @@
"use server";
import BackendAPI from "@/lib/autogpt-server-api";
import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
import { verifyTurnstileToken } from "@/lib/turnstile";
import { loginFormSchema, LoginProvider } from "@/types/auth";
import * as Sentry from "@sentry/nextjs";
import { revalidatePath } from "next/cache";
import { redirect } from "next/navigation";
import { z } from "zod";
import * as Sentry from "@sentry/nextjs";
import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
import BackendAPI from "@/lib/autogpt-server-api";
import { loginFormSchema, LoginProvider } from "@/types/auth";
import { verifyTurnstileToken } from "@/lib/turnstile";
async function shouldShowOnboarding() {
const api = new BackendAPI();
@@ -23,6 +23,7 @@ export async function login(
return await Sentry.withServerActionInstrumentation("login", {}, async () => {
const supabase = await getServerSupabase();
const api = new BackendAPI();
const isVercelPreview = process.env.VERCEL_ENV === "preview";
if (!supabase) {
redirect("/error");
@@ -30,7 +31,7 @@ export async function login(
// Verify Turnstile token if provided
const success = await verifyTurnstileToken(turnstileToken, "login");
if (!success) {
if (!success && !isVercelPreview) {
return "CAPTCHA verification failed. Please try again.";
}
@@ -38,7 +39,6 @@ export async function login(
const { error } = await supabase.auth.signInWithPassword(values);
if (error) {
console.error("Error logging in:", error);
return error.message;
}
@@ -76,6 +76,11 @@ export async function providerLogin(provider: LoginProvider) {
});
if (error) {
// FIXME: supabase doesn't return the correct error message for this case
if (error.message.includes("P0001")) {
return "not_allowed";
}
console.error("Error logging in", error);
return error.message;
}

View File

@@ -0,0 +1,34 @@
import { AuthCard } from "@/components/auth/AuthCard";
import { Skeleton } from "@/components/ui/skeleton";
export function LoadingLogin() {
return (
<div className="flex h-full min-h-[85vh] flex-col items-center justify-center">
<AuthCard title="">
<div className="w-full space-y-6">
<Skeleton className="mx-auto h-8 w-48" />
<Skeleton className="h-12 w-full rounded-md" />
<div className="flex w-full items-center">
<Skeleton className="h-px flex-1" />
<Skeleton className="mx-3 h-4 w-6" />
<Skeleton className="h-px flex-1" />
</div>
<div className="space-y-2">
<Skeleton className="h-4 w-12" />
<Skeleton className="h-12 w-full rounded-md" />
</div>
<div className="space-y-2">
<Skeleton className="h-4 w-16" />
<Skeleton className="h-12 w-full rounded-md" />
</div>
<Skeleton className="h-16 w-full rounded-md" />
<Skeleton className="h-12 w-full rounded-md" />
<div className="flex justify-center space-x-1">
<Skeleton className="h-4 w-32" />
<Skeleton className="h-4 w-12" />
</div>
</div>
</AuthCard>
</div>
);
}

View File

@@ -1,26 +1,15 @@
"use client";
import {
AuthBottomText,
AuthButton,
AuthCard,
AuthFeedback,
AuthHeader,
GoogleOAuthButton,
PasswordInput,
Turnstile,
} from "@/components/auth";
import {
Form,
FormControl,
FormField,
FormItem,
FormLabel,
FormMessage,
} from "@/components/ui/form";
import { Input } from "@/components/ui/input";
import LoadingBox from "@/components/ui/loading";
import { Button } from "@/components/atoms/Button/Button";
import { Input } from "@/components/atoms/Input/Input";
import { Link } from "@/components/atoms/Link/Link";
import { AuthCard } from "@/components/auth/AuthCard";
import AuthFeedback from "@/components/auth/AuthFeedback";
import { EmailNotAllowedModal } from "@/components/auth/EmailNotAllowedModal";
import { GoogleOAuthButton } from "@/components/auth/GoogleOAuthButton";
import Turnstile from "@/components/auth/Turnstile";
import { Form, FormField } from "@/components/ui/form";
import { getBehaveAs } from "@/lib/utils";
import Link from "next/link";
import { LoadingLogin } from "./components/LoadingLogin";
import { useLoginPage } from "./useLoginPage";
export default function LoginPage() {
@@ -30,17 +19,20 @@ export default function LoginPage() {
turnstile,
captchaKey,
isLoading,
isCloudEnv,
isLoggedIn,
isCloudEnv,
shouldNotRenderCaptcha,
isUserLoading,
isGoogleLoading,
showNotAllowedModal,
isSupabaseAvailable,
handleSubmit,
handleProviderLogin,
handleCloseNotAllowedModal,
} = useLoginPage();
if (isUserLoading || isLoggedIn) {
return <LoadingBox className="h-[80vh]" />;
return <LoadingLogin />;
}
if (!isSupabaseAvailable) {
@@ -52,99 +44,93 @@ export default function LoginPage() {
}
return (
<AuthCard className="mx-auto">
<AuthHeader>Login to your account</AuthHeader>
<div className="flex h-full min-h-[85vh] flex-col items-center justify-center py-10">
<AuthCard title="Login to your account">
<Form {...form}>
<form onSubmit={handleSubmit} className="flex w-full flex-col gap-1">
<FormField
control={form.control}
name="email"
render={({ field }) => (
<Input
id={field.name}
label="Email"
placeholder="m@example.com"
type="email"
autoComplete="username"
className="w-full"
error={form.formState.errors.email?.message}
{...field}
/>
)}
/>
<FormField
control={form.control}
name="password"
render={({ field }) => (
<Input
id={field.name}
label="Password"
placeholder="•••••••••••••••••••••"
type="password"
autoComplete="current-password"
error={form.formState.errors.password?.message}
hint={
<Link variant="secondary" href="/reset-password">
Forgot password?
</Link>
}
{...field}
/>
)}
/>
{isCloudEnv ? (
<>
<div className="mb-6">
{/* Turnstile CAPTCHA Component */}
{shouldNotRenderCaptcha ? null : (
<Turnstile
key={captchaKey}
siteKey={turnstile.siteKey}
onVerify={turnstile.handleVerify}
onExpire={turnstile.handleExpire}
onError={turnstile.handleError}
setWidgetId={turnstile.setWidgetId}
action="login"
shouldRender={turnstile.shouldRender}
/>
)}
<Button
variant="primary"
loading={isLoading}
type="submit"
className="mt-6 w-full"
>
{isLoading ? "Logging in..." : "Login"}
</Button>
</form>
{isCloudEnv ? (
<GoogleOAuthButton
onClick={() => handleProviderLogin("google")}
isLoading={isGoogleLoading}
disabled={isLoading}
/>
</div>
<div className="mb-6 flex items-center">
<div className="flex-1 border-t border-gray-300"></div>
<span className="mx-3 text-sm text-gray-500">or</span>
<div className="flex-1 border-t border-gray-300"></div>
</div>
</>
) : null}
<Form {...form}>
<form onSubmit={handleSubmit}>
<FormField
control={form.control}
name="email"
render={({ field }) => (
<FormItem className="mb-6">
<FormLabel>Email</FormLabel>
<FormControl>
<Input
placeholder="m@example.com"
{...field}
type="email" // Explicitly specify email type
autoComplete="username" // Added for password managers
/>
</FormControl>
<FormMessage />
</FormItem>
)}
) : null}
<AuthFeedback
type="login"
message={feedback}
isError={!!feedback}
behaveAs={getBehaveAs()}
/>
<FormField
control={form.control}
name="password"
render={({ field }) => (
<FormItem className="mb-6">
<FormLabel className="flex w-full items-center justify-between">
<span>Password</span>
<Link
href="/reset-password"
className="text-sm font-normal leading-normal text-black underline"
>
Forgot your password?
</Link>
</FormLabel>
<FormControl>
<PasswordInput
{...field}
autoComplete="current-password" // Added for password managers
/>
</FormControl>
<FormMessage />
</FormItem>
)}
/>
{/* Turnstile CAPTCHA Component */}
<Turnstile
key={captchaKey}
siteKey={turnstile.siteKey}
onVerify={turnstile.handleVerify}
onExpire={turnstile.handleExpire}
onError={turnstile.handleError}
setWidgetId={turnstile.setWidgetId}
action="login"
shouldRender={turnstile.shouldRender}
/>
<AuthButton isLoading={isLoading} type="submit">
Login
</AuthButton>
</form>
<AuthFeedback
type="login"
message={feedback}
isError={!!feedback}
behaveAs={getBehaveAs()}
</Form>
<AuthCard.BottomText
text="Don't have an account?"
link={{ text: "Sign up", href: "/signup" }}
/>
</Form>
<AuthBottomText
text="Don't have an account?"
linkText="Sign up"
href="/signup"
</AuthCard>
<EmailNotAllowedModal
isOpen={showNotAllowedModal}
onClose={handleCloseNotAllowedModal}
/>
</AuthCard>
</div>
);
}

View File

@@ -1,23 +1,26 @@
import { useTurnstile } from "@/hooks/useTurnstile";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import { BehaveAs, getBehaveAs } from "@/lib/utils";
import { loginFormSchema, LoginProvider } from "@/types/auth";
import { zodResolver } from "@hookform/resolvers/zod";
import { useRouter } from "next/navigation";
import { useCallback, useEffect, useState } from "react";
import { useForm } from "react-hook-form";
import { login, providerLogin } from "./actions";
import z from "zod";
import { BehaveAs } from "@/lib/utils";
import { getBehaveAs } from "@/lib/utils";
import { login, providerLogin } from "./actions";
import { useToast } from "@/components/molecules/Toast/use-toast";
export function useLoginPage() {
const { supabase, user, isUserLoading } = useSupabase();
const [feedback, setFeedback] = useState<string | null>(null);
const [captchaKey, setCaptchaKey] = useState(0);
const router = useRouter();
const { toast } = useToast();
const [isLoading, setIsLoading] = useState(false);
const [isGoogleLoading, setIsGoogleLoading] = useState(false);
const [showNotAllowedModal, setShowNotAllowedModal] = useState(false);
const isCloudEnv = getBehaveAs() === BehaveAs.CLOUD;
const isVercelPreview = process.env.NEXT_PUBLIC_VERCEL_ENV === "preview";
const turnstile = useTurnstile({
action: "login",
@@ -25,6 +28,8 @@ export function useLoginPage() {
resetOnError: true,
});
const shouldNotRenderCaptcha = isVercelPreview || turnstile.verified;
const form = useForm<z.infer<typeof loginFormSchema>>({
resolver: zodResolver(loginFormSchema),
defaultValues: {
@@ -44,29 +49,53 @@ export function useLoginPage() {
async function handleProviderLogin(provider: LoginProvider) {
setIsGoogleLoading(true);
if (!turnstile.verified && !isVercelPreview) {
toast({
title: "Please complete the CAPTCHA challenge.",
variant: "info",
});
setIsGoogleLoading(false);
resetCaptcha();
return;
}
try {
const error = await providerLogin(provider);
if (error) throw error;
setFeedback(null);
} catch (error) {
resetCaptcha();
setFeedback(JSON.stringify(error));
} finally {
setIsGoogleLoading(false);
const errorString = JSON.stringify(error);
if (errorString.includes("not_allowed")) {
setShowNotAllowedModal(true);
} else {
setFeedback(errorString);
}
}
}
async function handleLogin(data: z.infer<typeof loginFormSchema>) {
setIsLoading(true);
if (!turnstile.verified) {
setFeedback("Please complete the CAPTCHA challenge.");
if (!turnstile.verified && !isVercelPreview) {
toast({
title: "Please complete the CAPTCHA challenge.",
variant: "info",
});
setIsLoading(false);
resetCaptcha();
return;
}
if (data.email.includes("@agpt.co")) {
setFeedback("Please use Google SSO to login using an AutoGPT email.");
toast({
title: "Please use Google SSO to login using an AutoGPT email.",
variant: "default",
});
setIsLoading(false);
resetCaptcha();
return;
@@ -76,7 +105,11 @@ export function useLoginPage() {
await supabase?.auth.refreshSession();
setIsLoading(false);
if (error) {
setFeedback(error);
toast({
title: error,
variant: "destructive",
});
resetCaptcha();
// Always reset the turnstile on any error
turnstile.reset();
@@ -94,9 +127,12 @@ export function useLoginPage() {
isLoading,
isCloudEnv,
isUserLoading,
shouldNotRenderCaptcha,
isGoogleLoading,
showNotAllowedModal,
isSupabaseAvailable: !!supabase,
handleSubmit: form.handleSubmit(handleLogin),
handleProviderLogin,
handleCloseNotAllowedModal: () => setShowNotAllowedModal(false),
};
}

Some files were not shown because too many files have changed in this diff Show More