Compare commits

..

21 Commits

Author SHA1 Message Date
Otto
9b20f4cd13 refactor: simplify ExecutionQueue docstrings and move test file
- Trim verbose BUG FIX docstring to concise 3-line note
- Remove redundant method docstrings (add, get, empty)
- Move test file to backend/data/ with proper pytest conventions
- Add note about ProcessPoolExecutor migration for future devs

Co-authored-by: Zamil Majdy <majdyz@users.noreply.github.com>
2026-02-08 16:11:35 +00:00
Nikhil Bhagat
a3d0f9cbd2 fix(backend): format test_execution_queue.py and remove unused variable 2025-12-14 19:37:29 +05:45
Nikhil Bhagat
02ddb51446 Added test_execution_queue.py and test the execution part and the test got passed 2025-12-14 19:05:14 +05:45
Nikhil Bhagat
750e096f15 fix(backend): replace multiprocessing.Manager().Queue() with queue.Queue()
ExecutionQueue was unnecessarily using multiprocessing.Manager().Queue() which
spawns a subprocess for IPC. Since ExecutionQueue is only accessed from threads
within the same process, queue.Queue() is sufficient and more efficient.

- Eliminates unnecessary subprocess spawning per graph execution
- Removes IPC overhead for queue operations
- Prevents potential resource leaks from Manager processes
- Improves scalability for concurrent graph executions
2025-12-14 19:04:14 +05:45
Krzysztof Czerwinski
ff5c8f324b Merge branch 'master' into dev 2025-12-12 22:26:39 +09:00
Krzysztof Czerwinski
f121a22544 hotfix: update next (#11612)
Update next to `15.4.10`
2025-12-12 13:42:36 +01:00
Zamil Majdy
71157bddd7 feat(backend): add agent mode support to SmartDecisionMakerBlock with autonomous tool execution loops (#11547)
## Summary

<img width="2072" height="1836" alt="image"
src="https://github.com/user-attachments/assets/9d231a77-6309-46b9-bc11-befb5d8e9fcc"
/>

**🚀 Major Feature: Agent Mode Support**

Adds autonomous agent mode to SmartDecisionMakerBlock, enabling it to
execute tools directly in loops until tasks are completed, rather than
just yielding tool calls for external execution.

##  **Key New Features**

### 🤖 **Agent Mode with Tool Execution Loops**
- **New `agent_mode_max_iterations` parameter** controls execution
behavior:
  - `0` = Traditional mode (single LLM call, yield tool calls)
  - `1+` = Agent mode with iteration limit
  - `-1` = Infinite agent mode (loop until finished)

### 🔄 **Autonomous Tool Execution**  
- **Direct tool execution** instead of yielding for external handling
- **Multi-iteration loops** with conversation state management
- **Automatic completion detection** when LLM stops making tool calls
- **Iteration limit handling** with graceful completion messages

### 🏗️ **Proper Database Operations**
- **Replace manual execution ID generation** with proper
`upsert_execution_input`/`upsert_execution_output`
- **Real NodeExecutionEntry objects** from database results
- **Proper execution status management**: QUEUED → RUNNING →
COMPLETED/FAILED

### 🔧 **Enhanced Type Safety**
- **Pydantic models** replace TypedDict: `ToolInfo` and
`ExecutionParams`
- **Runtime validation** with better error messages
- **Improved developer experience** with IDE support

## 🔧 **Technical Implementation**

### Agent Mode Flow:
```python
# Agent mode enabled with iterations
if input_data.agent_mode_max_iterations != 0:
    async for result in self._execute_tools_agent_mode(...):
        yield result  # "conversations", "finished"
    return

# Traditional mode (existing behavior)  
# Single LLM call + yield tool calls for external execution
```

### Tool Execution with Database Operations:
```python
# Before: Manual execution IDs
tool_exec_id = f"{node_exec_id}_tool_{sink_node_id}_{len(input_data)}"

# After: Proper database operations
node_exec_result, final_input_data = await db_client.upsert_execution_input(
    node_id=sink_node_id,
    graph_exec_id=execution_params.graph_exec_id,
    input_name=input_name, 
    input_data=input_value,
)
```

### Type Safety with Pydantic:
```python
# Before: Dict access prone to errors
execution_params["user_id"]  

# After: Validated model access
execution_params.user_id  # Runtime validation + IDE support
```

## 🧪 **Comprehensive Test Coverage**

- **Agent mode execution tests** with multi-iteration scenarios
- **Database operation verification** 
- **Type safety validation**
- **Backward compatibility** for traditional mode
- **Enhanced dynamic fields tests**

## 📊 **Usage Examples**

### Traditional Mode (Existing Behavior):
```python
SmartDecisionMakerBlock.Input(
    prompt="Search for keywords",
    agent_mode_max_iterations=0  # Default
)
# → Yields tool calls for external execution
```

### Agent Mode (New Feature):
```python  
SmartDecisionMakerBlock.Input(
    prompt="Complete this task using available tools",
    agent_mode_max_iterations=5  # Max 5 iterations
)
# → Executes tools directly until task completion or iteration limit
```

### Infinite Agent Mode:
```python
SmartDecisionMakerBlock.Input(
    prompt="Analyze and process this data thoroughly", 
    agent_mode_max_iterations=-1  # No limit, run until finished
)
# → Executes tools autonomously until LLM indicates completion
```

##  **Backward Compatibility**

- **Zero breaking changes** to existing functionality
- **Traditional mode remains default** (`agent_mode_max_iterations=0`)
- **All existing tests pass**
- **Same API for tool definitions and execution**

This transforms the SmartDecisionMakerBlock from a simple tool call
generator into a powerful autonomous agent capable of complex multi-step
task execution! 🎯

🤖 Generated with [Claude Code](https://claude.ai/code)

---------

Co-authored-by: Claude <noreply@anthropic.com>
2025-12-12 09:58:06 +00:00
Bentlybro
152e747ea6 Merge branch 'master' into dev 2025-12-11 14:40:01 +00:00
Reinier van der Leer
4d4741d558 fix(frontend/library): Transition from empty tasks view on task init (#11600)
- Resolves #11599

### Changes 🏗️

- Manually update item counts when initiating a task from `EmptyTasks`
view
- Other improvements made while debugging

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] `NewAgentLibraryView` transitions to full layout when a first task
is created
- [x] `NewAgentLibraryView` transitions to full layout when a first
trigger is set up
2025-12-11 11:13:53 +00:00
Krzysztof Czerwinski
bd37fe946d feat(platform): Builder search history (#11457)
Preserve user searches in the new builder and cache search results for
more efficiency.
Search is saved, so the user can see their previous searches.

### Changes 🏗️

- Add `BuilderSearch` column&migration to save user search (with all
filters)
- Builder `db.py` now caches all search results using `@cached` and
returns paginated results, so following pages are returned much quicker
- Score and sort results
- Update models&routes
- Update frontend, so it works properly with modified endpoints
- Frontend: store `serachId` and use it for subsequent searches, so we
don't save partial searches (e.g. "b", "bl", ..., "block"). Search id is
reset when user clears the search field.
- Add clickable chips to the Suggestions builder tab
- Add `HorizontalScroll` component (chips use it)

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Search works and is cached
  - [x] Search sorts results
  - [x] Searches are preserved properly

---------

Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
2025-12-10 17:32:17 +00:00
Nicholas Tindle
7ff282c908 fix(frontend): Disable single dollar sign LaTeX mode in markdown rend… (#11598)
Single dollar signs ($10, $variable) are commonly used in content and
were being incorrectly interpreted as inline LaTeX math delimiters. This
change disables that behavior while keeping double dollar sign ($$...$$)
math blocks working.
## Changes 🏗️
• Configure remarkMath plugin with singleDollarTextMath: false in
MarkdownRenderer.tsx
• Double dollar sign display math ($$...$$) continues to work as
expected
• Single dollar signs are no longer interpreted as inline math
delimiters
## Checklist 📋
For code changes:
	-[x]	I have clearly listed my changes in the PR description
	-[x] I have made a test plan
	-[x] I have tested my changes according to the test plan:
-[x] Verify content with dollar amounts (e.g., “$100”) renders as plain
text
-[x] Verify double dollar sign math blocks ($$x^2$$) still render as
LaTeX
-[x] Verify other markdown features (code blocks, tables, links) still
work correctly

Co-authored-by: Claude <noreply@anthropic.com>
2025-12-10 17:10:52 +00:00
Reinier van der Leer
117bb05438 fix(frontend/library): Fix trigger UX flows in v3 library (#11589)
- Resolves #11586
- Follow-up to #11580

### Changes 🏗️

- Fix logic to include manual triggers as a possibility
- Fix input render logic to use trigger setup schema if applicable
- Fix rendering payload input for externally triggered runs
- Amend `RunAgentModal` to load preset inputs+credentials if selected
- Amend `SelectedTemplateView` to use modified input for run (if
applicable)
- Hide non-applicable buttons in `SelectedRunView` for externally
triggered runs
- Implement auto-navigation to `SelectedTriggerView` on trigger setup

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Can set up manual triggers
    - [x] Navigates to trigger view after setup
  - [x] Can set up automatic triggers
  - [x] Can create templates from runs
  - [x] Can run templates
  - [x] Can run templates with modified input
2025-12-10 15:52:02 +00:00
Nicholas Tindle
979d7c3b74 feat(blocks): Add 4 new GitHub webhook trigger blocks (#11588)
I want to be able to automate some actions on social media or our
sevrver in response to actions from discord


<!-- Clearly explain the need for these changes: -->

### Changes 🏗️
Add trigger blocks for common GitHub events to enable OSS automation:
- GithubReleaseTriggerBlock: Trigger on release events (published, etc.)
- GithubStarTriggerBlock: Trigger on star events for milestone
celebrations
- GithubIssuesTriggerBlock: Trigger on issue events for
triage/notifications
- GithubDiscussionTriggerBlock: Trigger on discussion events for Q&A
sync
<!-- Concisely describe all of the changes made in this pull request:
-->

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  <!-- Put your test plan here: -->
  - [x] Test Stars
  - [x] Test Discussions
  - [x] Test Issues
  - [x] Test Release

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-12-09 21:25:43 +00:00
Nicholas Tindle
95200b67f8 feat(blocks): add many new spreadsheet blocks (#11574)
<!-- Clearly explain the need for these changes: -->
We have lots we want to do with google sheets and we don't want a lack
of blocks to be a limiter so I pre-ddi a lot of blocks!

### Changes 🏗️
Adds 24 new blocks for google sheets (tested and working)
```
|-----|-------------------------------------------|----------------------------------------|
|  1  | GoogleSheetsFilterRowsBlock               | Filter rows based on column conditions |  |
|  2  | GoogleSheetsLookupRowBlock                | VLOOKUP-style row lookup               |  |
|  3  | GoogleSheetsDeleteRowsBlock               | Delete rows from a sheet               |  |
|  4  | GoogleSheetsGetColumnBlock                | Get data from a specific column        |  |
|  5  | GoogleSheetsSortBlock                     | Sort sheet data                        |  | 
|  6  | GoogleSheetsGetUniqueValuesBlock          | Get unique values from a column        |  | 
|  7  | GoogleSheetsInsertRowBlock                | Insert rows into a sheet               |  |
|  8  | GoogleSheetsAddColumnBlock                | Add a new column                       |  |
|  9  | GoogleSheetsGetRowCountBlock              | Get the number of rows                 |  |
| 10  | GoogleSheetsRemoveDuplicatesBlock         | Remove duplicate rows                  |  | 
| 11  | GoogleSheetsUpdateRowBlock                | Update an existing row                 |  |
| 12  | GoogleSheetsGetRowBlock                   | Get a specific row by index            |  |
| 13  | GoogleSheetsDeleteColumnBlock             | Delete a column                        |  |
| 14  | GoogleSheetsCreateNamedRangeBlock         | Create a named range                   |  | 
| 15  | GoogleSheetsListNamedRangesBlock          | List all named ranges                  |  | 
| 16  | GoogleSheetsAddDropdownBlock              | Add dropdown validation to cells       |  | 
| 17  | GoogleSheetsCopyToSpreadsheetBlock        | Copy sheet to another spreadsheet      |  |
| 18  | GoogleSheetsProtectRangeBlock             | Protect a range from editing           |  |
| 19  | GoogleSheetsExportCsvBlock                | Export sheet as CSV                    |  |
| 20  | GoogleSheetsImportCsvBlock                | Import CSV data                        |  |
| 21  | GoogleSheetsAddNoteBlock                  | Add notes to cells                     |  | 
| 22  | GoogleSheetsGetNotesBlock                 | Get notes from cells                   |  | 
| 23  | GoogleSheetsShareSpreadsheetBlock         | Share spreadsheet with users           |  | 
| 24  | GoogleSheetsSetPublicAccessBlock          | Set public access permissions          |  | 
```


<!-- Concisely describe all of the changes made in this pull request:
-->

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  <!-- Put your test plan here: -->
  - [x] Tested using the attached agent 
[super test for
spreadsheets_v9.json](https://github.com/user-attachments/files/24041582/super.test.for.spreadsheets_v9.json)


<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> Introduces a large suite of Google Sheets blocks for row/column ops,
filtering/sorting/lookup, CSV import/export, notes, named ranges,
protections, sheet copy, and sharing/public access, plus refactors
append to a simpler single-row append.
> 
> - **Google Sheets blocks (new)**:
> - **Data ops**: `GoogleSheetsFilterRowsBlock`,
`GoogleSheetsLookupRowBlock`, `GoogleSheetsDeleteRowsBlock`,
`GoogleSheetsGetColumnBlock`, `GoogleSheetsSortBlock`,
`GoogleSheetsGetUniqueValuesBlock`, `GoogleSheetsInsertRowBlock`,
`GoogleSheetsAddColumnBlock`, `GoogleSheetsGetRowCountBlock`,
`GoogleSheetsRemoveDuplicatesBlock`, `GoogleSheetsUpdateRowBlock`,
`GoogleSheetsGetRowBlock`, `GoogleSheetsDeleteColumnBlock`.
> - **Named ranges & validation**: `GoogleSheetsCreateNamedRangeBlock`,
`GoogleSheetsListNamedRangesBlock`, `GoogleSheetsAddDropdownBlock`.
> - **Sheet/admin**: `GoogleSheetsCopyToSpreadsheetBlock`,
`GoogleSheetsProtectRangeBlock`.
> - **CSV & notes**: `GoogleSheetsExportCsvBlock`,
`GoogleSheetsImportCsvBlock`, `GoogleSheetsAddNoteBlock`,
`GoogleSheetsGetNotesBlock`.
> - **Sharing**: `GoogleSheetsShareSpreadsheetBlock`,
`GoogleSheetsSetPublicAccessBlock`.
> - **Refactor**:
> - Rename and simplify append: `GoogleSheetsAppendRowBlock` (replaces
multi-row/dict input with single `row`), fixed insert option to
`INSERT_ROWS` and streamlined response.
> - **Utilities/Enums**:
> - Add helpers (`_column_letter_to_index`, `_index_to_column_letter`,
`_apply_filter`) and enums (`FilterOperator`, `SortOrder`, `ShareRole`,
`PublicAccessRole`).
> - Drive/Sheets service builders and file validation reused across new
blocks.
> 
> <sup>Written by [Cursor
Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit
6e9e2f4024. This will update automatically
on new commits. Configure
[here](https://cursor.com/dashboard?tab=bugbot).</sup>
<!-- /CURSOR_SUMMARY -->

---------

Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: Cursor Agent <cursoragent@cursor.com>
2025-12-09 17:28:22 +00:00
Abhimanyu Yadav
f8afc6044e fix(frontend): prevent file upload buttons from triggering form submission (#11576)
<!-- Clearly explain the need for these changes: -->

In the File Widget, the upload button was incorrectly behaving like a
submit button. When users clicked it, the rjsf library immediately
triggered form validation and displayed validation errors, even though
the user was only trying to upload a file.

This happened because HTML buttons inside a form default to
`type="submit"`, which triggers form submission on click. By explicitly
setting `type="button"` on all file-related buttons, we prevent them
from submitting the form while still allowing them to trigger the file
input dialog.

### Changes 🏗️

<!-- Concisely describe all of the changes made in this pull request:
-->

- Added `type="button"` attribute to the clear button in the compact
variant
- Added `type="button"` attribute to the upload button in the compact
variant
- Added `type="button"` attribute to the "Browse File" button in the
default variant

This ensures that clicking any of these buttons only triggers the
intended file selection/upload action without causing unwanted form
validation or submission.

### Checklist 📋

#### For code changes:

- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] Tested clicking the upload button in a form with File Widget -
form should not submit or show validation errors
- [x] Tested clicking the clear button - should clear the file without
triggering form validation
- [x] Tested clicking the "Browse File" button - should open file dialog
without triggering form validation
- [x] Verified file upload functionality still works correctly after
selecting a file
2025-12-09 15:54:17 +00:00
Abhimanyu Yadav
7edf01777e fix(frontend): sync flowVersion to URL when loading graph from Library (#11585)
<!-- Clearly explain the need for these changes: -->

When opening a graph from the Library, the `flowVersion` query parameter
was not being set in the URL. This caused issues when the graph data
didn't contain an internal `graphVersion`, resulting in the builder
failing and graphs breaking when running.

The `useGetV1GetSpecificGraph` hook relies on the `flowVersion` query
parameter to fetch the correct graph version. Without it being properly
set in the URL, the graph loading logic fails when the version
information is missing from the graph data itself.

### Changes 🏗️

- Added `setQueryStates` to the `useQueryStates` hook return value in
`useFlow.ts`
- Added logic to sync `flowVersion` to the URL query parameters when a
graph is loaded
- When `graph.version` is available, it now updates the `flowVersion`
query parameter in the URL (defaults to `1` if version is undefined)

This ensures the URL stays in sync with the loaded graph's version,
preventing builder failures and execution issues.

### Checklist 📋

#### For code changes:

- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] Open a graph from the Library that doesn't have a version in the
URL
- [x] Verify that `flowVersion` is automatically added to the URL query
parameters
  - [x] Verify that the graph loads correctly and can be executed
- [x] Verify that graphs with existing `flowVersion` in URL continue to
work correctly
- [x] Verify that graphs opened from Library with version information
sync correctly
2025-12-09 15:26:45 +00:00
Ubbe
c9681f5d44 fix(frontend): library page adjustments (#11587)
## Changes 🏗️

### Adjust layout and styles on mobile 📱 

<img width="448" height="843" alt="Screenshot 2025-12-09 at 22 53 14"
src="https://github.com/user-attachments/assets/159bdf4f-e6b2-42f5-8fdf-25f8a62c62d1"
/>

### Make the sidebar cards have contextual actions

<img width="486" height="243" alt="Screenshot 2025-12-09 at 22 53 27"
src="https://github.com/user-attachments/assets/2f530168-3217-47c4-b08d-feccbb9e9152"
/>

Depending on the card type, different type of actions are shown...

### Make buttons in "About agent" card do something

<img width="344" height="346" alt="Screenshot 2025-12-09 at 22 54 01"
src="https://github.com/user-attachments/assets/47181f80-1f68-4ef1-aecc-bbadc7cc9c44"
/>

### Other

- Hide `Schedule` button for agents with trigger run type
- Adjust secondary button background colour...
- Make drawer content scrollable on mobile 

## Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Run locally and test the above

Co-authored-by: Abhimanyu Yadav <122007096+Abhi1992002@users.noreply.github.com>
2025-12-09 23:17:44 +07:00
Abhimanyu Yadav
1305325813 fix(frontend): preserve button shape in credential select when content is long (#11577)
<!-- Clearly explain the need for these changes: -->

When the content inside the credential select dropdown becomes too long,
the adjacent link buttons lose their rounded shape and appear squarish.
This happens when the text stretches the container or affects the layout
of the buttons.

The issue occurs because the button's width can shrink below its
intended size when the flex container is stretched by long credential
names. By adding an explicit minimum width constraint with `!min-w-8`,
we ensure the button maintains its proper dimensions and rounded
appearance regardless of the select dropdown's content length.

### Changes 🏗️

<!-- Concisely describe all of the changes made in this pull request:
-->

- Added `!min-w-8` to the external link button's className in
`SelectCredential` component to enforce a minimum width of 2rem (8 *
0.25rem)
- This ensures the button maintains its rounded shape even when the
adjacent select dropdown contains long credential names

### Checklist 📋

#### For code changes:

- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] Tested credential select with short credential names - button
should maintain rounded shape
- [x] Tested credential select with very long credential names (e.g.,
long provider names, usernames, and hosts) - button should maintain
rounded shape
2025-12-09 15:26:22 +00:00
Abhimanyu Yadav
4f349281bd feat(frontend): switch copied graph storage from local storage to clipboard (#11578)
### Changes 🏗️

This PR migrates the copy/paste functionality for graph nodes and edges
from local storage to the Clipboard API. This change addresses storage
limitations and enables cross-tab copying.


https://github.com/user-attachments/assets/6ef55713-ca5b-4562-bb54-4c12db241d30


**Key changes:**
- Replaced `localStorage` with `navigator.clipboard` API for copy/paste
operations
- Added `CLIPBOARD_PREFIX` constant (`"autogpt-flow-data:"`) to identify
our clipboard data and prevent conflicts with other clipboard content
- Added toast notifications to provide user feedback when copying nodes
- Added error handling for clipboard read/write operations with console
error logging
- Removed dependency on `@/services/storage/local-storage` for copied
flow data
- Updated `useCopyPaste` hook to use async clipboard operations with
proper promise handling

**Benefits:**
-  Removes local storage size limitations (5-10MB)
-  Enables copying nodes between browser tabs/windows
-  Provides better user feedback through toast notifications
-  More standard approach using native browser Clipboard API

### Checklist 📋

#### For code changes:

- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Select one or more nodes in the flow editor
  - [x] Press `Ctrl+C` (or `Cmd+C` on Mac) to copy nodes
- [x] Verify toast notification appears showing "Copied successfully"
with node count
  - [x] Press `Ctrl+V` (or `Cmd+V` on Mac) to paste nodes
  - [x] Verify nodes are pasted at viewport center with new unique IDs
  - [x] Verify edges between copied nodes are also pasted correctly
- [x] Test copying nodes in one browser tab and pasting in another tab
(should work)
- [x] Test copying non-flow data (e.g., text) and verify paste doesn't
interfere with flow editor
2025-12-09 15:26:12 +00:00
Nicholas Tindle
c4eb7edb65 feat(platform): Improve Google Sheets/Drive integration with unified credentials (#11520)
Simplifies and improves the Google Sheets/Drive integration by merging
credentials with the file picker and using narrower OAuth scopes.

### Changes 🏗️

- Merge Google credentials and file picker into a single unified input
field for better UX
- Create spreadsheets using Drive API instead of Sheets API for proper
scope support
- Simplify Google Drive OAuth scope to only use `drive.file` (narrowest
permission needed)
- Clean up unused imports (NormalizedPickedFile)

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] Test creating a new Google Spreadsheet with
GoogleSheetsCreateSpreadsheetBlock
- [x] Test reading from existing spreadsheets with GoogleSheetsReadBlock
  - [x] Test writing to spreadsheets with GoogleSheetsWriteBlock
  - [x] Verify OAuth flow works with simplified scopes
  - [x] Verify file picker works with merged credentials field

#### For configuration changes:
- [x] `.env.default` is updated or already compatible with my changes
- [x] `docker-compose.yml` is updated or already compatible with my
changes
- [x] I have included a list of my configuration changes in the PR
description (under **Changes**)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> Unifies Google Drive picker and credentials with auto-credentials
across backend and frontend, updates all Sheets blocks and execution to
use it, and adds Drive-based spreadsheet creation plus supporting tests
and UI fixes.
> 
> - **Backend**:
> - **Google Drive model/field**: Introduce `GoogleDriveFile` (with
`_credentials_id`) and `GoogleDriveFileField()` for unified auth+picker
(`backend/blocks/google/_drive.py`).
> - **Sheets blocks**: Replace `GoogleDrivePickerField` and explicit
credentials with `GoogleDriveFileField` across all Sheets blocks;
preserve and emit credentials for chaining; add Drive service; create
spreadsheets via Drive API then manage via Sheets API.
> - **IO block**: Add `AgentGoogleDriveFileInputBlock` providing a Drive
picker input.
> - **Execution**: Support auto-generated credentials via
`BlockSchema.get_auto_credentials_fields()`; acquire/release multiple
credential locks; pass creds by `credentials_kwarg`
(`executor/manager.py`, `data/block.py`, `util/test.py`).
> - **Tests**: Add validation tests for duplicate/unique
`auto_credentials.kwarg_name` and defaults.
> - **Frontend**:
> - **Picker**: Enhance Google Drive picker to require/use saved
platform credentials, pass `_credentials_id`, validate scopes, and
manage dialog z-index/interaction; expose `requirePlatformCredentials`.
> - **UI**: Update dialogs/CSS to keep Google picker on top and prevent
overlay interactions.
> - **Types**: Extend `GoogleDrivePickerConfig` with `auto_credentials`
and related typings.
> 
> <sup>Written by [Cursor
Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit
7d25534def. This will update automatically
on new commits. Configure
[here](https://cursor.com/dashboard?tab=bugbot).</sup>
<!-- /CURSOR_SUMMARY -->

---------

Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com>
2025-12-05 09:44:38 -06:00
Nicholas Tindle
3f690ea7b8 fix(platform/frontend): security upgrade next from 15.4.7 to 15.4.8 (#11536)
![snyk-top-banner](https://res.cloudinary.com/snyk/image/upload/r-d/scm-platform/snyk-pull-requests/pr-banner-default.svg)

### Snyk has created this PR to fix 1 vulnerabilities in the yarn
dependencies of this project.

#### Snyk changed the following file(s):

- `autogpt_platform/frontend/package.json`


#### Note for
[zero-installs](https://yarnpkg.com/features/zero-installs) users

If you are using the Yarn feature
[zero-installs](https://yarnpkg.com/features/zero-installs) that was
introduced in Yarn V2, note that this PR does not update the
`.yarn/cache/` directory meaning this code cannot be pulled and
immediately developed on as one would expect for a zero-install project
- you will need to run `yarn` to update the contents of the
`./yarn/cache` directory.
If you are not using zero-install you can ignore this as your flow
should likely be unchanged.



<details>
<summary>⚠️ <b>Warning</b></summary>

```
Failed to update the yarn.lock, please update manually before merging.
```

</details>



#### Vulnerabilities that will be fixed with an upgrade:

|  | Issue |  
:-------------------------:|:-------------------------
![critical
severity](https://res.cloudinary.com/snyk/image/upload/w_20,h_20/v1561977819/icon/c.png
'critical severity') | Arbitrary Code Injection
<br/>[SNYK-JS-NEXT-14173355](https://snyk.io/vuln/SNYK-JS-NEXT-14173355)




---

> [!IMPORTANT]
>
> - Check the changes in this PR to ensure they won't cause issues with
your project.
> - Max score is 1000. Note that the real score may have changed since
the PR was raised.
> - This PR was automatically created by Snyk using the credentials of a
real user.

---

**Note:** _You are seeing this because you or someone else with access
to this repository has authorized Snyk to open fix PRs._

For more information: <img
src="https://api.segment.io/v1/pixel/track?data=eyJ3cml0ZUtleSI6InJyWmxZcEdHY2RyTHZsb0lYd0dUcVg4WkFRTnNCOUEwIiwiYW5vbnltb3VzSWQiOiJhNDQzN2JlZC0wMjYxLTRhZmMtYmQxOS1hMTUwY2RhMDE3ZDciLCJldmVudCI6IlBSIHZpZXdlZCIsInByb3BlcnRpZXMiOnsicHJJZCI6ImE0NDM3YmVkLTAyNjEtNGFmYy1iZDE5LWExNTBjZGEwMTdkNyJ9fQ=="
width="0" height="0"/>
🧐 [View latest project
report](https://app.snyk.io/org/significant-gravitas/project/3d924968-0cf3-4767-9609-501fa4962856?utm_source&#x3D;github&amp;utm_medium&#x3D;referral&amp;page&#x3D;fix-pr)
📜 [Customise PR
templates](https://docs.snyk.io/scan-using-snyk/pull-requests/snyk-fix-pull-or-merge-requests/customize-pr-templates?utm_source=github&utm_content=fix-pr-template)
🛠 [Adjust project
settings](https://app.snyk.io/org/significant-gravitas/project/3d924968-0cf3-4767-9609-501fa4962856?utm_source&#x3D;github&amp;utm_medium&#x3D;referral&amp;page&#x3D;fix-pr/settings)
📚 [Read about Snyk's upgrade
logic](https://docs.snyk.io/scan-with-snyk/snyk-open-source/manage-vulnerabilities/upgrade-package-versions-to-fix-vulnerabilities?utm_source=github&utm_content=fix-pr-template)

---

**Learn how to fix vulnerabilities with free interactive lessons:**

🦉 [Arbitrary Code
Injection](https://learn.snyk.io/lesson/insecure-deserialization/?loc&#x3D;fix-pr)

[//]: #
'snyk:metadata:{"breakingChangeRiskLevel":null,"FF_showPullRequestBreakingChanges":false,"FF_showPullRequestBreakingChangesWebSearch":false,"customTemplate":{"variablesUsed":[],"fieldsUsed":[]},"dependencies":[{"name":"next","from":"15.4.7","to":"15.4.8"}],"env":"prod","issuesToFix":["SNYK-JS-NEXT-14173355"],"prId":"a4437bed-0261-4afc-bd19-a150cda017d7","prPublicId":"a4437bed-0261-4afc-bd19-a150cda017d7","packageManager":"yarn","priorityScoreList":[null],"projectPublicId":"3d924968-0cf3-4767-9609-501fa4962856","projectUrl":"https://app.snyk.io/org/significant-gravitas/project/3d924968-0cf3-4767-9609-501fa4962856?utm_source=github&utm_medium=referral&page=fix-pr","prType":"fix","templateFieldSources":{"branchName":"default","commitMessage":"default","description":"default","title":"default"},"templateVariants":["updated-fix-title","pr-warning-shown"],"type":"auto","upgrade":["SNYK-JS-NEXT-14173355"],"vulns":["SNYK-JS-NEXT-14173355"],"patch":[],"isBreakingChange":false,"remediationStrategy":"vuln"}'

<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> Upgrades Next.js from 15.4.7 to 15.4.8 in the frontend and updates
lockfile/transitive references accordingly.
> 
> - **Dependencies**:
> - Bump `next` to `15.4.8` in `autogpt_platform/frontend/package.json`.
> - Update lockfile to align, including `@next/*` SWC binaries and
packages that peer-depend on `next` (e.g., `@sentry/nextjs`,
`@storybook/nextjs`, `@vercel/*`, `geist`, `nuqs`,
`@next/third-parties`).
> - Minor transitive tweak: `sharp` dependency `semver` updated to
`7.7.3`.
> 
> <sup>Written by [Cursor
Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit
e7741cbfb5. This will update automatically
on new commits. Configure
[here](https://cursor.com/dashboard?tab=bugbot).</sup>
<!-- /CURSOR_SUMMARY -->

---------

Co-authored-by: snyk-bot <snyk-bot@snyk.io>
Co-authored-by: Bentlybro <Github@bentlybro.com>
Co-authored-by: Abhimanyu Yadav <122007096+Abhi1992002@users.noreply.github.com>
2025-12-05 09:44:12 -06:00
134 changed files with 9799 additions and 15979 deletions

View File

@@ -0,0 +1,108 @@
{
"action": "created",
"discussion": {
"repository_url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT",
"category": {
"id": 12345678,
"node_id": "DIC_kwDOJKSTjM4CXXXX",
"repository_id": 614765452,
"emoji": ":pray:",
"name": "Q&A",
"description": "Ask the community for help",
"created_at": "2023-03-16T09:21:07Z",
"updated_at": "2023-03-16T09:21:07Z",
"slug": "q-a",
"is_answerable": true
},
"answer_html_url": null,
"answer_chosen_at": null,
"answer_chosen_by": null,
"html_url": "https://github.com/Significant-Gravitas/AutoGPT/discussions/9999",
"id": 5000000001,
"node_id": "D_kwDOJKSTjM4AYYYY",
"number": 9999,
"title": "How do I configure custom blocks?",
"user": {
"login": "curious-user",
"id": 22222222,
"node_id": "MDQ6VXNlcjIyMjIyMjIy",
"avatar_url": "https://avatars.githubusercontent.com/u/22222222?v=4",
"url": "https://api.github.com/users/curious-user",
"html_url": "https://github.com/curious-user",
"type": "User",
"site_admin": false
},
"state": "open",
"state_reason": null,
"locked": false,
"comments": 0,
"created_at": "2024-12-01T17:00:00Z",
"updated_at": "2024-12-01T17:00:00Z",
"author_association": "NONE",
"active_lock_reason": null,
"body": "## Question\n\nI'm trying to create a custom block for my specific use case. I've read the documentation but I'm not sure how to:\n\n1. Define the input/output schema\n2. Handle authentication\n3. Test my block locally\n\nCan someone point me to examples or provide guidance?\n\n## Environment\n\n- AutoGPT Platform version: latest\n- Python: 3.11",
"reactions": {
"url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT/discussions/9999/reactions",
"total_count": 0,
"+1": 0,
"-1": 0,
"laugh": 0,
"hooray": 0,
"confused": 0,
"heart": 0,
"rocket": 0,
"eyes": 0
},
"timeline_url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT/discussions/9999/timeline"
},
"repository": {
"id": 614765452,
"node_id": "R_kgDOJKSTjA",
"name": "AutoGPT",
"full_name": "Significant-Gravitas/AutoGPT",
"private": false,
"owner": {
"login": "Significant-Gravitas",
"id": 130738209,
"node_id": "O_kgDOB8roIQ",
"avatar_url": "https://avatars.githubusercontent.com/u/130738209?v=4",
"url": "https://api.github.com/users/Significant-Gravitas",
"html_url": "https://github.com/Significant-Gravitas",
"type": "Organization",
"site_admin": false
},
"html_url": "https://github.com/Significant-Gravitas/AutoGPT",
"description": "AutoGPT is the vision of accessible AI for everyone, to use and to build on.",
"fork": false,
"url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT",
"created_at": "2023-03-16T09:21:07Z",
"updated_at": "2024-12-01T17:00:00Z",
"pushed_at": "2024-12-01T12:00:00Z",
"stargazers_count": 170000,
"watchers_count": 170000,
"language": "Python",
"has_discussions": true,
"forks_count": 45000,
"visibility": "public",
"default_branch": "master"
},
"organization": {
"login": "Significant-Gravitas",
"id": 130738209,
"node_id": "O_kgDOB8roIQ",
"url": "https://api.github.com/orgs/Significant-Gravitas",
"avatar_url": "https://avatars.githubusercontent.com/u/130738209?v=4",
"description": ""
},
"sender": {
"login": "curious-user",
"id": 22222222,
"node_id": "MDQ6VXNlcjIyMjIyMjIy",
"avatar_url": "https://avatars.githubusercontent.com/u/22222222?v=4",
"gravatar_id": "",
"url": "https://api.github.com/users/curious-user",
"html_url": "https://github.com/curious-user",
"type": "User",
"site_admin": false
}
}

View File

@@ -0,0 +1,112 @@
{
"action": "opened",
"issue": {
"url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT/issues/12345",
"repository_url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT",
"labels_url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT/issues/12345/labels{/name}",
"comments_url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT/issues/12345/comments",
"events_url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT/issues/12345/events",
"html_url": "https://github.com/Significant-Gravitas/AutoGPT/issues/12345",
"id": 2000000001,
"node_id": "I_kwDOJKSTjM5wXXXX",
"number": 12345,
"title": "Bug: Application crashes when processing large files",
"user": {
"login": "bug-reporter",
"id": 11111111,
"node_id": "MDQ6VXNlcjExMTExMTEx",
"avatar_url": "https://avatars.githubusercontent.com/u/11111111?v=4",
"url": "https://api.github.com/users/bug-reporter",
"html_url": "https://github.com/bug-reporter",
"type": "User",
"site_admin": false
},
"labels": [
{
"id": 5272676214,
"node_id": "LA_kwDOJKSTjM8AAAABOkandg",
"url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT/labels/bug",
"name": "bug",
"color": "d73a4a",
"default": true,
"description": "Something isn't working"
}
],
"state": "open",
"locked": false,
"assignee": null,
"assignees": [],
"milestone": null,
"comments": 0,
"created_at": "2024-12-01T16:00:00Z",
"updated_at": "2024-12-01T16:00:00Z",
"closed_at": null,
"author_association": "NONE",
"active_lock_reason": null,
"body": "## Description\n\nWhen I try to process a file larger than 100MB, the application crashes with an out of memory error.\n\n## Steps to Reproduce\n\n1. Open the application\n2. Select a file larger than 100MB\n3. Click 'Process'\n4. Application crashes\n\n## Expected Behavior\n\nThe application should handle large files gracefully.\n\n## Environment\n\n- OS: Ubuntu 22.04\n- Python: 3.11\n- AutoGPT Version: 1.0.0",
"reactions": {
"url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT/issues/12345/reactions",
"total_count": 0,
"+1": 0,
"-1": 0,
"laugh": 0,
"hooray": 0,
"confused": 0,
"heart": 0,
"rocket": 0,
"eyes": 0
},
"timeline_url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT/issues/12345/timeline",
"state_reason": null
},
"repository": {
"id": 614765452,
"node_id": "R_kgDOJKSTjA",
"name": "AutoGPT",
"full_name": "Significant-Gravitas/AutoGPT",
"private": false,
"owner": {
"login": "Significant-Gravitas",
"id": 130738209,
"node_id": "O_kgDOB8roIQ",
"avatar_url": "https://avatars.githubusercontent.com/u/130738209?v=4",
"url": "https://api.github.com/users/Significant-Gravitas",
"html_url": "https://github.com/Significant-Gravitas",
"type": "Organization",
"site_admin": false
},
"html_url": "https://github.com/Significant-Gravitas/AutoGPT",
"description": "AutoGPT is the vision of accessible AI for everyone, to use and to build on.",
"fork": false,
"url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT",
"created_at": "2023-03-16T09:21:07Z",
"updated_at": "2024-12-01T16:00:00Z",
"pushed_at": "2024-12-01T12:00:00Z",
"stargazers_count": 170000,
"watchers_count": 170000,
"language": "Python",
"forks_count": 45000,
"open_issues_count": 190,
"visibility": "public",
"default_branch": "master"
},
"organization": {
"login": "Significant-Gravitas",
"id": 130738209,
"node_id": "O_kgDOB8roIQ",
"url": "https://api.github.com/orgs/Significant-Gravitas",
"avatar_url": "https://avatars.githubusercontent.com/u/130738209?v=4",
"description": ""
},
"sender": {
"login": "bug-reporter",
"id": 11111111,
"node_id": "MDQ6VXNlcjExMTExMTEx",
"avatar_url": "https://avatars.githubusercontent.com/u/11111111?v=4",
"gravatar_id": "",
"url": "https://api.github.com/users/bug-reporter",
"html_url": "https://github.com/bug-reporter",
"type": "User",
"site_admin": false
}
}

View File

@@ -0,0 +1,97 @@
{
"action": "published",
"release": {
"url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT/releases/123456789",
"assets_url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT/releases/123456789/assets",
"upload_url": "https://uploads.github.com/repos/Significant-Gravitas/AutoGPT/releases/123456789/assets{?name,label}",
"html_url": "https://github.com/Significant-Gravitas/AutoGPT/releases/tag/v1.0.0",
"id": 123456789,
"author": {
"login": "ntindle",
"id": 12345678,
"node_id": "MDQ6VXNlcjEyMzQ1Njc4",
"avatar_url": "https://avatars.githubusercontent.com/u/12345678?v=4",
"gravatar_id": "",
"url": "https://api.github.com/users/ntindle",
"html_url": "https://github.com/ntindle",
"type": "User",
"site_admin": false
},
"node_id": "RE_kwDOJKSTjM4HWwAA",
"tag_name": "v1.0.0",
"target_commitish": "master",
"name": "AutoGPT Platform v1.0.0",
"draft": false,
"prerelease": false,
"created_at": "2024-12-01T10:00:00Z",
"published_at": "2024-12-01T12:00:00Z",
"assets": [
{
"url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT/releases/assets/987654321",
"id": 987654321,
"node_id": "RA_kwDOJKSTjM4HWwBB",
"name": "autogpt-v1.0.0.zip",
"label": "Release Package",
"content_type": "application/zip",
"state": "uploaded",
"size": 52428800,
"download_count": 0,
"created_at": "2024-12-01T11:30:00Z",
"updated_at": "2024-12-01T11:35:00Z",
"browser_download_url": "https://github.com/Significant-Gravitas/AutoGPT/releases/download/v1.0.0/autogpt-v1.0.0.zip"
}
],
"tarball_url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT/tarball/v1.0.0",
"zipball_url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT/zipball/v1.0.0",
"body": "## What's New\n\n- Feature 1: Amazing new capability\n- Feature 2: Performance improvements\n- Bug fixes and stability improvements\n\n## Breaking Changes\n\nNone\n\n## Contributors\n\nThanks to all our contributors!"
},
"repository": {
"id": 614765452,
"node_id": "R_kgDOJKSTjA",
"name": "AutoGPT",
"full_name": "Significant-Gravitas/AutoGPT",
"private": false,
"owner": {
"login": "Significant-Gravitas",
"id": 130738209,
"node_id": "O_kgDOB8roIQ",
"avatar_url": "https://avatars.githubusercontent.com/u/130738209?v=4",
"url": "https://api.github.com/users/Significant-Gravitas",
"html_url": "https://github.com/Significant-Gravitas",
"type": "Organization",
"site_admin": false
},
"html_url": "https://github.com/Significant-Gravitas/AutoGPT",
"description": "AutoGPT is the vision of accessible AI for everyone, to use and to build on.",
"fork": false,
"url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT",
"created_at": "2023-03-16T09:21:07Z",
"updated_at": "2024-12-01T12:00:00Z",
"pushed_at": "2024-12-01T12:00:00Z",
"stargazers_count": 170000,
"watchers_count": 170000,
"language": "Python",
"forks_count": 45000,
"visibility": "public",
"default_branch": "master"
},
"organization": {
"login": "Significant-Gravitas",
"id": 130738209,
"node_id": "O_kgDOB8roIQ",
"url": "https://api.github.com/orgs/Significant-Gravitas",
"avatar_url": "https://avatars.githubusercontent.com/u/130738209?v=4",
"description": ""
},
"sender": {
"login": "ntindle",
"id": 12345678,
"node_id": "MDQ6VXNlcjEyMzQ1Njc4",
"avatar_url": "https://avatars.githubusercontent.com/u/12345678?v=4",
"gravatar_id": "",
"url": "https://api.github.com/users/ntindle",
"html_url": "https://github.com/ntindle",
"type": "User",
"site_admin": false
}
}

View File

@@ -0,0 +1,53 @@
{
"action": "created",
"starred_at": "2024-12-01T15:30:00Z",
"repository": {
"id": 614765452,
"node_id": "R_kgDOJKSTjA",
"name": "AutoGPT",
"full_name": "Significant-Gravitas/AutoGPT",
"private": false,
"owner": {
"login": "Significant-Gravitas",
"id": 130738209,
"node_id": "O_kgDOB8roIQ",
"avatar_url": "https://avatars.githubusercontent.com/u/130738209?v=4",
"url": "https://api.github.com/users/Significant-Gravitas",
"html_url": "https://github.com/Significant-Gravitas",
"type": "Organization",
"site_admin": false
},
"html_url": "https://github.com/Significant-Gravitas/AutoGPT",
"description": "AutoGPT is the vision of accessible AI for everyone, to use and to build on.",
"fork": false,
"url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT",
"created_at": "2023-03-16T09:21:07Z",
"updated_at": "2024-12-01T15:30:00Z",
"pushed_at": "2024-12-01T12:00:00Z",
"stargazers_count": 170001,
"watchers_count": 170001,
"language": "Python",
"forks_count": 45000,
"visibility": "public",
"default_branch": "master"
},
"organization": {
"login": "Significant-Gravitas",
"id": 130738209,
"node_id": "O_kgDOB8roIQ",
"url": "https://api.github.com/orgs/Significant-Gravitas",
"avatar_url": "https://avatars.githubusercontent.com/u/130738209?v=4",
"description": ""
},
"sender": {
"login": "awesome-contributor",
"id": 98765432,
"node_id": "MDQ6VXNlcjk4NzY1NDMy",
"avatar_url": "https://avatars.githubusercontent.com/u/98765432?v=4",
"gravatar_id": "",
"url": "https://api.github.com/users/awesome-contributor",
"html_url": "https://github.com/awesome-contributor",
"type": "User",
"site_admin": false
}
}

View File

@@ -159,3 +159,391 @@ class GithubPullRequestTriggerBlock(GitHubTriggerBase, Block):
# --8<-- [end:GithubTriggerExample] # --8<-- [end:GithubTriggerExample]
class GithubStarTriggerBlock(GitHubTriggerBase, Block):
"""Trigger block for GitHub star events - useful for milestone celebrations."""
EXAMPLE_PAYLOAD_FILE = (
Path(__file__).parent / "example_payloads" / "star.created.json"
)
class Input(GitHubTriggerBase.Input):
class EventsFilter(BaseModel):
"""
https://docs.github.com/en/webhooks/webhook-events-and-payloads#star
"""
created: bool = False
deleted: bool = False
events: EventsFilter = SchemaField(
title="Events", description="The star events to subscribe to"
)
class Output(GitHubTriggerBase.Output):
event: str = SchemaField(
description="The star event that triggered the webhook ('created' or 'deleted')"
)
starred_at: str = SchemaField(
description="ISO timestamp when the repo was starred (empty if deleted)"
)
stargazers_count: int = SchemaField(
description="Current number of stars on the repository"
)
repository_name: str = SchemaField(
description="Full name of the repository (owner/repo)"
)
repository_url: str = SchemaField(description="URL to the repository")
def __init__(self):
from backend.integrations.webhooks.github import GithubWebhookType
example_payload = json.loads(
self.EXAMPLE_PAYLOAD_FILE.read_text(encoding="utf-8")
)
super().__init__(
id="551e0a35-100b-49b7-89b8-3031322239b6",
description="This block triggers on GitHub star events. "
"Useful for celebrating milestones (e.g., 1k, 10k stars) or tracking engagement.",
categories={BlockCategory.DEVELOPER_TOOLS, BlockCategory.INPUT},
input_schema=GithubStarTriggerBlock.Input,
output_schema=GithubStarTriggerBlock.Output,
webhook_config=BlockWebhookConfig(
provider=ProviderName.GITHUB,
webhook_type=GithubWebhookType.REPO,
resource_format="{repo}",
event_filter_input="events",
event_format="star.{event}",
),
test_input={
"repo": "Significant-Gravitas/AutoGPT",
"events": {"created": True},
"credentials": TEST_CREDENTIALS_INPUT,
"payload": example_payload,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("payload", example_payload),
("triggered_by_user", example_payload["sender"]),
("event", example_payload["action"]),
("starred_at", example_payload.get("starred_at", "")),
("stargazers_count", example_payload["repository"]["stargazers_count"]),
("repository_name", example_payload["repository"]["full_name"]),
("repository_url", example_payload["repository"]["html_url"]),
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput: # type: ignore
async for name, value in super().run(input_data, **kwargs):
yield name, value
yield "event", input_data.payload["action"]
yield "starred_at", input_data.payload.get("starred_at", "")
yield "stargazers_count", input_data.payload["repository"]["stargazers_count"]
yield "repository_name", input_data.payload["repository"]["full_name"]
yield "repository_url", input_data.payload["repository"]["html_url"]
class GithubReleaseTriggerBlock(GitHubTriggerBase, Block):
"""Trigger block for GitHub release events - ideal for announcing new versions."""
EXAMPLE_PAYLOAD_FILE = (
Path(__file__).parent / "example_payloads" / "release.published.json"
)
class Input(GitHubTriggerBase.Input):
class EventsFilter(BaseModel):
"""
https://docs.github.com/en/webhooks/webhook-events-and-payloads#release
"""
published: bool = False
unpublished: bool = False
created: bool = False
edited: bool = False
deleted: bool = False
prereleased: bool = False
released: bool = False
events: EventsFilter = SchemaField(
title="Events", description="The release events to subscribe to"
)
class Output(GitHubTriggerBase.Output):
event: str = SchemaField(
description="The release event that triggered the webhook (e.g., 'published')"
)
release: dict = SchemaField(description="The full release object")
release_url: str = SchemaField(description="URL to the release page")
tag_name: str = SchemaField(description="The release tag name (e.g., 'v1.0.0')")
release_name: str = SchemaField(description="Human-readable release name")
body: str = SchemaField(description="Release notes/description")
prerelease: bool = SchemaField(description="Whether this is a prerelease")
draft: bool = SchemaField(description="Whether this is a draft release")
assets: list = SchemaField(description="List of release assets/files")
def __init__(self):
from backend.integrations.webhooks.github import GithubWebhookType
example_payload = json.loads(
self.EXAMPLE_PAYLOAD_FILE.read_text(encoding="utf-8")
)
super().__init__(
id="2052dd1b-74e1-46ac-9c87-c7a0e057b60b",
description="This block triggers on GitHub release events. "
"Perfect for automating announcements to Discord, Twitter, or other platforms.",
categories={BlockCategory.DEVELOPER_TOOLS, BlockCategory.INPUT},
input_schema=GithubReleaseTriggerBlock.Input,
output_schema=GithubReleaseTriggerBlock.Output,
webhook_config=BlockWebhookConfig(
provider=ProviderName.GITHUB,
webhook_type=GithubWebhookType.REPO,
resource_format="{repo}",
event_filter_input="events",
event_format="release.{event}",
),
test_input={
"repo": "Significant-Gravitas/AutoGPT",
"events": {"published": True},
"credentials": TEST_CREDENTIALS_INPUT,
"payload": example_payload,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("payload", example_payload),
("triggered_by_user", example_payload["sender"]),
("event", example_payload["action"]),
("release", example_payload["release"]),
("release_url", example_payload["release"]["html_url"]),
("tag_name", example_payload["release"]["tag_name"]),
("release_name", example_payload["release"]["name"]),
("body", example_payload["release"]["body"]),
("prerelease", example_payload["release"]["prerelease"]),
("draft", example_payload["release"]["draft"]),
("assets", example_payload["release"]["assets"]),
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput: # type: ignore
async for name, value in super().run(input_data, **kwargs):
yield name, value
release = input_data.payload["release"]
yield "event", input_data.payload["action"]
yield "release", release
yield "release_url", release["html_url"]
yield "tag_name", release["tag_name"]
yield "release_name", release.get("name", "")
yield "body", release.get("body", "")
yield "prerelease", release["prerelease"]
yield "draft", release["draft"]
yield "assets", release["assets"]
class GithubIssuesTriggerBlock(GitHubTriggerBase, Block):
"""Trigger block for GitHub issues events - great for triage and notifications."""
EXAMPLE_PAYLOAD_FILE = (
Path(__file__).parent / "example_payloads" / "issues.opened.json"
)
class Input(GitHubTriggerBase.Input):
class EventsFilter(BaseModel):
"""
https://docs.github.com/en/webhooks/webhook-events-and-payloads#issues
"""
opened: bool = False
edited: bool = False
deleted: bool = False
closed: bool = False
reopened: bool = False
assigned: bool = False
unassigned: bool = False
labeled: bool = False
unlabeled: bool = False
locked: bool = False
unlocked: bool = False
transferred: bool = False
milestoned: bool = False
demilestoned: bool = False
pinned: bool = False
unpinned: bool = False
events: EventsFilter = SchemaField(
title="Events", description="The issue events to subscribe to"
)
class Output(GitHubTriggerBase.Output):
event: str = SchemaField(
description="The issue event that triggered the webhook (e.g., 'opened')"
)
number: int = SchemaField(description="The issue number")
issue: dict = SchemaField(description="The full issue object")
issue_url: str = SchemaField(description="URL to the issue")
issue_title: str = SchemaField(description="The issue title")
issue_body: str = SchemaField(description="The issue body/description")
labels: list = SchemaField(description="List of labels on the issue")
assignees: list = SchemaField(description="List of assignees")
state: str = SchemaField(description="Issue state ('open' or 'closed')")
def __init__(self):
from backend.integrations.webhooks.github import GithubWebhookType
example_payload = json.loads(
self.EXAMPLE_PAYLOAD_FILE.read_text(encoding="utf-8")
)
super().__init__(
id="b2605464-e486-4bf4-aad3-d8a213c8a48a",
description="This block triggers on GitHub issues events. "
"Useful for automated triage, notifications, and welcoming first-time contributors.",
categories={BlockCategory.DEVELOPER_TOOLS, BlockCategory.INPUT},
input_schema=GithubIssuesTriggerBlock.Input,
output_schema=GithubIssuesTriggerBlock.Output,
webhook_config=BlockWebhookConfig(
provider=ProviderName.GITHUB,
webhook_type=GithubWebhookType.REPO,
resource_format="{repo}",
event_filter_input="events",
event_format="issues.{event}",
),
test_input={
"repo": "Significant-Gravitas/AutoGPT",
"events": {"opened": True},
"credentials": TEST_CREDENTIALS_INPUT,
"payload": example_payload,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("payload", example_payload),
("triggered_by_user", example_payload["sender"]),
("event", example_payload["action"]),
("number", example_payload["issue"]["number"]),
("issue", example_payload["issue"]),
("issue_url", example_payload["issue"]["html_url"]),
("issue_title", example_payload["issue"]["title"]),
("issue_body", example_payload["issue"]["body"]),
("labels", example_payload["issue"]["labels"]),
("assignees", example_payload["issue"]["assignees"]),
("state", example_payload["issue"]["state"]),
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput: # type: ignore
async for name, value in super().run(input_data, **kwargs):
yield name, value
issue = input_data.payload["issue"]
yield "event", input_data.payload["action"]
yield "number", issue["number"]
yield "issue", issue
yield "issue_url", issue["html_url"]
yield "issue_title", issue["title"]
yield "issue_body", issue.get("body") or ""
yield "labels", issue["labels"]
yield "assignees", issue["assignees"]
yield "state", issue["state"]
class GithubDiscussionTriggerBlock(GitHubTriggerBase, Block):
"""Trigger block for GitHub discussion events - perfect for community Q&A sync."""
EXAMPLE_PAYLOAD_FILE = (
Path(__file__).parent / "example_payloads" / "discussion.created.json"
)
class Input(GitHubTriggerBase.Input):
class EventsFilter(BaseModel):
"""
https://docs.github.com/en/webhooks/webhook-events-and-payloads#discussion
"""
created: bool = False
edited: bool = False
deleted: bool = False
answered: bool = False
unanswered: bool = False
labeled: bool = False
unlabeled: bool = False
locked: bool = False
unlocked: bool = False
category_changed: bool = False
transferred: bool = False
pinned: bool = False
unpinned: bool = False
events: EventsFilter = SchemaField(
title="Events", description="The discussion events to subscribe to"
)
class Output(GitHubTriggerBase.Output):
event: str = SchemaField(
description="The discussion event that triggered the webhook"
)
number: int = SchemaField(description="The discussion number")
discussion: dict = SchemaField(description="The full discussion object")
discussion_url: str = SchemaField(description="URL to the discussion")
title: str = SchemaField(description="The discussion title")
body: str = SchemaField(description="The discussion body")
category: dict = SchemaField(description="The discussion category object")
category_name: str = SchemaField(description="Name of the category")
state: str = SchemaField(description="Discussion state")
def __init__(self):
from backend.integrations.webhooks.github import GithubWebhookType
example_payload = json.loads(
self.EXAMPLE_PAYLOAD_FILE.read_text(encoding="utf-8")
)
super().__init__(
id="87f847b3-d81a-424e-8e89-acadb5c9d52b",
description="This block triggers on GitHub Discussions events. "
"Great for syncing Q&A to Discord or auto-responding to common questions. "
"Note: Discussions must be enabled on the repository.",
categories={BlockCategory.DEVELOPER_TOOLS, BlockCategory.INPUT},
input_schema=GithubDiscussionTriggerBlock.Input,
output_schema=GithubDiscussionTriggerBlock.Output,
webhook_config=BlockWebhookConfig(
provider=ProviderName.GITHUB,
webhook_type=GithubWebhookType.REPO,
resource_format="{repo}",
event_filter_input="events",
event_format="discussion.{event}",
),
test_input={
"repo": "Significant-Gravitas/AutoGPT",
"events": {"created": True},
"credentials": TEST_CREDENTIALS_INPUT,
"payload": example_payload,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("payload", example_payload),
("triggered_by_user", example_payload["sender"]),
("event", example_payload["action"]),
("number", example_payload["discussion"]["number"]),
("discussion", example_payload["discussion"]),
("discussion_url", example_payload["discussion"]["html_url"]),
("title", example_payload["discussion"]["title"]),
("body", example_payload["discussion"]["body"]),
("category", example_payload["discussion"]["category"]),
("category_name", example_payload["discussion"]["category"]["name"]),
("state", example_payload["discussion"]["state"]),
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput: # type: ignore
async for name, value in super().run(input_data, **kwargs):
yield name, value
discussion = input_data.payload["discussion"]
yield "event", input_data.payload["action"]
yield "number", discussion["number"]
yield "discussion", discussion
yield "discussion_url", discussion["html_url"]
yield "title", discussion["title"]
yield "body", discussion.get("body") or ""
yield "category", discussion["category"]
yield "category_name", discussion["category"]["name"]
yield "state", discussion["state"]

File diff suppressed because it is too large Load Diff

View File

@@ -1,8 +1,11 @@
import logging import logging
import re import re
from collections import Counter from collections import Counter
from concurrent.futures import Future
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from pydantic import BaseModel
import backend.blocks.llm as llm import backend.blocks.llm as llm
from backend.blocks.agent import AgentExecutorBlock from backend.blocks.agent import AgentExecutorBlock
from backend.data.block import ( from backend.data.block import (
@@ -20,16 +23,41 @@ from backend.data.dynamic_fields import (
is_dynamic_field, is_dynamic_field,
is_tool_pin, is_tool_pin,
) )
from backend.data.execution import ExecutionContext
from backend.data.model import NodeExecutionStats, SchemaField from backend.data.model import NodeExecutionStats, SchemaField
from backend.util import json from backend.util import json
from backend.util.clients import get_database_manager_async_client from backend.util.clients import get_database_manager_async_client
from backend.util.prompt import MAIN_OBJECTIVE_PREFIX
if TYPE_CHECKING: if TYPE_CHECKING:
from backend.data.graph import Link, Node from backend.data.graph import Link, Node
from backend.executor.manager import ExecutionProcessor
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ToolInfo(BaseModel):
"""Processed tool call information."""
tool_call: Any # The original tool call object from LLM response
tool_name: str # The function name
tool_def: dict[str, Any] # The tool definition from tool_functions
input_data: dict[str, Any] # Processed input data ready for tool execution
field_mapping: dict[str, str] # Field name mapping for the tool
class ExecutionParams(BaseModel):
"""Tool execution parameters."""
user_id: str
graph_id: str
node_id: str
graph_version: int
graph_exec_id: str
node_exec_id: str
execution_context: "ExecutionContext"
def _get_tool_requests(entry: dict[str, Any]) -> list[str]: def _get_tool_requests(entry: dict[str, Any]) -> list[str]:
""" """
Return a list of tool_call_ids if the entry is a tool request. Return a list of tool_call_ids if the entry is a tool request.
@@ -105,6 +133,50 @@ def _create_tool_response(call_id: str, output: Any) -> dict[str, Any]:
return {"role": "tool", "tool_call_id": call_id, "content": content} return {"role": "tool", "tool_call_id": call_id, "content": content}
def _combine_tool_responses(tool_outputs: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""
Combine multiple Anthropic tool responses into a single user message.
For non-Anthropic formats, returns the original list unchanged.
"""
if len(tool_outputs) <= 1:
return tool_outputs
# Anthropic responses have role="user", type="message", and content is a list with tool_result items
anthropic_responses = [
output
for output in tool_outputs
if (
output.get("role") == "user"
and output.get("type") == "message"
and isinstance(output.get("content"), list)
and any(
item.get("type") == "tool_result"
for item in output.get("content", [])
if isinstance(item, dict)
)
)
]
if len(anthropic_responses) > 1:
combined_content = [
item for response in anthropic_responses for item in response["content"]
]
combined_response = {
"role": "user",
"type": "message",
"content": combined_content,
}
non_anthropic_responses = [
output for output in tool_outputs if output not in anthropic_responses
]
return [combined_response] + non_anthropic_responses
return tool_outputs
def _convert_raw_response_to_dict(raw_response: Any) -> dict[str, Any]: def _convert_raw_response_to_dict(raw_response: Any) -> dict[str, Any]:
""" """
Safely convert raw_response to dictionary format for conversation history. Safely convert raw_response to dictionary format for conversation history.
@@ -204,6 +276,17 @@ class SmartDecisionMakerBlock(Block):
default="localhost:11434", default="localhost:11434",
description="Ollama host for local models", description="Ollama host for local models",
) )
agent_mode_max_iterations: int = SchemaField(
title="Agent Mode Max Iterations",
description="Maximum iterations for agent mode. 0 = traditional mode (single LLM call, yield tool calls for external execution), -1 = infinite agent mode (loop until finished), 1+ = agent mode with max iterations limit.",
advanced=True,
default=0,
)
conversation_compaction: bool = SchemaField(
default=True,
title="Context window auto-compaction",
description="Automatically compact the context window once it hits the limit",
)
@classmethod @classmethod
def get_missing_links(cls, data: BlockInput, links: list["Link"]) -> set[str]: def get_missing_links(cls, data: BlockInput, links: list["Link"]) -> set[str]:
@@ -506,6 +589,7 @@ class SmartDecisionMakerBlock(Block):
Returns the response if successful, raises ValueError if validation fails. Returns the response if successful, raises ValueError if validation fails.
""" """
resp = await llm.llm_call( resp = await llm.llm_call(
compress_prompt_to_fit=input_data.conversation_compaction,
credentials=credentials, credentials=credentials,
llm_model=input_data.model, llm_model=input_data.model,
prompt=current_prompt, prompt=current_prompt,
@@ -593,6 +677,291 @@ class SmartDecisionMakerBlock(Block):
return resp return resp
def _process_tool_calls(
self, response, tool_functions: list[dict[str, Any]]
) -> list[ToolInfo]:
"""Process tool calls and extract tool definitions, arguments, and input data.
Returns a list of tool info dicts with:
- tool_call: The original tool call object
- tool_name: The function name
- tool_def: The tool definition from tool_functions
- input_data: Processed input data dict (includes None values)
- field_mapping: Field name mapping for the tool
"""
if not response.tool_calls:
return []
processed_tools = []
for tool_call in response.tool_calls:
tool_name = tool_call.function.name
tool_args = json.loads(tool_call.function.arguments)
tool_def = next(
(
tool
for tool in tool_functions
if tool["function"]["name"] == tool_name
),
None,
)
if not tool_def:
if len(tool_functions) == 1:
tool_def = tool_functions[0]
else:
continue
# Build input data for the tool
input_data = {}
field_mapping = tool_def["function"].get("_field_mapping", {})
if "function" in tool_def and "parameters" in tool_def["function"]:
expected_args = tool_def["function"]["parameters"].get("properties", {})
for clean_arg_name in expected_args:
original_field_name = field_mapping.get(
clean_arg_name, clean_arg_name
)
arg_value = tool_args.get(clean_arg_name)
# Include all expected parameters, even if None (for backward compatibility with tests)
input_data[original_field_name] = arg_value
processed_tools.append(
ToolInfo(
tool_call=tool_call,
tool_name=tool_name,
tool_def=tool_def,
input_data=input_data,
field_mapping=field_mapping,
)
)
return processed_tools
def _update_conversation(
self, prompt: list[dict], response, tool_outputs: list | None = None
):
"""Update conversation history with response and tool outputs."""
# Don't add separate reasoning message with tool calls (breaks Anthropic's tool_use->tool_result pairing)
assistant_message = _convert_raw_response_to_dict(response.raw_response)
has_tool_calls = isinstance(assistant_message.get("content"), list) and any(
item.get("type") == "tool_use"
for item in assistant_message.get("content", [])
)
if response.reasoning and not has_tool_calls:
prompt.append(
{"role": "assistant", "content": f"[Reasoning]: {response.reasoning}"}
)
prompt.append(assistant_message)
if tool_outputs:
prompt.extend(tool_outputs)
async def _execute_single_tool_with_manager(
self,
tool_info: ToolInfo,
execution_params: ExecutionParams,
execution_processor: "ExecutionProcessor",
) -> dict:
"""Execute a single tool using the execution manager for proper integration."""
# Lazy imports to avoid circular dependencies
from backend.data.execution import NodeExecutionEntry
tool_call = tool_info.tool_call
tool_def = tool_info.tool_def
raw_input_data = tool_info.input_data
# Get sink node and field mapping
sink_node_id = tool_def["function"]["_sink_node_id"]
# Use proper database operations for tool execution
db_client = get_database_manager_async_client()
# Get target node
target_node = await db_client.get_node(sink_node_id)
if not target_node:
raise ValueError(f"Target node {sink_node_id} not found")
# Create proper node execution using upsert_execution_input
node_exec_result = None
final_input_data = None
# Add all inputs to the execution
if not raw_input_data:
raise ValueError(f"Tool call has no input data: {tool_call}")
for input_name, input_value in raw_input_data.items():
node_exec_result, final_input_data = await db_client.upsert_execution_input(
node_id=sink_node_id,
graph_exec_id=execution_params.graph_exec_id,
input_name=input_name,
input_data=input_value,
)
assert node_exec_result is not None, "node_exec_result should not be None"
# Create NodeExecutionEntry for execution manager
node_exec_entry = NodeExecutionEntry(
user_id=execution_params.user_id,
graph_exec_id=execution_params.graph_exec_id,
graph_id=execution_params.graph_id,
graph_version=execution_params.graph_version,
node_exec_id=node_exec_result.node_exec_id,
node_id=sink_node_id,
block_id=target_node.block_id,
inputs=final_input_data or {},
execution_context=execution_params.execution_context,
)
# Use the execution manager to execute the tool node
try:
# Get NodeExecutionProgress from the execution manager's running nodes
node_exec_progress = execution_processor.running_node_execution[
sink_node_id
]
# Use the execution manager's own graph stats
graph_stats_pair = (
execution_processor.execution_stats,
execution_processor.execution_stats_lock,
)
# Create a completed future for the task tracking system
node_exec_future = Future()
node_exec_progress.add_task(
node_exec_id=node_exec_result.node_exec_id,
task=node_exec_future,
)
# Execute the node directly since we're in the SmartDecisionMaker context
node_exec_future.set_result(
await execution_processor.on_node_execution(
node_exec=node_exec_entry,
node_exec_progress=node_exec_progress,
nodes_input_masks=None,
graph_stats_pair=graph_stats_pair,
)
)
# Get outputs from database after execution completes using database manager client
node_outputs = await db_client.get_execution_outputs_by_node_exec_id(
node_exec_result.node_exec_id
)
# Create tool response
tool_response_content = (
json.dumps(node_outputs)
if node_outputs
else "Tool executed successfully"
)
return _create_tool_response(tool_call.id, tool_response_content)
except Exception as e:
logger.error(f"Tool execution with manager failed: {e}")
# Return error response
return _create_tool_response(
tool_call.id, f"Tool execution failed: {str(e)}"
)
async def _execute_tools_agent_mode(
self,
input_data,
credentials,
tool_functions: list[dict[str, Any]],
prompt: list[dict],
graph_exec_id: str,
node_id: str,
node_exec_id: str,
user_id: str,
graph_id: str,
graph_version: int,
execution_context: ExecutionContext,
execution_processor: "ExecutionProcessor",
):
"""Execute tools in agent mode with a loop until finished."""
max_iterations = input_data.agent_mode_max_iterations
iteration = 0
# Execution parameters for tool execution
execution_params = ExecutionParams(
user_id=user_id,
graph_id=graph_id,
node_id=node_id,
graph_version=graph_version,
graph_exec_id=graph_exec_id,
node_exec_id=node_exec_id,
execution_context=execution_context,
)
current_prompt = list(prompt)
while max_iterations < 0 or iteration < max_iterations:
iteration += 1
logger.debug(f"Agent mode iteration {iteration}")
# Prepare prompt for this iteration
iteration_prompt = list(current_prompt)
# On the last iteration, add a special system message to encourage completion
if max_iterations > 0 and iteration == max_iterations:
last_iteration_message = {
"role": "system",
"content": f"{MAIN_OBJECTIVE_PREFIX}This is your last iteration ({iteration}/{max_iterations}). "
"Try to complete the task with the information you have. If you cannot fully complete it, "
"provide a summary of what you've accomplished and what remains to be done. "
"Prefer finishing with a clear response rather than making additional tool calls.",
}
iteration_prompt.append(last_iteration_message)
# Get LLM response
try:
response = await self._attempt_llm_call_with_validation(
credentials, input_data, iteration_prompt, tool_functions
)
except Exception as e:
yield "error", f"LLM call failed in agent mode iteration {iteration}: {str(e)}"
return
# Process tool calls
processed_tools = self._process_tool_calls(response, tool_functions)
# If no tool calls, we're done
if not processed_tools:
yield "finished", response.response
self._update_conversation(current_prompt, response)
yield "conversations", current_prompt
return
# Execute tools and collect responses
tool_outputs = []
for tool_info in processed_tools:
try:
tool_response = await self._execute_single_tool_with_manager(
tool_info, execution_params, execution_processor
)
tool_outputs.append(tool_response)
except Exception as e:
logger.error(f"Tool execution failed: {e}")
# Create error response for the tool
error_response = _create_tool_response(
tool_info.tool_call.id, f"Error: {str(e)}"
)
tool_outputs.append(error_response)
tool_outputs = _combine_tool_responses(tool_outputs)
self._update_conversation(current_prompt, response, tool_outputs)
# Yield intermediate conversation state
yield "conversations", current_prompt
# If we reach max iterations, yield the current state
if max_iterations < 0:
yield "finished", f"Agent mode completed after {iteration} iterations"
else:
yield "finished", f"Agent mode completed after {max_iterations} iterations (limit reached)"
yield "conversations", current_prompt
async def run( async def run(
self, self,
input_data: Input, input_data: Input,
@@ -603,8 +972,12 @@ class SmartDecisionMakerBlock(Block):
graph_exec_id: str, graph_exec_id: str,
node_exec_id: str, node_exec_id: str,
user_id: str, user_id: str,
graph_version: int,
execution_context: ExecutionContext,
execution_processor: "ExecutionProcessor",
**kwargs, **kwargs,
) -> BlockOutput: ) -> BlockOutput:
tool_functions = await self._create_tool_node_signatures(node_id) tool_functions = await self._create_tool_node_signatures(node_id)
yield "tool_functions", json.dumps(tool_functions) yield "tool_functions", json.dumps(tool_functions)
@@ -648,24 +1021,52 @@ class SmartDecisionMakerBlock(Block):
input_data.prompt = llm.fmt.format_string(input_data.prompt, values) input_data.prompt = llm.fmt.format_string(input_data.prompt, values)
input_data.sys_prompt = llm.fmt.format_string(input_data.sys_prompt, values) input_data.sys_prompt = llm.fmt.format_string(input_data.sys_prompt, values)
prefix = "[Main Objective Prompt]: "
if input_data.sys_prompt and not any( if input_data.sys_prompt and not any(
p["role"] == "system" and p["content"].startswith(prefix) for p in prompt p["role"] == "system" and p["content"].startswith(MAIN_OBJECTIVE_PREFIX)
for p in prompt
): ):
prompt.append({"role": "system", "content": prefix + input_data.sys_prompt}) prompt.append(
{
"role": "system",
"content": MAIN_OBJECTIVE_PREFIX + input_data.sys_prompt,
}
)
if input_data.prompt and not any( if input_data.prompt and not any(
p["role"] == "user" and p["content"].startswith(prefix) for p in prompt p["role"] == "user" and p["content"].startswith(MAIN_OBJECTIVE_PREFIX)
for p in prompt
): ):
prompt.append({"role": "user", "content": prefix + input_data.prompt}) prompt.append(
{"role": "user", "content": MAIN_OBJECTIVE_PREFIX + input_data.prompt}
)
# Execute tools based on the selected mode
if input_data.agent_mode_max_iterations != 0:
# In agent mode, execute tools directly in a loop until finished
async for result in self._execute_tools_agent_mode(
input_data=input_data,
credentials=credentials,
tool_functions=tool_functions,
prompt=prompt,
graph_exec_id=graph_exec_id,
node_id=node_id,
node_exec_id=node_exec_id,
user_id=user_id,
graph_id=graph_id,
graph_version=graph_version,
execution_context=execution_context,
execution_processor=execution_processor,
):
yield result
return
# One-off mode: single LLM call and yield tool calls for external execution
current_prompt = list(prompt) current_prompt = list(prompt)
max_attempts = max(1, int(input_data.retry)) max_attempts = max(1, int(input_data.retry))
response = None response = None
last_error = None last_error = None
for attempt in range(max_attempts): for _ in range(max_attempts):
try: try:
response = await self._attempt_llm_call_with_validation( response = await self._attempt_llm_call_with_validation(
credentials, input_data, current_prompt, tool_functions credentials, input_data, current_prompt, tool_functions

View File

@@ -1,7 +1,11 @@
import logging import logging
import threading
from collections import defaultdict
from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
from backend.data.execution import ExecutionContext
from backend.data.model import ProviderName, User from backend.data.model import ProviderName, User
from backend.server.model import CreateGraph from backend.server.model import CreateGraph
from backend.server.rest_api import AgentServer from backend.server.rest_api import AgentServer
@@ -17,10 +21,10 @@ async def create_graph(s: SpinTestServer, g, u: User):
async def create_credentials(s: SpinTestServer, u: User): async def create_credentials(s: SpinTestServer, u: User):
import backend.blocks.llm as llm import backend.blocks.llm as llm_module
provider = ProviderName.OPENAI provider = ProviderName.OPENAI
credentials = llm.TEST_CREDENTIALS credentials = llm_module.TEST_CREDENTIALS
return await s.agent_server.test_create_credentials(u.id, provider, credentials) return await s.agent_server.test_create_credentials(u.id, provider, credentials)
@@ -196,8 +200,6 @@ async def test_smart_decision_maker_function_signature(server: SpinTestServer):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_smart_decision_maker_tracks_llm_stats(): async def test_smart_decision_maker_tracks_llm_stats():
"""Test that SmartDecisionMakerBlock correctly tracks LLM usage stats.""" """Test that SmartDecisionMakerBlock correctly tracks LLM usage stats."""
from unittest.mock import MagicMock, patch
import backend.blocks.llm as llm_module import backend.blocks.llm as llm_module
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
@@ -216,7 +218,6 @@ async def test_smart_decision_maker_tracks_llm_stats():
} }
# Mock the _create_tool_node_signatures method to avoid database calls # Mock the _create_tool_node_signatures method to avoid database calls
from unittest.mock import AsyncMock
with patch( with patch(
"backend.blocks.llm.llm_call", "backend.blocks.llm.llm_call",
@@ -234,10 +235,19 @@ async def test_smart_decision_maker_tracks_llm_stats():
prompt="Should I continue with this task?", prompt="Should I continue with this task?",
model=llm_module.LlmModel.GPT4O, model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
agent_mode_max_iterations=0,
) )
# Execute the block # Execute the block
outputs = {} outputs = {}
# Create execution context
mock_execution_context = ExecutionContext(safe_mode=False)
# Create a mock execution processor for tests
mock_execution_processor = MagicMock()
async for output_name, output_data in block.run( async for output_name, output_data in block.run(
input_data, input_data,
credentials=llm_module.TEST_CREDENTIALS, credentials=llm_module.TEST_CREDENTIALS,
@@ -246,6 +256,9 @@ async def test_smart_decision_maker_tracks_llm_stats():
graph_exec_id="test-exec-id", graph_exec_id="test-exec-id",
node_exec_id="test-node-exec-id", node_exec_id="test-node-exec-id",
user_id="test-user-id", user_id="test-user-id",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
): ):
outputs[output_name] = output_data outputs[output_name] = output_data
@@ -263,8 +276,6 @@ async def test_smart_decision_maker_tracks_llm_stats():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_smart_decision_maker_parameter_validation(): async def test_smart_decision_maker_parameter_validation():
"""Test that SmartDecisionMakerBlock correctly validates tool call parameters.""" """Test that SmartDecisionMakerBlock correctly validates tool call parameters."""
from unittest.mock import MagicMock, patch
import backend.blocks.llm as llm_module import backend.blocks.llm as llm_module
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
@@ -311,8 +322,6 @@ async def test_smart_decision_maker_parameter_validation():
mock_response_with_typo.reasoning = None mock_response_with_typo.reasoning = None
mock_response_with_typo.raw_response = {"role": "assistant", "content": None} mock_response_with_typo.raw_response = {"role": "assistant", "content": None}
from unittest.mock import AsyncMock
with patch( with patch(
"backend.blocks.llm.llm_call", "backend.blocks.llm.llm_call",
new_callable=AsyncMock, new_callable=AsyncMock,
@@ -329,8 +338,17 @@ async def test_smart_decision_maker_parameter_validation():
model=llm_module.LlmModel.GPT4O, model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
retry=2, # Set retry to 2 for testing retry=2, # Set retry to 2 for testing
agent_mode_max_iterations=0,
) )
# Create execution context
mock_execution_context = ExecutionContext(safe_mode=False)
# Create a mock execution processor for tests
mock_execution_processor = MagicMock()
# Should raise ValueError after retries due to typo'd parameter name # Should raise ValueError after retries due to typo'd parameter name
with pytest.raises(ValueError) as exc_info: with pytest.raises(ValueError) as exc_info:
outputs = {} outputs = {}
@@ -342,6 +360,9 @@ async def test_smart_decision_maker_parameter_validation():
graph_exec_id="test-exec-id", graph_exec_id="test-exec-id",
node_exec_id="test-node-exec-id", node_exec_id="test-node-exec-id",
user_id="test-user-id", user_id="test-user-id",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
): ):
outputs[output_name] = output_data outputs[output_name] = output_data
@@ -368,8 +389,6 @@ async def test_smart_decision_maker_parameter_validation():
mock_response_missing_required.reasoning = None mock_response_missing_required.reasoning = None
mock_response_missing_required.raw_response = {"role": "assistant", "content": None} mock_response_missing_required.raw_response = {"role": "assistant", "content": None}
from unittest.mock import AsyncMock
with patch( with patch(
"backend.blocks.llm.llm_call", "backend.blocks.llm.llm_call",
new_callable=AsyncMock, new_callable=AsyncMock,
@@ -385,8 +404,17 @@ async def test_smart_decision_maker_parameter_validation():
prompt="Search for keywords", prompt="Search for keywords",
model=llm_module.LlmModel.GPT4O, model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
agent_mode_max_iterations=0,
) )
# Create execution context
mock_execution_context = ExecutionContext(safe_mode=False)
# Create a mock execution processor for tests
mock_execution_processor = MagicMock()
# Should raise ValueError due to missing required parameter # Should raise ValueError due to missing required parameter
with pytest.raises(ValueError) as exc_info: with pytest.raises(ValueError) as exc_info:
outputs = {} outputs = {}
@@ -398,6 +426,9 @@ async def test_smart_decision_maker_parameter_validation():
graph_exec_id="test-exec-id", graph_exec_id="test-exec-id",
node_exec_id="test-node-exec-id", node_exec_id="test-node-exec-id",
user_id="test-user-id", user_id="test-user-id",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
): ):
outputs[output_name] = output_data outputs[output_name] = output_data
@@ -418,8 +449,6 @@ async def test_smart_decision_maker_parameter_validation():
mock_response_valid.reasoning = None mock_response_valid.reasoning = None
mock_response_valid.raw_response = {"role": "assistant", "content": None} mock_response_valid.raw_response = {"role": "assistant", "content": None}
from unittest.mock import AsyncMock
with patch( with patch(
"backend.blocks.llm.llm_call", "backend.blocks.llm.llm_call",
new_callable=AsyncMock, new_callable=AsyncMock,
@@ -435,10 +464,19 @@ async def test_smart_decision_maker_parameter_validation():
prompt="Search for keywords", prompt="Search for keywords",
model=llm_module.LlmModel.GPT4O, model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
agent_mode_max_iterations=0,
) )
# Should succeed - optional parameter missing is OK # Should succeed - optional parameter missing is OK
outputs = {} outputs = {}
# Create execution context
mock_execution_context = ExecutionContext(safe_mode=False)
# Create a mock execution processor for tests
mock_execution_processor = MagicMock()
async for output_name, output_data in block.run( async for output_name, output_data in block.run(
input_data, input_data,
credentials=llm_module.TEST_CREDENTIALS, credentials=llm_module.TEST_CREDENTIALS,
@@ -447,6 +485,9 @@ async def test_smart_decision_maker_parameter_validation():
graph_exec_id="test-exec-id", graph_exec_id="test-exec-id",
node_exec_id="test-node-exec-id", node_exec_id="test-node-exec-id",
user_id="test-user-id", user_id="test-user-id",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
): ):
outputs[output_name] = output_data outputs[output_name] = output_data
@@ -472,8 +513,6 @@ async def test_smart_decision_maker_parameter_validation():
mock_response_all_params.reasoning = None mock_response_all_params.reasoning = None
mock_response_all_params.raw_response = {"role": "assistant", "content": None} mock_response_all_params.raw_response = {"role": "assistant", "content": None}
from unittest.mock import AsyncMock
with patch( with patch(
"backend.blocks.llm.llm_call", "backend.blocks.llm.llm_call",
new_callable=AsyncMock, new_callable=AsyncMock,
@@ -489,10 +528,19 @@ async def test_smart_decision_maker_parameter_validation():
prompt="Search for keywords", prompt="Search for keywords",
model=llm_module.LlmModel.GPT4O, model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
agent_mode_max_iterations=0,
) )
# Should succeed with all parameters # Should succeed with all parameters
outputs = {} outputs = {}
# Create execution context
mock_execution_context = ExecutionContext(safe_mode=False)
# Create a mock execution processor for tests
mock_execution_processor = MagicMock()
async for output_name, output_data in block.run( async for output_name, output_data in block.run(
input_data, input_data,
credentials=llm_module.TEST_CREDENTIALS, credentials=llm_module.TEST_CREDENTIALS,
@@ -501,6 +549,9 @@ async def test_smart_decision_maker_parameter_validation():
graph_exec_id="test-exec-id", graph_exec_id="test-exec-id",
node_exec_id="test-node-exec-id", node_exec_id="test-node-exec-id",
user_id="test-user-id", user_id="test-user-id",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
): ):
outputs[output_name] = output_data outputs[output_name] = output_data
@@ -513,8 +564,6 @@ async def test_smart_decision_maker_parameter_validation():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_smart_decision_maker_raw_response_conversion(): async def test_smart_decision_maker_raw_response_conversion():
"""Test that SmartDecisionMaker correctly handles different raw_response types with retry mechanism.""" """Test that SmartDecisionMaker correctly handles different raw_response types with retry mechanism."""
from unittest.mock import MagicMock, patch
import backend.blocks.llm as llm_module import backend.blocks.llm as llm_module
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
@@ -584,7 +633,6 @@ async def test_smart_decision_maker_raw_response_conversion():
) )
# Mock llm_call to return different responses on different calls # Mock llm_call to return different responses on different calls
from unittest.mock import AsyncMock
with patch( with patch(
"backend.blocks.llm.llm_call", new_callable=AsyncMock "backend.blocks.llm.llm_call", new_callable=AsyncMock
@@ -603,10 +651,19 @@ async def test_smart_decision_maker_raw_response_conversion():
model=llm_module.LlmModel.GPT4O, model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
retry=2, retry=2,
agent_mode_max_iterations=0,
) )
# Should succeed after retry, demonstrating our helper function works # Should succeed after retry, demonstrating our helper function works
outputs = {} outputs = {}
# Create execution context
mock_execution_context = ExecutionContext(safe_mode=False)
# Create a mock execution processor for tests
mock_execution_processor = MagicMock()
async for output_name, output_data in block.run( async for output_name, output_data in block.run(
input_data, input_data,
credentials=llm_module.TEST_CREDENTIALS, credentials=llm_module.TEST_CREDENTIALS,
@@ -615,6 +672,9 @@ async def test_smart_decision_maker_raw_response_conversion():
graph_exec_id="test-exec-id", graph_exec_id="test-exec-id",
node_exec_id="test-node-exec-id", node_exec_id="test-node-exec-id",
user_id="test-user-id", user_id="test-user-id",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
): ):
outputs[output_name] = output_data outputs[output_name] = output_data
@@ -650,8 +710,6 @@ async def test_smart_decision_maker_raw_response_conversion():
"I'll help you with that." # Ollama returns string "I'll help you with that." # Ollama returns string
) )
from unittest.mock import AsyncMock
with patch( with patch(
"backend.blocks.llm.llm_call", "backend.blocks.llm.llm_call",
new_callable=AsyncMock, new_callable=AsyncMock,
@@ -666,9 +724,18 @@ async def test_smart_decision_maker_raw_response_conversion():
prompt="Simple prompt", prompt="Simple prompt",
model=llm_module.LlmModel.GPT4O, model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
agent_mode_max_iterations=0,
) )
outputs = {} outputs = {}
# Create execution context
mock_execution_context = ExecutionContext(safe_mode=False)
# Create a mock execution processor for tests
mock_execution_processor = MagicMock()
async for output_name, output_data in block.run( async for output_name, output_data in block.run(
input_data, input_data,
credentials=llm_module.TEST_CREDENTIALS, credentials=llm_module.TEST_CREDENTIALS,
@@ -677,6 +744,9 @@ async def test_smart_decision_maker_raw_response_conversion():
graph_exec_id="test-exec-id", graph_exec_id="test-exec-id",
node_exec_id="test-node-exec-id", node_exec_id="test-node-exec-id",
user_id="test-user-id", user_id="test-user-id",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
): ):
outputs[output_name] = output_data outputs[output_name] = output_data
@@ -696,8 +766,6 @@ async def test_smart_decision_maker_raw_response_conversion():
"content": "Test response", "content": "Test response",
} # Dict format } # Dict format
from unittest.mock import AsyncMock
with patch( with patch(
"backend.blocks.llm.llm_call", "backend.blocks.llm.llm_call",
new_callable=AsyncMock, new_callable=AsyncMock,
@@ -712,6 +780,160 @@ async def test_smart_decision_maker_raw_response_conversion():
prompt="Another test", prompt="Another test",
model=llm_module.LlmModel.GPT4O, model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
agent_mode_max_iterations=0,
)
outputs = {}
# Create execution context
mock_execution_context = ExecutionContext(safe_mode=False)
# Create a mock execution processor for tests
mock_execution_processor = MagicMock()
async for output_name, output_data in block.run(
input_data,
credentials=llm_module.TEST_CREDENTIALS,
graph_id="test-graph-id",
node_id="test-node-id",
graph_exec_id="test-exec-id",
node_exec_id="test-node-exec-id",
user_id="test-user-id",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
):
outputs[output_name] = output_data
assert "finished" in outputs
assert outputs["finished"] == "Test response"
@pytest.mark.asyncio
async def test_smart_decision_maker_agent_mode():
"""Test that agent mode executes tools directly and loops until finished."""
import backend.blocks.llm as llm_module
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
block = SmartDecisionMakerBlock()
# Mock tool call that requires multiple iterations
mock_tool_call_1 = MagicMock()
mock_tool_call_1.id = "call_1"
mock_tool_call_1.function.name = "search_keywords"
mock_tool_call_1.function.arguments = (
'{"query": "test", "max_keyword_difficulty": 50}'
)
mock_response_1 = MagicMock()
mock_response_1.response = None
mock_response_1.tool_calls = [mock_tool_call_1]
mock_response_1.prompt_tokens = 50
mock_response_1.completion_tokens = 25
mock_response_1.reasoning = "Using search tool"
mock_response_1.raw_response = {
"role": "assistant",
"content": None,
"tool_calls": [{"id": "call_1", "type": "function"}],
}
# Final response with no tool calls (finished)
mock_response_2 = MagicMock()
mock_response_2.response = "Task completed successfully"
mock_response_2.tool_calls = []
mock_response_2.prompt_tokens = 30
mock_response_2.completion_tokens = 15
mock_response_2.reasoning = None
mock_response_2.raw_response = {
"role": "assistant",
"content": "Task completed successfully",
}
# Mock the LLM call to return different responses on each iteration
llm_call_mock = AsyncMock()
llm_call_mock.side_effect = [mock_response_1, mock_response_2]
# Mock tool node signatures
mock_tool_signatures = [
{
"type": "function",
"function": {
"name": "search_keywords",
"_sink_node_id": "test-sink-node-id",
"_field_mapping": {},
"parameters": {
"properties": {
"query": {"type": "string"},
"max_keyword_difficulty": {"type": "integer"},
},
"required": ["query", "max_keyword_difficulty"],
},
},
}
]
# Mock database and execution components
mock_db_client = AsyncMock()
mock_node = MagicMock()
mock_node.block_id = "test-block-id"
mock_db_client.get_node.return_value = mock_node
# Mock upsert_execution_input to return proper NodeExecutionResult and input data
mock_node_exec_result = MagicMock()
mock_node_exec_result.node_exec_id = "test-tool-exec-id"
mock_input_data = {"query": "test", "max_keyword_difficulty": 50}
mock_db_client.upsert_execution_input.return_value = (
mock_node_exec_result,
mock_input_data,
)
# No longer need mock_execute_node since we use execution_processor.on_node_execution
with patch("backend.blocks.llm.llm_call", llm_call_mock), patch.object(
block, "_create_tool_node_signatures", return_value=mock_tool_signatures
), patch(
"backend.blocks.smart_decision_maker.get_database_manager_async_client",
return_value=mock_db_client,
), patch(
"backend.executor.manager.async_update_node_execution_status",
new_callable=AsyncMock,
), patch(
"backend.integrations.creds_manager.IntegrationCredentialsManager"
):
# Create a mock execution context
mock_execution_context = ExecutionContext(
safe_mode=False,
)
# Create a mock execution processor for agent mode tests
mock_execution_processor = AsyncMock()
# Configure the execution processor mock with required attributes
mock_execution_processor.running_node_execution = defaultdict(MagicMock)
mock_execution_processor.execution_stats = MagicMock()
mock_execution_processor.execution_stats_lock = threading.Lock()
# Mock the on_node_execution method to return successful stats
mock_node_stats = MagicMock()
mock_node_stats.error = None # No error
mock_execution_processor.on_node_execution = AsyncMock(
return_value=mock_node_stats
)
# Mock the get_execution_outputs_by_node_exec_id method
mock_db_client.get_execution_outputs_by_node_exec_id.return_value = {
"result": {"status": "success", "data": "search completed"}
}
# Test agent mode with max_iterations = 3
input_data = SmartDecisionMakerBlock.Input(
prompt="Complete this task using tools",
model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
agent_mode_max_iterations=3, # Enable agent mode with 3 max iterations
) )
outputs = {} outputs = {}
@@ -723,8 +945,115 @@ async def test_smart_decision_maker_raw_response_conversion():
graph_exec_id="test-exec-id", graph_exec_id="test-exec-id",
node_exec_id="test-node-exec-id", node_exec_id="test-node-exec-id",
user_id="test-user-id", user_id="test-user-id",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
): ):
outputs[output_name] = output_data outputs[output_name] = output_data
# Verify agent mode behavior
assert "tool_functions" in outputs # tool_functions is yielded in both modes
assert "finished" in outputs assert "finished" in outputs
assert outputs["finished"] == "Test response" assert outputs["finished"] == "Task completed successfully"
assert "conversations" in outputs
# Verify the conversation includes tool responses
conversations = outputs["conversations"]
assert len(conversations) > 2 # Should have multiple conversation entries
# Verify LLM was called twice (once for tool call, once for finish)
assert llm_call_mock.call_count == 2
# Verify tool was executed via execution processor
assert mock_execution_processor.on_node_execution.call_count == 1
@pytest.mark.asyncio
async def test_smart_decision_maker_traditional_mode_default():
"""Test that default behavior (agent_mode_max_iterations=0) works as traditional mode."""
import backend.blocks.llm as llm_module
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
block = SmartDecisionMakerBlock()
# Mock tool call
mock_tool_call = MagicMock()
mock_tool_call.function.name = "search_keywords"
mock_tool_call.function.arguments = (
'{"query": "test", "max_keyword_difficulty": 50}'
)
mock_response = MagicMock()
mock_response.response = None
mock_response.tool_calls = [mock_tool_call]
mock_response.prompt_tokens = 50
mock_response.completion_tokens = 25
mock_response.reasoning = None
mock_response.raw_response = {"role": "assistant", "content": None}
mock_tool_signatures = [
{
"type": "function",
"function": {
"name": "search_keywords",
"_sink_node_id": "test-sink-node-id",
"_field_mapping": {},
"parameters": {
"properties": {
"query": {"type": "string"},
"max_keyword_difficulty": {"type": "integer"},
},
"required": ["query", "max_keyword_difficulty"],
},
},
}
]
with patch(
"backend.blocks.llm.llm_call",
new_callable=AsyncMock,
return_value=mock_response,
), patch.object(
block, "_create_tool_node_signatures", return_value=mock_tool_signatures
):
# Test default behavior (traditional mode)
input_data = SmartDecisionMakerBlock.Input(
prompt="Test prompt",
model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
agent_mode_max_iterations=0, # Traditional mode
)
# Create execution context
mock_execution_context = ExecutionContext(safe_mode=False)
# Create a mock execution processor for tests
mock_execution_processor = MagicMock()
outputs = {}
async for output_name, output_data in block.run(
input_data,
credentials=llm_module.TEST_CREDENTIALS,
graph_id="test-graph-id",
node_id="test-node-id",
graph_exec_id="test-exec-id",
node_exec_id="test-node-exec-id",
user_id="test-user-id",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
):
outputs[output_name] = output_data
# Verify traditional mode behavior
assert (
"tool_functions" in outputs
) # Should yield tool_functions in traditional mode
assert (
"tools_^_test-sink-node-id_~_query" in outputs
) # Should yield individual tool parameters
assert "tools_^_test-sink-node-id_~_max_keyword_difficulty" in outputs
assert "conversations" in outputs

View File

@@ -1,7 +1,7 @@
"""Comprehensive tests for SmartDecisionMakerBlock dynamic field handling.""" """Comprehensive tests for SmartDecisionMakerBlock dynamic field handling."""
import json import json
from unittest.mock import AsyncMock, Mock, patch from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest import pytest
@@ -308,10 +308,47 @@ async def test_output_yielding_with_dynamic_fields():
) as mock_llm: ) as mock_llm:
mock_llm.return_value = mock_response mock_llm.return_value = mock_response
# Mock the function signature creation # Mock the database manager to avoid HTTP calls during tool execution
with patch.object( with patch(
"backend.blocks.smart_decision_maker.get_database_manager_async_client"
) as mock_db_manager, patch.object(
block, "_create_tool_node_signatures", new_callable=AsyncMock block, "_create_tool_node_signatures", new_callable=AsyncMock
) as mock_sig: ) as mock_sig:
# Set up the mock database manager
mock_db_client = AsyncMock()
mock_db_manager.return_value = mock_db_client
# Mock the node retrieval
mock_target_node = Mock()
mock_target_node.id = "test-sink-node-id"
mock_target_node.block_id = "CreateDictionaryBlock"
mock_target_node.block = Mock()
mock_target_node.block.name = "Create Dictionary"
mock_db_client.get_node.return_value = mock_target_node
# Mock the execution result creation
mock_node_exec_result = Mock()
mock_node_exec_result.node_exec_id = "mock-node-exec-id"
mock_final_input_data = {
"values_#_name": "Alice",
"values_#_age": 30,
"values_#_email": "alice@example.com",
}
mock_db_client.upsert_execution_input.return_value = (
mock_node_exec_result,
mock_final_input_data,
)
# Mock the output retrieval
mock_outputs = {
"values_#_name": "Alice",
"values_#_age": 30,
"values_#_email": "alice@example.com",
}
mock_db_client.get_execution_outputs_by_node_exec_id.return_value = (
mock_outputs
)
mock_sig.return_value = [ mock_sig.return_value = [
{ {
"type": "function", "type": "function",
@@ -337,10 +374,16 @@ async def test_output_yielding_with_dynamic_fields():
prompt="Create a user dictionary", prompt="Create a user dictionary",
credentials=llm.TEST_CREDENTIALS_INPUT, credentials=llm.TEST_CREDENTIALS_INPUT,
model=llm.LlmModel.GPT4O, model=llm.LlmModel.GPT4O,
agent_mode_max_iterations=0, # Use traditional mode to test output yielding
) )
# Run the block # Run the block
outputs = {} outputs = {}
from backend.data.execution import ExecutionContext
mock_execution_context = ExecutionContext(safe_mode=False)
mock_execution_processor = MagicMock()
async for output_name, output_value in block.run( async for output_name, output_value in block.run(
input_data, input_data,
credentials=llm.TEST_CREDENTIALS, credentials=llm.TEST_CREDENTIALS,
@@ -349,6 +392,9 @@ async def test_output_yielding_with_dynamic_fields():
graph_exec_id="test_exec", graph_exec_id="test_exec",
node_exec_id="test_node_exec", node_exec_id="test_node_exec",
user_id="test_user", user_id="test_user",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
): ):
outputs[output_name] = output_value outputs[output_name] = output_value
@@ -511,6 +557,37 @@ async def test_validation_errors_dont_pollute_conversation():
} }
] ]
# Mock the database manager to avoid HTTP calls during tool execution
with patch(
"backend.blocks.smart_decision_maker.get_database_manager_async_client"
) as mock_db_manager:
# Set up the mock database manager for agent mode
mock_db_client = AsyncMock()
mock_db_manager.return_value = mock_db_client
# Mock the node retrieval
mock_target_node = Mock()
mock_target_node.id = "test-sink-node-id"
mock_target_node.block_id = "TestBlock"
mock_target_node.block = Mock()
mock_target_node.block.name = "Test Block"
mock_db_client.get_node.return_value = mock_target_node
# Mock the execution result creation
mock_node_exec_result = Mock()
mock_node_exec_result.node_exec_id = "mock-node-exec-id"
mock_final_input_data = {"correct_param": "value"}
mock_db_client.upsert_execution_input.return_value = (
mock_node_exec_result,
mock_final_input_data,
)
# Mock the output retrieval
mock_outputs = {"correct_param": "value"}
mock_db_client.get_execution_outputs_by_node_exec_id.return_value = (
mock_outputs
)
# Create input data # Create input data
from backend.blocks import llm from backend.blocks import llm
@@ -519,10 +596,41 @@ async def test_validation_errors_dont_pollute_conversation():
credentials=llm.TEST_CREDENTIALS_INPUT, credentials=llm.TEST_CREDENTIALS_INPUT,
model=llm.LlmModel.GPT4O, model=llm.LlmModel.GPT4O,
retry=3, # Allow retries retry=3, # Allow retries
agent_mode_max_iterations=1,
) )
# Run the block # Run the block
outputs = {} outputs = {}
from backend.data.execution import ExecutionContext
mock_execution_context = ExecutionContext(safe_mode=False)
# Create a proper mock execution processor for agent mode
from collections import defaultdict
mock_execution_processor = AsyncMock()
mock_execution_processor.execution_stats = MagicMock()
mock_execution_processor.execution_stats_lock = MagicMock()
# Create a mock NodeExecutionProgress for the sink node
mock_node_exec_progress = MagicMock()
mock_node_exec_progress.add_task = MagicMock()
mock_node_exec_progress.pop_output = MagicMock(
return_value=None
) # No outputs to process
# Set up running_node_execution as a defaultdict that returns our mock for any key
mock_execution_processor.running_node_execution = defaultdict(
lambda: mock_node_exec_progress
)
# Mock the on_node_execution method that gets called during tool execution
mock_node_stats = MagicMock()
mock_node_stats.error = None
mock_execution_processor.on_node_execution.return_value = (
mock_node_stats
)
async for output_name, output_value in block.run( async for output_name, output_value in block.run(
input_data, input_data,
credentials=llm.TEST_CREDENTIALS, credentials=llm.TEST_CREDENTIALS,
@@ -531,16 +639,20 @@ async def test_validation_errors_dont_pollute_conversation():
graph_exec_id="test_exec", graph_exec_id="test_exec",
node_exec_id="test_node_exec", node_exec_id="test_node_exec",
user_id="test_user", user_id="test_user",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
): ):
outputs[output_name] = output_value outputs[output_name] = output_value
# Verify we had 2 LLM calls (initial + retry) # Verify we had at least 1 LLM call
assert call_count == 2 assert call_count >= 1
# Check the final conversation output # Check the final conversation output
final_conversation = outputs.get("conversations", []) final_conversation = outputs.get("conversations", [])
# The final conversation should NOT contain the validation error message # The final conversation should NOT contain validation error messages
# Even if retries don't happen in agent mode, we should not leak errors
error_messages = [ error_messages = [
msg msg
for msg in final_conversation for msg in final_conversation
@@ -550,6 +662,3 @@ async def test_validation_errors_dont_pollute_conversation():
assert ( assert (
len(error_messages) == 0 len(error_messages) == 0
), "Validation error leaked into final conversation" ), "Validation error leaked into final conversation"
# The final conversation should only have the successful response
assert final_conversation[-1]["content"] == "valid"

View File

@@ -6,7 +6,7 @@ from typing import Optional
from autogpt_libs.api_key.keysmith import APIKeySmith from autogpt_libs.api_key.keysmith import APIKeySmith
from prisma.enums import APIKeyPermission, APIKeyStatus from prisma.enums import APIKeyPermission, APIKeyStatus
from prisma.models import APIKey as PrismaAPIKey from prisma.models import APIKey as PrismaAPIKey
from prisma.types import APIKeyCreateInput, APIKeyWhereUniqueInput from prisma.types import APIKeyWhereUniqueInput
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from backend.data.includes import MAX_USER_API_KEYS_FETCH from backend.data.includes import MAX_USER_API_KEYS_FETCH
@@ -83,17 +83,17 @@ async def create_api_key(
generated_key = keysmith.generate_key() generated_key = keysmith.generate_key()
saved_key_obj = await PrismaAPIKey.prisma().create( saved_key_obj = await PrismaAPIKey.prisma().create(
data=APIKeyCreateInput( data={
id=str(uuid.uuid4()), "id": str(uuid.uuid4()),
name=name, "name": name,
head=generated_key.head, "head": generated_key.head,
tail=generated_key.tail, "tail": generated_key.tail,
hash=generated_key.hash, "hash": generated_key.hash,
salt=generated_key.salt, "salt": generated_key.salt,
permissions=permissions, "permissions": [p for p in permissions],
description=description, "description": description,
userId=user_id, "userId": user_id,
) }
) )
return APIKeyInfo.from_db(saved_key_obj), generated_key.key return APIKeyInfo.from_db(saved_key_obj), generated_key.key

View File

@@ -1,327 +0,0 @@
"""
Credential Grant data layer.
Handles database operations for credential grants which allow OAuth clients
to use credentials on behalf of users.
"""
from datetime import datetime, timezone
from typing import Optional
from prisma.enums import CredentialGrantPermission
from prisma.models import CredentialGrant
from backend.data.db import prisma
async def create_credential_grant(
user_id: str,
client_id: str,
credential_id: str,
provider: str,
granted_scopes: list[str],
permissions: list[CredentialGrantPermission],
expires_at: Optional[datetime] = None,
) -> CredentialGrant:
"""
Create a new credential grant.
Args:
user_id: ID of the user granting access
client_id: Database ID of the OAuth client
credential_id: ID of the credential being granted
provider: Provider name (e.g., "google", "github")
granted_scopes: List of integration scopes granted
permissions: List of permissions (USE, DELETE)
expires_at: Optional expiration datetime
Returns:
Created CredentialGrant
"""
return await prisma.credentialgrant.create(
data={ # type: ignore[typeddict-item]
"userId": user_id,
"clientId": client_id,
"credentialId": credential_id,
"provider": provider,
"grantedScopes": granted_scopes,
"permissions": permissions,
"expiresAt": expires_at,
}
)
async def get_credential_grant(
grant_id: str,
user_id: Optional[str] = None,
client_id: Optional[str] = None,
) -> Optional[CredentialGrant]:
"""
Get a credential grant by ID.
Args:
grant_id: Grant ID
user_id: Optional user ID filter
client_id: Optional client database ID filter
Returns:
CredentialGrant or None
"""
where: dict[str, str] = {"id": grant_id}
if user_id:
where["userId"] = user_id
if client_id:
where["clientId"] = client_id
return await prisma.credentialgrant.find_first(where=where) # type: ignore[arg-type]
async def get_grants_for_user_client(
user_id: str,
client_id: str,
include_revoked: bool = False,
include_expired: bool = False,
) -> list[CredentialGrant]:
"""
Get all credential grants for a user-client pair.
Args:
user_id: User ID
client_id: Client database ID
include_revoked: Include revoked grants
include_expired: Include expired grants
Returns:
List of CredentialGrant objects
"""
where: dict[str, str | None] = {
"userId": user_id,
"clientId": client_id,
}
if not include_revoked:
where["revokedAt"] = None
grants = await prisma.credentialgrant.find_many(
where=where, # type: ignore[arg-type]
order={"createdAt": "desc"},
)
# Filter expired if needed
if not include_expired:
now = datetime.now(timezone.utc)
grants = [g for g in grants if g.expiresAt is None or g.expiresAt > now]
return grants
async def get_grants_for_credential(
user_id: str,
credential_id: str,
) -> list[CredentialGrant]:
"""
Get all active grants for a specific credential.
Args:
user_id: User ID
credential_id: Credential ID
Returns:
List of active CredentialGrant objects
"""
now = datetime.now(timezone.utc)
grants = await prisma.credentialgrant.find_many(
where={
"userId": user_id,
"credentialId": credential_id,
"revokedAt": None,
},
include={"Client": True},
)
# Filter expired
return [g for g in grants if g.expiresAt is None or g.expiresAt > now]
async def get_grant_by_credential_and_client(
user_id: str,
credential_id: str,
client_id: str,
) -> Optional[CredentialGrant]:
"""
Get the grant for a specific credential and client.
Args:
user_id: User ID
credential_id: Credential ID
client_id: Client database ID
Returns:
CredentialGrant or None
"""
return await prisma.credentialgrant.find_first(
where={
"userId": user_id,
"credentialId": credential_id,
"clientId": client_id,
"revokedAt": None,
}
)
async def update_grant_scopes(
grant_id: str,
granted_scopes: list[str],
) -> CredentialGrant:
"""
Update the granted scopes for a credential grant.
Args:
grant_id: Grant ID
granted_scopes: New list of granted scopes
Returns:
Updated CredentialGrant
"""
result = await prisma.credentialgrant.update(
where={"id": grant_id},
data={"grantedScopes": granted_scopes},
)
if result is None:
raise ValueError(f"Grant {grant_id} not found")
return result
async def update_grant_last_used(grant_id: str) -> None:
"""
Update the lastUsedAt timestamp for a grant.
Args:
grant_id: Grant ID
"""
await prisma.credentialgrant.update(
where={"id": grant_id},
data={"lastUsedAt": datetime.now(timezone.utc)},
)
async def revoke_grant(grant_id: str) -> CredentialGrant:
"""
Revoke a credential grant.
Args:
grant_id: Grant ID
Returns:
Revoked CredentialGrant
"""
result = await prisma.credentialgrant.update(
where={"id": grant_id},
data={"revokedAt": datetime.now(timezone.utc)},
)
if result is None:
raise ValueError(f"Grant {grant_id} not found")
return result
async def revoke_grants_for_credential(
user_id: str,
credential_id: str,
) -> int:
"""
Revoke all grants for a specific credential.
Args:
user_id: User ID
credential_id: Credential ID
Returns:
Number of grants revoked
"""
return await prisma.credentialgrant.update_many(
where={
"userId": user_id,
"credentialId": credential_id,
"revokedAt": None,
},
data={"revokedAt": datetime.now(timezone.utc)},
)
async def revoke_grants_for_client(
user_id: str,
client_id: str,
) -> int:
"""
Revoke all grants for a specific client.
Args:
user_id: User ID
client_id: Client database ID
Returns:
Number of grants revoked
"""
return await prisma.credentialgrant.update_many(
where={
"userId": user_id,
"clientId": client_id,
"revokedAt": None,
},
data={"revokedAt": datetime.now(timezone.utc)},
)
async def delete_grant(grant_id: str) -> None:
"""
Permanently delete a credential grant.
Args:
grant_id: Grant ID
"""
await prisma.credentialgrant.delete(where={"id": grant_id})
async def check_grant_permission(
grant_id: str,
required_permission: CredentialGrantPermission,
) -> bool:
"""
Check if a grant has a specific permission.
Args:
grant_id: Grant ID
required_permission: Permission to check
Returns:
True if grant has the permission
"""
grant = await prisma.credentialgrant.find_unique(where={"id": grant_id})
if not grant:
return False
return required_permission in grant.permissions
async def is_grant_valid(grant_id: str) -> bool:
"""
Check if a grant is valid (not revoked and not expired).
Args:
grant_id: Grant ID
Returns:
True if grant is valid
"""
grant = await prisma.credentialgrant.find_unique(where={"id": grant_id})
if not grant:
return False
if grant.revokedAt:
return False
if grant.expiresAt and grant.expiresAt < datetime.now(timezone.utc):
return False
return True

View File

@@ -11,7 +11,6 @@ import pytest
from prisma.enums import CreditTransactionType from prisma.enums import CreditTransactionType
from prisma.errors import UniqueViolationError from prisma.errors import UniqueViolationError
from prisma.models import CreditTransaction, User, UserBalance from prisma.models import CreditTransaction, User, UserBalance
from prisma.types import UserBalanceCreateInput, UserBalanceUpsertInput, UserCreateInput
from backend.data.credit import UserCredit from backend.data.credit import UserCredit
from backend.util.json import SafeJson from backend.util.json import SafeJson
@@ -22,11 +21,11 @@ async def create_test_user(user_id: str) -> None:
"""Create a test user for ceiling tests.""" """Create a test user for ceiling tests."""
try: try:
await User.prisma().create( await User.prisma().create(
data=UserCreateInput( data={
id=user_id, "id": user_id,
email=f"test-{user_id}@example.com", "email": f"test-{user_id}@example.com",
name=f"Test User {user_id[:8]}", "name": f"Test User {user_id[:8]}",
) }
) )
except UniqueViolationError: except UniqueViolationError:
# User already exists, continue # User already exists, continue
@@ -34,10 +33,7 @@ async def create_test_user(user_id: str) -> None:
await UserBalance.prisma().upsert( await UserBalance.prisma().upsert(
where={"userId": user_id}, where={"userId": user_id},
data=UserBalanceUpsertInput( data={"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
create=UserBalanceCreateInput(userId=user_id, balance=0),
update={"balance": 0},
),
) )

View File

@@ -14,7 +14,6 @@ import pytest
from prisma.enums import CreditTransactionType from prisma.enums import CreditTransactionType
from prisma.errors import UniqueViolationError from prisma.errors import UniqueViolationError
from prisma.models import CreditTransaction, User, UserBalance from prisma.models import CreditTransaction, User, UserBalance
from prisma.types import UserBalanceCreateInput, UserBalanceUpsertInput, UserCreateInput
from backend.data.credit import POSTGRES_INT_MAX, UsageTransactionMetadata, UserCredit from backend.data.credit import POSTGRES_INT_MAX, UsageTransactionMetadata, UserCredit
from backend.util.exceptions import InsufficientBalanceError from backend.util.exceptions import InsufficientBalanceError
@@ -29,11 +28,11 @@ async def create_test_user(user_id: str) -> None:
"""Create a test user with initial balance.""" """Create a test user with initial balance."""
try: try:
await User.prisma().create( await User.prisma().create(
data=UserCreateInput( data={
id=user_id, "id": user_id,
email=f"test-{user_id}@example.com", "email": f"test-{user_id}@example.com",
name=f"Test User {user_id[:8]}", "name": f"Test User {user_id[:8]}",
) }
) )
except UniqueViolationError: except UniqueViolationError:
# User already exists, continue # User already exists, continue
@@ -42,10 +41,7 @@ async def create_test_user(user_id: str) -> None:
# Ensure UserBalance record exists # Ensure UserBalance record exists
await UserBalance.prisma().upsert( await UserBalance.prisma().upsert(
where={"userId": user_id}, where={"userId": user_id},
data=UserBalanceUpsertInput( data={"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
create=UserBalanceCreateInput(userId=user_id, balance=0),
update={"balance": 0},
),
) )
@@ -346,10 +342,10 @@ async def test_integer_overflow_protection(server: SpinTestServer):
# First, set balance near max # First, set balance near max
await UserBalance.prisma().upsert( await UserBalance.prisma().upsert(
where={"userId": user_id}, where={"userId": user_id},
data=UserBalanceUpsertInput( data={
create=UserBalanceCreateInput(userId=user_id, balance=max_int - 100), "create": {"userId": user_id, "balance": max_int - 100},
update={"balance": max_int - 100}, "update": {"balance": max_int - 100},
), },
) )
# Try to add more than possible - should clamp to POSTGRES_INT_MAX # Try to add more than possible - should clamp to POSTGRES_INT_MAX

View File

@@ -8,7 +8,6 @@ which would have caught the CreditTransactionType enum casting bug.
import pytest import pytest
from prisma.enums import CreditTransactionType from prisma.enums import CreditTransactionType
from prisma.models import CreditTransaction, User, UserBalance from prisma.models import CreditTransaction, User, UserBalance
from prisma.types import UserCreateInput
from backend.data.credit import ( from backend.data.credit import (
AutoTopUpConfig, AutoTopUpConfig,
@@ -30,12 +29,12 @@ async def cleanup_test_user():
# Create the user first # Create the user first
try: try:
await User.prisma().create( await User.prisma().create(
data=UserCreateInput( data={
id=user_id, "id": user_id,
email=f"test-{user_id}@example.com", "email": f"test-{user_id}@example.com",
topUpConfig=SafeJson({}), "topUpConfig": SafeJson({}),
timezone="UTC", "timezone": "UTC",
) }
) )
except Exception: except Exception:
# User might already exist, that's fine # User might already exist, that's fine

View File

@@ -12,12 +12,6 @@ import pytest
import stripe import stripe
from prisma.enums import CreditTransactionType from prisma.enums import CreditTransactionType
from prisma.models import CreditRefundRequest, CreditTransaction, User, UserBalance from prisma.models import CreditRefundRequest, CreditTransaction, User, UserBalance
from prisma.types import (
CreditRefundRequestCreateInput,
CreditTransactionCreateInput,
UserBalanceCreateInput,
UserCreateInput,
)
from backend.data.credit import UserCredit from backend.data.credit import UserCredit
from backend.util.json import SafeJson from backend.util.json import SafeJson
@@ -41,32 +35,32 @@ async def setup_test_user_with_topup():
# Create user # Create user
await User.prisma().create( await User.prisma().create(
data=UserCreateInput( data={
id=REFUND_TEST_USER_ID, "id": REFUND_TEST_USER_ID,
email=f"{REFUND_TEST_USER_ID}@example.com", "email": f"{REFUND_TEST_USER_ID}@example.com",
name="Refund Test User", "name": "Refund Test User",
) }
) )
# Create user balance # Create user balance
await UserBalance.prisma().create( await UserBalance.prisma().create(
data=UserBalanceCreateInput( data={
userId=REFUND_TEST_USER_ID, "userId": REFUND_TEST_USER_ID,
balance=1000, # $10 "balance": 1000, # $10
) }
) )
# Create a top-up transaction that can be refunded # Create a top-up transaction that can be refunded
topup_tx = await CreditTransaction.prisma().create( topup_tx = await CreditTransaction.prisma().create(
data=CreditTransactionCreateInput( data={
userId=REFUND_TEST_USER_ID, "userId": REFUND_TEST_USER_ID,
amount=1000, "amount": 1000,
type=CreditTransactionType.TOP_UP, "type": CreditTransactionType.TOP_UP,
transactionKey="pi_test_12345", "transactionKey": "pi_test_12345",
runningBalance=1000, "runningBalance": 1000,
isActive=True, "isActive": True,
metadata=SafeJson({"stripe_payment_intent": "pi_test_12345"}), "metadata": SafeJson({"stripe_payment_intent": "pi_test_12345"}),
) }
) )
return topup_tx return topup_tx
@@ -99,12 +93,12 @@ async def test_deduct_credits_atomic(server: SpinTestServer):
# Create refund request record (simulating webhook flow) # Create refund request record (simulating webhook flow)
await CreditRefundRequest.prisma().create( await CreditRefundRequest.prisma().create(
data=CreditRefundRequestCreateInput( data={
userId=REFUND_TEST_USER_ID, "userId": REFUND_TEST_USER_ID,
amount=500, "amount": 500,
transactionKey=topup_tx.transactionKey, # Should match the original transaction "transactionKey": topup_tx.transactionKey, # Should match the original transaction
reason="Test refund", "reason": "Test refund",
) }
) )
# Call deduct_credits # Call deduct_credits
@@ -292,12 +286,12 @@ async def test_concurrent_refunds(server: SpinTestServer):
refund_requests = [] refund_requests = []
for i in range(5): for i in range(5):
req = await CreditRefundRequest.prisma().create( req = await CreditRefundRequest.prisma().create(
data=CreditRefundRequestCreateInput( data={
userId=REFUND_TEST_USER_ID, "userId": REFUND_TEST_USER_ID,
amount=100, # $1 each "amount": 100, # $1 each
transactionKey=topup_tx.transactionKey, "transactionKey": topup_tx.transactionKey,
reason=f"Test refund {i}", "reason": f"Test refund {i}",
) }
) )
refund_requests.append(req) refund_requests.append(req)

View File

@@ -3,11 +3,6 @@ from datetime import datetime, timedelta, timezone
import pytest import pytest
from prisma.enums import CreditTransactionType from prisma.enums import CreditTransactionType
from prisma.models import CreditTransaction, UserBalance from prisma.models import CreditTransaction, UserBalance
from prisma.types import (
CreditTransactionCreateInput,
UserBalanceCreateInput,
UserBalanceUpsertInput,
)
from backend.blocks.llm import AITextGeneratorBlock from backend.blocks.llm import AITextGeneratorBlock
from backend.data.block import get_block from backend.data.block import get_block
@@ -28,10 +23,10 @@ async def disable_test_user_transactions():
old_date = datetime.now(timezone.utc) - timedelta(days=35) # More than a month ago old_date = datetime.now(timezone.utc) - timedelta(days=35) # More than a month ago
await UserBalance.prisma().upsert( await UserBalance.prisma().upsert(
where={"userId": DEFAULT_USER_ID}, where={"userId": DEFAULT_USER_ID},
data=UserBalanceUpsertInput( data={
create=UserBalanceCreateInput(userId=DEFAULT_USER_ID, balance=0), "create": {"userId": DEFAULT_USER_ID, "balance": 0},
update={"balance": 0, "updatedAt": old_date}, "update": {"balance": 0, "updatedAt": old_date},
), },
) )
@@ -145,23 +140,23 @@ async def test_block_credit_reset(server: SpinTestServer):
# Manually create a transaction with month 1 timestamp to establish history # Manually create a transaction with month 1 timestamp to establish history
await CreditTransaction.prisma().create( await CreditTransaction.prisma().create(
data=CreditTransactionCreateInput( data={
userId=DEFAULT_USER_ID, "userId": DEFAULT_USER_ID,
amount=100, "amount": 100,
type=CreditTransactionType.TOP_UP, "type": CreditTransactionType.TOP_UP,
runningBalance=1100, "runningBalance": 1100,
isActive=True, "isActive": True,
createdAt=month1, # Set specific timestamp "createdAt": month1, # Set specific timestamp
) }
) )
# Update user balance to match # Update user balance to match
await UserBalance.prisma().upsert( await UserBalance.prisma().upsert(
where={"userId": DEFAULT_USER_ID}, where={"userId": DEFAULT_USER_ID},
data=UserBalanceUpsertInput( data={
create=UserBalanceCreateInput(userId=DEFAULT_USER_ID, balance=1100), "create": {"userId": DEFAULT_USER_ID, "balance": 1100},
update={"balance": 1100}, "update": {"balance": 1100},
), },
) )
# Now test month 2 behavior # Now test month 2 behavior
@@ -180,14 +175,14 @@ async def test_block_credit_reset(server: SpinTestServer):
# Create a month 2 transaction to update the last transaction time # Create a month 2 transaction to update the last transaction time
await CreditTransaction.prisma().create( await CreditTransaction.prisma().create(
data=CreditTransactionCreateInput( data={
userId=DEFAULT_USER_ID, "userId": DEFAULT_USER_ID,
amount=-700, # Spent 700 to get to 400 "amount": -700, # Spent 700 to get to 400
type=CreditTransactionType.USAGE, "type": CreditTransactionType.USAGE,
runningBalance=400, "runningBalance": 400,
isActive=True, "isActive": True,
createdAt=month2, "createdAt": month2,
) }
) )
# Move to month 3 # Move to month 3

View File

@@ -12,7 +12,6 @@ import pytest
from prisma.enums import CreditTransactionType from prisma.enums import CreditTransactionType
from prisma.errors import UniqueViolationError from prisma.errors import UniqueViolationError
from prisma.models import CreditTransaction, User, UserBalance from prisma.models import CreditTransaction, User, UserBalance
from prisma.types import UserBalanceCreateInput, UserBalanceUpsertInput, UserCreateInput
from backend.data.credit import POSTGRES_INT_MIN, UserCredit from backend.data.credit import POSTGRES_INT_MIN, UserCredit
from backend.util.test import SpinTestServer from backend.util.test import SpinTestServer
@@ -22,11 +21,11 @@ async def create_test_user(user_id: str) -> None:
"""Create a test user for underflow tests.""" """Create a test user for underflow tests."""
try: try:
await User.prisma().create( await User.prisma().create(
data=UserCreateInput( data={
id=user_id, "id": user_id,
email=f"test-{user_id}@example.com", "email": f"test-{user_id}@example.com",
name=f"Test User {user_id[:8]}", "name": f"Test User {user_id[:8]}",
) }
) )
except UniqueViolationError: except UniqueViolationError:
# User already exists, continue # User already exists, continue
@@ -34,10 +33,7 @@ async def create_test_user(user_id: str) -> None:
await UserBalance.prisma().upsert( await UserBalance.prisma().upsert(
where={"userId": user_id}, where={"userId": user_id},
data=UserBalanceUpsertInput( data={"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
create=UserBalanceCreateInput(userId=user_id, balance=0),
update={"balance": 0},
),
) )
@@ -70,14 +66,14 @@ async def test_debug_underflow_step_by_step(server: SpinTestServer):
initial_balance_target = POSTGRES_INT_MIN + 100 initial_balance_target = POSTGRES_INT_MIN + 100
# Use direct database update to set the balance close to underflow # Use direct database update to set the balance close to underflow
from prisma.models import UserBalance
await UserBalance.prisma().upsert( await UserBalance.prisma().upsert(
where={"userId": user_id}, where={"userId": user_id},
data=UserBalanceUpsertInput( data={
create=UserBalanceCreateInput( "create": {"userId": user_id, "balance": initial_balance_target},
userId=user_id, balance=initial_balance_target "update": {"balance": initial_balance_target},
), },
update={"balance": initial_balance_target},
),
) )
current_balance = await credit_system.get_credits(user_id) current_balance = await credit_system.get_credits(user_id)
@@ -114,10 +110,10 @@ async def test_debug_underflow_step_by_step(server: SpinTestServer):
# Set balance to exactly POSTGRES_INT_MIN # Set balance to exactly POSTGRES_INT_MIN
await UserBalance.prisma().upsert( await UserBalance.prisma().upsert(
where={"userId": user_id}, where={"userId": user_id},
data=UserBalanceUpsertInput( data={
create=UserBalanceCreateInput(userId=user_id, balance=POSTGRES_INT_MIN), "create": {"userId": user_id, "balance": POSTGRES_INT_MIN},
update={"balance": POSTGRES_INT_MIN}, "update": {"balance": POSTGRES_INT_MIN},
), },
) )
edge_balance = await credit_system.get_credits(user_id) edge_balance = await credit_system.get_credits(user_id)
@@ -151,13 +147,15 @@ async def test_underflow_protection_large_refunds(server: SpinTestServer):
# Set up balance close to underflow threshold to test the protection # Set up balance close to underflow threshold to test the protection
# Set balance to POSTGRES_INT_MIN + 1000, then try to subtract 2000 # Set balance to POSTGRES_INT_MIN + 1000, then try to subtract 2000
# This should trigger underflow protection # This should trigger underflow protection
from prisma.models import UserBalance
test_balance = POSTGRES_INT_MIN + 1000 test_balance = POSTGRES_INT_MIN + 1000
await UserBalance.prisma().upsert( await UserBalance.prisma().upsert(
where={"userId": user_id}, where={"userId": user_id},
data=UserBalanceUpsertInput( data={
create=UserBalanceCreateInput(userId=user_id, balance=test_balance), "create": {"userId": user_id, "balance": test_balance},
update={"balance": test_balance}, "update": {"balance": test_balance},
), },
) )
current_balance = await credit_system.get_credits(user_id) current_balance = await credit_system.get_credits(user_id)
@@ -214,13 +212,15 @@ async def test_multiple_large_refunds_cumulative_underflow(server: SpinTestServe
try: try:
# Set up balance close to underflow threshold # Set up balance close to underflow threshold
from prisma.models import UserBalance
initial_balance = POSTGRES_INT_MIN + 500 # Close to minimum but with some room initial_balance = POSTGRES_INT_MIN + 500 # Close to minimum but with some room
await UserBalance.prisma().upsert( await UserBalance.prisma().upsert(
where={"userId": user_id}, where={"userId": user_id},
data=UserBalanceUpsertInput( data={
create=UserBalanceCreateInput(userId=user_id, balance=initial_balance), "create": {"userId": user_id, "balance": initial_balance},
update={"balance": initial_balance}, "update": {"balance": initial_balance},
), },
) )
# Apply multiple refunds that would cumulatively underflow # Apply multiple refunds that would cumulatively underflow
@@ -290,13 +290,15 @@ async def test_concurrent_large_refunds_no_underflow(server: SpinTestServer):
try: try:
# Set up balance close to underflow threshold # Set up balance close to underflow threshold
from prisma.models import UserBalance
initial_balance = POSTGRES_INT_MIN + 1000 # Close to minimum initial_balance = POSTGRES_INT_MIN + 1000 # Close to minimum
await UserBalance.prisma().upsert( await UserBalance.prisma().upsert(
where={"userId": user_id}, where={"userId": user_id},
data=UserBalanceUpsertInput( data={
create=UserBalanceCreateInput(userId=user_id, balance=initial_balance), "create": {"userId": user_id, "balance": initial_balance},
update={"balance": initial_balance}, "update": {"balance": initial_balance},
), },
) )
async def large_refund(amount: int, label: str): async def large_refund(amount: int, label: str):

View File

@@ -14,7 +14,6 @@ import pytest
from prisma.enums import CreditTransactionType from prisma.enums import CreditTransactionType
from prisma.errors import UniqueViolationError from prisma.errors import UniqueViolationError
from prisma.models import CreditTransaction, User, UserBalance from prisma.models import CreditTransaction, User, UserBalance
from prisma.types import UserBalanceCreateInput, UserCreateInput
from backend.data.credit import UsageTransactionMetadata, UserCredit from backend.data.credit import UsageTransactionMetadata, UserCredit
from backend.util.json import SafeJson from backend.util.json import SafeJson
@@ -25,11 +24,11 @@ async def create_test_user(user_id: str) -> None:
"""Create a test user for migration tests.""" """Create a test user for migration tests."""
try: try:
await User.prisma().create( await User.prisma().create(
data=UserCreateInput( data={
id=user_id, "id": user_id,
email=f"test-{user_id}@example.com", "email": f"test-{user_id}@example.com",
name=f"Test User {user_id[:8]}", "name": f"Test User {user_id[:8]}",
) }
) )
except UniqueViolationError: except UniqueViolationError:
# User already exists, continue # User already exists, continue
@@ -122,7 +121,7 @@ async def test_detect_stale_user_balance_queries(server: SpinTestServer):
try: try:
# Create UserBalance with specific value # Create UserBalance with specific value
await UserBalance.prisma().create( await UserBalance.prisma().create(
data=UserBalanceCreateInput(userId=user_id, balance=5000) # $50 data={"userId": user_id, "balance": 5000} # $50
) )
# Verify that get_credits returns UserBalance value (5000), not any stale User.balance value # Verify that get_credits returns UserBalance value (5000), not any stale User.balance value
@@ -161,9 +160,7 @@ async def test_concurrent_operations_use_userbalance_only(server: SpinTestServer
try: try:
# Set initial balance in UserBalance # Set initial balance in UserBalance
await UserBalance.prisma().create( await UserBalance.prisma().create(data={"userId": user_id, "balance": 1000})
data=UserBalanceCreateInput(userId=user_id, balance=1000)
)
# Run concurrent operations to ensure they all use UserBalance atomic operations # Run concurrent operations to ensure they all use UserBalance atomic operations
async def concurrent_spend(amount: int, label: str): async def concurrent_spend(amount: int, label: str):

View File

@@ -1,10 +1,10 @@
import logging import logging
import queue
from collections import defaultdict from collections import defaultdict
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from enum import Enum from enum import Enum
from multiprocessing import Manager
from queue import Empty
from typing import ( from typing import (
TYPE_CHECKING,
Annotated, Annotated,
Any, Any,
AsyncGenerator, AsyncGenerator,
@@ -27,7 +27,6 @@ from prisma.models import (
AgentNodeExecutionKeyValueData, AgentNodeExecutionKeyValueData,
) )
from prisma.types import ( from prisma.types import (
AgentGraphExecutionCreateInput,
AgentGraphExecutionUpdateManyMutationInput, AgentGraphExecutionUpdateManyMutationInput,
AgentGraphExecutionWhereInput, AgentGraphExecutionWhereInput,
AgentNodeExecutionCreateInput, AgentNodeExecutionCreateInput,
@@ -35,7 +34,7 @@ from prisma.types import (
AgentNodeExecutionKeyValueDataCreateInput, AgentNodeExecutionKeyValueDataCreateInput,
AgentNodeExecutionUpdateInput, AgentNodeExecutionUpdateInput,
AgentNodeExecutionWhereInput, AgentNodeExecutionWhereInput,
_AgentNodeExecutionWhereUnique_id_Input, AgentNodeExecutionWhereUniqueInput,
) )
from pydantic import BaseModel, ConfigDict, JsonValue, ValidationError from pydantic import BaseModel, ConfigDict, JsonValue, ValidationError
from pydantic.fields import Field from pydantic.fields import Field
@@ -66,19 +65,15 @@ from .includes import (
) )
from .model import CredentialsMetaInput, GraphExecutionStats, NodeExecutionStats from .model import CredentialsMetaInput, GraphExecutionStats, NodeExecutionStats
if TYPE_CHECKING:
pass
T = TypeVar("T") T = TypeVar("T")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
config = Config() config = Config()
class GrantResolverContext(BaseModel):
"""Context for grant-based credential resolution in external API executions."""
client_db_id: str # The OAuth client database UUID
grant_ids: list[str] # List of grant IDs to use for credential resolution
class ExecutionContext(BaseModel): class ExecutionContext(BaseModel):
""" """
Unified context that carries execution-level data throughout the entire execution flow. Unified context that carries execution-level data throughout the entire execution flow.
@@ -89,8 +84,6 @@ class ExecutionContext(BaseModel):
user_timezone: str = "UTC" user_timezone: str = "UTC"
root_execution_id: Optional[str] = None root_execution_id: Optional[str] = None
parent_execution_id: Optional[str] = None parent_execution_id: Optional[str] = None
# For external API executions using credential grants
grant_resolver_context: Optional[GrantResolverContext] = None
# -------------------------- Models -------------------------- # # -------------------------- Models -------------------------- #
@@ -715,18 +708,18 @@ async def create_graph_execution(
The id of the AgentGraphExecution and the list of ExecutionResult for each node. The id of the AgentGraphExecution and the list of ExecutionResult for each node.
""" """
result = await AgentGraphExecution.prisma().create( result = await AgentGraphExecution.prisma().create(
data=AgentGraphExecutionCreateInput( data={
agentGraphId=graph_id, "agentGraphId": graph_id,
agentGraphVersion=graph_version, "agentGraphVersion": graph_version,
executionStatus=ExecutionStatus.INCOMPLETE, "executionStatus": ExecutionStatus.INCOMPLETE,
inputs=SafeJson(inputs), "inputs": SafeJson(inputs),
credentialInputs=( "credentialInputs": (
SafeJson(credential_inputs) if credential_inputs else Json({}) SafeJson(credential_inputs) if credential_inputs else Json({})
), ),
nodesInputMasks=( "nodesInputMasks": (
SafeJson(nodes_input_masks) if nodes_input_masks else Json({}) SafeJson(nodes_input_masks) if nodes_input_masks else Json({})
), ),
NodeExecutions={ "NodeExecutions": {
"create": [ "create": [
AgentNodeExecutionCreateInput( AgentNodeExecutionCreateInput(
agentNodeId=node_id, agentNodeId=node_id,
@@ -742,10 +735,10 @@ async def create_graph_execution(
for node_id, node_input in starting_nodes_input for node_id, node_input in starting_nodes_input
] ]
}, },
userId=user_id, "userId": user_id,
agentPresetId=preset_id, "agentPresetId": preset_id,
parentGraphExecutionId=parent_graph_exec_id, "parentGraphExecutionId": parent_graph_exec_id,
), },
include=GRAPH_EXECUTION_INCLUDE_WITH_NODES, include=GRAPH_EXECUTION_INCLUDE_WITH_NODES,
) )
@@ -837,15 +830,39 @@ async def upsert_execution_output(
""" """
Insert AgentNodeExecutionInputOutput record for as one of AgentNodeExecution.Output. Insert AgentNodeExecutionInputOutput record for as one of AgentNodeExecution.Output.
""" """
data = AgentNodeExecutionInputOutputCreateInput( data: AgentNodeExecutionInputOutputCreateInput = {
name=output_name, "name": output_name,
referencedByOutputExecId=node_exec_id, "referencedByOutputExecId": node_exec_id,
) }
if output_data is not None: if output_data is not None:
data["data"] = SafeJson(output_data) data["data"] = SafeJson(output_data)
await AgentNodeExecutionInputOutput.prisma().create(data=data) await AgentNodeExecutionInputOutput.prisma().create(data=data)
async def get_execution_outputs_by_node_exec_id(
node_exec_id: str,
) -> dict[str, Any]:
"""
Get all execution outputs for a specific node execution ID.
Args:
node_exec_id: The node execution ID to get outputs for
Returns:
Dictionary mapping output names to their data values
"""
outputs = await AgentNodeExecutionInputOutput.prisma().find_many(
where={"referencedByOutputExecId": node_exec_id}
)
result = {}
for output in outputs:
if output.data is not None:
result[output.name] = type_utils.convert(output.data, JsonValue)
return result
async def update_graph_execution_start_time( async def update_graph_execution_start_time(
graph_exec_id: str, graph_exec_id: str,
) -> GraphExecution | None: ) -> GraphExecution | None:
@@ -958,7 +975,7 @@ async def update_node_execution_status(
if res := await AgentNodeExecution.prisma().update( if res := await AgentNodeExecution.prisma().update(
where=cast( where=cast(
_AgentNodeExecutionWhereUnique_id_Input, AgentNodeExecutionWhereUniqueInput,
{ {
"id": node_exec_id, "id": node_exec_id,
"executionStatus": {"in": [s.value for s in allowed_from]}, "executionStatus": {"in": [s.value for s in allowed_from]},
@@ -1146,12 +1163,16 @@ class NodeExecutionEntry(BaseModel):
class ExecutionQueue(Generic[T]): class ExecutionQueue(Generic[T]):
""" """
Queue for managing the execution of agents. Thread-safe queue for managing node execution within a single graph execution.
This will be shared between different processes
Note: Uses queue.Queue (not multiprocessing.Queue) since all access is from
threads within the same process. If migrating back to ProcessPoolExecutor,
replace with multiprocessing.Manager().Queue() for cross-process safety.
""" """
def __init__(self): def __init__(self):
self.queue = Manager().Queue() # Thread-safe queue (not multiprocessing) — see class docstring
self.queue: queue.Queue[T] = queue.Queue()
def add(self, execution: T) -> T: def add(self, execution: T) -> T:
self.queue.put(execution) self.queue.put(execution)
@@ -1166,7 +1187,7 @@ class ExecutionQueue(Generic[T]):
def get_or_none(self) -> T | None: def get_or_none(self) -> T | None:
try: try:
return self.queue.get_nowait() return self.queue.get_nowait()
except Empty: except queue.Empty:
return None return None

View File

@@ -0,0 +1,60 @@
"""Tests for ExecutionQueue thread-safety."""
import queue
import threading
import pytest
from backend.data.execution import ExecutionQueue
def test_execution_queue_uses_stdlib_queue():
"""Verify ExecutionQueue uses queue.Queue (not multiprocessing)."""
q = ExecutionQueue()
assert isinstance(q.queue, queue.Queue)
def test_basic_operations():
"""Test add, get, empty, and get_or_none."""
q = ExecutionQueue()
assert q.empty() is True
assert q.get_or_none() is None
result = q.add("item1")
assert result == "item1"
assert q.empty() is False
item = q.get()
assert item == "item1"
assert q.empty() is True
def test_thread_safety():
"""Test concurrent access from multiple threads."""
q = ExecutionQueue()
results = []
num_items = 100
def producer():
for i in range(num_items):
q.add(f"item_{i}")
def consumer():
count = 0
while count < num_items:
item = q.get_or_none()
if item is not None:
results.append(item)
count += 1
producer_thread = threading.Thread(target=producer)
consumer_thread = threading.Thread(target=consumer)
producer_thread.start()
consumer_thread.start()
producer_thread.join(timeout=5)
consumer_thread.join(timeout=5)
assert len(results) == num_items

View File

@@ -10,11 +10,7 @@ from typing import Optional
from prisma.enums import ReviewStatus from prisma.enums import ReviewStatus
from prisma.models import PendingHumanReview from prisma.models import PendingHumanReview
from prisma.types import ( from prisma.types import PendingHumanReviewUpdateInput
PendingHumanReviewCreateInput,
PendingHumanReviewUpdateInput,
PendingHumanReviewUpsertInput,
)
from pydantic import BaseModel from pydantic import BaseModel
from backend.server.v2.executions.review.model import ( from backend.server.v2.executions.review.model import (
@@ -70,20 +66,20 @@ async def get_or_create_human_review(
# Upsert - get existing or create new review # Upsert - get existing or create new review
review = await PendingHumanReview.prisma().upsert( review = await PendingHumanReview.prisma().upsert(
where={"nodeExecId": node_exec_id}, where={"nodeExecId": node_exec_id},
data=PendingHumanReviewUpsertInput( data={
create=PendingHumanReviewCreateInput( "create": {
userId=user_id, "userId": user_id,
nodeExecId=node_exec_id, "nodeExecId": node_exec_id,
graphExecId=graph_exec_id, "graphExecId": graph_exec_id,
graphId=graph_id, "graphId": graph_id,
graphVersion=graph_version, "graphVersion": graph_version,
payload=SafeJson(input_data), "payload": SafeJson(input_data),
instructions=message, "instructions": message,
editable=editable, "editable": editable,
status=ReviewStatus.WAITING, "status": ReviewStatus.WAITING,
), },
update={}, # Do nothing on update - keep existing review as is "update": {}, # Do nothing on update - keep existing review as is
), },
) )
logger.info( logger.info(

View File

@@ -1,302 +0,0 @@
"""
Integration scopes mapping.
Maps AutoGPT's fine-grained integration scopes to provider-specific OAuth scopes.
These scopes are used to request granular permissions when connecting integrations
through the Credential Broker.
"""
from enum import Enum
from typing import Optional
from backend.integrations.providers import ProviderName
class IntegrationScope(str, Enum):
"""
Fine-grained integration scopes for credential grants.
Format: {provider}:{resource}.{permission}
"""
# Google scopes
GOOGLE_EMAIL_READ = "google:email.read"
GOOGLE_GMAIL_READONLY = "google:gmail.readonly"
GOOGLE_GMAIL_SEND = "google:gmail.send"
GOOGLE_GMAIL_MODIFY = "google:gmail.modify"
GOOGLE_DRIVE_READONLY = "google:drive.readonly"
GOOGLE_DRIVE_FILE = "google:drive.file"
GOOGLE_CALENDAR_READONLY = "google:calendar.readonly"
GOOGLE_CALENDAR_EVENTS = "google:calendar.events"
GOOGLE_SHEETS_READONLY = "google:sheets.readonly"
GOOGLE_SHEETS = "google:sheets"
GOOGLE_DOCS_READONLY = "google:docs.readonly"
GOOGLE_DOCS = "google:docs"
# GitHub scopes
GITHUB_REPOS_READ = "github:repos.read"
GITHUB_REPOS_WRITE = "github:repos.write"
GITHUB_ISSUES_READ = "github:issues.read"
GITHUB_ISSUES_WRITE = "github:issues.write"
GITHUB_USER_READ = "github:user.read"
GITHUB_GISTS = "github:gists"
GITHUB_NOTIFICATIONS = "github:notifications"
# Discord scopes
DISCORD_IDENTIFY = "discord:identify"
DISCORD_EMAIL = "discord:email"
DISCORD_GUILDS = "discord:guilds"
DISCORD_MESSAGES_READ = "discord:messages.read"
# Twitter scopes
TWITTER_READ = "twitter:read"
TWITTER_WRITE = "twitter:write"
TWITTER_DM = "twitter:dm"
# Notion scopes
NOTION_READ = "notion:read"
NOTION_WRITE = "notion:write"
# Todoist scopes
TODOIST_READ = "todoist:read"
TODOIST_WRITE = "todoist:write"
# Scope descriptions for consent UI
INTEGRATION_SCOPE_DESCRIPTIONS: dict[str, str] = {
# Google
IntegrationScope.GOOGLE_EMAIL_READ.value: "Read your email address",
IntegrationScope.GOOGLE_GMAIL_READONLY.value: "Read your Gmail messages",
IntegrationScope.GOOGLE_GMAIL_SEND.value: "Send emails on your behalf",
IntegrationScope.GOOGLE_GMAIL_MODIFY.value: "Read, send, and manage your emails",
IntegrationScope.GOOGLE_DRIVE_READONLY.value: "View files in your Google Drive",
IntegrationScope.GOOGLE_DRIVE_FILE.value: "Create and edit files in Google Drive",
IntegrationScope.GOOGLE_CALENDAR_READONLY.value: "View your calendar",
IntegrationScope.GOOGLE_CALENDAR_EVENTS.value: "Create and edit calendar events",
IntegrationScope.GOOGLE_SHEETS_READONLY.value: "View your spreadsheets",
IntegrationScope.GOOGLE_SHEETS.value: "Create and edit spreadsheets",
IntegrationScope.GOOGLE_DOCS_READONLY.value: "View your documents",
IntegrationScope.GOOGLE_DOCS.value: "Create and edit documents",
# GitHub
IntegrationScope.GITHUB_REPOS_READ.value: "Read repository information",
IntegrationScope.GITHUB_REPOS_WRITE.value: "Create and manage repositories",
IntegrationScope.GITHUB_ISSUES_READ.value: "Read issues and pull requests",
IntegrationScope.GITHUB_ISSUES_WRITE.value: "Create and manage issues",
IntegrationScope.GITHUB_USER_READ.value: "Read your GitHub profile",
IntegrationScope.GITHUB_GISTS.value: "Create and manage gists",
IntegrationScope.GITHUB_NOTIFICATIONS.value: "Access notifications",
# Discord
IntegrationScope.DISCORD_IDENTIFY.value: "Access your Discord username",
IntegrationScope.DISCORD_EMAIL.value: "Access your Discord email",
IntegrationScope.DISCORD_GUILDS.value: "View your server list",
IntegrationScope.DISCORD_MESSAGES_READ.value: "Read messages",
# Twitter
IntegrationScope.TWITTER_READ.value: "Read tweets and profile",
IntegrationScope.TWITTER_WRITE.value: "Post tweets on your behalf",
IntegrationScope.TWITTER_DM.value: "Send and read direct messages",
# Notion
IntegrationScope.NOTION_READ.value: "View Notion pages",
IntegrationScope.NOTION_WRITE.value: "Create and edit Notion pages",
# Todoist
IntegrationScope.TODOIST_READ.value: "View your tasks",
IntegrationScope.TODOIST_WRITE.value: "Create and manage tasks",
}
# Mapping from integration scopes to provider OAuth scopes
INTEGRATION_SCOPE_MAPPING: dict[str, dict[str, list[str]]] = {
ProviderName.GOOGLE.value: {
IntegrationScope.GOOGLE_EMAIL_READ.value: [
"https://www.googleapis.com/auth/userinfo.email",
"openid",
],
IntegrationScope.GOOGLE_GMAIL_READONLY.value: [
"https://www.googleapis.com/auth/gmail.readonly",
],
IntegrationScope.GOOGLE_GMAIL_SEND.value: [
"https://www.googleapis.com/auth/gmail.send",
],
IntegrationScope.GOOGLE_GMAIL_MODIFY.value: [
"https://www.googleapis.com/auth/gmail.modify",
],
IntegrationScope.GOOGLE_DRIVE_READONLY.value: [
"https://www.googleapis.com/auth/drive.readonly",
],
IntegrationScope.GOOGLE_DRIVE_FILE.value: [
"https://www.googleapis.com/auth/drive.file",
],
IntegrationScope.GOOGLE_CALENDAR_READONLY.value: [
"https://www.googleapis.com/auth/calendar.readonly",
],
IntegrationScope.GOOGLE_CALENDAR_EVENTS.value: [
"https://www.googleapis.com/auth/calendar.events",
],
IntegrationScope.GOOGLE_SHEETS_READONLY.value: [
"https://www.googleapis.com/auth/spreadsheets.readonly",
],
IntegrationScope.GOOGLE_SHEETS.value: [
"https://www.googleapis.com/auth/spreadsheets",
],
IntegrationScope.GOOGLE_DOCS_READONLY.value: [
"https://www.googleapis.com/auth/documents.readonly",
],
IntegrationScope.GOOGLE_DOCS.value: [
"https://www.googleapis.com/auth/documents",
],
},
ProviderName.GITHUB.value: {
IntegrationScope.GITHUB_REPOS_READ.value: [
"repo:status",
"public_repo",
],
IntegrationScope.GITHUB_REPOS_WRITE.value: [
"repo",
],
IntegrationScope.GITHUB_ISSUES_READ.value: [
"repo:status",
],
IntegrationScope.GITHUB_ISSUES_WRITE.value: [
"repo",
],
IntegrationScope.GITHUB_USER_READ.value: [
"read:user",
"user:email",
],
IntegrationScope.GITHUB_GISTS.value: [
"gist",
],
IntegrationScope.GITHUB_NOTIFICATIONS.value: [
"notifications",
],
},
ProviderName.DISCORD.value: {
IntegrationScope.DISCORD_IDENTIFY.value: [
"identify",
],
IntegrationScope.DISCORD_EMAIL.value: [
"email",
],
IntegrationScope.DISCORD_GUILDS.value: [
"guilds",
],
IntegrationScope.DISCORD_MESSAGES_READ.value: [
"messages.read",
],
},
ProviderName.TWITTER.value: {
IntegrationScope.TWITTER_READ.value: [
"tweet.read",
"users.read",
],
IntegrationScope.TWITTER_WRITE.value: [
"tweet.write",
],
IntegrationScope.TWITTER_DM.value: [
"dm.read",
"dm.write",
],
},
ProviderName.NOTION.value: {
IntegrationScope.NOTION_READ.value: [], # Notion uses workspace-level access
IntegrationScope.NOTION_WRITE.value: [],
},
ProviderName.TODOIST.value: {
IntegrationScope.TODOIST_READ.value: [
"data:read",
],
IntegrationScope.TODOIST_WRITE.value: [
"data:read_write",
],
},
}
def get_provider_scopes(
provider: ProviderName | str, integration_scopes: list[str]
) -> list[str]:
"""
Convert integration scopes to provider-specific OAuth scopes.
Args:
provider: The provider name
integration_scopes: List of integration scope strings
Returns:
List of provider-specific OAuth scopes
"""
provider_value = provider.value if isinstance(provider, ProviderName) else provider
provider_mapping = INTEGRATION_SCOPE_MAPPING.get(provider_value, {})
oauth_scopes: set[str] = set()
for scope in integration_scopes:
if scope in provider_mapping:
oauth_scopes.update(provider_mapping[scope])
return list(oauth_scopes)
def get_provider_for_scope(scope: str) -> Optional[ProviderName]:
"""
Get the provider for an integration scope.
Args:
scope: Integration scope string (e.g., "google:gmail.readonly")
Returns:
ProviderName or None if not recognized
"""
if ":" not in scope:
return None
provider_prefix = scope.split(":")[0]
# Map prefixes to providers
prefix_mapping = {
"google": ProviderName.GOOGLE,
"github": ProviderName.GITHUB,
"discord": ProviderName.DISCORD,
"twitter": ProviderName.TWITTER,
"notion": ProviderName.NOTION,
"todoist": ProviderName.TODOIST,
}
return prefix_mapping.get(provider_prefix)
def validate_integration_scopes(scopes: list[str]) -> tuple[bool, list[str]]:
"""
Validate a list of integration scopes.
Args:
scopes: List of integration scope strings
Returns:
Tuple of (valid, invalid_scopes)
"""
valid_scopes = {s.value for s in IntegrationScope}
invalid = [s for s in scopes if s not in valid_scopes]
return len(invalid) == 0, invalid
def group_scopes_by_provider(
scopes: list[str],
) -> dict[ProviderName, list[str]]:
"""
Group integration scopes by their provider.
Args:
scopes: List of integration scope strings
Returns:
Dictionary mapping providers to their scopes
"""
grouped: dict[ProviderName, list[str]] = {}
for scope in scopes:
provider = get_provider_for_scope(scope)
if provider:
if provider not in grouped:
grouped[provider] = []
grouped[provider].append(scope)
return grouped

View File

@@ -1,176 +0,0 @@
"""
OAuth Audit Logging.
Logs all OAuth-related operations for security auditing and compliance.
"""
import logging
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Optional
from backend.data.db import prisma
logger = logging.getLogger(__name__)
class OAuthEventType(str, Enum):
"""Types of OAuth events to audit."""
# Client events
CLIENT_REGISTERED = "client.registered"
CLIENT_UPDATED = "client.updated"
CLIENT_DELETED = "client.deleted"
CLIENT_SECRET_ROTATED = "client.secret_rotated"
CLIENT_SUSPENDED = "client.suspended"
CLIENT_ACTIVATED = "client.activated"
# Authorization events
AUTHORIZATION_REQUESTED = "authorization.requested"
AUTHORIZATION_GRANTED = "authorization.granted"
AUTHORIZATION_DENIED = "authorization.denied"
AUTHORIZATION_REVOKED = "authorization.revoked"
# Token events
TOKEN_ISSUED = "token.issued"
TOKEN_REFRESHED = "token.refreshed"
TOKEN_REVOKED = "token.revoked"
TOKEN_EXPIRED = "token.expired"
# Grant events
GRANT_CREATED = "grant.created"
GRANT_UPDATED = "grant.updated"
GRANT_REVOKED = "grant.revoked"
GRANT_USED = "grant.used"
# Credential events
CREDENTIAL_CONNECTED = "credential.connected"
CREDENTIAL_DELETED = "credential.deleted"
# Execution events
EXECUTION_STARTED = "execution.started"
EXECUTION_COMPLETED = "execution.completed"
EXECUTION_FAILED = "execution.failed"
EXECUTION_CANCELLED = "execution.cancelled"
async def log_oauth_event(
event_type: OAuthEventType,
user_id: Optional[str] = None,
client_id: Optional[str] = None,
grant_id: Optional[str] = None,
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
details: Optional[dict[str, Any]] = None,
) -> str:
"""
Log an OAuth audit event.
Args:
event_type: Type of event
user_id: User ID involved (if any)
client_id: OAuth client ID involved (if any)
grant_id: Grant ID involved (if any)
ip_address: Client IP address
user_agent: Client user agent
details: Additional event details
Returns:
ID of the created audit log entry
"""
try:
from prisma import Json
audit_entry = await prisma.oauthauditlog.create(
data={ # type: ignore[typeddict-item]
"eventType": event_type.value,
"userId": user_id,
"clientId": client_id,
"grantId": grant_id,
"ipAddress": ip_address,
"userAgent": user_agent,
"details": Json(details or {}),
}
)
logger.debug(
f"OAuth audit: {event_type.value} - "
f"user={user_id}, client={client_id}, grant={grant_id}"
)
return audit_entry.id
except Exception as e:
# Log but don't fail the operation if audit logging fails
logger.error(f"Failed to create OAuth audit log: {e}")
return ""
async def get_audit_logs(
user_id: Optional[str] = None,
client_id: Optional[str] = None,
event_type: Optional[OAuthEventType] = None,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
limit: int = 100,
offset: int = 0,
) -> list:
"""
Query OAuth audit logs.
Args:
user_id: Filter by user ID
client_id: Filter by client ID
event_type: Filter by event type
start_date: Filter by start date
end_date: Filter by end date
limit: Maximum number of results
offset: Offset for pagination
Returns:
List of audit log entries
"""
where: dict[str, Any] = {}
if user_id:
where["userId"] = user_id
if client_id:
where["clientId"] = client_id
if event_type:
where["eventType"] = event_type.value
if start_date:
where["createdAt"] = {"gte": start_date}
if end_date:
if "createdAt" in where:
where["createdAt"]["lte"] = end_date
else:
where["createdAt"] = {"lte": end_date}
return await prisma.oauthauditlog.find_many(
where=where if where else None, # type: ignore[arg-type]
order={"createdAt": "desc"},
take=limit,
skip=offset,
)
async def cleanup_old_audit_logs(days_to_keep: int = 90) -> int:
"""
Delete audit logs older than the specified number of days.
Args:
days_to_keep: Number of days of logs to retain
Returns:
Number of logs deleted
"""
from datetime import timedelta
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_to_keep)
result = await prisma.oauthauditlog.delete_many(
where={"createdAt": {"lt": cutoff_date}}
)
logger.info(f"Cleaned up {result} OAuth audit logs older than {days_to_keep} days")
return result

View File

@@ -7,11 +7,7 @@ import prisma
import pydantic import pydantic
from prisma.enums import OnboardingStep from prisma.enums import OnboardingStep
from prisma.models import UserOnboarding from prisma.models import UserOnboarding
from prisma.types import ( from prisma.types import UserOnboardingCreateInput, UserOnboardingUpdateInput
UserOnboardingCreateInput,
UserOnboardingUpdateInput,
UserOnboardingUpsertInput,
)
from backend.data import execution as execution_db from backend.data import execution as execution_db
from backend.data.credit import get_user_credit_model from backend.data.credit import get_user_credit_model
@@ -116,10 +112,10 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
return await UserOnboarding.prisma().upsert( return await UserOnboarding.prisma().upsert(
where={"userId": user_id}, where={"userId": user_id},
data=UserOnboardingUpsertInput( data={
create=UserOnboardingCreateInput(userId=user_id, **update), "create": {"userId": user_id, **update},
update=update, "update": update,
), },
) )

View File

@@ -13,6 +13,7 @@ from backend.data.execution import (
get_block_error_stats, get_block_error_stats,
get_child_graph_executions, get_child_graph_executions,
get_execution_kv_data, get_execution_kv_data,
get_execution_outputs_by_node_exec_id,
get_frequently_executed_graphs, get_frequently_executed_graphs,
get_graph_execution_meta, get_graph_execution_meta,
get_graph_executions, get_graph_executions,
@@ -147,6 +148,7 @@ class DatabaseManager(AppService):
update_graph_execution_stats = _(update_graph_execution_stats) update_graph_execution_stats = _(update_graph_execution_stats)
upsert_execution_input = _(upsert_execution_input) upsert_execution_input = _(upsert_execution_input)
upsert_execution_output = _(upsert_execution_output) upsert_execution_output = _(upsert_execution_output)
get_execution_outputs_by_node_exec_id = _(get_execution_outputs_by_node_exec_id)
get_execution_kv_data = _(get_execution_kv_data) get_execution_kv_data = _(get_execution_kv_data)
set_execution_kv_data = _(set_execution_kv_data) set_execution_kv_data = _(set_execution_kv_data)
get_block_error_stats = _(get_block_error_stats) get_block_error_stats = _(get_block_error_stats)
@@ -277,6 +279,7 @@ class DatabaseManagerAsyncClient(AppServiceClient):
get_user_integrations = d.get_user_integrations get_user_integrations = d.get_user_integrations
upsert_execution_input = d.upsert_execution_input upsert_execution_input = d.upsert_execution_input
upsert_execution_output = d.upsert_execution_output upsert_execution_output = d.upsert_execution_output
get_execution_outputs_by_node_exec_id = d.get_execution_outputs_by_node_exec_id
update_graph_execution_stats = d.update_graph_execution_stats update_graph_execution_stats = d.update_graph_execution_stats
update_node_execution_status = d.update_node_execution_status update_node_execution_status = d.update_node_execution_status
update_node_execution_status_batch = d.update_node_execution_status_batch update_node_execution_status_batch = d.update_node_execution_status_batch

View File

@@ -67,7 +67,6 @@ from backend.executor.utils import (
validate_exec, validate_exec,
) )
from backend.integrations.creds_manager import IntegrationCredentialsManager from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.webhook_notifier import get_webhook_notifier
from backend.notifications.notifications import queue_notification from backend.notifications.notifications import queue_notification
from backend.server.v2.AutoMod.manager import automod_manager from backend.server.v2.AutoMod.manager import automod_manager
from backend.util import json from backend.util import json
@@ -134,9 +133,8 @@ def execute_graph(
cluster_lock: ClusterLock, cluster_lock: ClusterLock,
): ):
"""Execute graph using thread-local ExecutionProcessor instance""" """Execute graph using thread-local ExecutionProcessor instance"""
return _tls.processor.on_graph_execution( processor: ExecutionProcessor = _tls.processor
graph_exec_entry, cancel_event, cluster_lock return processor.on_graph_execution(graph_exec_entry, cancel_event, cluster_lock)
)
T = TypeVar("T") T = TypeVar("T")
@@ -144,8 +142,8 @@ T = TypeVar("T")
async def execute_node( async def execute_node(
node: Node, node: Node,
creds_manager: IntegrationCredentialsManager,
data: NodeExecutionEntry, data: NodeExecutionEntry,
execution_processor: "ExecutionProcessor",
execution_stats: NodeExecutionStats | None = None, execution_stats: NodeExecutionStats | None = None,
nodes_input_masks: Optional[NodesInputMasks] = None, nodes_input_masks: Optional[NodesInputMasks] = None,
) -> BlockOutput: ) -> BlockOutput:
@@ -170,6 +168,7 @@ async def execute_node(
node_id = data.node_id node_id = data.node_id
node_block = node.block node_block = node.block
execution_context = data.execution_context execution_context = data.execution_context
creds_manager = execution_processor.creds_manager
log_metadata = LogMetadata( log_metadata = LogMetadata(
logger=_logger, logger=_logger,
@@ -213,6 +212,7 @@ async def execute_node(
"node_exec_id": node_exec_id, "node_exec_id": node_exec_id,
"user_id": user_id, "user_id": user_id,
"execution_context": execution_context, "execution_context": execution_context,
"execution_processor": execution_processor,
} }
# Last-minute fetch credentials + acquire a system-wide read-write lock to prevent # Last-minute fetch credentials + acquire a system-wide read-write lock to prevent
@@ -222,30 +222,10 @@ async def execute_node(
creds_locks: list[AsyncRedisLock] = [] creds_locks: list[AsyncRedisLock] = []
input_model = cast(type[BlockSchema], node_block.input_schema) input_model = cast(type[BlockSchema], node_block.input_schema)
# Check if this is an external API execution using grant-based credential resolution
grant_resolver = None
if execution_context and execution_context.grant_resolver_context:
from backend.integrations.grant_resolver import GrantBasedCredentialResolver
grant_ctx = execution_context.grant_resolver_context
grant_resolver = GrantBasedCredentialResolver(
user_id=user_id,
client_id=grant_ctx.client_db_id,
grant_ids=grant_ctx.grant_ids,
)
await grant_resolver.initialize()
# Handle regular credentials fields # Handle regular credentials fields
for field_name, input_type in input_model.get_credentials_fields().items(): for field_name, input_type in input_model.get_credentials_fields().items():
credentials_meta = input_type(**input_data[field_name]) credentials_meta = input_type(**input_data[field_name])
if grant_resolver: credentials, lock = await creds_manager.acquire(user_id, credentials_meta.id)
# External API execution - use grant resolver (no locking needed)
credentials = await grant_resolver.resolve_credential(credentials_meta.id)
else:
# Normal execution - use credentials manager with locking
credentials, lock = await creds_manager.acquire(
user_id, credentials_meta.id
)
creds_locks.append(lock) creds_locks.append(lock)
extra_exec_kwargs[field_name] = credentials extra_exec_kwargs[field_name] = credentials
@@ -264,13 +244,6 @@ async def execute_node(
) )
file_name = field_data.get("name", "selected file") file_name = field_data.get("name", "selected file")
try: try:
if grant_resolver:
# External API execution - use grant resolver
credentials = await grant_resolver.resolve_credential(
cred_id
)
else:
# Normal execution - use credentials manager
credentials, lock = await creds_manager.acquire( credentials, lock = await creds_manager.acquire(
user_id, cred_id user_id, cred_id
) )
@@ -636,8 +609,8 @@ class ExecutionProcessor:
async for output_name, output_data in execute_node( async for output_name, output_data in execute_node(
node=node, node=node,
creds_manager=self.creds_manager,
data=node_exec, data=node_exec,
execution_processor=self,
execution_stats=stats, execution_stats=stats,
nodes_input_masks=nodes_input_masks, nodes_input_masks=nodes_input_masks,
): ):
@@ -813,7 +786,6 @@ class ExecutionProcessor:
graph_exec_id=graph_exec.graph_exec_id, graph_exec_id=graph_exec.graph_exec_id,
status=exec_meta.status, status=exec_meta.status,
stats=exec_stats, stats=exec_stats,
event_loop=self.node_execution_loop,
) )
def _charge_usage( def _charge_usage(
@@ -889,12 +861,17 @@ class ExecutionProcessor:
execution_stats_lock = threading.Lock() execution_stats_lock = threading.Lock()
# State holders ---------------------------------------------------- # State holders ----------------------------------------------------
running_node_execution: dict[str, NodeExecutionProgress] = defaultdict( self.running_node_execution: dict[str, NodeExecutionProgress] = defaultdict(
NodeExecutionProgress NodeExecutionProgress
) )
running_node_evaluation: dict[str, Future] = {} self.running_node_evaluation: dict[str, Future] = {}
self.execution_stats = execution_stats
self.execution_stats_lock = execution_stats_lock
execution_queue = ExecutionQueue[NodeExecutionEntry]() execution_queue = ExecutionQueue[NodeExecutionEntry]()
running_node_execution = self.running_node_execution
running_node_evaluation = self.running_node_evaluation
try: try:
if db_client.get_credits(graph_exec.user_id) <= 0: if db_client.get_credits(graph_exec.user_id) <= 0:
raise InsufficientBalanceError( raise InsufficientBalanceError(
@@ -1945,53 +1922,6 @@ def update_node_execution_status(
return exec_update return exec_update
async def _notify_execution_webhook(
execution_id: str,
agent_id: str,
status: ExecutionStatus,
outputs: dict[str, Any] | None = None,
error: str | None = None,
) -> None:
"""
Send webhook notification for execution completion if registered.
This is a fire-and-forget operation that checks if a webhook was registered
for this execution and sends the appropriate notification.
"""
from backend.data.db import prisma
try:
webhook = await prisma.executionwebhook.find_first(
where={"executionId": execution_id}
)
if not webhook:
return
notifier = get_webhook_notifier()
if status == ExecutionStatus.COMPLETED:
await notifier.notify_execution_completed(
execution_id=execution_id,
agent_id=agent_id,
client_id=webhook.clientId,
webhook_url=webhook.webhookUrl,
outputs=outputs or {},
webhook_secret=webhook.secret,
)
elif status == ExecutionStatus.FAILED:
await notifier.notify_execution_failed(
execution_id=execution_id,
agent_id=agent_id,
client_id=webhook.clientId,
webhook_url=webhook.webhookUrl,
error=error or "Execution failed",
webhook_secret=webhook.secret,
)
except Exception as e:
# Don't let webhook failures affect execution state updates
logger.warning(f"Failed to send webhook notification for {execution_id}: {e}")
async def async_update_graph_execution_state( async def async_update_graph_execution_state(
db_client: "DatabaseManagerAsyncClient", db_client: "DatabaseManagerAsyncClient",
graph_exec_id: str, graph_exec_id: str,
@@ -2004,17 +1934,6 @@ async def async_update_graph_execution_state(
) )
if graph_update: if graph_update:
await send_async_execution_update(graph_update) await send_async_execution_update(graph_update)
# Send webhook notification for terminal states
if status == ExecutionStatus.COMPLETED or status == ExecutionStatus.FAILED:
await _notify_execution_webhook(
execution_id=graph_exec_id,
agent_id=graph_update.graph_id,
status=status,
outputs=(
graph_update.outputs if hasattr(graph_update, "outputs") else None
),
)
else: else:
logger.error(f"Failed to update graph execution stats for {graph_exec_id}") logger.error(f"Failed to update graph execution stats for {graph_exec_id}")
return graph_update return graph_update
@@ -2025,33 +1944,11 @@ def update_graph_execution_state(
graph_exec_id: str, graph_exec_id: str,
status: ExecutionStatus | None = None, status: ExecutionStatus | None = None,
stats: GraphExecutionStats | None = None, stats: GraphExecutionStats | None = None,
event_loop: asyncio.AbstractEventLoop | None = None,
) -> GraphExecution | None: ) -> GraphExecution | None:
"""Sets status and fetches+broadcasts the latest state of the graph execution""" """Sets status and fetches+broadcasts the latest state of the graph execution"""
graph_update = db_client.update_graph_execution_stats(graph_exec_id, status, stats) graph_update = db_client.update_graph_execution_stats(graph_exec_id, status, stats)
if graph_update: if graph_update:
send_execution_update(graph_update) send_execution_update(graph_update)
# Send webhook notification for terminal states (fire-and-forget)
if (
status == ExecutionStatus.COMPLETED or status == ExecutionStatus.FAILED
) and event_loop:
try:
asyncio.run_coroutine_threadsafe(
_notify_execution_webhook(
execution_id=graph_exec_id,
agent_id=graph_update.graph_id,
status=status,
outputs=(
graph_update.outputs
if hasattr(graph_update, "outputs")
else None
),
),
event_loop,
)
except Exception as e:
logger.warning(f"Failed to schedule webhook notification: {e}")
else: else:
logger.error(f"Failed to update graph execution stats for {graph_exec_id}") logger.error(f"Failed to update graph execution stats for {graph_exec_id}")
return graph_update return graph_update

View File

@@ -1,278 +0,0 @@
"""
Grant-Based Credential Resolver.
Resolves credentials during agent execution based on credential grants.
External applications can only use credentials they have been granted access to,
and only for the scopes that were granted.
Credentials are NEVER exposed to external applications - this resolver
provides the credentials to the execution engine internally.
"""
import logging
from datetime import datetime, timezone
from typing import Optional
from prisma.enums import CredentialGrantPermission
from prisma.models import CredentialGrant
from backend.data import credential_grants as grants_db
from backend.data.db import prisma
from backend.data.model import Credentials
from backend.integrations.creds_manager import IntegrationCredentialsManager
logger = logging.getLogger(__name__)
class GrantValidationError(Exception):
"""Raised when a grant is invalid or lacks required permissions."""
pass
class CredentialNotFoundError(Exception):
"""Raised when a credential referenced by a grant is not found."""
pass
class ScopeMismatchError(Exception):
"""Raised when the grant doesn't cover required scopes."""
pass
class GrantBasedCredentialResolver:
"""
Resolves credentials for agent execution based on credential grants.
This resolver validates that:
1. The grant exists and is valid (not revoked/expired)
2. The grant has USE permission
3. The grant covers the required scopes (if specified)
4. The underlying credential exists
Then it provides the credential to the execution engine internally.
The credential value is NEVER exposed to external applications.
"""
def __init__(
self,
user_id: str,
client_id: str,
grant_ids: list[str],
):
"""
Initialize the resolver.
Args:
user_id: User ID who owns the credentials
client_id: Database ID of the OAuth client
grant_ids: List of grant IDs the client is using for this execution
"""
self.user_id = user_id
self.client_id = client_id
self.grant_ids = grant_ids
self._grants: dict[str, CredentialGrant] = {}
self._credentials_manager = IntegrationCredentialsManager()
self._initialized = False
async def initialize(self) -> None:
"""
Load and validate all grants.
This should be called before any credential resolution.
Raises:
GrantValidationError: If any grant is invalid
"""
now = datetime.now(timezone.utc)
for grant_id in self.grant_ids:
grant = await grants_db.get_credential_grant(
grant_id=grant_id,
user_id=self.user_id,
client_id=self.client_id,
)
if not grant:
raise GrantValidationError(f"Grant {grant_id} not found")
# Check if revoked
if grant.revokedAt:
raise GrantValidationError(f"Grant {grant_id} has been revoked")
# Check if expired
if grant.expiresAt and grant.expiresAt < now:
raise GrantValidationError(f"Grant {grant_id} has expired")
# Check USE permission
if CredentialGrantPermission.USE not in grant.permissions:
raise GrantValidationError(
f"Grant {grant_id} does not have USE permission"
)
self._grants[grant_id] = grant
self._initialized = True
logger.info(
f"Initialized grant resolver with {len(self._grants)} grants "
f"for user {self.user_id}, client {self.client_id}"
)
async def resolve_credential(
self,
credential_id: str,
required_scopes: Optional[list[str]] = None,
) -> Credentials:
"""
Resolve a credential for agent execution.
This method:
1. Finds a grant that covers this credential
2. Validates the grant covers required scopes
3. Retrieves the actual credential
4. Updates grant usage tracking
Args:
credential_id: ID of the credential to resolve
required_scopes: Optional list of scopes the credential must have
Returns:
The resolved Credentials object
Raises:
GrantValidationError: If no valid grant covers this credential
ScopeMismatchError: If the grant doesn't cover required scopes
CredentialNotFoundError: If the underlying credential doesn't exist
"""
if not self._initialized:
raise RuntimeError("Resolver not initialized. Call initialize() first.")
# Find a grant that covers this credential
matching_grant: Optional[CredentialGrant] = None
for grant in self._grants.values():
if grant.credentialId == credential_id:
matching_grant = grant
break
if not matching_grant:
raise GrantValidationError(f"No grant found for credential {credential_id}")
# Validate scopes if required
if required_scopes:
granted_scopes = set(matching_grant.grantedScopes)
required_scopes_set = set(required_scopes)
missing_scopes = required_scopes_set - granted_scopes
if missing_scopes:
raise ScopeMismatchError(
f"Grant {matching_grant.id} is missing required scopes: "
f"{', '.join(missing_scopes)}"
)
# Get the actual credential
credentials = await self._credentials_manager.get(
user_id=self.user_id,
credentials_id=credential_id,
lock=True,
)
if not credentials:
raise CredentialNotFoundError(
f"Credential {credential_id} not found for user {self.user_id}"
)
# Update last used timestamp for the grant
await grants_db.update_grant_last_used(matching_grant.id)
logger.debug(
f"Resolved credential {credential_id} via grant {matching_grant.id} "
f"for client {self.client_id}"
)
return credentials
async def get_available_credentials(self) -> list[dict]:
"""
Get list of available credentials based on grants.
Returns a list of credential metadata (NOT the actual credential values).
Returns:
List of dicts with credential metadata
"""
if not self._initialized:
raise RuntimeError("Resolver not initialized. Call initialize() first.")
credentials_info = []
for grant in self._grants.values():
credentials_info.append(
{
"grant_id": grant.id,
"credential_id": grant.credentialId,
"provider": grant.provider,
"granted_scopes": grant.grantedScopes,
}
)
return credentials_info
def get_grant_for_credential(self, credential_id: str) -> Optional[CredentialGrant]:
"""
Get the grant for a specific credential.
Args:
credential_id: ID of the credential
Returns:
CredentialGrant or None if not found
"""
for grant in self._grants.values():
if grant.credentialId == credential_id:
return grant
return None
async def create_resolver_from_oauth_token(
user_id: str,
client_public_id: str,
grant_ids: Optional[list[str]] = None,
) -> GrantBasedCredentialResolver:
"""
Create a credential resolver from OAuth token context.
This is a convenience function for creating a resolver from
the context available in OAuth-authenticated requests.
Args:
user_id: User ID from the OAuth token
client_public_id: Public client ID from the OAuth token
grant_ids: Optional list of grant IDs to use
Returns:
Initialized GrantBasedCredentialResolver
"""
# Look up the OAuth client database ID from the public client ID
client = await prisma.oauthclient.find_unique(where={"clientId": client_public_id})
if not client:
raise GrantValidationError(f"OAuth client {client_public_id} not found")
# If no grant IDs specified, get all grants for this client+user
if grant_ids is None:
grants = await grants_db.get_grants_for_user_client(
user_id=user_id,
client_id=client.id,
include_revoked=False,
include_expired=False,
)
grant_ids = [g.id for g in grants]
resolver = GrantBasedCredentialResolver(
user_id=user_id,
client_id=client.id,
grant_ids=grant_ids,
)
await resolver.initialize()
return resolver

View File

@@ -1,331 +0,0 @@
"""
Webhook Notification System for External API.
Sends webhook notifications to external applications for execution events.
"""
import asyncio
import hashlib
import hmac
import json
import logging
import weakref
from datetime import datetime, timezone
from typing import Any, Coroutine, Optional
from urllib.parse import urlparse
import httpx
logger = logging.getLogger(__name__)
# Webhook delivery settings
WEBHOOK_TIMEOUT_SECONDS = 30
WEBHOOK_MAX_RETRIES = 3
WEBHOOK_RETRY_DELAYS = [5, 30, 300] # seconds: 5s, 30s, 5min
class WebhookDeliveryError(Exception):
"""Raised when webhook delivery fails."""
pass
def sign_webhook_payload(payload: dict[str, Any], secret: str) -> str:
"""
Create HMAC-SHA256 signature for webhook payload.
Args:
payload: The webhook payload to sign
secret: The webhook secret key
Returns:
Hex-encoded HMAC-SHA256 signature
"""
payload_bytes = json.dumps(payload, sort_keys=True, separators=(",", ":")).encode()
signature = hmac.new(
secret.encode(),
payload_bytes,
hashlib.sha256,
).hexdigest()
return signature
def verify_webhook_signature(
payload: dict[str, Any],
signature: str,
secret: str,
) -> bool:
"""
Verify a webhook signature.
Args:
payload: The webhook payload
signature: The signature to verify
secret: The webhook secret key
Returns:
True if signature is valid
"""
expected = sign_webhook_payload(payload, secret)
return hmac.compare_digest(expected, signature)
def validate_webhook_url(url: str, allowed_domains: list[str]) -> bool:
"""
Validate that a webhook URL is allowed.
Args:
url: The webhook URL to validate
allowed_domains: List of allowed domains (from OAuth client config)
Returns:
True if URL is valid and allowed
"""
from backend.util.url import hostname_matches_any_domain
try:
parsed = urlparse(url)
# Must be HTTPS (except for localhost in development)
if parsed.scheme != "https":
if not (
parsed.scheme == "http"
and parsed.hostname in ["localhost", "127.0.0.1"]
):
return False
# Must have a host
if not parsed.hostname:
return False
# Check against allowed domains
return hostname_matches_any_domain(parsed.hostname, allowed_domains)
except Exception:
return False
async def send_webhook(
url: str,
payload: dict[str, Any],
secret: Optional[str] = None,
timeout: int = WEBHOOK_TIMEOUT_SECONDS,
) -> bool:
"""
Send a webhook notification.
Args:
url: Webhook URL
payload: Payload to send
secret: Optional secret for signature
timeout: Request timeout in seconds
Returns:
True if webhook was delivered successfully
"""
headers = {
"Content-Type": "application/json",
"User-Agent": "AutoGPT-Webhook/1.0",
"X-Webhook-Timestamp": datetime.now(timezone.utc).isoformat(),
}
if secret:
signature = sign_webhook_payload(payload, secret)
headers["X-Webhook-Signature"] = f"sha256={signature}"
try:
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.post(
url,
json=payload,
headers=headers,
)
if response.status_code >= 200 and response.status_code < 300:
logger.debug(f"Webhook delivered successfully to {url}")
return True
else:
logger.warning(
f"Webhook delivery failed: {url} returned {response.status_code}"
)
return False
except httpx.TimeoutException:
logger.warning(f"Webhook delivery timed out: {url}")
return False
except Exception as e:
logger.error(f"Webhook delivery error: {url} - {str(e)}")
return False
async def send_webhook_with_retry(
url: str,
payload: dict[str, Any],
secret: Optional[str] = None,
max_retries: int = WEBHOOK_MAX_RETRIES,
) -> bool:
"""
Send a webhook with automatic retries.
Args:
url: Webhook URL
payload: Payload to send
secret: Optional secret for signature
max_retries: Maximum number of retry attempts
Returns:
True if webhook was eventually delivered successfully
"""
for attempt in range(max_retries + 1):
if await send_webhook(url, payload, secret):
return True
if attempt < max_retries:
delay = WEBHOOK_RETRY_DELAYS[min(attempt, len(WEBHOOK_RETRY_DELAYS) - 1)]
logger.info(
f"Webhook delivery failed, retrying in {delay}s (attempt {attempt + 1})"
)
await asyncio.sleep(delay)
logger.error(f"Webhook delivery failed after {max_retries} retries: {url}")
return False
# Track pending webhook tasks to prevent garbage collection
# Using WeakSet so tasks are automatically removed when they complete and are dereferenced
_pending_webhook_tasks: weakref.WeakSet[asyncio.Task[Any]] = weakref.WeakSet()
def _create_tracked_task(coro: Coroutine[Any, Any, bool]) -> asyncio.Task[bool]:
"""Create a task that is tracked to prevent garbage collection."""
task = asyncio.create_task(coro)
_pending_webhook_tasks.add(task)
# No explicit done callback needed - WeakSet automatically removes
# references when tasks are garbage collected after completion
return task
class WebhookNotifier:
"""
Service for sending webhook notifications to external applications.
"""
def __init__(self):
pass
async def notify_execution_started(
self,
execution_id: str,
agent_id: str,
client_id: str,
webhook_url: str,
webhook_secret: Optional[str] = None,
) -> None:
"""
Notify external app that an execution has started.
"""
payload = {
"event": "execution.started",
"timestamp": datetime.now(timezone.utc).isoformat(),
"data": {
"execution_id": execution_id,
"agent_id": agent_id,
"status": "running",
},
}
_create_tracked_task(
send_webhook_with_retry(webhook_url, payload, webhook_secret)
)
async def notify_execution_completed(
self,
execution_id: str,
agent_id: str,
client_id: str,
webhook_url: str,
outputs: dict[str, Any],
webhook_secret: Optional[str] = None,
) -> None:
"""
Notify external app that an execution has completed successfully.
"""
payload = {
"event": "execution.completed",
"timestamp": datetime.now(timezone.utc).isoformat(),
"data": {
"execution_id": execution_id,
"agent_id": agent_id,
"status": "completed",
"outputs": outputs,
},
}
_create_tracked_task(
send_webhook_with_retry(webhook_url, payload, webhook_secret)
)
async def notify_execution_failed(
self,
execution_id: str,
agent_id: str,
client_id: str,
webhook_url: str,
error: str,
webhook_secret: Optional[str] = None,
) -> None:
"""
Notify external app that an execution has failed.
"""
payload = {
"event": "execution.failed",
"timestamp": datetime.now(timezone.utc).isoformat(),
"data": {
"execution_id": execution_id,
"agent_id": agent_id,
"status": "failed",
"error": error,
},
}
_create_tracked_task(
send_webhook_with_retry(webhook_url, payload, webhook_secret)
)
async def notify_grant_revoked(
self,
grant_id: str,
credential_id: str,
provider: str,
client_id: str,
webhook_url: str,
webhook_secret: Optional[str] = None,
) -> None:
"""
Notify external app that a credential grant has been revoked.
"""
payload = {
"event": "grant.revoked",
"timestamp": datetime.now(timezone.utc).isoformat(),
"data": {
"grant_id": grant_id,
"credential_id": credential_id,
"provider": provider,
},
}
_create_tracked_task(
send_webhook_with_retry(webhook_url, payload, webhook_secret)
)
# Module-level singleton
_webhook_notifier: Optional[WebhookNotifier] = None
def get_webhook_notifier() -> WebhookNotifier:
"""Get the singleton webhook notifier instance."""
global _webhook_notifier
if _webhook_notifier is None:
_webhook_notifier = WebhookNotifier()
return _webhook_notifier

View File

@@ -3,19 +3,21 @@ from fastapi import FastAPI
from backend.monitoring.instrumentation import instrument_fastapi from backend.monitoring.instrumentation import instrument_fastapi
from backend.server.middleware.security import SecurityHeadersMiddleware from backend.server.middleware.security import SecurityHeadersMiddleware
from .routes.execution import execution_router from .routes.integrations import integrations_router
from .routes.grants import grants_router from .routes.tools import tools_router
from .routes.v1 import v1_router
external_app = FastAPI( external_app = FastAPI(
title="AutoGPT External API", title="AutoGPT External API",
description="External API for AutoGPT integrations (OAuth-based)", description="External API for AutoGPT integrations",
docs_url="/docs", docs_url="/docs",
version="1.0", version="1.0",
) )
external_app.add_middleware(SecurityHeadersMiddleware) external_app.add_middleware(SecurityHeadersMiddleware)
external_app.include_router(grants_router, prefix="/v1") external_app.include_router(v1_router, prefix="/v1")
external_app.include_router(execution_router, prefix="/v1") external_app.include_router(tools_router, prefix="/v1")
external_app.include_router(integrations_router, prefix="/v1")
# Add Prometheus instrumentation # Add Prometheus instrumentation
instrument_fastapi( instrument_fastapi(

View File

@@ -0,0 +1,36 @@
from fastapi import HTTPException, Security
from fastapi.security import APIKeyHeader
from prisma.enums import APIKeyPermission
from backend.data.api_key import APIKeyInfo, has_permission, validate_api_key
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
async def require_api_key(api_key: str | None = Security(api_key_header)) -> APIKeyInfo:
"""Base middleware for API key authentication"""
if api_key is None:
raise HTTPException(status_code=401, detail="Missing API key")
api_key_obj = await validate_api_key(api_key)
if not api_key_obj:
raise HTTPException(status_code=401, detail="Invalid API key")
return api_key_obj
def require_permission(permission: APIKeyPermission):
"""Dependency function for checking specific permissions"""
async def check_permission(
api_key: APIKeyInfo = Security(require_api_key),
) -> APIKeyInfo:
if not has_permission(api_key, permission):
raise HTTPException(
status_code=403,
detail=f"API key lacks the required permission '{permission}'",
)
return api_key
return check_permission

View File

@@ -1,164 +0,0 @@
"""
OAuth Access Token middleware for external API.
Validates OAuth access tokens and provides user/client context
for external API endpoints that use OAuth authentication.
"""
from datetime import datetime, timezone
from typing import Optional
import jwt
from fastapi import HTTPException, Security
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel
from backend.data.db import prisma
from backend.server.oauth.token_service import get_token_service
class OAuthTokenInfo(BaseModel):
"""Information extracted from a validated OAuth access token."""
user_id: str
client_id: str
scopes: list[str]
token_id: str
# HTTP Bearer token extractor
oauth_bearer = HTTPBearer(auto_error=False)
async def require_oauth_token(
credentials: Optional[HTTPAuthorizationCredentials] = Security(oauth_bearer),
) -> OAuthTokenInfo:
"""
Validate an OAuth access token and return token info.
Extracts the Bearer token from the Authorization header,
validates the JWT signature and claims, and checks that
the token hasn't been revoked.
Raises:
HTTPException: 401 if token is missing, invalid, or revoked
"""
if credentials is None:
raise HTTPException(
status_code=401,
detail="Missing authorization token",
headers={"WWW-Authenticate": "Bearer"},
)
token = credentials.credentials
token_service = get_token_service()
try:
# Verify JWT signature and claims
claims = token_service.verify_access_token(token)
# Check if token is in database and not revoked
token_hash = token_service.hash_token(token)
stored_token = await prisma.oauthaccesstoken.find_unique(
where={"tokenHash": token_hash}
)
if not stored_token:
raise HTTPException(
status_code=401,
detail="Token not found",
headers={"WWW-Authenticate": "Bearer"},
)
if stored_token.revokedAt:
raise HTTPException(
status_code=401,
detail="Token has been revoked",
headers={"WWW-Authenticate": "Bearer"},
)
if stored_token.expiresAt < datetime.now(timezone.utc):
raise HTTPException(
status_code=401,
detail="Token has expired",
headers={"WWW-Authenticate": "Bearer"},
)
# Update last used timestamp (fire and forget)
await prisma.oauthaccesstoken.update(
where={"id": stored_token.id},
data={"lastUsedAt": datetime.now(timezone.utc)},
)
return OAuthTokenInfo(
user_id=claims.sub,
client_id=claims.client_id,
scopes=claims.scope.split() if claims.scope else [],
token_id=stored_token.id,
)
except jwt.ExpiredSignatureError:
raise HTTPException(
status_code=401,
detail="Token has expired",
headers={"WWW-Authenticate": "Bearer"},
)
except jwt.InvalidTokenError as e:
raise HTTPException(
status_code=401,
detail=f"Invalid token: {str(e)}",
headers={"WWW-Authenticate": "Bearer"},
)
def require_scope(required_scope: str):
"""
Dependency that validates OAuth token and checks for required scope.
Args:
required_scope: The scope required for this endpoint
Returns:
Dependency function that returns OAuthTokenInfo if authorized
"""
async def check_scope(
token: OAuthTokenInfo = Security(require_oauth_token),
) -> OAuthTokenInfo:
if required_scope not in token.scopes:
raise HTTPException(
status_code=403,
detail=f"Token lacks required scope '{required_scope}'",
headers={"WWW-Authenticate": f'Bearer scope="{required_scope}"'},
)
return token
return check_scope
def require_any_scope(*required_scopes: str):
"""
Dependency that validates OAuth token and checks for any of the required scopes.
Args:
required_scopes: At least one of these scopes is required
Returns:
Dependency function that returns OAuthTokenInfo if authorized
"""
async def check_scopes(
token: OAuthTokenInfo = Security(require_oauth_token),
) -> OAuthTokenInfo:
for scope in required_scopes:
if scope in token.scopes:
return token
scope_list = " ".join(required_scopes)
raise HTTPException(
status_code=403,
detail=f"Token lacks required scopes (need one of: {scope_list})",
headers={"WWW-Authenticate": f'Bearer scope="{scope_list}"'},
)
return check_scopes

View File

@@ -1,377 +0,0 @@
"""
Agent Execution endpoints for external OAuth clients.
Allows external applications to:
- Execute agents using granted credentials
- Poll execution status
- Cancel running executions
- Get available capabilities
External apps can only use credentials they have been granted access to.
"""
import logging
from datetime import datetime
from typing import Any, Optional
from fastapi import APIRouter, HTTPException, Security
from prisma.enums import AgentExecutionStatus
from pydantic import BaseModel, Field
from backend.data import execution as execution_db
from backend.data import graph as graph_db
from backend.data.db import prisma
from backend.data.execution import ExecutionContext, GrantResolverContext
from backend.executor.utils import add_graph_execution
from backend.integrations.grant_resolver import (
GrantValidationError,
create_resolver_from_oauth_token,
)
from backend.integrations.webhook_notifier import validate_webhook_url
from backend.server.external.oauth_middleware import OAuthTokenInfo, require_scope
logger = logging.getLogger(__name__)
execution_router = APIRouter(prefix="/executions", tags=["executions"])
# ================================================================
# Request/Response Models
# ================================================================
class ExecuteAgentRequest(BaseModel):
"""Request to execute an agent."""
inputs: dict[str, Any] = Field(
default_factory=dict,
description="Input values for the agent",
)
grant_ids: Optional[list[str]] = Field(
default=None,
description="Specific grant IDs to use. If not provided, uses all available grants.",
)
webhook_url: Optional[str] = Field(
default=None,
description="URL to receive execution status webhooks",
)
class ExecuteAgentResponse(BaseModel):
"""Response from starting an agent execution."""
execution_id: str
status: str
message: str
class ExecutionStatusResponse(BaseModel):
"""Response with execution status."""
execution_id: str
status: str
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
outputs: Optional[dict[str, Any]] = None
error: Optional[str] = None
class GrantInfo(BaseModel):
"""Summary of a credential grant for capabilities."""
grant_id: str
provider: str
scopes: list[str]
class CapabilitiesResponse(BaseModel):
"""Response describing what the client can do."""
user_id: str
client_id: str
grants: list[GrantInfo]
available_scopes: list[str]
# ================================================================
# Endpoints
# ================================================================
@execution_router.get("/capabilities", response_model=CapabilitiesResponse)
async def get_capabilities(
token: OAuthTokenInfo = Security(require_scope("agents:execute")),
) -> CapabilitiesResponse:
"""
Get the capabilities available to this client for the authenticated user.
Returns information about:
- Available credential grants (NOT credential values)
- Scopes the client has access to
"""
try:
resolver = await create_resolver_from_oauth_token(
user_id=token.user_id,
client_public_id=token.client_id,
)
credentials_info = await resolver.get_available_credentials()
grants = [
GrantInfo(
grant_id=info["grant_id"],
provider=info["provider"],
scopes=info["granted_scopes"],
)
for info in credentials_info
]
return CapabilitiesResponse(
user_id=token.user_id,
client_id=token.client_id,
grants=grants,
available_scopes=token.scopes,
)
except GrantValidationError:
# No grants available is not an error, just empty capabilities
return CapabilitiesResponse(
user_id=token.user_id,
client_id=token.client_id,
grants=[],
available_scopes=token.scopes,
)
@execution_router.post(
"/agents/{agent_id}/execute",
response_model=ExecuteAgentResponse,
)
async def execute_agent(
agent_id: str,
request: ExecuteAgentRequest,
token: OAuthTokenInfo = Security(require_scope("agents:execute")),
) -> ExecuteAgentResponse:
"""
Execute an agent using granted credentials.
The agent must be accessible to the user, and the client must have
valid credential grants that satisfy the agent's requirements.
Args:
agent_id: The agent (graph) ID to execute
request: Execution parameters including inputs and optional grant IDs
"""
# Verify the agent exists and user has access
# First try to get the latest version
graph = await graph_db.get_graph(
graph_id=agent_id,
version=None,
user_id=token.user_id,
)
if not graph:
# Try to find it in the store (public agents)
graph = await graph_db.get_graph(
graph_id=agent_id,
version=None,
user_id=None,
skip_access_check=True,
)
if not graph:
raise HTTPException(
status_code=404,
detail=f"Agent {agent_id} not found or not accessible",
)
# Initialize the grant resolver to validate grants exist
# The resolver context will be passed to the execution engine
grant_resolver_context = None
try:
resolver = await create_resolver_from_oauth_token(
user_id=token.user_id,
client_public_id=token.client_id,
grant_ids=request.grant_ids,
)
# Get available credentials info to build resolver context
credentials_info = await resolver.get_available_credentials()
grant_resolver_context = GrantResolverContext(
client_db_id=resolver.client_id,
grant_ids=[c["grant_id"] for c in credentials_info],
)
except GrantValidationError as e:
raise HTTPException(
status_code=403,
detail=f"Grant validation failed: {str(e)}",
)
try:
# Build execution context with grant resolver info
execution_context = ExecutionContext(
grant_resolver_context=grant_resolver_context,
)
# Execute the agent with grant resolver context
graph_exec = await add_graph_execution(
graph_id=agent_id,
user_id=token.user_id,
inputs=request.inputs,
graph_version=graph.version,
execution_context=execution_context,
)
# Log the execution for audit
logger.info(
f"External execution started: agent={agent_id}, "
f"execution={graph_exec.id}, client={token.client_id}, "
f"user={token.user_id}"
)
# Register webhook if provided
if request.webhook_url:
# Get client to check webhook domains
client = await prisma.oauthclient.find_unique(
where={"clientId": token.client_id}
)
if client:
if not validate_webhook_url(request.webhook_url, client.webhookDomains):
raise HTTPException(
status_code=400,
detail="Webhook URL not in allowed domains for this client",
)
# Store webhook registration with client's webhook secret
await prisma.executionwebhook.create(
data={ # type: ignore[typeddict-item]
"executionId": graph_exec.id,
"webhookUrl": request.webhook_url,
"clientId": client.id,
"userId": token.user_id,
"secret": client.webhookSecret,
}
)
logger.info(
f"Registered webhook for execution {graph_exec.id}: {request.webhook_url}"
)
return ExecuteAgentResponse(
execution_id=graph_exec.id,
status="queued",
message="Agent execution has been queued",
)
except ValueError as e:
# Client error - invalid input or configuration
logger.warning(
f"Invalid execution request: agent={agent_id}, "
f"client={token.client_id}, error={str(e)}"
)
raise HTTPException(
status_code=400,
detail=f"Invalid request: {str(e)}",
)
except HTTPException:
# Re-raise HTTP exceptions as-is
raise
except Exception:
# Server error - log full exception but don't expose details to client
logger.exception(
f"Unexpected error starting execution: agent={agent_id}, "
f"client={token.client_id}"
)
raise HTTPException(
status_code=500,
detail="An internal error occurred while starting execution",
)
@execution_router.get(
"/{execution_id}",
response_model=ExecutionStatusResponse,
)
async def get_execution_status(
execution_id: str,
token: OAuthTokenInfo = Security(require_scope("agents:execute")),
) -> ExecutionStatusResponse:
"""
Get the status of an agent execution.
Returns current status, outputs (if completed), and any error messages.
"""
graph_exec = await execution_db.get_graph_execution(
user_id=token.user_id,
execution_id=execution_id,
include_node_executions=False,
)
if not graph_exec:
raise HTTPException(
status_code=404,
detail=f"Execution {execution_id} not found",
)
# Build response
outputs = None
error = None
if graph_exec.status == AgentExecutionStatus.COMPLETED:
outputs = graph_exec.outputs
elif graph_exec.status == AgentExecutionStatus.FAILED:
# Get error from execution stats
# Note: Currently no standard error field in stats, but could be added
error = "Execution failed"
return ExecutionStatusResponse(
execution_id=execution_id,
status=graph_exec.status.value,
started_at=graph_exec.started_at,
completed_at=graph_exec.ended_at,
outputs=outputs,
error=error,
)
@execution_router.post("/{execution_id}/cancel")
async def cancel_execution(
execution_id: str,
token: OAuthTokenInfo = Security(require_scope("agents:execute")),
) -> dict:
"""
Cancel a running agent execution.
Only executions in QUEUED or RUNNING status can be cancelled.
"""
graph_exec = await execution_db.get_graph_execution(
user_id=token.user_id,
execution_id=execution_id,
include_node_executions=False,
)
if not graph_exec:
raise HTTPException(
status_code=404,
detail=f"Execution {execution_id} not found",
)
# Check if execution can be cancelled
if graph_exec.status not in [
AgentExecutionStatus.QUEUED,
AgentExecutionStatus.RUNNING,
]:
raise HTTPException(
status_code=400,
detail=f"Cannot cancel execution with status {graph_exec.status.value}",
)
# Update execution status to TERMINATED
# Note: This is a simplified implementation. A full implementation would
# need to signal the executor to stop processing.
await prisma.agentgraphexecution.update(
where={"id": execution_id},
data={"executionStatus": AgentExecutionStatus.TERMINATED},
)
logger.info(
f"Execution terminated: execution={execution_id}, "
f"client={token.client_id}, user={token.user_id}"
)
return {"message": "Execution terminated", "execution_id": execution_id}

View File

@@ -1,207 +0,0 @@
"""
Credential Grants endpoints for external OAuth clients.
Allows external applications to:
- List their credential grants (metadata only, NOT credential values)
- Get grant details
- Delete credentials via grants (if permitted)
Credentials are NEVER returned to external applications.
"""
from datetime import datetime, timezone
from typing import Optional
from fastapi import APIRouter, HTTPException, Security
from pydantic import BaseModel
from backend.data import credential_grants as grants_db
from backend.data.db import prisma
from backend.integrations.credentials_store import IntegrationCredentialsStore
from backend.server.external.oauth_middleware import OAuthTokenInfo, require_scope
grants_router = APIRouter(prefix="/grants", tags=["grants"])
# ================================================================
# Response Models
# ================================================================
class GrantSummary(BaseModel):
"""Summary of a credential grant (returned in list endpoints)."""
id: str
provider: str
granted_scopes: list[str]
permissions: list[str]
created_at: datetime
last_used_at: Optional[datetime] = None
expires_at: Optional[datetime] = None
class GrantDetail(BaseModel):
"""Detailed grant information."""
id: str
provider: str
credential_id: str
granted_scopes: list[str]
permissions: list[str]
created_at: datetime
updated_at: datetime
last_used_at: Optional[datetime] = None
expires_at: Optional[datetime] = None
revoked_at: Optional[datetime] = None
# ================================================================
# Endpoints
# ================================================================
@grants_router.get("/", response_model=list[GrantSummary])
async def list_grants(
token: OAuthTokenInfo = Security(require_scope("integrations:list")),
) -> list[GrantSummary]:
"""
List all active credential grants for this client and user.
Returns grant metadata but NOT credential values.
Credentials are never exposed to external applications.
"""
# Get the OAuth client's database ID from the public client_id
client = await prisma.oauthclient.find_unique(where={"clientId": token.client_id})
if not client:
raise HTTPException(status_code=400, detail="Invalid client")
grants = await grants_db.get_grants_for_user_client(
user_id=token.user_id,
client_id=client.id,
include_revoked=False,
include_expired=False,
)
return [
GrantSummary(
id=grant.id,
provider=grant.provider,
granted_scopes=grant.grantedScopes,
permissions=[p.value for p in grant.permissions],
created_at=grant.createdAt,
last_used_at=grant.lastUsedAt,
expires_at=grant.expiresAt,
)
for grant in grants
]
@grants_router.get("/{grant_id}", response_model=GrantDetail)
async def get_grant(
grant_id: str,
token: OAuthTokenInfo = Security(require_scope("integrations:list")),
) -> GrantDetail:
"""
Get detailed information about a specific grant.
Returns grant metadata including scopes and permissions.
Does NOT return the credential value.
"""
# Get the OAuth client's database ID
client = await prisma.oauthclient.find_unique(where={"clientId": token.client_id})
if not client:
raise HTTPException(status_code=400, detail="Invalid client")
grant = await grants_db.get_credential_grant(
grant_id=grant_id,
user_id=token.user_id,
client_id=client.id,
)
if not grant:
raise HTTPException(status_code=404, detail="Grant not found")
# Check if expired
if grant.expiresAt and grant.expiresAt < datetime.now(timezone.utc):
raise HTTPException(status_code=404, detail="Grant has expired")
# Check if revoked
if grant.revokedAt:
raise HTTPException(status_code=404, detail="Grant has been revoked")
return GrantDetail(
id=grant.id,
provider=grant.provider,
credential_id=grant.credentialId,
granted_scopes=grant.grantedScopes,
permissions=[p.value for p in grant.permissions],
created_at=grant.createdAt,
updated_at=grant.updatedAt,
last_used_at=grant.lastUsedAt,
expires_at=grant.expiresAt,
revoked_at=grant.revokedAt,
)
@grants_router.delete("/{grant_id}/credential")
async def delete_credential_via_grant(
grant_id: str,
token: OAuthTokenInfo = Security(require_scope("integrations:delete")),
) -> dict:
"""
Delete the underlying credential associated with a grant.
This requires the grant to have the DELETE permission.
Deleting the credential also invalidates all grants for that credential.
"""
from prisma.enums import CredentialGrantPermission
# Get the OAuth client's database ID
client = await prisma.oauthclient.find_unique(where={"clientId": token.client_id})
if not client:
raise HTTPException(status_code=400, detail="Invalid client")
# Get the grant
grant = await grants_db.get_credential_grant(
grant_id=grant_id,
user_id=token.user_id,
client_id=client.id,
)
if not grant:
raise HTTPException(status_code=404, detail="Grant not found")
# Check if grant is valid
if grant.revokedAt:
raise HTTPException(status_code=400, detail="Grant has been revoked")
if grant.expiresAt and grant.expiresAt < datetime.now(timezone.utc):
raise HTTPException(status_code=400, detail="Grant has expired")
# Check DELETE permission
if CredentialGrantPermission.DELETE not in grant.permissions:
raise HTTPException(
status_code=403,
detail="Grant does not have DELETE permission for this credential",
)
# Delete the credential using the credentials store
try:
creds_store = IntegrationCredentialsStore()
await creds_store.delete_creds_by_id(
user_id=token.user_id,
credentials_id=grant.credentialId,
)
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Failed to delete credential: {str(e)}",
)
# Revoke all grants for this credential
await grants_db.revoke_grants_for_credential(
user_id=token.user_id,
credential_id=grant.credentialId,
)
return {"message": "Credential deleted successfully"}

View File

@@ -0,0 +1,650 @@
"""
External API endpoints for integrations and credentials.
This module provides endpoints for external applications (like Autopilot) to:
- Initiate OAuth flows with custom callback URLs
- Complete OAuth flows by exchanging authorization codes
- Create API key, user/password, and host-scoped credentials
- List and manage user credentials
"""
import logging
from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional, Union
from urllib.parse import urlparse
from fastapi import APIRouter, Body, HTTPException, Path, Security, status
from prisma.enums import APIKeyPermission
from pydantic import BaseModel, Field, SecretStr
from backend.data.api_key import APIKeyInfo
from backend.data.model import (
APIKeyCredentials,
Credentials,
CredentialsType,
HostScopedCredentials,
OAuth2Credentials,
UserPasswordCredentials,
)
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName
from backend.server.external.middleware import require_permission
from backend.server.integrations.models import get_all_provider_names
from backend.util.settings import Settings
if TYPE_CHECKING:
from backend.integrations.oauth import BaseOAuthHandler
logger = logging.getLogger(__name__)
settings = Settings()
creds_manager = IntegrationCredentialsManager()
integrations_router = APIRouter(prefix="/integrations", tags=["integrations"])
# ==================== Request/Response Models ==================== #
class OAuthInitiateRequest(BaseModel):
"""Request model for initiating an OAuth flow."""
callback_url: str = Field(
..., description="The external app's callback URL for OAuth redirect"
)
scopes: list[str] = Field(
default_factory=list, description="OAuth scopes to request"
)
state_metadata: dict[str, Any] = Field(
default_factory=dict,
description="Arbitrary metadata to echo back on completion",
)
class OAuthInitiateResponse(BaseModel):
"""Response model for OAuth initiation."""
login_url: str = Field(..., description="URL to redirect user for OAuth consent")
state_token: str = Field(..., description="State token for CSRF protection")
expires_at: int = Field(
..., description="Unix timestamp when the state token expires"
)
class OAuthCompleteRequest(BaseModel):
"""Request model for completing an OAuth flow."""
code: str = Field(..., description="Authorization code from OAuth provider")
state_token: str = Field(..., description="State token from initiate request")
class OAuthCompleteResponse(BaseModel):
"""Response model for OAuth completion."""
credentials_id: str = Field(..., description="ID of the stored credentials")
provider: str = Field(..., description="Provider name")
type: str = Field(..., description="Credential type (oauth2)")
title: Optional[str] = Field(None, description="Credential title")
scopes: list[str] = Field(default_factory=list, description="Granted scopes")
username: Optional[str] = Field(None, description="Username from provider")
state_metadata: dict[str, Any] = Field(
default_factory=dict, description="Echoed metadata from initiate request"
)
class CredentialSummary(BaseModel):
"""Summary of a credential without sensitive data."""
id: str
provider: str
type: CredentialsType
title: Optional[str] = None
scopes: Optional[list[str]] = None
username: Optional[str] = None
host: Optional[str] = None
class ProviderInfo(BaseModel):
"""Information about an integration provider."""
name: str
supports_oauth: bool = False
supports_api_key: bool = False
supports_user_password: bool = False
supports_host_scoped: bool = False
default_scopes: list[str] = Field(default_factory=list)
# ==================== Credential Creation Models ==================== #
class CreateAPIKeyCredentialRequest(BaseModel):
"""Request model for creating API key credentials."""
type: Literal["api_key"] = "api_key"
api_key: str = Field(..., description="The API key")
title: str = Field(..., description="A name for this credential")
expires_at: Optional[int] = Field(
None, description="Unix timestamp when the API key expires"
)
class CreateUserPasswordCredentialRequest(BaseModel):
"""Request model for creating username/password credentials."""
type: Literal["user_password"] = "user_password"
username: str = Field(..., description="Username")
password: str = Field(..., description="Password")
title: str = Field(..., description="A name for this credential")
class CreateHostScopedCredentialRequest(BaseModel):
"""Request model for creating host-scoped credentials."""
type: Literal["host_scoped"] = "host_scoped"
host: str = Field(..., description="Host/domain pattern to match")
headers: dict[str, str] = Field(..., description="Headers to include in requests")
title: str = Field(..., description="A name for this credential")
# Union type for credential creation
CreateCredentialRequest = Annotated[
CreateAPIKeyCredentialRequest
| CreateUserPasswordCredentialRequest
| CreateHostScopedCredentialRequest,
Field(discriminator="type"),
]
class CreateCredentialResponse(BaseModel):
"""Response model for credential creation."""
id: str
provider: str
type: CredentialsType
title: Optional[str] = None
# ==================== Helper Functions ==================== #
def validate_callback_url(callback_url: str) -> bool:
"""Validate that the callback URL is from an allowed origin."""
allowed_origins = settings.config.external_oauth_callback_origins
try:
parsed = urlparse(callback_url)
callback_origin = f"{parsed.scheme}://{parsed.netloc}"
for allowed in allowed_origins:
# Simple origin matching
if callback_origin == allowed:
return True
# Allow localhost with any port in development (proper hostname check)
if parsed.hostname == "localhost":
for allowed in allowed_origins:
allowed_parsed = urlparse(allowed)
if allowed_parsed.hostname == "localhost":
return True
return False
except Exception:
return False
def _get_oauth_handler_for_external(
provider_name: str, redirect_uri: str
) -> "BaseOAuthHandler":
"""Get an OAuth handler configured with an external redirect URI."""
# Ensure blocks are loaded so SDK providers are available
try:
from backend.blocks import load_all_blocks
load_all_blocks()
except Exception as e:
logger.warning(f"Failed to load blocks: {e}")
if provider_name not in HANDLERS_BY_NAME:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Provider '{provider_name}' does not support OAuth",
)
# Check if this provider has custom OAuth credentials
oauth_credentials = CREDENTIALS_BY_PROVIDER.get(provider_name)
if oauth_credentials and not oauth_credentials.use_secrets:
import os
client_id = (
os.getenv(oauth_credentials.client_id_env_var)
if oauth_credentials.client_id_env_var
else None
)
client_secret = (
os.getenv(oauth_credentials.client_secret_env_var)
if oauth_credentials.client_secret_env_var
else None
)
else:
client_id = getattr(settings.secrets, f"{provider_name}_client_id", None)
client_secret = getattr(
settings.secrets, f"{provider_name}_client_secret", None
)
if not (client_id and client_secret):
logger.error(f"Attempt to use unconfigured {provider_name} OAuth integration")
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED,
detail={
"message": f"Integration with provider '{provider_name}' is not configured.",
"hint": "Set client ID and secret in the application's deployment environment",
},
)
handler_class = HANDLERS_BY_NAME[provider_name]
return handler_class(
client_id=client_id,
client_secret=client_secret,
redirect_uri=redirect_uri,
)
# ==================== Endpoints ==================== #
@integrations_router.get("/providers", response_model=list[ProviderInfo])
async def list_providers(
api_key: APIKeyInfo = Security(
require_permission(APIKeyPermission.READ_INTEGRATIONS)
),
) -> list[ProviderInfo]:
"""
List all available integration providers.
Returns a list of all providers with their supported credential types.
Most providers support API key credentials, and some also support OAuth.
"""
# Ensure blocks are loaded
try:
from backend.blocks import load_all_blocks
load_all_blocks()
except Exception as e:
logger.warning(f"Failed to load blocks: {e}")
from backend.sdk.registry import AutoRegistry
providers = []
for name in get_all_provider_names():
supports_oauth = name in HANDLERS_BY_NAME
handler_class = HANDLERS_BY_NAME.get(name)
default_scopes = (
getattr(handler_class, "DEFAULT_SCOPES", []) if handler_class else []
)
# Check if provider has specific auth types from SDK registration
sdk_provider = AutoRegistry.get_provider(name)
if sdk_provider and sdk_provider.supported_auth_types:
supports_api_key = "api_key" in sdk_provider.supported_auth_types
supports_user_password = (
"user_password" in sdk_provider.supported_auth_types
)
supports_host_scoped = "host_scoped" in sdk_provider.supported_auth_types
else:
# Fallback for legacy providers
supports_api_key = True # All providers can accept API keys
supports_user_password = name in ("smtp",)
supports_host_scoped = name == "http"
providers.append(
ProviderInfo(
name=name,
supports_oauth=supports_oauth,
supports_api_key=supports_api_key,
supports_user_password=supports_user_password,
supports_host_scoped=supports_host_scoped,
default_scopes=default_scopes,
)
)
return providers
@integrations_router.post(
"/{provider}/oauth/initiate",
response_model=OAuthInitiateResponse,
summary="Initiate OAuth flow",
)
async def initiate_oauth(
provider: Annotated[str, Path(title="The OAuth provider")],
request: OAuthInitiateRequest,
api_key: APIKeyInfo = Security(
require_permission(APIKeyPermission.MANAGE_INTEGRATIONS)
),
) -> OAuthInitiateResponse:
"""
Initiate an OAuth flow for an external application.
This endpoint allows external apps to start an OAuth flow with a custom
callback URL. The callback URL must be from an allowed origin configured
in the platform settings.
Returns a login URL to redirect the user to, along with a state token
for CSRF protection.
"""
# Validate callback URL
if not validate_callback_url(request.callback_url):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Callback URL origin is not allowed. Allowed origins: {settings.config.external_oauth_callback_origins}",
)
# Validate provider
try:
provider_name = ProviderName(provider)
except ValueError:
# Check if it's a dynamically registered provider
if provider not in HANDLERS_BY_NAME:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Provider '{provider}' not found",
)
provider_name = provider
# Get OAuth handler with external callback URL
handler = _get_oauth_handler_for_external(
provider if isinstance(provider_name, str) else provider_name.value,
request.callback_url,
)
# Store state token with external flow metadata
state_token, code_challenge = await creds_manager.store.store_state_token(
user_id=api_key.user_id,
provider=provider if isinstance(provider_name, str) else provider_name.value,
scopes=request.scopes,
callback_url=request.callback_url,
state_metadata=request.state_metadata,
initiated_by_api_key_id=api_key.id,
)
# Build login URL
login_url = handler.get_login_url(
request.scopes, state_token, code_challenge=code_challenge
)
# Calculate expiration (10 minutes from now)
from datetime import datetime, timedelta, timezone
expires_at = int((datetime.now(timezone.utc) + timedelta(minutes=10)).timestamp())
return OAuthInitiateResponse(
login_url=login_url,
state_token=state_token,
expires_at=expires_at,
)
@integrations_router.post(
"/{provider}/oauth/complete",
response_model=OAuthCompleteResponse,
summary="Complete OAuth flow",
)
async def complete_oauth(
provider: Annotated[str, Path(title="The OAuth provider")],
request: OAuthCompleteRequest,
api_key: APIKeyInfo = Security(
require_permission(APIKeyPermission.MANAGE_INTEGRATIONS)
),
) -> OAuthCompleteResponse:
"""
Complete an OAuth flow by exchanging the authorization code for tokens.
This endpoint should be called after the user has authorized the application
and been redirected back to the external app's callback URL with an
authorization code.
"""
# Verify state token
valid_state = await creds_manager.store.verify_state_token(
api_key.user_id, request.state_token, provider
)
if not valid_state:
logger.warning(f"Invalid or expired state token for provider {provider}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid or expired state token",
)
# Verify this is an external flow (callback_url must be set)
if not valid_state.callback_url:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="State token was not created for external OAuth flow",
)
# Get OAuth handler with the original callback URL
handler = _get_oauth_handler_for_external(provider, valid_state.callback_url)
try:
scopes = valid_state.scopes
scopes = handler.handle_default_scopes(scopes)
credentials = await handler.exchange_code_for_tokens(
request.code, scopes, valid_state.code_verifier
)
# Handle Linear's space-separated scopes
if len(credentials.scopes) == 1 and " " in credentials.scopes[0]:
credentials.scopes = credentials.scopes[0].split(" ")
# Check scope mismatch
if not set(scopes).issubset(set(credentials.scopes)):
logger.warning(
f"Granted scopes {credentials.scopes} for provider {provider} "
f"do not include all requested scopes {scopes}"
)
except Exception as e:
logger.error(f"OAuth2 Code->Token exchange failed for provider {provider}: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"OAuth2 callback failed to exchange code for tokens: {str(e)}",
)
# Store credentials
await creds_manager.create(api_key.user_id, credentials)
logger.info(f"Successfully completed external OAuth for provider {provider}")
return OAuthCompleteResponse(
credentials_id=credentials.id,
provider=credentials.provider,
type=credentials.type,
title=credentials.title,
scopes=credentials.scopes,
username=credentials.username,
state_metadata=valid_state.state_metadata,
)
@integrations_router.get("/credentials", response_model=list[CredentialSummary])
async def list_credentials(
api_key: APIKeyInfo = Security(
require_permission(APIKeyPermission.READ_INTEGRATIONS)
),
) -> list[CredentialSummary]:
"""
List all credentials for the authenticated user.
Returns metadata about each credential without exposing sensitive tokens.
"""
credentials = await creds_manager.store.get_all_creds(api_key.user_id)
return [
CredentialSummary(
id=cred.id,
provider=cred.provider,
type=cred.type,
title=cred.title,
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
)
for cred in credentials
]
@integrations_router.get(
"/{provider}/credentials", response_model=list[CredentialSummary]
)
async def list_credentials_by_provider(
provider: Annotated[str, Path(title="The provider to list credentials for")],
api_key: APIKeyInfo = Security(
require_permission(APIKeyPermission.READ_INTEGRATIONS)
),
) -> list[CredentialSummary]:
"""
List credentials for a specific provider.
"""
credentials = await creds_manager.store.get_creds_by_provider(
api_key.user_id, provider
)
return [
CredentialSummary(
id=cred.id,
provider=cred.provider,
type=cred.type,
title=cred.title,
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
)
for cred in credentials
]
@integrations_router.post(
"/{provider}/credentials",
response_model=CreateCredentialResponse,
status_code=status.HTTP_201_CREATED,
summary="Create credentials",
)
async def create_credential(
provider: Annotated[str, Path(title="The provider to create credentials for")],
request: Union[
CreateAPIKeyCredentialRequest,
CreateUserPasswordCredentialRequest,
CreateHostScopedCredentialRequest,
] = Body(..., discriminator="type"),
api_key: APIKeyInfo = Security(
require_permission(APIKeyPermission.MANAGE_INTEGRATIONS)
),
) -> CreateCredentialResponse:
"""
Create non-OAuth credentials for a provider.
Supports creating:
- API key credentials (type: "api_key")
- Username/password credentials (type: "user_password")
- Host-scoped credentials (type: "host_scoped")
For OAuth credentials, use the OAuth initiate/complete flow instead.
"""
# Validate provider exists
all_providers = get_all_provider_names()
if provider not in all_providers:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Provider '{provider}' not found",
)
# Create the appropriate credential type
credentials: Credentials
if request.type == "api_key":
credentials = APIKeyCredentials(
provider=provider,
api_key=SecretStr(request.api_key),
title=request.title,
expires_at=request.expires_at,
)
elif request.type == "user_password":
credentials = UserPasswordCredentials(
provider=provider,
username=SecretStr(request.username),
password=SecretStr(request.password),
title=request.title,
)
elif request.type == "host_scoped":
# Convert string headers to SecretStr
secret_headers = {k: SecretStr(v) for k, v in request.headers.items()}
credentials = HostScopedCredentials(
provider=provider,
host=request.host,
headers=secret_headers,
title=request.title,
)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Unsupported credential type: {request.type}",
)
# Store credentials
try:
await creds_manager.create(api_key.user_id, credentials)
except Exception as e:
logger.error(f"Failed to store credentials: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to store credentials: {str(e)}",
)
logger.info(f"Created {request.type} credentials for provider {provider}")
return CreateCredentialResponse(
id=credentials.id,
provider=provider,
type=credentials.type,
title=credentials.title,
)
class DeleteCredentialResponse(BaseModel):
"""Response model for deleting a credential."""
deleted: bool = Field(..., description="Whether the credential was deleted")
credentials_id: str = Field(..., description="ID of the deleted credential")
@integrations_router.delete(
"/{provider}/credentials/{cred_id}",
response_model=DeleteCredentialResponse,
)
async def delete_credential(
provider: Annotated[str, Path(title="The provider")],
cred_id: Annotated[str, Path(title="The credential ID to delete")],
api_key: APIKeyInfo = Security(
require_permission(APIKeyPermission.DELETE_INTEGRATIONS)
),
) -> DeleteCredentialResponse:
"""
Delete a credential.
Note: This does not revoke the tokens with the provider. For full cleanup,
use the main API's delete endpoint which handles webhook cleanup and
token revocation.
"""
creds = await creds_manager.store.get_creds_by_id(api_key.user_id, cred_id)
if not creds:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
)
if creds.provider != provider:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Credentials do not match the specified provider",
)
await creds_manager.delete(api_key.user_id, cred_id)
return DeleteCredentialResponse(deleted=True, credentials_id=cred_id)

View File

@@ -0,0 +1,148 @@
"""External API routes for chat tools - stateless HTTP endpoints.
Note: These endpoints use ephemeral sessions that are not persisted to Redis.
As a result, session-based rate limiting (max_agent_runs, max_agent_schedules)
is not enforced for external API calls. Each request creates a fresh session
with zeroed counters. Rate limiting for external API consumers should be
handled separately (e.g., via API key quotas).
"""
import logging
from typing import Any
from fastapi import APIRouter, Security
from prisma.enums import APIKeyPermission
from pydantic import BaseModel, Field
from backend.data.api_key import APIKeyInfo
from backend.server.external.middleware import require_permission
from backend.server.v2.chat.model import ChatSession
from backend.server.v2.chat.tools import find_agent_tool, run_agent_tool
from backend.server.v2.chat.tools.models import ToolResponseBase
logger = logging.getLogger(__name__)
tools_router = APIRouter(prefix="/tools", tags=["tools"])
# Note: We use Security() as a function parameter dependency (api_key: APIKeyInfo = Security(...))
# rather than in the decorator's dependencies= list. This avoids duplicate permission checks
# while still enforcing auth AND giving us access to the api_key for extracting user_id.
# Request models
class FindAgentRequest(BaseModel):
query: str = Field(..., description="Search query for finding agents")
class RunAgentRequest(BaseModel):
"""Request to run or schedule an agent.
The tool automatically handles the setup flow:
- First call returns available inputs so user can decide what values to use
- Returns missing credentials if user needs to configure them
- Executes when inputs are provided OR use_defaults=true
- Schedules execution if schedule_name and cron are provided
"""
username_agent_slug: str = Field(
...,
description="The marketplace agent slug (e.g., 'username/agent-name')",
)
inputs: dict[str, Any] = Field(
default_factory=dict,
description="Dictionary of input values for the agent",
)
use_defaults: bool = Field(
default=False,
description="Set to true to run with default values (user must confirm)",
)
schedule_name: str | None = Field(
None,
description="Name for scheduled execution (triggers scheduling mode)",
)
cron: str | None = Field(
None,
description="Cron expression (5 fields: minute hour day month weekday)",
)
timezone: str = Field(
default="UTC",
description="IANA timezone (e.g., 'America/New_York', 'UTC')",
)
def _create_ephemeral_session(user_id: str | None) -> ChatSession:
"""Create an ephemeral session for stateless API requests."""
return ChatSession.new(user_id)
@tools_router.post(
path="/find-agent",
)
async def find_agent(
request: FindAgentRequest,
api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.USE_TOOLS)),
) -> dict[str, Any]:
"""
Search for agents in the marketplace based on capabilities and user needs.
Args:
request: Search query for finding agents
Returns:
List of matching agents or no results response
"""
session = _create_ephemeral_session(api_key.user_id)
result = await find_agent_tool._execute(
user_id=api_key.user_id,
session=session,
query=request.query,
)
return _response_to_dict(result)
@tools_router.post(
path="/run-agent",
)
async def run_agent(
request: RunAgentRequest,
api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.USE_TOOLS)),
) -> dict[str, Any]:
"""
Run or schedule an agent from the marketplace.
The endpoint automatically handles the setup flow:
- Returns missing inputs if required fields are not provided
- Returns missing credentials if user needs to configure them
- Executes immediately if all requirements are met
- Schedules execution if schedule_name and cron are provided
For scheduled execution:
- Cron format: "minute hour day month weekday"
- Examples: "0 9 * * 1-5" (9am weekdays), "0 0 * * *" (daily at midnight)
- Timezone: Use IANA timezone names like "America/New_York"
Args:
request: Agent slug, inputs, and optional schedule config
Returns:
- setup_requirements: If inputs or credentials are missing
- execution_started: If agent was run or scheduled successfully
- error: If something went wrong
"""
session = _create_ephemeral_session(api_key.user_id)
result = await run_agent_tool._execute(
user_id=api_key.user_id,
session=session,
username_agent_slug=request.username_agent_slug,
inputs=request.inputs,
use_defaults=request.use_defaults,
schedule_name=request.schedule_name or "",
cron=request.cron or "",
timezone=request.timezone,
)
return _response_to_dict(result)
def _response_to_dict(result: ToolResponseBase) -> dict[str, Any]:
"""Convert a tool response to a dictionary for JSON serialization."""
return result.model_dump()

View File

@@ -0,0 +1,295 @@
import logging
import urllib.parse
from collections import defaultdict
from typing import Annotated, Any, Literal, Optional, Sequence
from fastapi import APIRouter, Body, HTTPException, Security
from prisma.enums import AgentExecutionStatus, APIKeyPermission
from typing_extensions import TypedDict
import backend.data.block
import backend.server.v2.store.cache as store_cache
import backend.server.v2.store.model as store_model
from backend.data import execution as execution_db
from backend.data import graph as graph_db
from backend.data.api_key import APIKeyInfo
from backend.data.block import BlockInput, CompletedBlockOutput
from backend.executor.utils import add_graph_execution
from backend.server.external.middleware import require_permission
from backend.util.settings import Settings
settings = Settings()
logger = logging.getLogger(__name__)
v1_router = APIRouter()
class NodeOutput(TypedDict):
key: str
value: Any
class ExecutionNode(TypedDict):
node_id: str
input: Any
output: dict[str, Any]
class ExecutionNodeOutput(TypedDict):
node_id: str
outputs: list[NodeOutput]
class GraphExecutionResult(TypedDict):
execution_id: str
status: str
nodes: list[ExecutionNode]
output: Optional[list[dict[str, str]]]
@v1_router.get(
path="/blocks",
tags=["blocks"],
dependencies=[Security(require_permission(APIKeyPermission.READ_BLOCK))],
)
async def get_graph_blocks() -> Sequence[dict[Any, Any]]:
blocks = [block() for block in backend.data.block.get_blocks().values()]
return [b.to_dict() for b in blocks if not b.disabled]
@v1_router.post(
path="/blocks/{block_id}/execute",
tags=["blocks"],
dependencies=[Security(require_permission(APIKeyPermission.EXECUTE_BLOCK))],
)
async def execute_graph_block(
block_id: str,
data: BlockInput,
api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.EXECUTE_BLOCK)),
) -> CompletedBlockOutput:
obj = backend.data.block.get_block(block_id)
if not obj:
raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.")
output = defaultdict(list)
async for name, data in obj.execute(data):
output[name].append(data)
return output
@v1_router.post(
path="/graphs/{graph_id}/execute/{graph_version}",
tags=["graphs"],
)
async def execute_graph(
graph_id: str,
graph_version: int,
node_input: Annotated[dict[str, Any], Body(..., embed=True, default_factory=dict)],
api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.EXECUTE_GRAPH)),
) -> dict[str, Any]:
try:
graph_exec = await add_graph_execution(
graph_id=graph_id,
user_id=api_key.user_id,
inputs=node_input,
graph_version=graph_version,
)
return {"id": graph_exec.id}
except Exception as e:
msg = str(e).encode().decode("unicode_escape")
raise HTTPException(status_code=400, detail=msg)
@v1_router.get(
path="/graphs/{graph_id}/executions/{graph_exec_id}/results",
tags=["graphs"],
)
async def get_graph_execution_results(
graph_id: str,
graph_exec_id: str,
api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.READ_GRAPH)),
) -> GraphExecutionResult:
graph_exec = await execution_db.get_graph_execution(
user_id=api_key.user_id,
execution_id=graph_exec_id,
include_node_executions=True,
)
if not graph_exec:
raise HTTPException(
status_code=404, detail=f"Graph execution #{graph_exec_id} not found."
)
if not await graph_db.get_graph(
graph_id=graph_exec.graph_id,
version=graph_exec.graph_version,
user_id=api_key.user_id,
):
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
return GraphExecutionResult(
execution_id=graph_exec_id,
status=graph_exec.status.value,
nodes=[
ExecutionNode(
node_id=node_exec.node_id,
input=node_exec.input_data.get("value", node_exec.input_data),
output={k: v for k, v in node_exec.output_data.items()},
)
for node_exec in graph_exec.node_executions
],
output=(
[
{name: value}
for name, values in graph_exec.outputs.items()
for value in values
]
if graph_exec.status == AgentExecutionStatus.COMPLETED
else None
),
)
##############################################
############### Store Endpoints ##############
##############################################
@v1_router.get(
path="/store/agents",
tags=["store"],
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
response_model=store_model.StoreAgentsResponse,
)
async def get_store_agents(
featured: bool = False,
creator: str | None = None,
sorted_by: Literal["rating", "runs", "name", "updated_at"] | None = None,
search_query: str | None = None,
category: str | None = None,
page: int = 1,
page_size: int = 20,
) -> store_model.StoreAgentsResponse:
"""
Get a paginated list of agents from the store with optional filtering and sorting.
Args:
featured: Filter to only show featured agents
creator: Filter agents by creator username
sorted_by: Sort agents by "runs", "rating", "name", or "updated_at"
search_query: Search agents by name, subheading and description
category: Filter agents by category
page: Page number for pagination (default 1)
page_size: Number of agents per page (default 20)
Returns:
StoreAgentsResponse: Paginated list of agents matching the filters
"""
if page < 1:
raise HTTPException(status_code=422, detail="Page must be greater than 0")
if page_size < 1:
raise HTTPException(status_code=422, detail="Page size must be greater than 0")
agents = await store_cache._get_cached_store_agents(
featured=featured,
creator=creator,
sorted_by=sorted_by,
search_query=search_query,
category=category,
page=page,
page_size=page_size,
)
return agents
@v1_router.get(
path="/store/agents/{username}/{agent_name}",
tags=["store"],
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
response_model=store_model.StoreAgentDetails,
)
async def get_store_agent(
username: str,
agent_name: str,
) -> store_model.StoreAgentDetails:
"""
Get details of a specific store agent by username and agent name.
Args:
username: Creator's username
agent_name: Name/slug of the agent
Returns:
StoreAgentDetails: Detailed information about the agent
"""
username = urllib.parse.unquote(username).lower()
agent_name = urllib.parse.unquote(agent_name).lower()
agent = await store_cache._get_cached_agent_details(
username=username, agent_name=agent_name
)
return agent
@v1_router.get(
path="/store/creators",
tags=["store"],
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
response_model=store_model.CreatorsResponse,
)
async def get_store_creators(
featured: bool = False,
search_query: str | None = None,
sorted_by: Literal["agent_rating", "agent_runs", "num_agents"] | None = None,
page: int = 1,
page_size: int = 20,
) -> store_model.CreatorsResponse:
"""
Get a paginated list of store creators with optional filtering and sorting.
Args:
featured: Filter to only show featured creators
search_query: Search creators by profile description
sorted_by: Sort by "agent_rating", "agent_runs", or "num_agents"
page: Page number for pagination (default 1)
page_size: Number of creators per page (default 20)
Returns:
CreatorsResponse: Paginated list of creators matching the filters
"""
if page < 1:
raise HTTPException(status_code=422, detail="Page must be greater than 0")
if page_size < 1:
raise HTTPException(status_code=422, detail="Page size must be greater than 0")
creators = await store_cache._get_cached_store_creators(
featured=featured,
search_query=search_query,
sorted_by=sorted_by,
page=page,
page_size=page_size,
)
return creators
@v1_router.get(
path="/store/creators/{username}",
tags=["store"],
dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
response_model=store_model.CreatorDetails,
)
async def get_store_creator(
username: str,
) -> store_model.CreatorDetails:
"""
Get details of a specific store creator by username.
Args:
username: Creator's username
Returns:
CreatorDetails: Detailed information about the creator
"""
username = urllib.parse.unquote(username).lower()
creator = await store_cache._get_cached_creator_details(username=username)
return creator

View File

@@ -1,471 +0,0 @@
"""
Security utilities for the integration connect popup flow.
Handles state management, nonce validation, and origin verification
for the OAuth-style popup flow when connecting integrations.
"""
import hashlib
import logging
import secrets
from datetime import datetime, timezone
from typing import Any, Optional
from urllib.parse import urlparse
from prisma.models import OAuthClient
from pydantic import BaseModel
from backend.data.redis_client import get_redis_async
logger = logging.getLogger(__name__)
# State expiration time
STATE_EXPIRATION_SECONDS = 600 # 10 minutes
NONCE_EXPIRATION_SECONDS = 3600 # 1 hour (nonces valid for longer to prevent races)
LOGIN_STATE_EXPIRATION_SECONDS = 600 # 10 minutes for login redirect flow
class ConnectState(BaseModel):
"""Pydantic model for connect state stored in Redis."""
user_id: str
client_id: str
provider: str
requested_scopes: list[str]
redirect_origin: str
nonce: str
credential_id: Optional[str] = None
created_at: str
expires_at: str
class ConnectContinuationState(BaseModel):
"""
State for continuing the connect flow after OAuth completes.
When a user chooses to "connect new" during the connect flow,
we store this state so we can complete the grant creation after
the OAuth callback.
"""
user_id: str
client_id: str # Public client ID
client_db_id: str # Database UUID of the OAuth client
provider: str
requested_scopes: list[str] # Integration scopes (e.g., "google:gmail.readonly")
redirect_origin: str
nonce: str
created_at: str
class ConnectLoginState(BaseModel):
"""
State for connect flow when user needs to log in first.
When an unauthenticated user tries to access /connect/{provider},
we store the connect parameters and redirect to login. After login,
the user is redirected back to complete the connect flow.
"""
client_id: str
provider: str
requested_scopes: list[str]
redirect_origin: str
nonce: str
created_at: str
expires_at: str
# Continuation state expiration (same as regular state)
CONTINUATION_EXPIRATION_SECONDS = 600 # 10 minutes
async def store_connect_continuation(
user_id: str,
client_id: str,
client_db_id: str,
provider: str,
requested_scopes: list[str],
redirect_origin: str,
nonce: str,
) -> str:
"""
Store continuation state for completing connect flow after OAuth.
Args:
user_id: User initiating the connection
client_id: Public OAuth client ID
client_db_id: Database UUID of the OAuth client
provider: Integration provider name
requested_scopes: Requested integration scopes
redirect_origin: Origin to send postMessage to
nonce: Client-provided nonce for replay protection
Returns:
Continuation token to be stored in OAuth state metadata
"""
token = generate_connect_token()
now = datetime.now(timezone.utc)
state = ConnectContinuationState(
user_id=user_id,
client_id=client_id,
client_db_id=client_db_id,
provider=provider,
requested_scopes=requested_scopes,
redirect_origin=redirect_origin,
nonce=nonce,
created_at=now.isoformat(),
)
redis = await get_redis_async()
key = f"connect_continuation:{token}"
await redis.setex(key, CONTINUATION_EXPIRATION_SECONDS, state.model_dump_json())
logger.debug(f"Stored connect continuation state for token {token[:8]}...")
return token
async def get_connect_continuation(token: str) -> Optional[ConnectContinuationState]:
"""
Get continuation state without consuming it.
Args:
token: Continuation token
Returns:
ConnectContinuationState or None if not found/expired
"""
redis = await get_redis_async()
key = f"connect_continuation:{token}"
data = await redis.get(key)
if not data:
return None
return ConnectContinuationState.model_validate_json(data)
async def consume_connect_continuation(
token: str,
) -> Optional[ConnectContinuationState]:
"""
Get and consume (delete) continuation state.
This ensures the token can only be used once.
Args:
token: Continuation token
Returns:
ConnectContinuationState or None if not found/expired
"""
redis = await get_redis_async()
key = f"connect_continuation:{token}"
# Atomic get-and-delete to prevent race conditions
data = await redis.getdel(key)
if not data:
return None
state = ConnectContinuationState.model_validate_json(data)
logger.debug(f"Consumed connect continuation state for token {token[:8]}...")
return state
def generate_connect_token() -> str:
"""Generate a secure random token for connect state."""
return secrets.token_urlsafe(32)
async def store_connect_state(
user_id: str,
client_id: str,
provider: str,
requested_scopes: list[str],
redirect_origin: str,
nonce: str,
credential_id: Optional[str] = None,
) -> str:
"""
Store connect state in Redis and return a state token.
Args:
user_id: User initiating the connection
client_id: OAuth client ID (public identifier)
provider: Integration provider name
requested_scopes: Requested integration scopes
redirect_origin: Origin to send postMessage to
nonce: Client-provided nonce for replay protection
credential_id: Optional existing credential to grant access to
Returns:
State token to be used in the connect flow
"""
token = generate_connect_token()
now = datetime.now(timezone.utc)
expires_at = now.timestamp() + STATE_EXPIRATION_SECONDS
state = ConnectState(
user_id=user_id,
client_id=client_id,
provider=provider,
requested_scopes=requested_scopes,
redirect_origin=redirect_origin,
nonce=nonce,
credential_id=credential_id,
created_at=now.isoformat(),
expires_at=datetime.fromtimestamp(expires_at, tz=timezone.utc).isoformat(),
)
redis = await get_redis_async()
key = f"connect_state:{token}"
await redis.setex(key, STATE_EXPIRATION_SECONDS, state.model_dump_json())
logger.debug(f"Stored connect state for token {token[:8]}...")
return token
async def get_connect_state(token: str) -> Optional[ConnectState]:
"""
Get connect state without consuming it.
Args:
token: State token
Returns:
ConnectState or None if not found/expired
"""
redis = await get_redis_async()
key = f"connect_state:{token}"
data = await redis.get(key)
if not data:
return None
return ConnectState.model_validate_json(data)
async def consume_connect_state(token: str) -> Optional[ConnectState]:
"""
Get and consume (delete) connect state.
This ensures the token can only be used once.
Args:
token: State token
Returns:
ConnectState or None if not found/expired
"""
redis = await get_redis_async()
key = f"connect_state:{token}"
# Atomic get-and-delete to prevent race conditions
data = await redis.getdel(key)
if not data:
return None
state = ConnectState.model_validate_json(data)
logger.debug(f"Consumed connect state for token {token[:8]}...")
return state
async def store_connect_login_state(
client_id: str,
provider: str,
requested_scopes: list[str],
redirect_origin: str,
nonce: str,
) -> str:
"""
Store connect parameters for unauthenticated users.
When a user isn't logged in, we store the connect params and redirect
to login. After login, the frontend calls /connect/resume with the token.
Args:
client_id: OAuth client ID
provider: Integration provider name
requested_scopes: Requested integration scopes
redirect_origin: Origin to send postMessage to
nonce: Client-provided nonce for replay protection
Returns:
Login state token to be used after login completes
"""
token = generate_connect_token()
now = datetime.now(timezone.utc)
expires_at = now.timestamp() + LOGIN_STATE_EXPIRATION_SECONDS
state = ConnectLoginState(
client_id=client_id,
provider=provider,
requested_scopes=requested_scopes,
redirect_origin=redirect_origin,
nonce=nonce,
created_at=now.isoformat(),
expires_at=datetime.fromtimestamp(expires_at, tz=timezone.utc).isoformat(),
)
redis = await get_redis_async()
key = f"connect_login_state:{token}"
await redis.setex(key, LOGIN_STATE_EXPIRATION_SECONDS, state.model_dump_json())
logger.debug(f"Stored connect login state for token {token[:8]}...")
return token
async def get_connect_login_state(token: str) -> Optional[ConnectLoginState]:
"""
Get connect login state without consuming it.
Args:
token: Login state token
Returns:
ConnectLoginState or None if not found/expired
"""
redis = await get_redis_async()
key = f"connect_login_state:{token}"
data = await redis.get(key)
if not data:
return None
return ConnectLoginState.model_validate_json(data)
async def consume_connect_login_state(token: str) -> Optional[ConnectLoginState]:
"""
Get and consume (delete) connect login state.
This ensures the token can only be used once.
Args:
token: Login state token
Returns:
ConnectLoginState or None if not found/expired
"""
redis = await get_redis_async()
key = f"connect_login_state:{token}"
# Atomic get-and-delete to prevent race conditions
data = await redis.getdel(key)
if not data:
return None
state = ConnectLoginState.model_validate_json(data)
logger.debug(f"Consumed connect login state for token {token[:8]}...")
return state
async def validate_nonce(client_id: str, nonce: str) -> bool:
"""
Validate that a nonce hasn't been used before (replay protection).
Uses atomic SET NX EX for check-and-set with automatic TTL expiry.
Args:
client_id: OAuth client ID
nonce: Client-provided nonce
Returns:
True if nonce is valid (not replayed)
"""
redis = await get_redis_async()
# Create a hash of the nonce for storage
nonce_hash = hashlib.sha256(nonce.encode()).hexdigest()
key = f"nonce:{client_id}:{nonce_hash}"
# Atomic set-if-not-exists with expiration (prevents race condition)
was_set = await redis.set(key, "1", nx=True, ex=NONCE_EXPIRATION_SECONDS)
if was_set:
return True
logger.warning(f"Nonce replay detected for client {client_id}")
return False
def validate_redirect_origin(origin: str, client: OAuthClient) -> bool:
"""
Validate that a redirect origin is allowed for the client.
The origin must match one of the client's registered redirect URIs
or webhook domains.
Args:
origin: Origin URL to validate
client: OAuth client to check against
Returns:
True if origin is allowed
"""
from backend.util.url import hostname_matches_any_domain
try:
parsed_origin = urlparse(origin)
origin_host = parsed_origin.netloc.lower()
# Check against redirect URIs
for redirect_uri in client.redirectUris:
parsed_redirect = urlparse(redirect_uri)
if parsed_redirect.netloc.lower() == origin_host:
return True
# Check against webhook domains
if hostname_matches_any_domain(origin_host, client.webhookDomains):
return True
return False
except Exception:
return False
def create_post_message_data(
success: bool,
grant_id: Optional[str] = None,
credential_id: Optional[str] = None,
provider: Optional[str] = None,
error: Optional[str] = None,
error_description: Optional[str] = None,
nonce: Optional[str] = None,
) -> dict[str, Any]:
"""
Create the postMessage data to send back to the opener.
Args:
success: Whether the operation succeeded
grant_id: ID of the created grant (if successful)
credential_id: ID of the credential (if successful)
provider: Provider name
error: Error code (if failed)
error_description: Human-readable error description
nonce: Original nonce for correlation
Returns:
Dictionary to be sent via postMessage
"""
data: dict[str, Any] = {
"type": "autogpt_connect_result",
"success": success,
}
if nonce:
data["nonce"] = nonce
if success:
data["grant_id"] = grant_id
data["credential_id"] = credential_id
data["provider"] = provider
else:
data["error"] = error
data["error_description"] = error_description
return data

View File

@@ -1,20 +0,0 @@
"""
OAuth 2.0 Provider module for AutoGPT Platform.
This module implements AutoGPT as an OAuth 2.0 Authorization Server,
allowing external applications to authenticate users and access
platform resources with user consent.
Key components:
- router.py: OAuth authorization and token endpoints
- discovery_router.py: OIDC discovery endpoints
- client_router.py: OAuth client management
- token_service.py: JWT generation and validation
- service.py: Core OAuth business logic
"""
from backend.server.oauth.client_router import client_router
from backend.server.oauth.discovery_router import discovery_router
from backend.server.oauth.router import oauth_router
__all__ = ["oauth_router", "discovery_router", "client_router"]

View File

@@ -1,367 +0,0 @@
"""
OAuth Client Management endpoints.
Implements self-service client registration and management:
- POST /oauth/clients - Register a new client
- GET /oauth/clients - List owned clients
- GET /oauth/clients/{client_id} - Get client details
- PATCH /oauth/clients/{client_id} - Update client
- DELETE /oauth/clients/{client_id} - Delete client
- POST /oauth/clients/{client_id}/rotate-secret - Rotate client secret
"""
import hashlib
import secrets
from autogpt_libs.auth import get_user_id
from fastapi import APIRouter, HTTPException, Security
from prisma.enums import OAuthClientStatus
from pydantic import BaseModel
from backend.data.db import prisma
from backend.server.oauth.models import (
ClientResponse,
ClientSecretResponse,
OAuthScope,
RegisterClientRequest,
UpdateClientRequest,
)
client_router = APIRouter(prefix="/oauth/clients", tags=["oauth-clients"])
def _generate_client_id() -> str:
"""Generate a unique client ID."""
return f"app_{secrets.token_urlsafe(16)}"
def _generate_client_secret() -> str:
"""Generate a secure client secret."""
return secrets.token_urlsafe(32)
def _generate_webhook_secret() -> str:
"""Generate a secure webhook secret for HMAC signing."""
return secrets.token_urlsafe(32)
def _hash_secret(secret: str, salt: str) -> str:
"""Hash a client secret with salt."""
return hashlib.sha256(f"{salt}{secret}".encode()).hexdigest()
def _client_to_response(client) -> ClientResponse:
"""Convert Prisma client to response model."""
return ClientResponse(
id=client.id,
client_id=client.clientId,
client_type=client.clientType,
name=client.name,
description=client.description,
logo_url=client.logoUrl,
homepage_url=client.homepageUrl,
privacy_policy_url=client.privacyPolicyUrl,
terms_of_service_url=client.termsOfServiceUrl,
redirect_uris=client.redirectUris,
allowed_scopes=client.allowedScopes,
webhook_domains=client.webhookDomains,
status=client.status,
created_at=client.createdAt,
updated_at=client.updatedAt,
)
# Default allowed scopes for new clients
DEFAULT_ALLOWED_SCOPES = [
OAuthScope.OPENID.value,
OAuthScope.PROFILE.value,
OAuthScope.EMAIL.value,
OAuthScope.INTEGRATIONS_LIST.value,
OAuthScope.INTEGRATIONS_CONNECT.value,
OAuthScope.INTEGRATIONS_DELETE.value,
OAuthScope.AGENTS_EXECUTE.value,
]
@client_router.post("/", response_model=ClientSecretResponse)
async def register_client(
request: RegisterClientRequest,
user_id: str = Security(get_user_id),
) -> ClientSecretResponse:
"""
Register a new OAuth client.
The client is immediately active (no admin approval required).
For confidential clients, the client_secret is returned only once.
The webhook_secret is always generated and returned only once.
"""
# Generate client credentials
client_id = _generate_client_id()
client_secret = None
client_secret_hash = None
client_secret_salt = None
if request.client_type == "confidential":
client_secret = _generate_client_secret()
client_secret_salt = secrets.token_urlsafe(16)
client_secret_hash = _hash_secret(client_secret, client_secret_salt)
# Generate webhook secret for HMAC signing
webhook_secret = _generate_webhook_secret()
# Create client
await prisma.oauthclient.create(
data={ # type: ignore[typeddict-item]
"clientId": client_id,
"clientSecretHash": client_secret_hash,
"clientSecretSalt": client_secret_salt,
"clientType": request.client_type,
"name": request.name,
"description": request.description,
"logoUrl": str(request.logo_url) if request.logo_url else None,
"homepageUrl": str(request.homepage_url) if request.homepage_url else None,
"privacyPolicyUrl": (
str(request.privacy_policy_url) if request.privacy_policy_url else None
),
"termsOfServiceUrl": (
str(request.terms_of_service_url)
if request.terms_of_service_url
else None
),
"redirectUris": request.redirect_uris,
"allowedScopes": DEFAULT_ALLOWED_SCOPES,
"webhookDomains": request.webhook_domains,
"webhookSecret": webhook_secret,
"status": OAuthClientStatus.ACTIVE,
"ownerId": user_id,
}
)
return ClientSecretResponse(
client_id=client_id,
client_secret=client_secret or "",
webhook_secret=webhook_secret,
)
@client_router.get("/", response_model=list[ClientResponse])
async def list_clients(
user_id: str = Security(get_user_id),
) -> list[ClientResponse]:
"""List all OAuth clients owned by the current user."""
clients = await prisma.oauthclient.find_many(
where={"ownerId": user_id},
order={"createdAt": "desc"},
)
return [_client_to_response(c) for c in clients]
@client_router.get("/{client_id}", response_model=ClientResponse)
async def get_client(
client_id: str,
user_id: str = Security(get_user_id),
) -> ClientResponse:
"""Get details of a specific OAuth client."""
client = await prisma.oauthclient.find_first(
where={"clientId": client_id, "ownerId": user_id}
)
if not client:
raise HTTPException(status_code=404, detail="Client not found")
return _client_to_response(client)
@client_router.patch("/{client_id}", response_model=ClientResponse)
async def update_client(
client_id: str,
request: UpdateClientRequest,
user_id: str = Security(get_user_id),
) -> ClientResponse:
"""Update an OAuth client."""
client = await prisma.oauthclient.find_first(
where={"clientId": client_id, "ownerId": user_id}
)
if not client:
raise HTTPException(status_code=404, detail="Client not found")
# Build update data
update_data: dict[str, str | list[str] | None] = {}
if request.name is not None:
update_data["name"] = request.name
if request.description is not None:
update_data["description"] = request.description
if request.logo_url is not None:
update_data["logoUrl"] = str(request.logo_url)
if request.homepage_url is not None:
update_data["homepageUrl"] = str(request.homepage_url)
if request.privacy_policy_url is not None:
update_data["privacyPolicyUrl"] = str(request.privacy_policy_url)
if request.terms_of_service_url is not None:
update_data["termsOfServiceUrl"] = str(request.terms_of_service_url)
if request.redirect_uris is not None:
update_data["redirectUris"] = request.redirect_uris
if request.webhook_domains is not None:
update_data["webhookDomains"] = request.webhook_domains
if not update_data:
return _client_to_response(client)
updated = await prisma.oauthclient.update(
where={"id": client.id},
data=update_data, # type: ignore[arg-type]
)
return _client_to_response(updated)
@client_router.delete("/{client_id}")
async def delete_client(
client_id: str,
user_id: str = Security(get_user_id),
) -> dict:
"""
Delete an OAuth client.
This will also revoke all tokens and authorizations for this client.
"""
client = await prisma.oauthclient.find_first(
where={"clientId": client_id, "ownerId": user_id}
)
if not client:
raise HTTPException(status_code=404, detail="Client not found")
# Delete cascades will handle tokens, codes, and authorizations
await prisma.oauthclient.delete(where={"id": client.id})
return {"status": "deleted", "client_id": client_id}
@client_router.post("/{client_id}/rotate-secret", response_model=ClientSecretResponse)
async def rotate_client_secret(
client_id: str,
user_id: str = Security(get_user_id),
) -> ClientSecretResponse:
"""
Rotate the client secret for a confidential client.
The new secret is returned only once. All existing tokens remain valid.
Also rotates the webhook secret for security.
"""
client = await prisma.oauthclient.find_first(
where={"clientId": client_id, "ownerId": user_id}
)
if not client:
raise HTTPException(status_code=404, detail="Client not found")
if client.clientType != "confidential":
raise HTTPException(
status_code=400,
detail="Cannot rotate secret for public clients",
)
# Generate new secrets
new_secret = _generate_client_secret()
new_salt = secrets.token_urlsafe(16)
new_hash = _hash_secret(new_secret, new_salt)
new_webhook_secret = _generate_webhook_secret()
await prisma.oauthclient.update(
where={"id": client.id},
data={
"clientSecretHash": new_hash,
"clientSecretSalt": new_salt,
"webhookSecret": new_webhook_secret,
},
)
return ClientSecretResponse(
client_id=client_id,
client_secret=new_secret,
webhook_secret=new_webhook_secret,
)
class WebhookSecretResponse(BaseModel):
"""Response containing newly generated webhook secret."""
client_id: str
webhook_secret: str
@client_router.post(
"/{client_id}/rotate-webhook-secret", response_model=WebhookSecretResponse
)
async def rotate_webhook_secret(
client_id: str,
user_id: str = Security(get_user_id),
) -> WebhookSecretResponse:
"""
Rotate only the webhook secret for a client.
The new webhook secret is returned only once.
"""
client = await prisma.oauthclient.find_first(
where={"clientId": client_id, "ownerId": user_id}
)
if not client:
raise HTTPException(status_code=404, detail="Client not found")
# Generate new webhook secret
new_webhook_secret = _generate_webhook_secret()
await prisma.oauthclient.update(
where={"id": client.id},
data={"webhookSecret": new_webhook_secret},
)
return WebhookSecretResponse(
client_id=client_id,
webhook_secret=new_webhook_secret,
)
@client_router.post("/{client_id}/suspend")
async def suspend_client(
client_id: str,
user_id: str = Security(get_user_id),
) -> ClientResponse:
"""Suspend an OAuth client (prevents new authorizations)."""
client = await prisma.oauthclient.find_first(
where={"clientId": client_id, "ownerId": user_id}
)
if not client:
raise HTTPException(status_code=404, detail="Client not found")
updated = await prisma.oauthclient.update(
where={"id": client.id},
data={"status": OAuthClientStatus.SUSPENDED},
)
return _client_to_response(updated)
@client_router.post("/{client_id}/activate")
async def activate_client(
client_id: str,
user_id: str = Security(get_user_id),
) -> ClientResponse:
"""Reactivate a suspended OAuth client."""
client = await prisma.oauthclient.find_first(
where={"clientId": client_id, "ownerId": user_id}
)
if not client:
raise HTTPException(status_code=404, detail="Client not found")
updated = await prisma.oauthclient.update(
where={"id": client.id},
data={"status": OAuthClientStatus.ACTIVE},
)
return _client_to_response(updated)

View File

@@ -1,678 +0,0 @@
"""
Server-rendered HTML templates for OAuth consent UI.
These templates are used for the OAuth authorization flow
when the user needs to approve access for an external application.
"""
import html
from typing import Optional
from backend.server.oauth.models import SCOPE_DESCRIPTIONS
def _base_styles() -> str:
"""Common CSS styles for all OAuth pages."""
return """
* {
box-sizing: border-box;
margin: 0;
padding: 0;
}
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif;
background: linear-gradient(135deg, #1a1a2e 0%, #16213e 100%);
min-height: 100vh;
display: flex;
align-items: center;
justify-content: center;
padding: 20px;
color: #e4e4e7;
}
.container {
background: #27272a;
border-radius: 16px;
box-shadow: 0 25px 50px -12px rgba(0, 0, 0, 0.5);
max-width: 420px;
width: 100%;
padding: 32px;
}
.header {
text-align: center;
margin-bottom: 24px;
}
.logo {
width: 64px;
height: 64px;
border-radius: 12px;
margin-bottom: 16px;
background: #3f3f46;
display: flex;
align-items: center;
justify-content: center;
margin-left: auto;
margin-right: auto;
}
.logo img {
max-width: 48px;
max-height: 48px;
border-radius: 8px;
}
.logo-placeholder {
font-size: 28px;
color: #a1a1aa;
}
h1 {
font-size: 20px;
font-weight: 600;
margin-bottom: 8px;
}
.subtitle {
color: #a1a1aa;
font-size: 14px;
}
.app-name {
color: #22d3ee;
font-weight: 600;
}
.divider {
height: 1px;
background: #3f3f46;
margin: 24px 0;
}
.scopes-section h2 {
font-size: 14px;
font-weight: 500;
color: #a1a1aa;
margin-bottom: 16px;
}
.scope-item {
display: flex;
align-items: flex-start;
gap: 12px;
padding: 12px 0;
border-bottom: 1px solid #3f3f46;
}
.scope-item:last-child {
border-bottom: none;
}
.scope-icon {
width: 20px;
height: 20px;
color: #22d3ee;
flex-shrink: 0;
margin-top: 2px;
}
.scope-text {
font-size: 14px;
line-height: 1.5;
}
.buttons {
display: flex;
gap: 12px;
margin-top: 24px;
}
.btn {
flex: 1;
padding: 12px 24px;
border-radius: 8px;
font-size: 14px;
font-weight: 500;
cursor: pointer;
border: none;
transition: all 0.2s;
}
.btn-cancel {
background: #3f3f46;
color: #e4e4e7;
}
.btn-cancel:hover {
background: #52525b;
}
.btn-allow {
background: #22d3ee;
color: #0f172a;
}
.btn-allow:hover {
background: #06b6d4;
}
.footer {
margin-top: 24px;
text-align: center;
font-size: 12px;
color: #71717a;
}
.footer a {
color: #a1a1aa;
text-decoration: none;
}
.footer a:hover {
text-decoration: underline;
}
.error-container {
text-align: center;
}
.error-icon {
width: 64px;
height: 64px;
margin: 0 auto 16px;
color: #ef4444;
}
.error-title {
color: #ef4444;
font-size: 18px;
font-weight: 600;
margin-bottom: 8px;
}
.error-message {
color: #a1a1aa;
font-size: 14px;
margin-bottom: 24px;
}
.success-icon {
width: 64px;
height: 64px;
margin: 0 auto 16px;
color: #22c55e;
}
.success-title {
color: #22c55e;
}
"""
def _check_icon() -> str:
"""SVG checkmark icon."""
return """
<svg class="scope-icon" viewBox="0 0 20 20" fill="currentColor">
<path fill-rule="evenodd" d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z" clip-rule="evenodd"/>
</svg>
"""
def _error_icon() -> str:
"""SVG error icon."""
return """
<svg class="error-icon" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2">
<circle cx="12" cy="12" r="10"/>
<line x1="15" y1="9" x2="9" y2="15"/>
<line x1="9" y1="9" x2="15" y2="15"/>
</svg>
"""
def _success_icon() -> str:
"""SVG success icon."""
return """
<svg class="success-icon" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2">
<circle cx="12" cy="12" r="10"/>
<path d="M9 12l2 2 4-4"/>
</svg>
"""
def render_consent_page(
client_name: str,
client_logo: Optional[str],
scopes: list[str],
consent_token: str,
action_url: str,
privacy_policy_url: Optional[str] = None,
terms_url: Optional[str] = None,
) -> str:
"""
Render the OAuth consent page.
Args:
client_name: Name of the requesting application
client_logo: URL to the client's logo (optional)
scopes: List of requested scopes
consent_token: CSRF token for the consent form
action_url: URL to submit the consent form
privacy_policy_url: Client's privacy policy URL (optional)
terms_url: Client's terms of service URL (optional)
Returns:
HTML string for the consent page
"""
# Escape user-provided values to prevent XSS
safe_client_name = html.escape(client_name)
safe_client_logo = html.escape(client_logo) if client_logo else None
# Build logo HTML
if safe_client_logo:
logo_html = f'<img src="{safe_client_logo}" alt="{safe_client_name}">'
else:
logo_html = f'<span class="logo-placeholder">{html.escape(client_name[0].upper())}</span>'
# Build scopes HTML
scopes_html = ""
for scope in scopes:
description = SCOPE_DESCRIPTIONS.get(scope, scope)
scopes_html += f"""
<div class="scope-item">
{_check_icon()}
<span class="scope-text">{html.escape(description)}</span>
</div>
"""
# Build footer links (escape URLs)
footer_links = []
if privacy_policy_url:
footer_links.append(
f'<a href="{html.escape(privacy_policy_url)}" target="_blank">Privacy Policy</a>'
)
if terms_url:
footer_links.append(
f'<a href="{html.escape(terms_url)}" target="_blank">Terms of Service</a>'
)
footer_html = " &bull; ".join(footer_links) if footer_links else ""
# Escape action_url and consent_token
safe_action_url = html.escape(action_url)
safe_consent_token = html.escape(consent_token)
return f"""
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Authorize {safe_client_name} - AutoGPT</title>
<style>{_base_styles()}</style>
</head>
<body>
<div class="container">
<div class="header">
<div class="logo">{logo_html}</div>
<h1>Authorize <span class="app-name">{safe_client_name}</span></h1>
<p class="subtitle">wants to access your AutoGPT account</p>
</div>
<div class="divider"></div>
<div class="scopes-section">
<h2>This will allow {safe_client_name} to:</h2>
{scopes_html}
</div>
<form method="POST" action="{safe_action_url}">
<input type="hidden" name="consent_token" value="{safe_consent_token}">
<div class="buttons">
<button type="submit" name="authorize" value="false" class="btn btn-cancel">
Cancel
</button>
<button type="submit" name="authorize" value="true" class="btn btn-allow">
Allow
</button>
</div>
</form>
{f'<div class="footer">{footer_html}</div>' if footer_html else ''}
</div>
</body>
</html>
"""
def render_error_page(
error: str,
error_description: str,
redirect_url: Optional[str] = None,
) -> str:
"""
Render an OAuth error page.
Args:
error: Error code
error_description: Human-readable error description
redirect_url: Optional URL to redirect back (if safe)
Returns:
HTML string for the error page
"""
# Escape user-provided values to prevent XSS
safe_error = html.escape(error)
safe_error_description = html.escape(error_description)
redirect_html = ""
if redirect_url:
safe_redirect_url = html.escape(redirect_url)
redirect_html = f"""
<a href="{safe_redirect_url}" class="btn btn-cancel" style="display: inline-block; text-decoration: none;">
Go Back
</a>
"""
return f"""
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Authorization Error - AutoGPT</title>
<style>{_base_styles()}</style>
</head>
<body>
<div class="container">
<div class="error-container">
{_error_icon()}
<h1 class="error-title">Authorization Failed</h1>
<p class="error-message">{safe_error_description}</p>
<p class="error-message" style="font-size: 12px; color: #52525b;">
Error code: {safe_error}
</p>
{redirect_html}
</div>
</div>
</body>
</html>
"""
def render_success_page(
message: str,
redirect_origin: Optional[str] = None,
post_message_data: Optional[dict] = None,
) -> str:
"""
Render a success page, optionally with postMessage for popup flows.
Args:
message: Success message to display
redirect_origin: Origin for postMessage (popup flows)
post_message_data: Data to send via postMessage (popup flows)
Returns:
HTML string for the success page
"""
# Escape user-provided values to prevent XSS
safe_message = html.escape(message)
# PostMessage script for popup flows
post_message_script = ""
if redirect_origin and post_message_data:
import json
# json.dumps escapes for JS context, but we also escape < > for HTML context
safe_json_origin = (
json.dumps(redirect_origin).replace("<", "\\u003c").replace(">", "\\u003e")
)
safe_json_data = (
json.dumps(post_message_data)
.replace("<", "\\u003c")
.replace(">", "\\u003e")
)
post_message_script = f"""
<script>
(function() {{
var targetOrigin = {safe_json_origin};
var message = {safe_json_data};
if (window.opener) {{
window.opener.postMessage(message, targetOrigin);
setTimeout(function() {{ window.close(); }}, 1000);
}}
}})();
</script>
"""
return f"""
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Authorization Successful - AutoGPT</title>
<style>{_base_styles()}</style>
</head>
<body>
<div class="container">
<div class="error-container">
{_success_icon()}
<h1 class="success-title">Success!</h1>
<p class="error-message">{safe_message}</p>
<p class="error-message" style="font-size: 12px;">
This window will close automatically...
</p>
</div>
</div>
{post_message_script}
</body>
</html>
"""
def render_login_redirect_page(login_url: str) -> str:
"""
Render a page that redirects to login.
Args:
login_url: URL to redirect to for login
Returns:
HTML string with auto-redirect
"""
# Escape URL to prevent XSS
safe_login_url = html.escape(login_url)
return f"""
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<meta http-equiv="refresh" content="0;url={safe_login_url}">
<title>Login Required - AutoGPT</title>
<style>{_base_styles()}</style>
</head>
<body>
<div class="container">
<div class="error-container">
<p class="error-message">Redirecting to login...</p>
<a href="{safe_login_url}" class="btn btn-allow" style="display: inline-block; text-decoration: none;">
Click here if not redirected
</a>
</div>
</div>
</body>
</html>
"""
def _login_form_styles() -> str:
"""Additional CSS styles for login form."""
return """
.form-group {
margin-bottom: 16px;
}
.form-group label {
display: block;
font-size: 14px;
font-weight: 500;
color: #a1a1aa;
margin-bottom: 8px;
}
.form-group input {
width: 100%;
padding: 12px 16px;
border-radius: 8px;
border: 1px solid #3f3f46;
background: #18181b;
color: #e4e4e7;
font-size: 14px;
outline: none;
transition: border-color 0.2s;
}
.form-group input:focus {
border-color: #22d3ee;
}
.form-group input::placeholder {
color: #52525b;
}
.error-alert {
background: rgba(239, 68, 68, 0.1);
border: 1px solid #ef4444;
border-radius: 8px;
padding: 12px 16px;
margin-bottom: 16px;
color: #fca5a5;
font-size: 14px;
}
.btn-login {
width: 100%;
padding: 12px 24px;
border-radius: 8px;
font-size: 14px;
font-weight: 500;
cursor: pointer;
border: none;
background: #22d3ee;
color: #0f172a;
transition: all 0.2s;
margin-top: 8px;
}
.btn-login:hover {
background: #06b6d4;
}
.btn-login:disabled {
background: #3f3f46;
color: #71717a;
cursor: not-allowed;
}
.signup-link {
text-align: center;
margin-top: 16px;
font-size: 14px;
color: #a1a1aa;
}
.signup-link a {
color: #22d3ee;
text-decoration: none;
}
.signup-link a:hover {
text-decoration: underline;
}
"""
def render_login_page(
action_url: str,
login_state: str,
client_name: Optional[str] = None,
error_message: Optional[str] = None,
signup_url: Optional[str] = None,
browser_login_url: Optional[str] = None,
) -> str:
"""
Render an embedded login page for OAuth flow.
Args:
action_url: URL to submit the login form
login_state: State token to preserve OAuth parameters
client_name: Name of the application requesting access (optional)
error_message: Error message to display (optional)
signup_url: URL to signup page (optional)
browser_login_url: URL to redirect to frontend login (optional)
Returns:
HTML string for the login page
"""
# Escape all user-provided values to prevent XSS
safe_action_url = html.escape(action_url)
safe_login_state = html.escape(login_state)
safe_client_name = html.escape(client_name) if client_name else None
error_html = ""
if error_message:
safe_error_message = html.escape(error_message)
error_html = f'<div class="error-alert">{safe_error_message}</div>'
subtitle = "wants to access your AutoGPT account" if safe_client_name else ""
title_html = (
'<h1>Sign in to <span class="app-name">AutoGPT</span></h1>'
if not safe_client_name
else f'<h1><span class="app-name">{safe_client_name}</span></h1>'
)
signup_html = ""
if signup_url:
safe_signup_url = html.escape(signup_url)
signup_html = f"""
<div class="signup-link">
Don't have an account? <a href="{safe_signup_url}">Sign up</a>
</div>
"""
browser_login_html = ""
if browser_login_url:
safe_browser_login_url = html.escape(browser_login_url)
browser_login_html = f"""
<div class="divider"></div>
<div class="signup-link">
<a href="{safe_browser_login_url}">Sign in with Google or other providers</a>
</div>
"""
return f"""
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Sign In - AutoGPT</title>
<style>
{_base_styles()}
{_login_form_styles()}
</style>
</head>
<body>
<div class="container">
<div class="header">
<div class="logo">
<span class="logo-placeholder">A</span>
</div>
{title_html}
<p class="subtitle">{subtitle}</p>
</div>
<div class="divider"></div>
{error_html}
<form method="POST" action="{safe_action_url}">
<input type="hidden" name="login_state" value="{safe_login_state}">
<div class="form-group">
<label for="email">Email</label>
<input
type="email"
id="email"
name="email"
placeholder="you@example.com"
required
autocomplete="email"
>
</div>
<div class="form-group">
<label for="password">Password</label>
<input
type="password"
id="password"
name="password"
placeholder="Enter your password"
required
autocomplete="current-password"
>
</div>
<button type="submit" class="btn-login">Sign In</button>
</form>
{signup_html}
{browser_login_html}
</div>
</body>
</html>
"""

View File

@@ -1,71 +0,0 @@
"""
OIDC Discovery endpoints.
Implements:
- GET /.well-known/openid-configuration - OIDC Discovery Document
- GET /.well-known/jwks.json - JSON Web Key Set
"""
from fastapi import APIRouter
from backend.server.oauth.models import JWKS, OpenIDConfiguration
from backend.server.oauth.token_service import get_token_service
from backend.util.settings import Settings
discovery_router = APIRouter(tags=["oidc-discovery"])
@discovery_router.get(
"/.well-known/openid-configuration",
response_model=OpenIDConfiguration,
)
async def openid_configuration() -> OpenIDConfiguration:
"""
OIDC Discovery Document.
Returns metadata about the OAuth 2.0 authorization server including
endpoints, supported features, and algorithms.
"""
settings = Settings()
base_url = settings.config.platform_base_url or "https://platform.agpt.co"
return OpenIDConfiguration(
issuer=base_url,
authorization_endpoint=f"{base_url}/oauth/authorize",
token_endpoint=f"{base_url}/oauth/token",
userinfo_endpoint=f"{base_url}/oauth/userinfo",
revocation_endpoint=f"{base_url}/oauth/revoke",
jwks_uri=f"{base_url}/.well-known/jwks.json",
scopes_supported=[
"openid",
"profile",
"email",
"integrations:list",
"integrations:connect",
"integrations:delete",
"agents:execute",
],
response_types_supported=["code"],
grant_types_supported=["authorization_code", "refresh_token"],
token_endpoint_auth_methods_supported=[
"client_secret_post",
"client_secret_basic",
"none", # For public clients with PKCE
],
code_challenge_methods_supported=["S256"],
subject_types_supported=["public"],
id_token_signing_alg_values_supported=["RS256"],
)
@discovery_router.get("/.well-known/jwks.json", response_model=JWKS)
async def jwks() -> dict:
"""
JSON Web Key Set (JWKS).
Returns the public key(s) used to verify JWT signatures.
External applications can use these keys to verify access tokens
and ID tokens issued by this authorization server.
"""
token_service = get_token_service()
return token_service.get_jwks()

View File

@@ -1,162 +0,0 @@
"""
OAuth 2.0 Error Responses (RFC 6749 Section 5.2).
"""
from enum import Enum
from typing import Optional
from urllib.parse import urlencode
from fastapi import HTTPException
from fastapi.responses import RedirectResponse
from pydantic import BaseModel
class OAuthErrorCode(str, Enum):
"""Standard OAuth 2.0 error codes."""
# Authorization endpoint errors (RFC 6749 Section 4.1.2.1)
INVALID_REQUEST = "invalid_request"
UNAUTHORIZED_CLIENT = "unauthorized_client"
ACCESS_DENIED = "access_denied"
UNSUPPORTED_RESPONSE_TYPE = "unsupported_response_type"
INVALID_SCOPE = "invalid_scope"
SERVER_ERROR = "server_error"
TEMPORARILY_UNAVAILABLE = "temporarily_unavailable"
# Token endpoint errors (RFC 6749 Section 5.2)
INVALID_CLIENT = "invalid_client"
INVALID_GRANT = "invalid_grant"
UNSUPPORTED_GRANT_TYPE = "unsupported_grant_type"
# Extension errors
LOGIN_REQUIRED = "login_required"
CONSENT_REQUIRED = "consent_required"
class OAuthErrorResponse(BaseModel):
"""OAuth error response model."""
error: str
error_description: Optional[str] = None
error_uri: Optional[str] = None
class OAuthError(Exception):
"""Base OAuth error exception."""
def __init__(
self,
error: OAuthErrorCode,
description: Optional[str] = None,
uri: Optional[str] = None,
state: Optional[str] = None,
):
self.error = error
self.description = description
self.uri = uri
self.state = state
super().__init__(description or error.value)
def to_response(self) -> OAuthErrorResponse:
"""Convert to response model."""
return OAuthErrorResponse(
error=self.error.value,
error_description=self.description,
error_uri=self.uri,
)
def to_redirect(self, redirect_uri: str) -> RedirectResponse:
"""Convert to redirect response with error in query params."""
params = {"error": self.error.value}
if self.description:
params["error_description"] = self.description
if self.uri:
params["error_uri"] = self.uri
if self.state:
params["state"] = self.state
separator = "&" if "?" in redirect_uri else "?"
url = f"{redirect_uri}{separator}{urlencode(params)}"
return RedirectResponse(url=url, status_code=302)
def to_http_exception(self, status_code: int = 400) -> HTTPException:
"""Convert to FastAPI HTTPException."""
return HTTPException(
status_code=status_code,
detail=self.to_response().model_dump(exclude_none=True),
)
# Convenience error classes
class InvalidRequestError(OAuthError):
"""The request is missing a required parameter or is otherwise malformed."""
def __init__(self, description: str, state: Optional[str] = None):
super().__init__(OAuthErrorCode.INVALID_REQUEST, description, state=state)
class UnauthorizedClientError(OAuthError):
"""The client is not authorized to request an authorization code."""
def __init__(self, description: str, state: Optional[str] = None):
super().__init__(OAuthErrorCode.UNAUTHORIZED_CLIENT, description, state=state)
class AccessDeniedError(OAuthError):
"""The resource owner denied the request."""
def __init__(self, description: str = "Access denied", state: Optional[str] = None):
super().__init__(OAuthErrorCode.ACCESS_DENIED, description, state=state)
class InvalidScopeError(OAuthError):
"""The requested scope is invalid, unknown, or malformed."""
def __init__(self, description: str, state: Optional[str] = None):
super().__init__(OAuthErrorCode.INVALID_SCOPE, description, state=state)
class InvalidClientError(OAuthError):
"""Client authentication failed."""
def __init__(self, description: str = "Invalid client"):
super().__init__(OAuthErrorCode.INVALID_CLIENT, description)
class InvalidGrantError(OAuthError):
"""The provided authorization code or refresh token is invalid."""
def __init__(self, description: str = "Invalid grant"):
super().__init__(OAuthErrorCode.INVALID_GRANT, description)
class UnsupportedGrantTypeError(OAuthError):
"""The authorization grant type is not supported."""
def __init__(self, grant_type: str):
super().__init__(
OAuthErrorCode.UNSUPPORTED_GRANT_TYPE,
f"Grant type '{grant_type}' is not supported",
)
class LoginRequiredError(OAuthError):
"""User must be logged in to complete the request."""
def __init__(self, state: Optional[str] = None):
super().__init__(
OAuthErrorCode.LOGIN_REQUIRED,
"User authentication required",
state=state,
)
class ConsentRequiredError(OAuthError):
"""User consent is required for the requested scopes."""
def __init__(self, state: Optional[str] = None):
super().__init__(
OAuthErrorCode.CONSENT_REQUIRED,
"User consent required",
state=state,
)

View File

@@ -1,288 +0,0 @@
"""
Pydantic models for OAuth 2.0 requests and responses.
"""
from datetime import datetime
from enum import Enum
from typing import Literal, Optional
from pydantic import BaseModel, Field, HttpUrl
# ============================================================
# Enums and Constants
# ============================================================
class OAuthScope(str, Enum):
"""Supported OAuth scopes."""
# OpenID Connect standard scopes
OPENID = "openid"
PROFILE = "profile"
EMAIL = "email"
# AutoGPT-specific scopes
INTEGRATIONS_LIST = "integrations:list"
INTEGRATIONS_CONNECT = "integrations:connect"
INTEGRATIONS_DELETE = "integrations:delete"
AGENTS_EXECUTE = "agents:execute"
SCOPE_DESCRIPTIONS: dict[str, str] = {
OAuthScope.OPENID.value: "Access your user ID",
OAuthScope.PROFILE.value: "Access your profile information (name)",
OAuthScope.EMAIL.value: "Access your email address",
OAuthScope.INTEGRATIONS_LIST.value: "View your connected integrations",
OAuthScope.INTEGRATIONS_CONNECT.value: "Connect new integrations on your behalf",
OAuthScope.INTEGRATIONS_DELETE.value: "Delete integrations on your behalf",
OAuthScope.AGENTS_EXECUTE.value: "Run agents on your behalf",
}
# ============================================================
# Authorization Request/Response Models
# ============================================================
class AuthorizationRequest(BaseModel):
"""OAuth 2.0 Authorization Request (RFC 6749 Section 4.1.1)."""
response_type: Literal["code"] = Field(
..., description="Must be 'code' for authorization code flow"
)
client_id: str = Field(..., description="Client identifier")
redirect_uri: str = Field(..., description="Redirect URI after authorization")
scope: str = Field(default="", description="Space-separated list of scopes")
state: str = Field(..., description="CSRF protection token (required)")
code_challenge: str = Field(..., description="PKCE code challenge (required)")
code_challenge_method: Literal["S256"] = Field(
default="S256", description="PKCE method (only S256 supported)"
)
nonce: Optional[str] = Field(None, description="OIDC nonce for replay protection")
prompt: Optional[Literal["consent", "login", "none"]] = Field(
None, description="Prompt behavior"
)
class ConsentFormData(BaseModel):
"""Consent form submission data."""
consent_token: str = Field(..., description="CSRF token for consent")
authorize: bool = Field(..., description="Whether user authorized")
# ============================================================
# Token Request/Response Models
# ============================================================
class TokenRequest(BaseModel):
"""OAuth 2.0 Token Request (RFC 6749 Section 4.1.3)."""
grant_type: Literal["authorization_code", "refresh_token"] = Field(
..., description="Grant type"
)
code: Optional[str] = Field(
None, description="Authorization code (for authorization_code grant)"
)
redirect_uri: Optional[str] = Field(
None, description="Must match authorization request"
)
client_id: str = Field(..., description="Client identifier")
client_secret: Optional[str] = Field(
None, description="Client secret (for confidential clients)"
)
code_verifier: Optional[str] = Field(
None, description="PKCE code verifier (for authorization_code grant)"
)
refresh_token: Optional[str] = Field(
None, description="Refresh token (for refresh_token grant)"
)
scope: Optional[str] = Field(
None, description="Requested scopes (for refresh_token grant)"
)
class TokenResponse(BaseModel):
"""OAuth 2.0 Token Response (RFC 6749 Section 5.1)."""
access_token: str = Field(..., description="Access token")
token_type: Literal["Bearer"] = Field(default="Bearer", description="Token type")
expires_in: int = Field(..., description="Token lifetime in seconds")
refresh_token: Optional[str] = Field(None, description="Refresh token")
scope: Optional[str] = Field(None, description="Granted scopes")
id_token: Optional[str] = Field(None, description="OIDC ID token")
# ============================================================
# UserInfo Response Model
# ============================================================
class UserInfoResponse(BaseModel):
"""OIDC UserInfo Response."""
sub: str = Field(..., description="User ID (subject)")
email: Optional[str] = Field(None, description="User email")
email_verified: Optional[bool] = Field(
None, description="Whether email is verified"
)
name: Optional[str] = Field(None, description="User display name")
updated_at: Optional[int] = Field(None, description="Last profile update timestamp")
# ============================================================
# OIDC Discovery Models
# ============================================================
class OpenIDConfiguration(BaseModel):
"""OIDC Discovery Document."""
issuer: str
authorization_endpoint: str
token_endpoint: str
userinfo_endpoint: str
revocation_endpoint: str
jwks_uri: str
scopes_supported: list[str]
response_types_supported: list[str]
grant_types_supported: list[str]
token_endpoint_auth_methods_supported: list[str]
code_challenge_methods_supported: list[str]
subject_types_supported: list[str]
id_token_signing_alg_values_supported: list[str]
class JWK(BaseModel):
"""JSON Web Key."""
kty: str = Field(..., description="Key type (RSA)")
use: str = Field(default="sig", description="Key use (signature)")
kid: str = Field(..., description="Key ID")
alg: str = Field(default="RS256", description="Algorithm")
n: str = Field(..., description="RSA modulus")
e: str = Field(..., description="RSA exponent")
class JWKS(BaseModel):
"""JSON Web Key Set."""
keys: list[JWK]
# ============================================================
# Client Management Models
# ============================================================
class RegisterClientRequest(BaseModel):
"""Request to register a new OAuth client."""
name: str = Field(..., min_length=1, max_length=100, description="Client name")
description: Optional[str] = Field(
None, max_length=500, description="Client description"
)
logo_url: Optional[HttpUrl] = Field(None, description="Logo URL")
homepage_url: Optional[HttpUrl] = Field(None, description="Homepage URL")
privacy_policy_url: Optional[HttpUrl] = Field(
None, description="Privacy policy URL"
)
terms_of_service_url: Optional[HttpUrl] = Field(
None, description="Terms of service URL"
)
redirect_uris: list[str] = Field(
..., min_length=1, description="Allowed redirect URIs"
)
client_type: Literal["public", "confidential"] = Field(
default="public", description="Client type"
)
webhook_domains: list[str] = Field(
default_factory=list, description="Allowed webhook domains"
)
class UpdateClientRequest(BaseModel):
"""Request to update an OAuth client."""
name: Optional[str] = Field(None, min_length=1, max_length=100)
description: Optional[str] = Field(None, max_length=500)
logo_url: Optional[HttpUrl] = None
homepage_url: Optional[HttpUrl] = None
privacy_policy_url: Optional[HttpUrl] = None
terms_of_service_url: Optional[HttpUrl] = None
redirect_uris: Optional[list[str]] = None
webhook_domains: Optional[list[str]] = None
class ClientResponse(BaseModel):
"""OAuth client response."""
id: str
client_id: str
client_type: str
name: str
description: Optional[str]
logo_url: Optional[str]
homepage_url: Optional[str]
privacy_policy_url: Optional[str]
terms_of_service_url: Optional[str]
redirect_uris: list[str]
allowed_scopes: list[str]
webhook_domains: list[str]
status: str
created_at: datetime
updated_at: datetime
class ClientSecretResponse(BaseModel):
"""Response containing newly generated client credentials."""
client_id: str
client_secret: str = Field(
..., description="Client secret (only shown once, store securely)"
)
webhook_secret: str = Field(
...,
description="Webhook secret for HMAC signing (only shown once, store securely)",
)
# ============================================================
# Token Introspection/Revocation Models
# ============================================================
class TokenRevocationRequest(BaseModel):
"""Token revocation request (RFC 7009)."""
token: str = Field(..., description="Token to revoke")
token_type_hint: Optional[Literal["access_token", "refresh_token"]] = Field(
None, description="Hint about token type"
)
class TokenIntrospectionRequest(BaseModel):
"""Token introspection request (RFC 7662)."""
token: str = Field(..., description="Token to introspect")
token_type_hint: Optional[Literal["access_token", "refresh_token"]] = Field(
None, description="Hint about token type"
)
class TokenIntrospectionResponse(BaseModel):
"""Token introspection response."""
active: bool = Field(..., description="Whether the token is active")
scope: Optional[str] = Field(None, description="Token scopes")
client_id: Optional[str] = Field(
None, description="Client that token was issued to"
)
username: Optional[str] = Field(None, description="User identifier")
token_type: Optional[str] = Field(None, description="Token type")
exp: Optional[int] = Field(None, description="Expiration timestamp")
iat: Optional[int] = Field(None, description="Issued at timestamp")
sub: Optional[str] = Field(None, description="Subject (user ID)")
aud: Optional[str] = Field(None, description="Audience")
iss: Optional[str] = Field(None, description="Issuer")

View File

@@ -1,66 +0,0 @@
"""
PKCE (Proof Key for Code Exchange) implementation for OAuth 2.0.
RFC 7636: https://tools.ietf.org/html/rfc7636
"""
import base64
import hashlib
import secrets
def generate_code_verifier(length: int = 64) -> str:
"""
Generate a cryptographically random code verifier.
Args:
length: Length of the verifier (43-128 characters, default 64)
Returns:
URL-safe base64 encoded random string
"""
if not 43 <= length <= 128:
raise ValueError("Code verifier length must be between 43 and 128")
return secrets.token_urlsafe(length)[:length]
def generate_code_challenge(verifier: str, method: str = "S256") -> str:
"""
Generate a code challenge from the verifier.
Args:
verifier: The code verifier string
method: Challenge method ("S256" or "plain")
Returns:
The code challenge string
"""
if method == "S256":
digest = hashlib.sha256(verifier.encode("ascii")).digest()
# URL-safe base64 encoding without padding
return base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")
elif method == "plain":
return verifier
else:
raise ValueError(f"Unsupported code challenge method: {method}")
def verify_code_challenge(
verifier: str,
challenge: str,
method: str = "S256",
) -> bool:
"""
Verify that a code verifier matches the stored challenge.
Args:
verifier: The code verifier from the token request
challenge: The code challenge stored during authorization
method: The challenge method used
Returns:
True if the verifier matches the challenge
"""
expected = generate_code_challenge(verifier, method)
# Use constant-time comparison to prevent timing attacks
return secrets.compare_digest(expected, challenge)

View File

@@ -1,860 +0,0 @@
"""
OAuth 2.0 Authorization Server endpoints.
Implements:
- GET /oauth/authorize - Authorization endpoint
- POST /oauth/authorize/consent - Consent form submission
- POST /oauth/token - Token endpoint
- GET /oauth/userinfo - OIDC UserInfo endpoint
- POST /oauth/revoke - Token revocation endpoint
Authentication:
- X-API-Key header - API key for external apps (preferred)
- Authorization: Bearer <jwt> - JWT token authentication
- access_token cookie - Browser-based auth
"""
import json
import logging
import secrets
from datetime import datetime, timedelta, timezone
from typing import Optional
from urllib.parse import urlencode
from fastapi import APIRouter, Form, HTTPException, Query, Request
from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
from backend.data.db import prisma
from backend.data.redis_client import get_redis_async
from backend.server.oauth.consent_templates import (
render_consent_page,
render_error_page,
render_login_redirect_page,
)
from backend.server.oauth.errors import (
InvalidClientError,
InvalidRequestError,
OAuthError,
UnsupportedGrantTypeError,
)
from backend.server.oauth.models import TokenResponse, UserInfoResponse
from backend.server.oauth.service import get_oauth_service
from backend.server.oauth.token_service import get_token_service
from backend.util.rate_limiter import check_rate_limit
logger = logging.getLogger(__name__)
oauth_router = APIRouter(prefix="/oauth", tags=["oauth"])
# Redis key prefix and TTL for consent state storage
CONSENT_STATE_PREFIX = "oauth:consent:"
CONSENT_STATE_TTL = 600 # 10 minutes
# Redis key prefix and TTL for login redirect state storage
LOGIN_STATE_PREFIX = "oauth:login:"
LOGIN_STATE_TTL = 900 # 15 minutes (longer to allow time for login)
async def _store_login_state(token: str, state: dict) -> None:
"""Store OAuth login state in Redis with TTL."""
redis = await get_redis_async()
await redis.setex(
f"{LOGIN_STATE_PREFIX}{token}",
LOGIN_STATE_TTL,
json.dumps(state, default=str),
)
async def _get_and_delete_login_state(token: str) -> Optional[dict]:
"""Retrieve and delete login state from Redis (one-time use, atomic)."""
redis = await get_redis_async()
key = f"{LOGIN_STATE_PREFIX}{token}"
# Use GETDEL for atomic get+delete to prevent race conditions
state_json = await redis.getdel(key)
if state_json:
return json.loads(state_json)
return None
async def _store_consent_state(token: str, state: dict) -> None:
"""Store consent state in Redis with TTL."""
redis = await get_redis_async()
await redis.setex(
f"{CONSENT_STATE_PREFIX}{token}",
CONSENT_STATE_TTL,
json.dumps(state, default=str),
)
async def _get_and_delete_consent_state(token: str) -> Optional[dict]:
"""Retrieve and delete consent state from Redis (atomic get+delete)."""
redis = await get_redis_async()
key = f"{CONSENT_STATE_PREFIX}{token}"
# Use GETDEL for atomic get+delete to prevent race conditions
state_json = await redis.getdel(key)
if state_json:
return json.loads(state_json)
return None
async def _get_user_id_from_request(
request: Request, strict_bearer: bool = False
) -> Optional[str]:
"""
Extract user ID from request, checking API key, Authorization header, and cookie.
Supports:
1. X-API-Key header - API key authentication (preferred for external apps)
2. Authorization: Bearer <jwt> - JWT token authentication
3. access_token cookie - Cookie-based auth (for browser flows)
Args:
request: The incoming request
strict_bearer: If True and Bearer token is provided but invalid,
do NOT fallthrough to cookie auth (prevents auth downgrade attacks)
"""
from autogpt_libs.auth.jwt_utils import parse_jwt_token
from backend.data.api_key import validate_api_key
# First try X-API-Key header (for external apps)
api_key = request.headers.get("X-API-Key")
if api_key:
try:
api_key_info = await validate_api_key(api_key)
if api_key_info:
return api_key_info.user_id
except Exception:
logger.debug("API key validation failed")
# Then try Authorization header (JWT)
auth_header = request.headers.get("Authorization", "")
if auth_header.startswith("Bearer "):
try:
token = auth_header[7:]
payload = parse_jwt_token(token)
return payload.get("sub")
except Exception as e:
logger.debug("JWT token validation failed: %s", type(e).__name__)
# Security fix: If Bearer token was provided but invalid,
# don't fallthrough to weaker auth methods when strict_bearer is True
if strict_bearer:
return None
# Finally try cookie (browser-based auth)
token = request.cookies.get("access_token")
if token:
try:
payload = parse_jwt_token(token)
return payload.get("sub")
except Exception as e:
logger.debug("Cookie token validation failed: %s", type(e).__name__)
return None
def _parse_scopes(scope_str: str) -> list[str]:
"""Parse space-separated scope string into list."""
if not scope_str:
return []
return [s.strip() for s in scope_str.split() if s.strip()]
def _get_client_ip(request: Request) -> str:
"""Get client IP address from request."""
forwarded = request.headers.get("X-Forwarded-For")
if forwarded:
return forwarded.split(",")[0].strip()
return request.client.host if request.client else "unknown"
# ================================================================
# Authorization Endpoint
# ================================================================
@oauth_router.get("/authorize", response_model=None)
async def authorize(
request: Request,
response_type: str = Query(..., description="Must be 'code'"),
client_id: str = Query(..., description="Client identifier"),
redirect_uri: str = Query(..., description="Redirect URI"),
state: str = Query(..., description="CSRF state parameter"),
code_challenge: str = Query(..., description="PKCE code challenge"),
code_challenge_method: str = Query("S256", description="PKCE method"),
scope: str = Query("", description="Space-separated scopes"),
nonce: Optional[str] = Query(None, description="OIDC nonce"),
prompt: Optional[str] = Query(None, description="Prompt behavior"),
) -> HTMLResponse | RedirectResponse:
"""
OAuth 2.0 Authorization Endpoint.
Validates the request, checks user authentication, and either:
- Returns error if user is not authenticated (API key or JWT required)
- Shows consent page if user hasn't authorized these scopes
- Redirects with authorization code if already authorized
Authentication methods (in order of preference):
1. X-API-Key header - API key for external apps
2. Authorization: Bearer <jwt> - JWT token
3. access_token cookie - Browser-based auth
"""
# Get user ID from API key, Authorization header, or cookie
user_id = await _get_user_id_from_request(request)
# Rate limiting - use client IP as identifier for authorize endpoint
client_ip = _get_client_ip(request)
rate_result = await check_rate_limit(client_ip, "oauth_authorize")
if not rate_result.allowed:
return HTMLResponse(
render_error_page(
"rate_limit_exceeded",
"Too many authorization requests. Please try again later.",
),
status_code=429,
)
oauth_service = get_oauth_service()
try:
# Validate response_type
if response_type != "code":
raise InvalidRequestError(
"Only 'code' response_type is supported", state=state
)
# Validate PKCE method
if code_challenge_method != "S256":
raise InvalidRequestError(
"Only 'S256' code_challenge_method is supported", state=state
)
# Parse scopes
scopes = _parse_scopes(scope)
# Validate client and redirect URI
client = await oauth_service.validate_client(client_id, redirect_uri, scopes)
# Check if user is authenticated
if not user_id:
# User needs to log in - store OAuth params and redirect to frontend login
from backend.util.settings import Settings
settings = Settings()
login_token = secrets.token_urlsafe(32)
logger.info(f"Storing login state with token: {login_token}")
await _store_login_state(
login_token,
{
"client_id": client_id,
"redirect_uri": redirect_uri,
"scopes": scopes,
"state": state,
"code_challenge": code_challenge,
"code_challenge_method": code_challenge_method,
"nonce": nonce,
"prompt": prompt,
"created_at": datetime.now(timezone.utc).isoformat(),
"expires_at": (
datetime.now(timezone.utc) + timedelta(seconds=LOGIN_STATE_TTL)
).isoformat(),
},
)
logger.info(f"Login state stored successfully for token: {login_token}")
# Build redirect URL to frontend login
frontend_base_url = settings.config.frontend_base_url
if not frontend_base_url:
return _add_security_headers(
HTMLResponse(
render_error_page(
"server_error", "Frontend URL not configured"
),
status_code=500,
)
)
# Redirect to frontend login with oauth_session parameter
login_url = f"{frontend_base_url}/login?oauth_session={login_token}"
return _add_security_headers(
HTMLResponse(render_login_redirect_page(login_url))
)
# Check if user has already authorized these scopes
if prompt != "consent":
has_auth = await oauth_service.has_valid_authorization(
user_id, client_id, scopes
)
if has_auth:
# Skip consent, issue code directly
code = await oauth_service.create_authorization_code(
user_id=user_id,
client_id=client_id,
redirect_uri=redirect_uri,
scopes=scopes,
code_challenge=code_challenge,
code_challenge_method=code_challenge_method,
nonce=nonce,
)
redirect_url = (
f"{redirect_uri}?{urlencode({'code': code, 'state': state})}"
)
return RedirectResponse(url=redirect_url, status_code=302)
# Generate consent token and store state in Redis
consent_token = secrets.token_urlsafe(32)
await _store_consent_state(
consent_token,
{
"user_id": user_id,
"client_id": client_id,
"redirect_uri": redirect_uri,
"scopes": scopes,
"state": state,
"code_challenge": code_challenge,
"code_challenge_method": code_challenge_method,
"nonce": nonce,
"expires_at": (
datetime.now(timezone.utc) + timedelta(minutes=10)
).isoformat(),
},
)
# Render consent page
return _add_security_headers(
HTMLResponse(
render_consent_page(
client_name=client.name,
client_logo=client.logoUrl,
scopes=scopes,
consent_token=consent_token,
action_url="/oauth/authorize/consent",
privacy_policy_url=client.privacyPolicyUrl,
terms_url=client.termsOfServiceUrl,
)
)
)
except OAuthError as e:
# If we have a valid redirect_uri, redirect with error
# Otherwise show error page
try:
client = await oauth_service.get_client(client_id)
if client and redirect_uri in client.redirectUris:
return e.to_redirect(redirect_uri)
except Exception:
pass
return _add_security_headers(
HTMLResponse(
render_error_page(e.error.value, e.description or "An error occurred"),
status_code=400,
)
)
@oauth_router.post("/authorize/consent", response_model=None)
async def submit_consent(
request: Request,
consent_token: str = Form(...),
authorize: str = Form(...),
) -> HTMLResponse | RedirectResponse:
"""
Process consent form submission.
Creates authorization code and redirects to client's redirect_uri.
"""
# Rate limiting on consent submission to prevent brute force attacks
client_ip = _get_client_ip(request)
rate_result = await check_rate_limit(client_ip, "oauth_consent")
if not rate_result.allowed:
return _add_security_headers(
HTMLResponse(
render_error_page(
"rate_limit_exceeded",
"Too many consent requests. Please try again later.",
),
status_code=429,
)
)
oauth_service = get_oauth_service()
# Validate consent token (retrieves and deletes from Redis atomically)
consent_state = await _get_and_delete_consent_state(consent_token)
if not consent_state:
return HTMLResponse(
render_error_page("invalid_request", "Invalid or expired consent token"),
status_code=400,
)
# Check expiration (expires_at is stored as ISO string in Redis)
expires_at = datetime.fromisoformat(consent_state["expires_at"])
if expires_at < datetime.now(timezone.utc):
return HTMLResponse(
render_error_page("invalid_request", "Consent session expired"),
status_code=400,
)
redirect_uri = consent_state["redirect_uri"]
state = consent_state["state"]
# Check if user denied
if authorize.lower() != "true":
error_params = urlencode(
{
"error": "access_denied",
"error_description": "User denied the authorization request",
"state": state,
}
)
return RedirectResponse(
url=f"{redirect_uri}?{error_params}",
status_code=302,
)
try:
# Create authorization code
code = await oauth_service.create_authorization_code(
user_id=consent_state["user_id"],
client_id=consent_state["client_id"],
redirect_uri=redirect_uri,
scopes=consent_state["scopes"],
code_challenge=consent_state["code_challenge"],
code_challenge_method=consent_state["code_challenge_method"],
nonce=consent_state["nonce"],
)
# Redirect with code
return RedirectResponse(
url=f"{redirect_uri}?{urlencode({'code': code, 'state': state})}",
status_code=302,
)
except OAuthError as e:
return e.to_redirect(redirect_uri)
def _wants_json(request: Request) -> bool:
"""Check if client prefers JSON response (for frontend fetch calls)."""
accept = request.headers.get("Accept", "")
return "application/json" in accept
def _add_security_headers(response: HTMLResponse) -> HTMLResponse:
"""Add security headers to OAuth HTML responses."""
response.headers["X-Frame-Options"] = "DENY"
response.headers["Content-Security-Policy"] = "frame-ancestors 'none'"
response.headers["X-Content-Type-Options"] = "nosniff"
return response
@oauth_router.get("/authorize/resume", response_model=None)
async def resume_authorization(
request: Request,
session_id: str = Query(..., description="OAuth login session ID"),
) -> HTMLResponse | RedirectResponse | JSONResponse:
"""
Resume OAuth authorization after user login.
This endpoint is called after the user completes login on the frontend.
It retrieves the stored OAuth parameters and continues the authorization flow.
Supports Accept: application/json header to return JSON for frontend fetch calls,
solving CORS issues with redirect responses.
"""
wants_json = _wants_json(request)
# Rate limiting - use client IP
client_ip = _get_client_ip(request)
rate_result = await check_rate_limit(client_ip, "oauth_authorize")
if not rate_result.allowed:
if wants_json:
return JSONResponse(
{
"error": "rate_limit_exceeded",
"error_description": "Too many requests",
},
status_code=429,
)
return _add_security_headers(
HTMLResponse(
render_error_page(
"rate_limit_exceeded",
"Too many authorization requests. Please try again later.",
),
status_code=429,
)
)
# Verify user is now authenticated (use strict_bearer to prevent auth downgrade)
user_id = await _get_user_id_from_request(request, strict_bearer=True)
if not user_id:
from backend.util.settings import Settings
frontend_url = Settings().config.frontend_base_url or "http://localhost:3000"
if wants_json:
return JSONResponse(
{
"error": "login_required",
"error_description": "Authentication required",
"redirect_url": f"{frontend_url}/login",
},
status_code=401,
)
return _add_security_headers(
HTMLResponse(
render_error_page(
"login_required",
"Authentication required. Please log in and try again.",
redirect_url=f"{frontend_url}/login",
),
status_code=401,
)
)
# Retrieve and delete login state (one-time use)
logger.info(f"Attempting to retrieve login state for session_id: {session_id}")
login_state = await _get_and_delete_login_state(session_id)
if not login_state:
logger.warning(f"Login state not found for session_id: {session_id}")
if wants_json:
return JSONResponse(
{
"error": "invalid_request",
"error_description": "Invalid or expired authorization session",
},
status_code=400,
)
return _add_security_headers(
HTMLResponse(
render_error_page(
"invalid_request",
"Invalid or expired authorization session. Please start over.",
),
status_code=400,
)
)
# Check expiration
expires_at = datetime.fromisoformat(login_state["expires_at"])
if expires_at < datetime.now(timezone.utc):
if wants_json:
return JSONResponse(
{
"error": "invalid_request",
"error_description": "Authorization session has expired",
},
status_code=400,
)
return _add_security_headers(
HTMLResponse(
render_error_page(
"invalid_request",
"Authorization session has expired. Please start over.",
),
status_code=400,
)
)
# Extract stored OAuth parameters
client_id = login_state["client_id"]
redirect_uri = login_state["redirect_uri"]
scopes = login_state["scopes"]
state = login_state["state"]
code_challenge = login_state["code_challenge"]
code_challenge_method = login_state["code_challenge_method"]
nonce = login_state.get("nonce")
prompt = login_state.get("prompt")
oauth_service = get_oauth_service()
try:
# Re-validate client (in case it was deactivated during login)
client = await oauth_service.validate_client(client_id, redirect_uri, scopes)
# Check if user has already authorized these scopes (skip consent if yes)
if prompt != "consent":
has_auth = await oauth_service.has_valid_authorization(
user_id, client_id, scopes
)
if has_auth:
# Skip consent, issue code directly
code = await oauth_service.create_authorization_code(
user_id=user_id,
client_id=client_id,
redirect_uri=redirect_uri,
scopes=scopes,
code_challenge=code_challenge,
code_challenge_method=code_challenge_method,
nonce=nonce,
)
redirect_url = (
f"{redirect_uri}?{urlencode({'code': code, 'state': state})}"
)
# Return JSON with redirect URL for frontend to handle
if wants_json:
return JSONResponse(
{"redirect_url": redirect_url, "needs_consent": False}
)
return RedirectResponse(url=redirect_url, status_code=302)
# Generate consent token and store state in Redis
consent_token = secrets.token_urlsafe(32)
await _store_consent_state(
consent_token,
{
"user_id": user_id,
"client_id": client_id,
"redirect_uri": redirect_uri,
"scopes": scopes,
"state": state,
"code_challenge": code_challenge,
"code_challenge_method": code_challenge_method,
"nonce": nonce,
"expires_at": (
datetime.now(timezone.utc) + timedelta(minutes=10)
).isoformat(),
},
)
# For JSON requests, return consent data instead of HTML
if wants_json:
from backend.server.oauth.models import SCOPE_DESCRIPTIONS
scope_details = [
{"scope": s, "description": SCOPE_DESCRIPTIONS.get(s, s)}
for s in scopes
]
return JSONResponse(
{
"needs_consent": True,
"consent_token": consent_token,
"client": {
"name": client.name,
"logo_url": client.logoUrl,
"privacy_policy_url": client.privacyPolicyUrl,
"terms_url": client.termsOfServiceUrl,
},
"scopes": scope_details,
"action_url": "/oauth/authorize/consent",
}
)
# Render consent page (HTML response)
return _add_security_headers(
HTMLResponse(
render_consent_page(
client_name=client.name,
client_logo=client.logoUrl,
scopes=scopes,
consent_token=consent_token,
action_url="/oauth/authorize/consent",
privacy_policy_url=client.privacyPolicyUrl,
terms_url=client.termsOfServiceUrl,
)
)
)
except OAuthError as e:
if wants_json:
return JSONResponse(
{"error": e.error.value, "error_description": e.description},
status_code=400,
)
# If we have a valid redirect_uri, redirect with error
try:
client = await oauth_service.get_client(client_id)
if client and redirect_uri in client.redirectUris:
return e.to_redirect(redirect_uri)
except Exception:
pass
return _add_security_headers(
HTMLResponse(
render_error_page(e.error.value, e.description or "An error occurred"),
status_code=400,
)
)
# ================================================================
# Token Endpoint
# ================================================================
@oauth_router.post("/token", response_model=TokenResponse)
async def token(
request: Request,
grant_type: str = Form(...),
code: Optional[str] = Form(None),
redirect_uri: Optional[str] = Form(None),
client_id: str = Form(...),
client_secret: Optional[str] = Form(None),
code_verifier: Optional[str] = Form(None),
refresh_token: Optional[str] = Form(None),
scope: Optional[str] = Form(None),
) -> TokenResponse:
"""
OAuth 2.0 Token Endpoint.
Supports:
- authorization_code grant (with PKCE)
- refresh_token grant
"""
# Rate limiting - use client_id as identifier
rate_result = await check_rate_limit(client_id, "oauth_token")
if not rate_result.allowed:
raise HTTPException(
status_code=429,
detail="Rate limit exceeded",
headers={
"Retry-After": str(int(rate_result.retry_after or 60)),
"X-RateLimit-Remaining": "0",
"X-RateLimit-Reset": str(int(rate_result.reset_at)),
},
)
oauth_service = get_oauth_service()
try:
# Validate client authentication
await oauth_service.validate_client_secret(client_id, client_secret)
if grant_type == "authorization_code":
# Validate required parameters
if not code:
raise InvalidRequestError("'code' is required")
if not redirect_uri:
raise InvalidRequestError("'redirect_uri' is required")
if not code_verifier:
raise InvalidRequestError("'code_verifier' is required for PKCE")
return await oauth_service.exchange_authorization_code(
code=code,
client_id=client_id,
redirect_uri=redirect_uri,
code_verifier=code_verifier,
)
elif grant_type == "refresh_token":
if not refresh_token:
raise InvalidRequestError("'refresh_token' is required")
requested_scopes = _parse_scopes(scope) if scope else None
return await oauth_service.refresh_access_token(
refresh_token=refresh_token,
client_id=client_id,
requested_scopes=requested_scopes,
)
else:
raise UnsupportedGrantTypeError(grant_type)
except OAuthError as e:
# 401 for client auth failure, 400 for other validation errors (per RFC 6749)
raise e.to_http_exception(401 if isinstance(e, InvalidClientError) else 400)
# ================================================================
# UserInfo Endpoint
# ================================================================
@oauth_router.get("/userinfo", response_model=UserInfoResponse)
async def userinfo(request: Request) -> UserInfoResponse:
"""
OIDC UserInfo Endpoint.
Returns user profile information based on the granted scopes.
"""
token_service = get_token_service()
# Extract bearer token
auth_header = request.headers.get("Authorization", "")
if not auth_header.startswith("Bearer "):
raise HTTPException(
status_code=401,
detail="Bearer token required",
headers={"WWW-Authenticate": "Bearer"},
)
token = auth_header[7:]
try:
# Verify token
claims = token_service.verify_access_token(token)
# Check token is not revoked
token_hash = token_service.hash_token(token)
stored_token = await prisma.oauthaccesstoken.find_unique(
where={"tokenHash": token_hash}
)
if not stored_token or stored_token.revokedAt:
raise HTTPException(
status_code=401,
detail="Token has been revoked",
headers={"WWW-Authenticate": "Bearer"},
)
# Update last used
await prisma.oauthaccesstoken.update(
where={"id": stored_token.id},
data={"lastUsedAt": datetime.now(timezone.utc)},
)
# Get user info based on scopes
user = await prisma.user.find_unique(where={"id": claims.sub})
if not user:
raise HTTPException(status_code=404, detail="User not found")
scopes = claims.scope.split()
# Build response based on scopes
email = user.email if "email" in scopes else None
email_verified = user.emailVerified if "email" in scopes else None
name = user.name if "profile" in scopes else None
updated_at = int(user.updatedAt.timestamp()) if "profile" in scopes else None
return UserInfoResponse(
sub=claims.sub,
email=email,
email_verified=email_verified,
name=name,
updated_at=updated_at,
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=401,
detail=f"Invalid token: {str(e)}",
headers={"WWW-Authenticate": "Bearer"},
)
# ================================================================
# Token Revocation Endpoint
# ================================================================
@oauth_router.post("/revoke")
async def revoke(
request: Request,
token: str = Form(...),
token_type_hint: Optional[str] = Form(None),
) -> JSONResponse:
"""
OAuth 2.0 Token Revocation Endpoint (RFC 7009).
Revokes an access token or refresh token.
"""
oauth_service = get_oauth_service()
# Note: Per RFC 7009, always return 200 even if token not found
await oauth_service.revoke_token(token, token_type_hint)
return JSONResponse(content={}, status_code=200)

View File

@@ -1,625 +0,0 @@
"""
Core OAuth 2.0 service logic.
Handles:
- Client validation and lookup
- Authorization code generation and exchange
- Token issuance and refresh
- User consent management
- Audit logging
"""
import hashlib
import json
import secrets
from datetime import datetime, timedelta, timezone
from typing import Any, Optional
from prisma.enums import OAuthClientStatus
from prisma.models import OAuthAuthorization, OAuthClient, User
from backend.data.db import prisma
from backend.server.oauth.errors import (
InvalidClientError,
InvalidGrantError,
InvalidRequestError,
InvalidScopeError,
)
from backend.server.oauth.models import TokenResponse
from backend.server.oauth.pkce import verify_code_challenge
from backend.server.oauth.token_service import OAuthTokenService, get_token_service
class OAuthService:
"""Core OAuth 2.0 service."""
def __init__(self, token_service: Optional[OAuthTokenService] = None):
self.token_service = token_service or get_token_service()
# ================================================================
# Client Operations
# ================================================================
async def get_client(self, client_id: str) -> Optional[OAuthClient]:
"""Get an OAuth client by client_id."""
return await prisma.oauthclient.find_unique(where={"clientId": client_id})
async def validate_client(
self,
client_id: str,
redirect_uri: str,
scopes: list[str],
) -> OAuthClient:
"""
Validate a client for authorization.
Args:
client_id: Client identifier
redirect_uri: Requested redirect URI
scopes: Requested scopes
Returns:
Validated OAuthClient
Raises:
InvalidClientError: Client not found or inactive
InvalidRequestError: Invalid redirect URI
InvalidScopeError: Invalid scopes requested
"""
client = await self.get_client(client_id)
if not client:
raise InvalidClientError(f"Client '{client_id}' not found")
if client.status != OAuthClientStatus.ACTIVE:
raise InvalidClientError(f"Client '{client_id}' is not active")
# Validate redirect URI (exact match required)
if redirect_uri not in client.redirectUris:
raise InvalidRequestError(
f"Redirect URI '{redirect_uri}' is not registered for this client"
)
# Validate scopes
invalid_scopes = set(scopes) - set(client.allowedScopes)
if invalid_scopes:
raise InvalidScopeError(
f"Scopes not allowed for this client: {', '.join(invalid_scopes)}"
)
return client
async def validate_client_secret(
self,
client_id: str,
client_secret: Optional[str],
) -> OAuthClient:
"""
Validate client authentication for token endpoint.
Args:
client_id: Client identifier
client_secret: Client secret (for confidential clients)
Returns:
Validated OAuthClient
Raises:
InvalidClientError: Invalid client or credentials
"""
client = await self.get_client(client_id)
if not client:
raise InvalidClientError(f"Client '{client_id}' not found")
if client.status != OAuthClientStatus.ACTIVE:
raise InvalidClientError(f"Client '{client_id}' is not active")
# Confidential clients must provide secret
if client.clientType == "confidential":
if not client_secret:
raise InvalidClientError("Client secret required")
# Hash and compare
secret_hash = self._hash_secret(
client_secret, client.clientSecretSalt or ""
)
if not secrets.compare_digest(secret_hash, client.clientSecretHash or ""):
raise InvalidClientError("Invalid client credentials")
return client
@staticmethod
def _hash_secret(secret: str, salt: str) -> str:
"""Hash a client secret with salt."""
return hashlib.sha256(f"{salt}{secret}".encode()).hexdigest()
# ================================================================
# Authorization Code Operations
# ================================================================
async def create_authorization_code(
self,
user_id: str,
client_id: str,
redirect_uri: str,
scopes: list[str],
code_challenge: str,
code_challenge_method: str = "S256",
nonce: Optional[str] = None,
) -> str:
"""
Create a new authorization code.
Args:
user_id: User who authorized
client_id: Client being authorized
redirect_uri: Redirect URI for callback
scopes: Granted scopes
code_challenge: PKCE code challenge
code_challenge_method: PKCE method (S256)
nonce: OIDC nonce (optional)
Returns:
Authorization code string
"""
code = secrets.token_urlsafe(32)
code_hash = self.token_service.hash_token(code)
# Get the OAuthClient to link
client = await self.get_client(client_id)
if not client:
raise InvalidClientError(f"Client '{client_id}' not found")
await prisma.oauthauthorizationcode.create(
data={ # type: ignore[typeddict-item]
"codeHash": code_hash,
"userId": user_id,
"clientId": client.id,
"redirectUri": redirect_uri,
"scopes": scopes,
"codeChallenge": code_challenge,
"codeChallengeMethod": code_challenge_method,
"nonce": nonce,
"expiresAt": datetime.now(timezone.utc) + timedelta(minutes=10),
}
)
return code
async def exchange_authorization_code(
self,
code: str,
client_id: str,
redirect_uri: str,
code_verifier: str,
) -> TokenResponse:
"""
Exchange an authorization code for tokens.
Args:
code: Authorization code
client_id: Client identifier
redirect_uri: Must match original redirect URI
code_verifier: PKCE code verifier
Returns:
TokenResponse with access token, refresh token, etc.
Raises:
InvalidGrantError: Invalid or expired code
InvalidRequestError: PKCE verification failed
"""
code_hash = self.token_service.hash_token(code)
# Find the authorization code
auth_code = await prisma.oauthauthorizationcode.find_unique(
where={"codeHash": code_hash},
include={"Client": True, "User": True},
)
if not auth_code:
raise InvalidGrantError("Authorization code not found")
# Ensure Client relation is loaded
if not auth_code.Client:
raise InvalidGrantError("Authorization code client not found")
# Check if already used
if auth_code.usedAt:
# Code reuse is a security incident - revoke all tokens for this authorization
await self._revoke_tokens_for_client_user(
auth_code.Client.clientId, auth_code.userId
)
raise InvalidGrantError("Authorization code has already been used")
# Check expiration
if auth_code.expiresAt < datetime.now(timezone.utc):
raise InvalidGrantError("Authorization code has expired")
# Validate client
if auth_code.Client.clientId != client_id:
raise InvalidGrantError("Client ID mismatch")
# Validate redirect URI
if auth_code.redirectUri != redirect_uri:
raise InvalidGrantError("Redirect URI mismatch")
# Verify PKCE
if not verify_code_challenge(
code_verifier, auth_code.codeChallenge, auth_code.codeChallengeMethod
):
raise InvalidRequestError("PKCE verification failed")
# Mark code as used
await prisma.oauthauthorizationcode.update(
where={"id": auth_code.id},
data={"usedAt": datetime.now(timezone.utc)},
)
# Create or update authorization record
await self._upsert_authorization(
auth_code.userId, auth_code.Client.id, auth_code.scopes
)
# Generate tokens
return await self._create_tokens(
user_id=auth_code.userId,
client=auth_code.Client,
scopes=auth_code.scopes,
nonce=auth_code.nonce,
user=auth_code.User,
)
async def refresh_access_token(
self,
refresh_token: str,
client_id: str,
requested_scopes: Optional[list[str]] = None,
) -> TokenResponse:
"""
Refresh an access token using a refresh token.
Args:
refresh_token: Refresh token string
client_id: Client identifier
requested_scopes: Optionally request fewer scopes
Returns:
New TokenResponse
Raises:
InvalidGrantError: Invalid or expired refresh token
"""
token_hash = self.token_service.hash_token(refresh_token)
# Find the refresh token
stored_token = await prisma.oauthrefreshtoken.find_unique(
where={"tokenHash": token_hash},
include={"Client": True, "User": True},
)
if not stored_token:
raise InvalidGrantError("Refresh token not found")
# Ensure Client relation is loaded
if not stored_token.Client:
raise InvalidGrantError("Refresh token client not found")
# Check if revoked
if stored_token.revokedAt:
raise InvalidGrantError("Refresh token has been revoked")
# Check expiration
if stored_token.expiresAt < datetime.now(timezone.utc):
raise InvalidGrantError("Refresh token has expired")
# Validate client
if stored_token.Client.clientId != client_id:
raise InvalidGrantError("Client ID mismatch")
# Determine scopes
scopes = stored_token.scopes
if requested_scopes:
# Can only request a subset of original scopes
invalid = set(requested_scopes) - set(stored_token.scopes)
if invalid:
raise InvalidScopeError(
f"Cannot request scopes not in original grant: {', '.join(invalid)}"
)
scopes = requested_scopes
# Generate new tokens (rotates refresh token)
return await self._create_tokens(
user_id=stored_token.userId,
client=stored_token.Client,
scopes=scopes,
user=stored_token.User,
old_refresh_token_id=stored_token.id,
)
# ================================================================
# Token Operations
# ================================================================
async def _create_tokens(
self,
user_id: str,
client: OAuthClient,
scopes: list[str],
user: Optional[User] = None,
nonce: Optional[str] = None,
old_refresh_token_id: Optional[str] = None,
) -> TokenResponse:
"""
Create access and refresh tokens.
Args:
user_id: User ID
client: OAuth client
scopes: Granted scopes
user: User object (for ID token claims)
nonce: OIDC nonce
old_refresh_token_id: ID of refresh token being rotated
Returns:
TokenResponse
"""
# Generate access token
access_token, access_expires_at = self.token_service.generate_access_token(
user_id=user_id,
client_id=client.clientId,
scopes=scopes,
expires_in=client.tokenLifetimeSecs,
)
# Store access token hash
await prisma.oauthaccesstoken.create(
data={ # type: ignore[typeddict-item]
"tokenHash": self.token_service.hash_token(access_token),
"userId": user_id,
"clientId": client.id,
"scopes": scopes,
"expiresAt": access_expires_at,
}
)
# Generate refresh token
refresh_token = self.token_service.generate_refresh_token()
refresh_expires_at = datetime.now(timezone.utc) + timedelta(
seconds=client.refreshTokenLifetimeSecs
)
await prisma.oauthrefreshtoken.create(
data={ # type: ignore[typeddict-item]
"tokenHash": self.token_service.hash_token(refresh_token),
"userId": user_id,
"clientId": client.id,
"scopes": scopes,
"expiresAt": refresh_expires_at,
}
)
# Revoke old refresh token if rotating
if old_refresh_token_id:
await prisma.oauthrefreshtoken.update(
where={"id": old_refresh_token_id},
data={"revokedAt": datetime.now(timezone.utc)},
)
# Generate ID token if openid scope requested
id_token = None
if "openid" in scopes and user:
email = user.email if "email" in scopes else None
name = user.name if "profile" in scopes else None
id_token = self.token_service.generate_id_token(
user_id=user_id,
client_id=client.clientId,
email=email,
name=name,
nonce=nonce,
)
# Audit log
await self._audit_log(
event_type="token.issued",
user_id=user_id,
client_id=client.clientId,
details={"scopes": scopes},
)
return TokenResponse(
access_token=access_token,
token_type="Bearer",
expires_in=client.tokenLifetimeSecs,
refresh_token=refresh_token,
scope=" ".join(scopes),
id_token=id_token,
)
async def revoke_token(
self,
token: str,
token_type_hint: Optional[str] = None,
) -> bool:
"""
Revoke an access or refresh token.
Args:
token: Token to revoke
token_type_hint: Hint about token type
Returns:
True if token was found and revoked
"""
token_hash = self.token_service.hash_token(token)
now = datetime.now(timezone.utc)
# Try refresh token first if hinted or no hint
if token_type_hint in (None, "refresh_token"):
result = await prisma.oauthrefreshtoken.update_many(
where={"tokenHash": token_hash, "revokedAt": None},
data={"revokedAt": now},
)
if result > 0:
return True
# Try access token
if token_type_hint in (None, "access_token"):
result = await prisma.oauthaccesstoken.update_many(
where={"tokenHash": token_hash, "revokedAt": None},
data={"revokedAt": now},
)
if result > 0:
return True
return False
async def _revoke_tokens_for_client_user(
self,
client_id: str,
user_id: str,
) -> None:
"""Revoke all tokens for a client-user pair (security incident response)."""
client = await self.get_client(client_id)
if not client:
return
now = datetime.now(timezone.utc)
await prisma.oauthaccesstoken.update_many(
where={"clientId": client.id, "userId": user_id, "revokedAt": None},
data={"revokedAt": now},
)
await prisma.oauthrefreshtoken.update_many(
where={"clientId": client.id, "userId": user_id, "revokedAt": None},
data={"revokedAt": now},
)
await self._audit_log(
event_type="tokens.revoked.security",
user_id=user_id,
client_id=client_id,
details={"reason": "authorization_code_reuse"},
)
# ================================================================
# Authorization (Consent) Operations
# ================================================================
async def get_authorization(
self,
user_id: str,
client_id: str,
) -> Optional[OAuthAuthorization]:
"""Get existing authorization for user-client pair."""
client = await self.get_client(client_id)
if not client:
return None
return await prisma.oauthauthorization.find_unique(
where={
"userId_clientId": {
"userId": user_id,
"clientId": client.id,
}
}
)
async def has_valid_authorization(
self,
user_id: str,
client_id: str,
scopes: list[str],
) -> bool:
"""
Check if user has already authorized these scopes for this client.
Args:
user_id: User ID
client_id: Client identifier
scopes: Requested scopes
Returns:
True if user has already authorized all requested scopes
"""
auth = await self.get_authorization(user_id, client_id)
if not auth or auth.revokedAt:
return False
# Check if all requested scopes are already authorized
return set(scopes).issubset(set(auth.scopes))
async def _upsert_authorization(
self,
user_id: str,
client_db_id: str,
scopes: list[str],
) -> None:
"""Create or update an authorization record."""
existing = await prisma.oauthauthorization.find_unique(
where={
"userId_clientId": {
"userId": user_id,
"clientId": client_db_id,
}
}
)
if existing:
# Merge scopes
merged_scopes = list(set(existing.scopes) | set(scopes))
await prisma.oauthauthorization.update(
where={"id": existing.id},
data={"scopes": merged_scopes, "revokedAt": None},
)
else:
await prisma.oauthauthorization.create(
data={ # type: ignore[typeddict-item]
"userId": user_id,
"clientId": client_db_id,
"scopes": scopes,
}
)
# ================================================================
# Audit Logging
# ================================================================
async def _audit_log(
self,
event_type: str,
user_id: Optional[str] = None,
client_id: Optional[str] = None,
grant_id: Optional[str] = None,
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
details: Optional[dict[str, Any]] = None,
) -> None:
"""Create an audit log entry."""
# Convert details to JSON for Prisma's Json field
details_json = json.dumps(details or {})
await prisma.oauthauditlog.create(
data={
"eventType": event_type,
"userId": user_id,
"clientId": client_id,
"grantId": grant_id,
"ipAddress": ip_address,
"userAgent": user_agent,
"details": json.loads(details_json), # type: ignore[arg-type]
}
)
# Module-level singleton
_oauth_service: Optional[OAuthService] = None
def get_oauth_service() -> OAuthService:
"""Get the singleton OAuth service instance."""
global _oauth_service
if _oauth_service is None:
_oauth_service = OAuthService()
return _oauth_service

View File

@@ -1,298 +0,0 @@
"""
JWT Token Service for OAuth 2.0 Provider.
Handles generation and validation of:
- Access tokens (JWT)
- Refresh tokens (opaque)
- ID tokens (JWT, OIDC)
"""
import base64
import hashlib
import secrets
from datetime import datetime, timedelta, timezone
from typing import Optional
import jwt
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric.rsa import (
RSAPrivateKey,
RSAPublicKey,
generate_private_key,
)
from pydantic import BaseModel
from backend.util.settings import Settings
class TokenClaims(BaseModel):
"""Decoded token claims."""
iss: str # Issuer
sub: str # Subject (user ID)
aud: str # Audience (client ID)
exp: int # Expiration timestamp
iat: int # Issued at timestamp
jti: str # JWT ID
scope: str # Space-separated scopes
client_id: str # Client ID
class OAuthTokenService:
"""
Service for generating and validating OAuth tokens.
Uses RS256 (RSA with SHA-256) for JWT signing.
"""
def __init__(self, settings: Optional[Settings] = None):
self._settings = settings or Settings()
self._private_key: Optional[RSAPrivateKey] = None
self._public_key: Optional[RSAPublicKey] = None
self._algorithm = "RS256"
@property
def issuer(self) -> str:
"""Get the token issuer URL."""
return self._settings.config.platform_base_url or "https://platform.agpt.co"
@property
def key_id(self) -> str:
"""Get the key ID for JWKS."""
return self._settings.secrets.oauth_jwt_key_id or "default-key-id"
def _get_private_key(self) -> RSAPrivateKey:
"""Load or generate the private key."""
if self._private_key is not None:
return self._private_key
key_pem = self._settings.secrets.oauth_jwt_private_key
if key_pem:
loaded_key = serialization.load_pem_private_key(
key_pem.encode(), password=None
)
if not isinstance(loaded_key, RSAPrivateKey):
raise ValueError("OAuth JWT private key must be RSA")
self._private_key = loaded_key
else:
# Generate a key for development (should not be used in production)
self._private_key = generate_private_key(
public_exponent=65537,
key_size=2048,
)
return self._private_key
def _get_public_key(self) -> RSAPublicKey:
"""Get the public key from the private key."""
if self._public_key is not None:
return self._public_key
key_pem = self._settings.secrets.oauth_jwt_public_key
if key_pem:
loaded_key = serialization.load_pem_public_key(key_pem.encode())
if not isinstance(loaded_key, RSAPublicKey):
raise ValueError("OAuth JWT public key must be RSA")
self._public_key = loaded_key
else:
self._public_key = self._get_private_key().public_key()
return self._public_key
def generate_access_token(
self,
user_id: str,
client_id: str,
scopes: list[str],
expires_in: int = 3600,
) -> tuple[str, datetime]:
"""
Generate a JWT access token.
Args:
user_id: User ID (subject)
client_id: Client ID (audience)
scopes: List of granted scopes
expires_in: Token lifetime in seconds
Returns:
Tuple of (token string, expiration datetime)
"""
now = datetime.now(timezone.utc)
expires_at = now + timedelta(seconds=expires_in)
payload = {
"iss": self.issuer,
"sub": user_id,
"aud": client_id,
"exp": int(expires_at.timestamp()),
"iat": int(now.timestamp()),
"jti": secrets.token_urlsafe(16),
"scope": " ".join(scopes),
"client_id": client_id,
}
token = jwt.encode(
payload,
self._get_private_key(),
algorithm=self._algorithm,
headers={"kid": self.key_id},
)
return token, expires_at
def generate_refresh_token(self) -> str:
"""
Generate an opaque refresh token.
Returns:
URL-safe random token string
"""
return secrets.token_urlsafe(48)
def generate_id_token(
self,
user_id: str,
client_id: str,
email: Optional[str] = None,
name: Optional[str] = None,
nonce: Optional[str] = None,
expires_in: int = 3600,
) -> str:
"""
Generate an OIDC ID token.
Args:
user_id: User ID (subject)
client_id: Client ID (audience)
email: User's email (optional)
name: User's name (optional)
nonce: OIDC nonce for replay protection (optional)
expires_in: Token lifetime in seconds
Returns:
JWT ID token string
"""
now = datetime.now(timezone.utc)
expires_at = now + timedelta(seconds=expires_in)
payload = {
"iss": self.issuer,
"sub": user_id,
"aud": client_id,
"exp": int(expires_at.timestamp()),
"iat": int(now.timestamp()),
"auth_time": int(now.timestamp()),
}
if email:
payload["email"] = email
payload["email_verified"] = True
if name:
payload["name"] = name
if nonce:
payload["nonce"] = nonce
return jwt.encode(
payload,
self._get_private_key(),
algorithm=self._algorithm,
headers={"kid": self.key_id},
)
def verify_access_token(
self,
token: str,
expected_client_id: Optional[str] = None,
) -> TokenClaims:
"""
Verify and decode a JWT access token.
Args:
token: JWT token string
expected_client_id: Expected client ID (audience)
Returns:
Decoded token claims
Raises:
jwt.ExpiredSignatureError: Token has expired
jwt.InvalidTokenError: Token is invalid
"""
options = {}
if expected_client_id:
options["audience"] = expected_client_id
payload = jwt.decode(
token,
self._get_public_key(),
algorithms=[self._algorithm],
issuer=self.issuer,
options={"verify_aud": bool(expected_client_id)},
**options,
)
return TokenClaims(
iss=payload["iss"],
sub=payload["sub"],
aud=payload.get("aud", payload.get("client_id", "")),
exp=payload["exp"],
iat=payload["iat"],
jti=payload["jti"],
scope=payload.get("scope", ""),
client_id=payload.get("client_id", payload.get("aud", "")),
)
@staticmethod
def hash_token(token: str) -> str:
"""
Hash a token for secure storage.
Args:
token: Token string to hash
Returns:
SHA-256 hash of the token
"""
return hashlib.sha256(token.encode()).hexdigest()
def get_jwks(self) -> dict:
"""
Get the JSON Web Key Set (JWKS) for public key distribution.
Returns:
JWKS dictionary with public key(s)
"""
public_key = self._get_public_key()
public_numbers = public_key.public_numbers()
# Convert to base64url encoding without padding
def int_to_base64url(n: int, length: int) -> str:
data = n.to_bytes(length, byteorder="big")
return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")
# RSA modulus and exponent
n = int_to_base64url(public_numbers.n, (public_numbers.n.bit_length() + 7) // 8)
e = int_to_base64url(public_numbers.e, 3)
return {
"keys": [
{
"kty": "RSA",
"use": "sig",
"kid": self.key_id,
"alg": self._algorithm,
"n": n,
"e": e,
}
]
}
# Module-level singleton
_token_service: Optional[OAuthTokenService] = None
def get_token_service() -> OAuthTokenService:
"""Get the singleton token service instance."""
global _token_service
if _token_service is None:
_token_service = OAuthTokenService()
return _token_service

View File

@@ -21,7 +21,6 @@ import backend.data.db
import backend.data.graph import backend.data.graph
import backend.data.user import backend.data.user
import backend.integrations.webhooks.utils import backend.integrations.webhooks.utils
import backend.server.integrations.connect_router
import backend.server.routers.postmark.postmark import backend.server.routers.postmark.postmark
import backend.server.routers.v1 import backend.server.routers.v1
import backend.server.v2.admin.credit_admin_routes import backend.server.v2.admin.credit_admin_routes
@@ -45,7 +44,6 @@ from backend.integrations.providers import ProviderName
from backend.monitoring.instrumentation import instrument_fastapi from backend.monitoring.instrumentation import instrument_fastapi
from backend.server.external.api import external_app from backend.server.external.api import external_app
from backend.server.middleware.security import SecurityHeadersMiddleware from backend.server.middleware.security import SecurityHeadersMiddleware
from backend.server.oauth import client_router, discovery_router, oauth_router
from backend.server.utils.cors import build_cors_params from backend.server.utils.cors import build_cors_params
from backend.util import json from backend.util import json
from backend.util.cloud_storage import shutdown_cloud_storage_handler from backend.util.cloud_storage import shutdown_cloud_storage_handler
@@ -302,18 +300,6 @@ app.include_router(
app.mount("/external-api", external_app) app.mount("/external-api", external_app)
# OAuth Provider routes
app.include_router(oauth_router, tags=["oauth"], prefix="")
app.include_router(discovery_router, tags=["oidc-discovery"], prefix="")
app.include_router(client_router, tags=["oauth-clients"], prefix="")
# Integration Connect popup routes (for Credential Broker)
app.include_router(
backend.server.integrations.connect_router.connect_router,
tags=["integration-connect"],
prefix="",
)
@app.get(path="/health", tags=["health"], dependencies=[]) @app.get(path="/health", tags=["health"], dependencies=[])
async def health(): async def health():

View File

@@ -1,9 +1,16 @@
import logging import logging
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from difflib import SequenceMatcher
from typing import Sequence
import prisma import prisma
import backend.data.block import backend.data.block
import backend.server.v2.library.db as library_db
import backend.server.v2.library.model as library_model
import backend.server.v2.store.db as store_db
import backend.server.v2.store.model as store_model
from backend.blocks import load_all_blocks from backend.blocks import load_all_blocks
from backend.blocks.llm import LlmModel from backend.blocks.llm import LlmModel
from backend.data.block import AnyBlockSchema, BlockCategory, BlockInfo, BlockSchema from backend.data.block import AnyBlockSchema, BlockCategory, BlockInfo, BlockSchema
@@ -14,17 +21,36 @@ from backend.server.v2.builder.model import (
BlockResponse, BlockResponse,
BlockType, BlockType,
CountResponse, CountResponse,
FilterType,
Provider, Provider,
ProviderResponse, ProviderResponse,
SearchBlocksResponse, SearchEntry,
) )
from backend.util.cache import cached from backend.util.cache import cached
from backend.util.models import Pagination from backend.util.models import Pagination
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
llm_models = [name.name.lower().replace("_", " ") for name in LlmModel] llm_models = [name.name.lower().replace("_", " ") for name in LlmModel]
_static_counts_cache: dict | None = None
_suggested_blocks: list[BlockInfo] | None = None MAX_LIBRARY_AGENT_RESULTS = 100
MAX_MARKETPLACE_AGENT_RESULTS = 100
MIN_SCORE_FOR_FILTERED_RESULTS = 10.0
SearchResultItem = BlockInfo | library_model.LibraryAgent | store_model.StoreAgent
@dataclass
class _ScoredItem:
item: SearchResultItem
filter_type: FilterType
score: float
sort_key: str
@dataclass
class _SearchCacheEntry:
items: list[SearchResultItem]
total_items: dict[FilterType, int]
def get_block_categories(category_blocks: int = 3) -> list[BlockCategoryResponse]: def get_block_categories(category_blocks: int = 3) -> list[BlockCategoryResponse]:
@@ -130,71 +156,244 @@ def get_block_by_id(block_id: str) -> BlockInfo | None:
return None return None
def search_blocks( async def update_search(user_id: str, search: SearchEntry) -> str:
include_blocks: bool = True,
include_integrations: bool = True,
query: str = "",
page: int = 1,
page_size: int = 50,
) -> SearchBlocksResponse:
""" """
Get blocks based on the filter and query. Upsert a search request for the user and return the search ID.
`providers` only applies for `integrations` filter.
""" """
blocks: list[AnyBlockSchema] = [] if search.search_id:
query = query.lower() # Update existing search
await prisma.models.BuilderSearchHistory.prisma().update(
where={
"id": search.search_id,
},
data={
"searchQuery": search.search_query or "",
"filter": search.filter or [], # type: ignore
"byCreator": search.by_creator or [],
},
)
return search.search_id
else:
# Create new search
new_search = await prisma.models.BuilderSearchHistory.prisma().create(
data={
"userId": user_id,
"searchQuery": search.search_query or "",
"filter": search.filter or [], # type: ignore
"byCreator": search.by_creator or [],
}
)
return new_search.id
total = 0
skip = (page - 1) * page_size async def get_recent_searches(user_id: str, limit: int = 5) -> list[SearchEntry]:
take = page_size """
Get the user's most recent search requests.
"""
searches = await prisma.models.BuilderSearchHistory.prisma().find_many(
where={
"userId": user_id,
},
order={
"updatedAt": "desc",
},
take=limit,
)
return [
SearchEntry(
search_query=s.searchQuery,
filter=s.filter, # type: ignore
by_creator=s.byCreator,
search_id=s.id,
)
for s in searches
]
async def get_sorted_search_results(
*,
user_id: str,
search_query: str | None,
filters: Sequence[FilterType],
by_creator: Sequence[str] | None = None,
) -> _SearchCacheEntry:
normalized_filters: tuple[FilterType, ...] = tuple(sorted(set(filters or [])))
normalized_creators: tuple[str, ...] = tuple(sorted(set(by_creator or [])))
return await _build_cached_search_results(
user_id=user_id,
search_query=search_query or "",
filters=normalized_filters,
by_creator=normalized_creators,
)
@cached(ttl_seconds=300, shared_cache=True)
async def _build_cached_search_results(
user_id: str,
search_query: str,
filters: tuple[FilterType, ...],
by_creator: tuple[str, ...],
) -> _SearchCacheEntry:
normalized_query = (search_query or "").strip().lower()
include_blocks = "blocks" in filters
include_integrations = "integrations" in filters
include_library_agents = "my_agents" in filters
include_marketplace_agents = "marketplace_agents" in filters
scored_items: list[_ScoredItem] = []
total_items: dict[FilterType, int] = {
"blocks": 0,
"integrations": 0,
"marketplace_agents": 0,
"my_agents": 0,
}
block_results, block_total, integration_total = _collect_block_results(
normalized_query=normalized_query,
include_blocks=include_blocks,
include_integrations=include_integrations,
)
scored_items.extend(block_results)
total_items["blocks"] = block_total
total_items["integrations"] = integration_total
if include_library_agents:
library_response = await library_db.list_library_agents(
user_id=user_id,
search_term=search_query or None,
page=1,
page_size=MAX_LIBRARY_AGENT_RESULTS,
)
total_items["my_agents"] = library_response.pagination.total_items
scored_items.extend(
_build_library_items(
agents=library_response.agents,
normalized_query=normalized_query,
)
)
if include_marketplace_agents:
marketplace_response = await store_db.get_store_agents(
creators=list(by_creator) or None,
search_query=search_query or None,
page=1,
page_size=MAX_MARKETPLACE_AGENT_RESULTS,
)
total_items["marketplace_agents"] = marketplace_response.pagination.total_items
scored_items.extend(
_build_marketplace_items(
agents=marketplace_response.agents,
normalized_query=normalized_query,
)
)
sorted_items = sorted(
scored_items,
key=lambda entry: (-entry.score, entry.sort_key, entry.filter_type),
)
return _SearchCacheEntry(
items=[entry.item for entry in sorted_items],
total_items=total_items,
)
def _collect_block_results(
*,
normalized_query: str,
include_blocks: bool,
include_integrations: bool,
) -> tuple[list[_ScoredItem], int, int]:
results: list[_ScoredItem] = []
block_count = 0 block_count = 0
integration_count = 0 integration_count = 0
if not include_blocks and not include_integrations:
return results, block_count, integration_count
for block_type in load_all_blocks().values(): for block_type in load_all_blocks().values():
block: AnyBlockSchema = block_type() block: AnyBlockSchema = block_type()
# Skip disabled blocks
if block.disabled: if block.disabled:
continue continue
# Skip blocks that don't match the query
if ( block_info = block.get_info()
query not in block.name.lower()
and query not in block.description.lower()
and not _matches_llm_model(block.input_schema, query)
):
continue
keep = False
credentials = list(block.input_schema.get_credentials_fields().values()) credentials = list(block.input_schema.get_credentials_fields().values())
if include_integrations and len(credentials) > 0: is_integration = len(credentials) > 0
keep = True
if is_integration and not include_integrations:
continue
if not is_integration and not include_blocks:
continue
score = _score_block(block, block_info, normalized_query)
if not _should_include_item(score, normalized_query):
continue
filter_type: FilterType = "integrations" if is_integration else "blocks"
if is_integration:
integration_count += 1 integration_count += 1
if include_blocks and len(credentials) == 0: else:
keep = True
block_count += 1 block_count += 1
if not keep: results.append(
continue _ScoredItem(
item=block_info,
total += 1 filter_type=filter_type,
if skip > 0: score=score,
skip -= 1 sort_key=_get_item_name(block_info),
continue
if take > 0:
take -= 1
blocks.append(block)
return SearchBlocksResponse(
blocks=BlockResponse(
blocks=[b.get_info() for b in blocks],
pagination=Pagination(
total_items=total,
total_pages=(total + page_size - 1) // page_size,
current_page=page,
page_size=page_size,
),
),
total_block_count=block_count,
total_integration_count=integration_count,
) )
)
return results, block_count, integration_count
def _build_library_items(
*,
agents: list[library_model.LibraryAgent],
normalized_query: str,
) -> list[_ScoredItem]:
results: list[_ScoredItem] = []
for agent in agents:
score = _score_library_agent(agent, normalized_query)
if not _should_include_item(score, normalized_query):
continue
results.append(
_ScoredItem(
item=agent,
filter_type="my_agents",
score=score,
sort_key=_get_item_name(agent),
)
)
return results
def _build_marketplace_items(
*,
agents: list[store_model.StoreAgent],
normalized_query: str,
) -> list[_ScoredItem]:
results: list[_ScoredItem] = []
for agent in agents:
score = _score_store_agent(agent, normalized_query)
if not _should_include_item(score, normalized_query):
continue
results.append(
_ScoredItem(
item=agent,
filter_type="marketplace_agents",
score=score,
sort_key=_get_item_name(agent),
)
)
return results
def get_providers( def get_providers(
@@ -251,16 +450,12 @@ async def get_counts(user_id: str) -> CountResponse:
) )
@cached(ttl_seconds=3600)
async def _get_static_counts(): async def _get_static_counts():
""" """
Get counts of blocks, integrations, and marketplace agents. Get counts of blocks, integrations, and marketplace agents.
This is cached to avoid unnecessary database queries and calculations. This is cached to avoid unnecessary database queries and calculations.
Can't use functools.cache here because the function is async.
""" """
global _static_counts_cache
if _static_counts_cache is not None:
return _static_counts_cache
all_blocks = 0 all_blocks = 0
input_blocks = 0 input_blocks = 0
action_blocks = 0 action_blocks = 0
@@ -287,7 +482,7 @@ async def _get_static_counts():
marketplace_agents = await prisma.models.StoreAgent.prisma().count() marketplace_agents = await prisma.models.StoreAgent.prisma().count()
_static_counts_cache = { return {
"all_blocks": all_blocks, "all_blocks": all_blocks,
"input_blocks": input_blocks, "input_blocks": input_blocks,
"action_blocks": action_blocks, "action_blocks": action_blocks,
@@ -296,8 +491,6 @@ async def _get_static_counts():
"marketplace_agents": marketplace_agents, "marketplace_agents": marketplace_agents,
} }
return _static_counts_cache
def _matches_llm_model(schema_cls: type[BlockSchema], query: str) -> bool: def _matches_llm_model(schema_cls: type[BlockSchema], query: str) -> bool:
for field in schema_cls.model_fields.values(): for field in schema_cls.model_fields.values():
@@ -308,6 +501,123 @@ def _matches_llm_model(schema_cls: type[BlockSchema], query: str) -> bool:
return False return False
def _score_block(
block: AnyBlockSchema,
block_info: BlockInfo,
normalized_query: str,
) -> float:
if not normalized_query:
return 0.0
name = block_info.name.lower()
description = block_info.description.lower()
score = _score_primary_fields(name, description, normalized_query)
category_text = " ".join(
category.get("category", "").lower() for category in block_info.categories
)
score += _score_additional_field(category_text, normalized_query, 12, 6)
credentials_info = block.input_schema.get_credentials_fields_info().values()
provider_names = [
provider.value.lower()
for info in credentials_info
for provider in info.provider
]
provider_text = " ".join(provider_names)
score += _score_additional_field(provider_text, normalized_query, 15, 6)
if _matches_llm_model(block.input_schema, normalized_query):
score += 20
return score
def _score_library_agent(
agent: library_model.LibraryAgent,
normalized_query: str,
) -> float:
if not normalized_query:
return 0.0
name = agent.name.lower()
description = (agent.description or "").lower()
instructions = (agent.instructions or "").lower()
score = _score_primary_fields(name, description, normalized_query)
score += _score_additional_field(instructions, normalized_query, 15, 6)
score += _score_additional_field(
agent.creator_name.lower(), normalized_query, 10, 5
)
return score
def _score_store_agent(
agent: store_model.StoreAgent,
normalized_query: str,
) -> float:
if not normalized_query:
return 0.0
name = agent.agent_name.lower()
description = agent.description.lower()
sub_heading = agent.sub_heading.lower()
score = _score_primary_fields(name, description, normalized_query)
score += _score_additional_field(sub_heading, normalized_query, 12, 6)
score += _score_additional_field(agent.creator.lower(), normalized_query, 10, 5)
return score
def _score_primary_fields(name: str, description: str, query: str) -> float:
score = 0.0
if name == query:
score += 120
elif name.startswith(query):
score += 90
elif query in name:
score += 60
score += SequenceMatcher(None, name, query).ratio() * 50
if description:
if query in description:
score += 30
score += SequenceMatcher(None, description, query).ratio() * 25
return score
def _score_additional_field(
value: str,
query: str,
contains_weight: float,
similarity_weight: float,
) -> float:
if not value or not query:
return 0.0
score = 0.0
if query in value:
score += contains_weight
score += SequenceMatcher(None, value, query).ratio() * similarity_weight
return score
def _should_include_item(score: float, normalized_query: str) -> bool:
if not normalized_query:
return True
return score >= MIN_SCORE_FOR_FILTERED_RESULTS
def _get_item_name(item: SearchResultItem) -> str:
if isinstance(item, BlockInfo):
return item.name.lower()
if isinstance(item, library_model.LibraryAgent):
return item.name.lower()
return item.agent_name.lower()
@cached(ttl_seconds=3600) @cached(ttl_seconds=3600)
def _get_all_providers() -> dict[ProviderName, Provider]: def _get_all_providers() -> dict[ProviderName, Provider]:
providers: dict[ProviderName, Provider] = {} providers: dict[ProviderName, Provider] = {}
@@ -329,13 +639,9 @@ def _get_all_providers() -> dict[ProviderName, Provider]:
return providers return providers
@cached(ttl_seconds=3600)
async def get_suggested_blocks(count: int = 5) -> list[BlockInfo]: async def get_suggested_blocks(count: int = 5) -> list[BlockInfo]:
global _suggested_blocks suggested_blocks = []
if _suggested_blocks is not None and len(_suggested_blocks) >= count:
return _suggested_blocks[:count]
_suggested_blocks = []
# Sum the number of executions for each block type # Sum the number of executions for each block type
# Prisma cannot group by nested relations, so we do a raw query # Prisma cannot group by nested relations, so we do a raw query
# Calculate the cutoff timestamp # Calculate the cutoff timestamp
@@ -376,7 +682,7 @@ async def get_suggested_blocks(count: int = 5) -> list[BlockInfo]:
# Sort blocks by execution count # Sort blocks by execution count
blocks.sort(key=lambda x: x[1], reverse=True) blocks.sort(key=lambda x: x[1], reverse=True)
_suggested_blocks = [block[0] for block in blocks] suggested_blocks = [block[0] for block in blocks]
# Return the top blocks # Return the top blocks
return _suggested_blocks[:count] return suggested_blocks[:count]

View File

@@ -18,10 +18,17 @@ FilterType = Literal[
BlockType = Literal["all", "input", "action", "output"] BlockType = Literal["all", "input", "action", "output"]
class SearchEntry(BaseModel):
search_query: str | None = None
filter: list[FilterType] | None = None
by_creator: list[str] | None = None
search_id: str | None = None
# Suggestions # Suggestions
class SuggestionsResponse(BaseModel): class SuggestionsResponse(BaseModel):
otto_suggestions: list[str] otto_suggestions: list[str]
recent_searches: list[str] recent_searches: list[SearchEntry]
providers: list[ProviderName] providers: list[ProviderName]
top_blocks: list[BlockInfo] top_blocks: list[BlockInfo]
@@ -32,7 +39,7 @@ class BlockCategoryResponse(BaseModel):
total_blocks: int total_blocks: int
blocks: list[BlockInfo] blocks: list[BlockInfo]
model_config = {"use_enum_values": False} # <== use enum names like "AI" model_config = {"use_enum_values": False} # Use enum names like "AI"
# Input/Action/Output and see all for block categories # Input/Action/Output and see all for block categories
@@ -53,17 +60,11 @@ class ProviderResponse(BaseModel):
pagination: Pagination pagination: Pagination
class SearchBlocksResponse(BaseModel):
blocks: BlockResponse
total_block_count: int
total_integration_count: int
class SearchResponse(BaseModel): class SearchResponse(BaseModel):
items: list[BlockInfo | library_model.LibraryAgent | store_model.StoreAgent] items: list[BlockInfo | library_model.LibraryAgent | store_model.StoreAgent]
search_id: str
total_items: dict[FilterType, int] total_items: dict[FilterType, int]
page: int pagination: Pagination
more_pages: bool
class CountResponse(BaseModel): class CountResponse(BaseModel):

View File

@@ -6,10 +6,6 @@ from autogpt_libs.auth.dependencies import get_user_id, requires_user
import backend.server.v2.builder.db as builder_db import backend.server.v2.builder.db as builder_db
import backend.server.v2.builder.model as builder_model import backend.server.v2.builder.model as builder_model
import backend.server.v2.library.db as library_db
import backend.server.v2.library.model as library_model
import backend.server.v2.store.db as store_db
import backend.server.v2.store.model as store_model
from backend.integrations.providers import ProviderName from backend.integrations.providers import ProviderName
from backend.util.models import Pagination from backend.util.models import Pagination
@@ -45,7 +41,9 @@ def sanitize_query(query: str | None) -> str | None:
summary="Get Builder suggestions", summary="Get Builder suggestions",
response_model=builder_model.SuggestionsResponse, response_model=builder_model.SuggestionsResponse,
) )
async def get_suggestions() -> builder_model.SuggestionsResponse: async def get_suggestions(
user_id: Annotated[str, fastapi.Security(get_user_id)],
) -> builder_model.SuggestionsResponse:
""" """
Get all suggestions for the Blocks Menu. Get all suggestions for the Blocks Menu.
""" """
@@ -55,11 +53,7 @@ async def get_suggestions() -> builder_model.SuggestionsResponse:
"Help me create a list", "Help me create a list",
"Help me feed my data to Google Maps", "Help me feed my data to Google Maps",
], ],
recent_searches=[ recent_searches=await builder_db.get_recent_searches(user_id),
"image generation",
"deepfake",
"competitor analysis",
],
providers=[ providers=[
ProviderName.TWITTER, ProviderName.TWITTER,
ProviderName.GITHUB, ProviderName.GITHUB,
@@ -147,7 +141,6 @@ async def get_providers(
) )
# Not using post method because on frontend, orval doesn't support Infinite Query with POST method.
@router.get( @router.get(
"/search", "/search",
summary="Builder search", summary="Builder search",
@@ -157,7 +150,7 @@ async def get_providers(
async def search( async def search(
user_id: Annotated[str, fastapi.Security(get_user_id)], user_id: Annotated[str, fastapi.Security(get_user_id)],
search_query: Annotated[str | None, fastapi.Query()] = None, search_query: Annotated[str | None, fastapi.Query()] = None,
filter: Annotated[list[str] | None, fastapi.Query()] = None, filter: Annotated[list[builder_model.FilterType] | None, fastapi.Query()] = None,
search_id: Annotated[str | None, fastapi.Query()] = None, search_id: Annotated[str | None, fastapi.Query()] = None,
by_creator: Annotated[list[str] | None, fastapi.Query()] = None, by_creator: Annotated[list[str] | None, fastapi.Query()] = None,
page: Annotated[int, fastapi.Query()] = 1, page: Annotated[int, fastapi.Query()] = 1,
@@ -176,69 +169,43 @@ async def search(
] ]
search_query = sanitize_query(search_query) search_query = sanitize_query(search_query)
# Blocks&Integrations # Get all possible results
blocks = builder_model.SearchBlocksResponse( cached_results = await builder_db.get_sorted_search_results(
blocks=builder_model.BlockResponse(
blocks=[],
pagination=Pagination.empty(),
),
total_block_count=0,
total_integration_count=0,
)
if "blocks" in filter or "integrations" in filter:
blocks = builder_db.search_blocks(
include_blocks="blocks" in filter,
include_integrations="integrations" in filter,
query=search_query or "",
page=page,
page_size=page_size,
)
# Library Agents
my_agents = library_model.LibraryAgentResponse(
agents=[],
pagination=Pagination.empty(),
)
if "my_agents" in filter:
my_agents = await library_db.list_library_agents(
user_id=user_id, user_id=user_id,
search_term=search_query,
page=page,
page_size=page_size,
)
# Marketplace Agents
marketplace_agents = store_model.StoreAgentsResponse(
agents=[],
pagination=Pagination.empty(),
)
if "marketplace_agents" in filter:
marketplace_agents = await store_db.get_store_agents(
creators=by_creator,
search_query=search_query, search_query=search_query,
page=page, filters=filter,
by_creator=by_creator,
)
# Paginate results
total_combined_items = len(cached_results.items)
pagination = Pagination(
total_items=total_combined_items,
total_pages=(total_combined_items + page_size - 1) // page_size,
current_page=page,
page_size=page_size, page_size=page_size,
) )
more_pages = False start_idx = (page - 1) * page_size
if ( end_idx = start_idx + page_size
blocks.blocks.pagination.current_page < blocks.blocks.pagination.total_pages paginated_items = cached_results.items[start_idx:end_idx]
or my_agents.pagination.current_page < my_agents.pagination.total_pages
or marketplace_agents.pagination.current_page # Update the search entry by id
< marketplace_agents.pagination.total_pages search_id = await builder_db.update_search(
): user_id,
more_pages = True builder_model.SearchEntry(
search_query=search_query,
filter=filter,
by_creator=by_creator,
search_id=search_id,
),
)
return builder_model.SearchResponse( return builder_model.SearchResponse(
items=blocks.blocks.blocks + my_agents.agents + marketplace_agents.agents, items=paginated_items,
total_items={ search_id=search_id,
"blocks": blocks.total_block_count, total_items=cached_results.total_items,
"integrations": blocks.total_integration_count, pagination=pagination,
"marketplace_agents": marketplace_agents.pagination.total_items,
"my_agents": my_agents.pagination.total_items,
},
page=page,
more_pages=more_pages,
) )

View File

@@ -3,7 +3,6 @@ from datetime import UTC, datetime
from os import getenv from os import getenv
import pytest import pytest
from prisma.types import ProfileCreateInput
from pydantic import SecretStr from pydantic import SecretStr
from backend.blocks.firecrawl.scrape import FirecrawlScrapeBlock from backend.blocks.firecrawl.scrape import FirecrawlScrapeBlock
@@ -50,13 +49,13 @@ async def setup_test_data():
# 1b. Create a profile with username for the user (required for store agent lookup) # 1b. Create a profile with username for the user (required for store agent lookup)
username = user.email.split("@")[0] username = user.email.split("@")[0]
await prisma.profile.create( await prisma.profile.create(
data=ProfileCreateInput( data={
userId=user.id, "userId": user.id,
username=username, "username": username,
name=f"Test User {username}", "name": f"Test User {username}",
description="Test user profile", "description": "Test user profile",
links=[], # Required field - empty array for test profiles "links": [], # Required field - empty array for test profiles
) }
) )
# 2. Create a test graph with agent input -> agent output # 2. Create a test graph with agent input -> agent output
@@ -173,13 +172,13 @@ async def setup_llm_test_data():
# 1b. Create a profile with username for the user (required for store agent lookup) # 1b. Create a profile with username for the user (required for store agent lookup)
username = user.email.split("@")[0] username = user.email.split("@")[0]
await prisma.profile.create( await prisma.profile.create(
data=ProfileCreateInput( data={
userId=user.id, "userId": user.id,
username=username, "username": username,
name=f"Test User {username}", "name": f"Test User {username}",
description="Test user profile for LLM tests", "description": "Test user profile for LLM tests",
links=[], # Required field - empty array for test profiles "links": [], # Required field - empty array for test profiles
) }
) )
# 2. Create test OpenAI credentials for the user # 2. Create test OpenAI credentials for the user
@@ -333,13 +332,13 @@ async def setup_firecrawl_test_data():
# 1b. Create a profile with username for the user (required for store agent lookup) # 1b. Create a profile with username for the user (required for store agent lookup)
username = user.email.split("@")[0] username = user.email.split("@")[0]
await prisma.profile.create( await prisma.profile.create(
data=ProfileCreateInput( data={
userId=user.id, "userId": user.id,
username=username, "username": username,
name=f"Test User {username}", "name": f"Test User {username}",
description="Test user profile for Firecrawl tests", "description": "Test user profile for Firecrawl tests",
links=[], # Required field - empty array for test profiles "links": [], # Required field - empty array for test profiles
) }
) )
# NOTE: We deliberately do NOT create Firecrawl credentials for this user # NOTE: We deliberately do NOT create Firecrawl credentials for this user

View File

@@ -802,16 +802,18 @@ async def add_store_agent_to_library(
# Create LibraryAgent entry # Create LibraryAgent entry
added_agent = await prisma.models.LibraryAgent.prisma().create( added_agent = await prisma.models.LibraryAgent.prisma().create(
data=prisma.types.LibraryAgentCreateInput( data={
User={"connect": {"id": user_id}}, "User": {"connect": {"id": user_id}},
AgentGraph={ "AgentGraph": {
"connect": { "connect": {
"graphVersionId": {"id": graph.id, "version": graph.version} "graphVersionId": {"id": graph.id, "version": graph.version}
} }
}, },
isCreatedByUser=False, "isCreatedByUser": False,
settings=SafeJson(_initialize_graph_settings(graph_model).model_dump()), "settings": SafeJson(
_initialize_graph_settings(graph_model).model_dump()
), ),
},
include=library_agent_include( include=library_agent_include(
user_id, include_nodes=False, include_executions=False user_id, include_nodes=False, include_executions=False
), ),

View File

@@ -248,9 +248,7 @@ async def log_search_term(search_query: str):
date = datetime.now(timezone.utc).replace(hour=0, minute=0, second=0, microsecond=0) date = datetime.now(timezone.utc).replace(hour=0, minute=0, second=0, microsecond=0)
try: try:
await prisma.models.SearchTerms.prisma().create( await prisma.models.SearchTerms.prisma().create(
data=prisma.types.SearchTermsCreateInput( data={"searchTerm": search_query, "createdDate": date}
searchTerm=search_query, createdDate=date
)
) )
except Exception as e: except Exception as e:
# Fail silently here so that logging search terms doesn't break the app # Fail silently here so that logging search terms doesn't break the app
@@ -1432,10 +1430,13 @@ async def _approve_sub_agent(
# Create new version if no matching version found # Create new version if no matching version found
next_version = max((v.version for v in listing.Versions or []), default=0) + 1 next_version = max((v.version for v in listing.Versions or []), default=0) + 1
sub_agent_data = _create_sub_agent_version_data(sub_graph, heading, main_agent_name) await prisma.models.StoreListingVersion.prisma(tx).create(
sub_agent_data["version"] = next_version data={
sub_agent_data["storeListingId"] = listing.id **_create_sub_agent_version_data(sub_graph, heading, main_agent_name),
await prisma.models.StoreListingVersion.prisma(tx).create(data=sub_agent_data) "version": next_version,
"storeListingId": listing.id,
}
)
await prisma.models.StoreListing.prisma(tx).update( await prisma.models.StoreListing.prisma(tx).update(
where={"id": listing.id}, data={"hasApprovedVersion": True} where={"id": listing.id}, data={"hasApprovedVersion": True}
) )

View File

@@ -5,6 +5,13 @@ from tiktoken import encoding_for_model
from backend.util import json from backend.util import json
# ---------------------------------------------------------------------------#
# CONSTANTS #
# ---------------------------------------------------------------------------#
# Message prefixes for important system messages that should be protected during compression
MAIN_OBJECTIVE_PREFIX = "[Main Objective Prompt]: "
# ---------------------------------------------------------------------------# # ---------------------------------------------------------------------------#
# INTERNAL UTILITIES # # INTERNAL UTILITIES #
# ---------------------------------------------------------------------------# # ---------------------------------------------------------------------------#
@@ -63,6 +70,55 @@ def _msg_tokens(msg: dict, enc) -> int:
return WRAPPER + content_tokens + tool_call_tokens return WRAPPER + content_tokens + tool_call_tokens
def _is_tool_message(msg: dict) -> bool:
"""Check if a message contains tool calls or results that should be protected."""
content = msg.get("content")
# Check for Anthropic-style tool messages
if isinstance(content, list) and any(
isinstance(item, dict) and item.get("type") in ("tool_use", "tool_result")
for item in content
):
return True
# Check for OpenAI-style tool calls in the message
if "tool_calls" in msg or msg.get("role") == "tool":
return True
return False
def _is_objective_message(msg: dict) -> bool:
"""Check if a message contains objective/system prompts that should be absolutely protected."""
content = msg.get("content", "")
if isinstance(content, str):
# Protect any message with the main objective prefix
return content.startswith(MAIN_OBJECTIVE_PREFIX)
return False
def _truncate_tool_message_content(msg: dict, enc, max_tokens: int) -> None:
"""
Carefully truncate tool message content while preserving tool structure.
Only truncates tool_result content, leaves tool_use intact.
"""
content = msg.get("content")
if not isinstance(content, list):
return
for item in content:
# Only process tool_result items, leave tool_use blocks completely intact
if not (isinstance(item, dict) and item.get("type") == "tool_result"):
continue
result_content = item.get("content", "")
if (
isinstance(result_content, str)
and _tok_len(result_content, enc) > max_tokens
):
item["content"] = _truncate_middle_tokens(result_content, enc, max_tokens)
def _truncate_middle_tokens(text: str, enc, max_tok: int) -> str: def _truncate_middle_tokens(text: str, enc, max_tok: int) -> str:
""" """
Return *text* shortened to ≈max_tok tokens by keeping the head & tail Return *text* shortened to ≈max_tok tokens by keeping the head & tail
@@ -140,13 +196,21 @@ def compress_prompt(
return sum(_msg_tokens(m, enc) for m in msgs) return sum(_msg_tokens(m, enc) for m in msgs)
original_token_count = total_tokens() original_token_count = total_tokens()
if original_token_count + reserve <= target_tokens: if original_token_count + reserve <= target_tokens:
return msgs return msgs
# ---- STEP 0 : normalise content -------------------------------------- # ---- STEP 0 : normalise content --------------------------------------
# Convert non-string payloads to strings so token counting is coherent. # Convert non-string payloads to strings so token counting is coherent.
for m in msgs[1:-1]: # keep the first & last intact for i, m in enumerate(msgs):
if not isinstance(m.get("content"), str) and m.get("content") is not None: if not isinstance(m.get("content"), str) and m.get("content") is not None:
if _is_tool_message(m):
continue
# Keep first and last messages intact (unless they're tool messages)
if i == 0 or i == len(msgs) - 1:
continue
# Reasonable 20k-char ceiling prevents pathological blobs # Reasonable 20k-char ceiling prevents pathological blobs
content_str = json.dumps(m["content"], separators=(",", ":")) content_str = json.dumps(m["content"], separators=(",", ":"))
if len(content_str) > 20_000: if len(content_str) > 20_000:
@@ -157,34 +221,45 @@ def compress_prompt(
cap = start_cap cap = start_cap
while total_tokens() + reserve > target_tokens and cap >= floor_cap: while total_tokens() + reserve > target_tokens and cap >= floor_cap:
for m in msgs[1:-1]: # keep first & last intact for m in msgs[1:-1]: # keep first & last intact
if _tok_len(m.get("content") or "", enc) > cap: if _is_tool_message(m):
m["content"] = _truncate_middle_tokens(m["content"], enc, cap) # For tool messages, only truncate tool result content, preserve structure
_truncate_tool_message_content(m, enc, cap)
continue
if _is_objective_message(m):
# Never truncate objective messages - they contain the core task
continue
content = m.get("content") or ""
if _tok_len(content, enc) > cap:
m["content"] = _truncate_middle_tokens(content, enc, cap)
cap //= 2 # tighten the screw cap //= 2 # tighten the screw
# ---- STEP 2 : middle-out deletion ----------------------------------- # ---- STEP 2 : middle-out deletion -----------------------------------
while total_tokens() + reserve > target_tokens and len(msgs) > 2: while total_tokens() + reserve > target_tokens and len(msgs) > 2:
# Identify all deletable messages (not first/last, not tool messages, not objective messages)
deletable_indices = []
for i in range(1, len(msgs) - 1): # Skip first and last
if not _is_tool_message(msgs[i]) and not _is_objective_message(msgs[i]):
deletable_indices.append(i)
if not deletable_indices:
break # nothing more we can drop
# Delete from center outward - find the index closest to center
centre = len(msgs) // 2 centre = len(msgs) // 2
# Build a symmetrical centre-out index walk: centre, centre+1, centre-1, ... to_delete = min(deletable_indices, key=lambda i: abs(i - centre))
order = [centre] + [ del msgs[to_delete]
i
for pair in zip(range(centre + 1, len(msgs) - 1), range(centre - 1, 0, -1))
for i in pair
]
removed = False
for i in order:
msg = msgs[i]
if "tool_calls" in msg or msg.get("role") == "tool":
continue # protect tool shells
del msgs[i]
removed = True
break
if not removed: # nothing more we can drop
break
# ---- STEP 3 : final safety-net trim on first & last ------------------ # ---- STEP 3 : final safety-net trim on first & last ------------------
cap = start_cap cap = start_cap
while total_tokens() + reserve > target_tokens and cap >= floor_cap: while total_tokens() + reserve > target_tokens and cap >= floor_cap:
for idx in (0, -1): # first and last for idx in (0, -1): # first and last
if _is_tool_message(msgs[idx]):
# For tool messages at first/last position, truncate tool result content only
_truncate_tool_message_content(msgs[idx], enc, cap)
continue
text = msgs[idx].get("content") or "" text = msgs[idx].get("content") or ""
if _tok_len(text, enc) > cap: if _tok_len(text, enc) > cap:
msgs[idx]["content"] = _truncate_middle_tokens(text, enc, cap) msgs[idx]["content"] = _truncate_middle_tokens(text, enc, cap)

View File

@@ -1,228 +0,0 @@
"""
Rate Limiting for External API.
Implements sliding window rate limiting using Redis for distributed systems.
"""
import logging
import time
from dataclasses import dataclass
from typing import Optional
from backend.data.redis_client import get_redis_async
logger = logging.getLogger(__name__)
@dataclass
class RateLimitResult:
"""Result of a rate limit check."""
allowed: bool
remaining: int
reset_at: float
retry_after: Optional[float] = None
class RateLimiter:
"""
Redis-based sliding window rate limiter.
Supports multiple limit tiers (per-minute, per-hour, per-day).
"""
def __init__(self, prefix: str = "ratelimit"):
self.prefix = prefix
def _make_key(self, identifier: str, window: str) -> str:
"""Create a Redis key for the rate limit counter."""
return f"{self.prefix}:{identifier}:{window}"
async def check_and_increment(
self,
identifier: str,
limits: dict[str, tuple[int, int]], # window_name -> (limit, window_seconds)
) -> RateLimitResult:
"""
Check rate limits and increment counters if allowed.
Uses atomic increment-first approach to prevent race conditions:
1. Increment all counters atomically
2. Check if any limit exceeded
3. If exceeded, decrement and return rate limit error
Args:
identifier: Unique identifier (e.g., client_id, client_id:user_id)
limits: Dictionary of limit configurations
e.g., {"minute": (60, 60), "hour": (1000, 3600)}
Returns:
RateLimitResult with allowed status and remaining quota
"""
if not limits:
# No limits configured, allow request
return RateLimitResult(
allowed=True,
remaining=999999,
reset_at=time.time() + 60,
)
redis = await get_redis_async()
current_time = time.time()
# Increment all counters atomically first
incremented_keys: list[tuple[str, int, int, int]] = (
[]
) # (key, new_count, limit, window_seconds)
for window_name, (limit, window_seconds) in limits.items():
key = self._make_key(identifier, window_name)
# Atomic increment
new_count = await redis.incr(key)
# Set expiry if this is a new key
if new_count == 1:
await redis.expire(key, window_seconds)
incremented_keys.append((key, new_count, limit, window_seconds))
# Check if any limit exceeded
for key, new_count, limit, window_seconds in incremented_keys:
if new_count > limit:
# Rate limit exceeded - decrement all counters we just incremented
for decr_key, _, _, _ in incremented_keys:
await redis.decr(decr_key)
ttl = await redis.ttl(key)
reset_at = current_time + (ttl if ttl > 0 else window_seconds)
return RateLimitResult(
allowed=False,
remaining=0,
reset_at=reset_at,
retry_after=ttl if ttl > 0 else window_seconds,
)
# All limits passed
min_remaining = float("inf")
earliest_reset = current_time
for key, new_count, limit, window_seconds in incremented_keys:
remaining = max(0, limit - new_count)
min_remaining = min(min_remaining, remaining)
ttl = await redis.ttl(key)
reset_at = current_time + (ttl if ttl > 0 else window_seconds)
earliest_reset = max(earliest_reset, reset_at)
return RateLimitResult(
allowed=True,
remaining=int(min_remaining),
reset_at=earliest_reset,
)
async def get_remaining(
self,
identifier: str,
limits: dict[str, tuple[int, int]],
) -> dict[str, int]:
"""
Get remaining quota for all windows without incrementing.
Args:
identifier: Unique identifier
limits: Dictionary of limit configurations
Returns:
Dictionary of remaining quota per window
"""
redis = await get_redis_async()
remaining = {}
for window_name, (limit, _) in limits.items():
key = self._make_key(identifier, window_name)
count = await redis.get(key)
current_count = int(count) if count else 0
remaining[window_name] = max(0, limit - current_count)
return remaining
async def reset(self, identifier: str, window: Optional[str] = None) -> None:
"""
Reset rate limit counters.
Args:
identifier: Unique identifier
window: Optional specific window to reset (resets all if None)
"""
redis = await get_redis_async()
if window:
key = self._make_key(identifier, window)
await redis.delete(key)
else:
# Delete known window keys instead of scanning
# This avoids potentially slow scan operations with many keys
known_windows = ["minute", "hour", "day"]
keys_to_delete = [self._make_key(identifier, w) for w in known_windows]
# Delete all in one call (Redis handles non-existent keys gracefully)
if keys_to_delete:
await redis.delete(*keys_to_delete)
# Default rate limits for different endpoints
DEFAULT_RATE_LIMITS = {
# OAuth endpoints
"oauth_authorize": {"minute": (30, 60)}, # 30/min per IP
"oauth_token": {"minute": (20, 60)}, # 20/min per client
"oauth_consent": {"minute": (20, 60)}, # 20/min per IP for consent submission
# External API endpoints
"api_execute": {
"minute": (10, 60),
"hour": (100, 3600),
}, # 10/min, 100/hour per client+user
"api_read": {
"minute": (60, 60),
"hour": (1000, 3600),
}, # 60/min, 1000/hour per client+user
}
# Module-level singleton
_rate_limiter: Optional[RateLimiter] = None
def get_rate_limiter() -> RateLimiter:
"""Get the singleton rate limiter instance."""
global _rate_limiter
if _rate_limiter is None:
_rate_limiter = RateLimiter()
return _rate_limiter
async def check_rate_limit(
identifier: str,
limit_type: str,
) -> RateLimitResult:
"""
Convenience function to check rate limits.
Args:
identifier: Unique identifier for the rate limit
limit_type: Type of limit from DEFAULT_RATE_LIMITS
Returns:
RateLimitResult
"""
limits = DEFAULT_RATE_LIMITS.get(limit_type)
if not limits:
# No rate limit configured, allow
return RateLimitResult(
allowed=True,
remaining=999999,
reset_at=time.time() + 60,
)
rate_limiter = get_rate_limiter()
return await rate_limiter.check_and_increment(identifier, limits)

View File

@@ -651,23 +651,6 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
ayrshare_api_key: str = Field(default="", description="Ayrshare API Key") ayrshare_api_key: str = Field(default="", description="Ayrshare API Key")
ayrshare_jwt_key: str = Field(default="", description="Ayrshare private Key") ayrshare_jwt_key: str = Field(default="", description="Ayrshare private Key")
# OAuth Provider JWT keys
oauth_jwt_private_key: str = Field(
default="",
description="RSA private key for signing OAuth tokens (PEM format). "
"If not set, a development key will be auto-generated.",
)
oauth_jwt_public_key: str = Field(
default="",
description="RSA public key for verifying OAuth tokens (PEM format). "
"If not set, derived from private key.",
)
oauth_jwt_key_id: str = Field(
default="autogpt-oauth-key-1",
description="Key ID (kid) for JWKS. Used to identify the signing key.",
)
# Add more secret fields as needed # Add more secret fields as needed
model_config = SettingsConfigDict( model_config = SettingsConfigDict(
env_file=".env", env_file=".env",

View File

@@ -1,43 +0,0 @@
"""
Time utilities for the backend.
Common datetime operations used across the codebase.
"""
from datetime import datetime, timedelta, timezone
def expiration_datetime(seconds: int) -> datetime:
"""
Calculate an expiration datetime from now.
Args:
seconds: Number of seconds until expiration
Returns:
Datetime when the item will expire (UTC)
"""
return datetime.now(timezone.utc) + timedelta(seconds=seconds)
def is_expired(dt: datetime) -> bool:
"""
Check if a datetime has passed.
Args:
dt: The datetime to check (should be timezone-aware)
Returns:
True if the datetime is in the past
"""
return dt < datetime.now(timezone.utc)
def utc_now() -> datetime:
"""
Get the current UTC time.
Returns:
Current datetime in UTC
"""
return datetime.now(timezone.utc)

View File

@@ -1,46 +0,0 @@
"""
URL and domain validation utilities.
Common URL validation operations used across the codebase.
"""
def matches_domain_pattern(hostname: str, domain_pattern: str) -> bool:
"""
Check if a hostname matches a domain pattern.
Supports wildcard patterns (*.example.com) which match:
- The base domain (example.com)
- Any subdomain (sub.example.com, deep.sub.example.com)
Args:
hostname: The hostname to check (e.g., "api.example.com")
domain_pattern: The pattern to match against (e.g., "*.example.com" or "example.com")
Returns:
True if the hostname matches the pattern
"""
hostname = hostname.lower()
domain_pattern = domain_pattern.lower()
if domain_pattern.startswith("*."):
# Wildcard domain - matches base and any subdomains
base_domain = domain_pattern[2:]
return hostname == base_domain or hostname.endswith("." + base_domain)
# Exact match
return hostname == domain_pattern
def hostname_matches_any_domain(hostname: str, allowed_domains: list[str]) -> bool:
"""
Check if a hostname matches any of the allowed domain patterns.
Args:
hostname: The hostname to check
allowed_domains: List of allowed domain patterns (supports wildcards)
Returns:
True if the hostname matches any pattern
"""
return any(matches_domain_pattern(hostname, domain) for domain in allowed_domains)

View File

@@ -1,249 +0,0 @@
-- CreateEnum
CREATE TYPE "OAuthClientStatus" AS ENUM ('ACTIVE', 'SUSPENDED');
-- CreateEnum
CREATE TYPE "CredentialGrantPermission" AS ENUM ('USE', 'DELETE');
-- CreateTable
CREATE TABLE "OAuthClient" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
"clientId" TEXT NOT NULL,
"clientSecretHash" TEXT,
"clientSecretSalt" TEXT,
"clientType" TEXT NOT NULL,
"name" TEXT NOT NULL,
"description" TEXT,
"logoUrl" TEXT,
"homepageUrl" TEXT,
"privacyPolicyUrl" TEXT,
"termsOfServiceUrl" TEXT,
"redirectUris" TEXT[],
"allowedScopes" TEXT[],
"webhookDomains" TEXT[],
"requirePkce" BOOLEAN NOT NULL DEFAULT true,
"tokenLifetimeSecs" INTEGER NOT NULL DEFAULT 3600,
"refreshTokenLifetimeSecs" INTEGER NOT NULL DEFAULT 2592000,
"status" "OAuthClientStatus" NOT NULL DEFAULT 'ACTIVE',
"ownerId" TEXT NOT NULL,
CONSTRAINT "OAuthClient_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "OAuthAuthorization" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
"userId" TEXT NOT NULL,
"clientId" TEXT NOT NULL,
"scopes" TEXT[],
"revokedAt" TIMESTAMP(3),
CONSTRAINT "OAuthAuthorization_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "OAuthAuthorizationCode" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"codeHash" TEXT NOT NULL,
"userId" TEXT NOT NULL,
"clientId" TEXT NOT NULL,
"redirectUri" TEXT NOT NULL,
"scopes" TEXT[],
"nonce" TEXT,
"codeChallenge" TEXT NOT NULL,
"codeChallengeMethod" TEXT NOT NULL DEFAULT 'S256',
"expiresAt" TIMESTAMP(3) NOT NULL,
"usedAt" TIMESTAMP(3),
CONSTRAINT "OAuthAuthorizationCode_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "OAuthAccessToken" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"tokenHash" TEXT NOT NULL,
"userId" TEXT NOT NULL,
"clientId" TEXT NOT NULL,
"scopes" TEXT[],
"expiresAt" TIMESTAMP(3) NOT NULL,
"revokedAt" TIMESTAMP(3),
"lastUsedAt" TIMESTAMP(3),
CONSTRAINT "OAuthAccessToken_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "OAuthRefreshToken" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"tokenHash" TEXT NOT NULL,
"userId" TEXT NOT NULL,
"clientId" TEXT NOT NULL,
"scopes" TEXT[],
"expiresAt" TIMESTAMP(3) NOT NULL,
"revokedAt" TIMESTAMP(3),
CONSTRAINT "OAuthRefreshToken_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "CredentialGrant" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
"userId" TEXT NOT NULL,
"clientId" TEXT NOT NULL,
"credentialId" TEXT NOT NULL,
"provider" TEXT NOT NULL,
"grantedScopes" TEXT[],
"permissions" "CredentialGrantPermission"[],
"expiresAt" TIMESTAMP(3),
"revokedAt" TIMESTAMP(3),
"lastUsedAt" TIMESTAMP(3),
CONSTRAINT "CredentialGrant_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "OAuthAuditLog" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"eventType" TEXT NOT NULL,
"userId" TEXT,
"clientId" TEXT,
"grantId" TEXT,
"ipAddress" TEXT,
"userAgent" TEXT,
"details" JSONB NOT NULL DEFAULT '{}',
CONSTRAINT "OAuthAuditLog_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "ExecutionWebhook" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"executionId" TEXT NOT NULL,
"webhookUrl" TEXT NOT NULL,
"clientId" TEXT NOT NULL,
"userId" TEXT NOT NULL,
"secret" TEXT,
CONSTRAINT "ExecutionWebhook_pkey" PRIMARY KEY ("id")
);
-- CreateIndex
CREATE UNIQUE INDEX "OAuthClient_clientId_key" ON "OAuthClient"("clientId");
-- CreateIndex
CREATE INDEX "OAuthClient_clientId_idx" ON "OAuthClient"("clientId");
-- CreateIndex
CREATE INDEX "OAuthClient_ownerId_idx" ON "OAuthClient"("ownerId");
-- CreateIndex
CREATE INDEX "OAuthClient_status_idx" ON "OAuthClient"("status");
-- CreateIndex
CREATE INDEX "OAuthAuthorization_userId_idx" ON "OAuthAuthorization"("userId");
-- CreateIndex
CREATE INDEX "OAuthAuthorization_clientId_idx" ON "OAuthAuthorization"("clientId");
-- CreateIndex
CREATE UNIQUE INDEX "OAuthAuthorization_userId_clientId_key" ON "OAuthAuthorization"("userId", "clientId");
-- CreateIndex
CREATE UNIQUE INDEX "OAuthAuthorizationCode_codeHash_key" ON "OAuthAuthorizationCode"("codeHash");
-- CreateIndex
CREATE INDEX "OAuthAuthorizationCode_codeHash_idx" ON "OAuthAuthorizationCode"("codeHash");
-- CreateIndex
CREATE INDEX "OAuthAuthorizationCode_expiresAt_idx" ON "OAuthAuthorizationCode"("expiresAt");
-- CreateIndex
CREATE UNIQUE INDEX "OAuthAccessToken_tokenHash_key" ON "OAuthAccessToken"("tokenHash");
-- CreateIndex
CREATE INDEX "OAuthAccessToken_tokenHash_idx" ON "OAuthAccessToken"("tokenHash");
-- CreateIndex
CREATE INDEX "OAuthAccessToken_userId_clientId_idx" ON "OAuthAccessToken"("userId", "clientId");
-- CreateIndex
CREATE INDEX "OAuthAccessToken_expiresAt_idx" ON "OAuthAccessToken"("expiresAt");
-- CreateIndex
CREATE UNIQUE INDEX "OAuthRefreshToken_tokenHash_key" ON "OAuthRefreshToken"("tokenHash");
-- CreateIndex
CREATE INDEX "OAuthRefreshToken_tokenHash_idx" ON "OAuthRefreshToken"("tokenHash");
-- CreateIndex
CREATE INDEX "OAuthRefreshToken_expiresAt_idx" ON "OAuthRefreshToken"("expiresAt");
-- CreateIndex
CREATE INDEX "CredentialGrant_userId_clientId_idx" ON "CredentialGrant"("userId", "clientId");
-- CreateIndex
CREATE INDEX "CredentialGrant_clientId_idx" ON "CredentialGrant"("clientId");
-- CreateIndex
CREATE UNIQUE INDEX "CredentialGrant_userId_clientId_credentialId_key" ON "CredentialGrant"("userId", "clientId", "credentialId");
-- CreateIndex
CREATE INDEX "OAuthAuditLog_createdAt_idx" ON "OAuthAuditLog"("createdAt");
-- CreateIndex
CREATE INDEX "OAuthAuditLog_eventType_idx" ON "OAuthAuditLog"("eventType");
-- CreateIndex
CREATE INDEX "OAuthAuditLog_userId_idx" ON "OAuthAuditLog"("userId");
-- CreateIndex
CREATE INDEX "OAuthAuditLog_clientId_idx" ON "OAuthAuditLog"("clientId");
-- CreateIndex
CREATE INDEX "ExecutionWebhook_executionId_idx" ON "ExecutionWebhook"("executionId");
-- CreateIndex
CREATE INDEX "ExecutionWebhook_clientId_idx" ON "ExecutionWebhook"("clientId");
-- AddForeignKey
ALTER TABLE "OAuthClient" ADD CONSTRAINT "OAuthClient_ownerId_fkey" FOREIGN KEY ("ownerId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "OAuthAuthorization" ADD CONSTRAINT "OAuthAuthorization_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "OAuthAuthorization" ADD CONSTRAINT "OAuthAuthorization_clientId_fkey" FOREIGN KEY ("clientId") REFERENCES "OAuthClient"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "OAuthAuthorizationCode" ADD CONSTRAINT "OAuthAuthorizationCode_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "OAuthAuthorizationCode" ADD CONSTRAINT "OAuthAuthorizationCode_clientId_fkey" FOREIGN KEY ("clientId") REFERENCES "OAuthClient"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "OAuthAccessToken" ADD CONSTRAINT "OAuthAccessToken_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "OAuthAccessToken" ADD CONSTRAINT "OAuthAccessToken_clientId_fkey" FOREIGN KEY ("clientId") REFERENCES "OAuthClient"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "OAuthRefreshToken" ADD CONSTRAINT "OAuthRefreshToken_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "OAuthRefreshToken" ADD CONSTRAINT "OAuthRefreshToken_clientId_fkey" FOREIGN KEY ("clientId") REFERENCES "OAuthClient"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "CredentialGrant" ADD CONSTRAINT "CredentialGrant_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "CredentialGrant" ADD CONSTRAINT "CredentialGrant_clientId_fkey" FOREIGN KEY ("clientId") REFERENCES "OAuthClient"("id") ON DELETE CASCADE ON UPDATE CASCADE;

View File

@@ -1,2 +0,0 @@
-- AlterTable
ALTER TABLE "platform"."OAuthClient" ADD COLUMN "webhookSecret" TEXT;

View File

@@ -0,0 +1,15 @@
-- Create BuilderSearchHistory table
CREATE TABLE "BuilderSearchHistory" (
"id" TEXT NOT NULL,
"userId" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"searchQuery" TEXT NOT NULL,
"filter" TEXT[] DEFAULT ARRAY[]::TEXT[],
"byCreator" TEXT[] DEFAULT ARRAY[]::TEXT[],
CONSTRAINT "BuilderSearchHistory_pkey" PRIMARY KEY ("id")
);
-- Define User foreign relation
ALTER TABLE "BuilderSearchHistory" ADD CONSTRAINT "BuilderSearchHistory_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;

View File

@@ -53,6 +53,7 @@ model User {
Profile Profile[] Profile Profile[]
UserOnboarding UserOnboarding? UserOnboarding UserOnboarding?
BuilderSearchHistory BuilderSearchHistory[]
StoreListings StoreListing[] StoreListings StoreListing[]
StoreListingReviews StoreListingReview[] StoreListingReviews StoreListingReview[]
StoreVersionsReviewed StoreListingVersion[] StoreVersionsReviewed StoreListingVersion[]
@@ -60,14 +61,6 @@ model User {
IntegrationWebhooks IntegrationWebhook[] IntegrationWebhooks IntegrationWebhook[]
NotificationBatches UserNotificationBatch[] NotificationBatches UserNotificationBatch[]
PendingHumanReviews PendingHumanReview[] PendingHumanReviews PendingHumanReview[]
// OAuth Provider relations
OAuthClientsOwned OAuthClient[] @relation("OAuthClientOwner")
OAuthAuthorizations OAuthAuthorization[]
OAuthAuthorizationCodes OAuthAuthorizationCode[]
OAuthAccessTokens OAuthAccessToken[]
OAuthRefreshTokens OAuthRefreshToken[]
CredentialGrants CredentialGrant[]
} }
enum OnboardingStep { enum OnboardingStep {
@@ -122,6 +115,19 @@ model UserOnboarding {
User User @relation(fields: [userId], references: [id], onDelete: Cascade) User User @relation(fields: [userId], references: [id], onDelete: Cascade)
} }
model BuilderSearchHistory {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @default(now()) @updatedAt
searchQuery String
filter String[] @default([])
byCreator String[] @default([])
userId String
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
}
// This model describes the Agent Graph/Flow (Multi Agent System). // This model describes the Agent Graph/Flow (Multi Agent System).
model AgentGraph { model AgentGraph {
id String @default(uuid()) id String @default(uuid())
@@ -969,226 +975,3 @@ enum APIKeyStatus {
REVOKED REVOKED
SUSPENDED SUSPENDED
} }
// ============================================================
// OAuth Provider & Credential Broker Models
// ============================================================
enum OAuthClientStatus {
ACTIVE
SUSPENDED
}
enum CredentialGrantPermission {
USE // Can use credential for agent execution
DELETE // Can delete the credential
}
// OAuth Client - Registered external applications
model OAuthClient {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
// Client identification
clientId String @unique // Public identifier (e.g., "app_abc123")
clientSecretHash String? // Hashed (null for public clients)
clientSecretSalt String?
clientType String // "public" or "confidential"
// Metadata (shown on consent screen)
name String
description String?
logoUrl String?
homepageUrl String?
privacyPolicyUrl String?
termsOfServiceUrl String?
// Configuration
redirectUris String[]
allowedScopes String[]
webhookDomains String[] // For webhook URL validation
webhookSecret String? // Secret for HMAC signing webhooks
// Security
requirePkce Boolean @default(true)
tokenLifetimeSecs Int @default(3600)
refreshTokenLifetimeSecs Int @default(2592000) // 30 days
// Status
status OAuthClientStatus @default(ACTIVE)
// Owner
ownerId String
Owner User @relation("OAuthClientOwner", fields: [ownerId], references: [id], onDelete: Cascade)
// Relations
Authorizations OAuthAuthorization[]
AuthorizationCodes OAuthAuthorizationCode[]
AccessTokens OAuthAccessToken[]
RefreshTokens OAuthRefreshToken[]
CredentialGrants CredentialGrant[]
@@index([clientId])
@@index([ownerId])
@@index([status])
}
// OAuth Authorization - User consent record
model OAuthAuthorization {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
userId String
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
clientId String
Client OAuthClient @relation(fields: [clientId], references: [id], onDelete: Cascade)
scopes String[]
revokedAt DateTime?
@@unique([userId, clientId])
@@index([userId])
@@index([clientId])
}
// OAuth Authorization Code - Short-lived, single-use
model OAuthAuthorizationCode {
id String @id @default(uuid())
createdAt DateTime @default(now())
codeHash String @unique
userId String
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
clientId String
Client OAuthClient @relation(fields: [clientId], references: [id], onDelete: Cascade)
redirectUri String
scopes String[]
nonce String? // OIDC nonce
// PKCE
codeChallenge String
codeChallengeMethod String @default("S256")
expiresAt DateTime // 10 minutes
usedAt DateTime?
@@index([codeHash])
@@index([expiresAt])
}
// OAuth Access Token
model OAuthAccessToken {
id String @id @default(uuid())
createdAt DateTime @default(now())
tokenHash String @unique // SHA256 of token
userId String
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
clientId String
Client OAuthClient @relation(fields: [clientId], references: [id], onDelete: Cascade)
scopes String[]
expiresAt DateTime
revokedAt DateTime?
lastUsedAt DateTime?
@@index([tokenHash])
@@index([userId, clientId])
@@index([expiresAt])
}
// OAuth Refresh Token
model OAuthRefreshToken {
id String @id @default(uuid())
createdAt DateTime @default(now())
tokenHash String @unique
userId String
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
clientId String
Client OAuthClient @relation(fields: [clientId], references: [id], onDelete: Cascade)
scopes String[]
expiresAt DateTime
revokedAt DateTime?
@@index([tokenHash])
@@index([expiresAt])
}
// Credential Grant - Links external app to user's credential with scoped access
model CredentialGrant {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
userId String
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
clientId String
Client OAuthClient @relation(fields: [clientId], references: [id], onDelete: Cascade)
credentialId String // Reference to credential in User.integrations
provider String
// Fine-grained integration scopes (e.g., "google:gmail.readonly")
grantedScopes String[]
// Permissions for the credential itself
permissions CredentialGrantPermission[]
expiresAt DateTime?
revokedAt DateTime?
lastUsedAt DateTime?
@@unique([userId, clientId, credentialId])
@@index([userId, clientId])
@@index([clientId])
}
// OAuth Audit Log
model OAuthAuditLog {
id String @id @default(uuid())
createdAt DateTime @default(now())
eventType String // e.g., "token.issued", "grant.created"
userId String?
clientId String?
grantId String?
ipAddress String?
userAgent String?
details Json @default("{}")
@@index([createdAt])
@@index([eventType])
@@index([userId])
@@index([clientId])
}
// Execution Webhook - Webhook registration for external API executions
model ExecutionWebhook {
id String @id @default(uuid())
createdAt DateTime @default(now())
executionId String // The graph execution ID
webhookUrl String // URL to send notifications to
clientId String // The OAuth client database ID
userId String // The user who started the execution
secret String? // Optional webhook secret for HMAC signing
@@index([executionId])
@@index([clientId])
}

View File

@@ -22,7 +22,6 @@ import random
from typing import Any, Dict, List from typing import Any, Dict, List
from faker import Faker from faker import Faker
from prisma.types import AgentBlockCreateInput
from backend.data.api_key import create_api_key from backend.data.api_key import create_api_key
from backend.data.credit import get_user_credit_model from backend.data.credit import get_user_credit_model
@@ -178,12 +177,12 @@ class TestDataCreator:
for block in blocks_to_create: for block in blocks_to_create:
try: try:
await prisma.agentblock.create( await prisma.agentblock.create(
data=AgentBlockCreateInput( data={
id=block.id, "id": block.id,
name=block.name, "name": block.name,
inputSchema="{}", "inputSchema": "{}",
outputSchema="{}", "outputSchema": "{}",
) }
) )
except Exception as e: except Exception as e:
print(f"Error creating block {block.name}: {e}") print(f"Error creating block {block.name}: {e}")

View File

@@ -30,19 +30,13 @@ from prisma.types import (
AgentGraphCreateInput, AgentGraphCreateInput,
AgentNodeCreateInput, AgentNodeCreateInput,
AgentNodeLinkCreateInput, AgentNodeLinkCreateInput,
AgentPresetCreateInput,
AnalyticsDetailsCreateInput, AnalyticsDetailsCreateInput,
AnalyticsMetricsCreateInput, AnalyticsMetricsCreateInput,
APIKeyCreateInput,
CreditTransactionCreateInput, CreditTransactionCreateInput,
IntegrationWebhookCreateInput, IntegrationWebhookCreateInput,
LibraryAgentCreateInput,
ProfileCreateInput, ProfileCreateInput,
StoreListingCreateInput,
StoreListingReviewCreateInput, StoreListingReviewCreateInput,
StoreListingVersionCreateInput,
UserCreateInput, UserCreateInput,
UserOnboardingCreateInput,
) )
faker = Faker() faker = Faker()
@@ -178,14 +172,14 @@ async def main():
for _ in range(num_presets): # Create 1 AgentPreset per user for _ in range(num_presets): # Create 1 AgentPreset per user
graph = random.choice(agent_graphs) graph = random.choice(agent_graphs)
preset = await db.agentpreset.create( preset = await db.agentpreset.create(
data=AgentPresetCreateInput( data={
name=faker.sentence(nb_words=3), "name": faker.sentence(nb_words=3),
description=faker.text(max_nb_chars=200), "description": faker.text(max_nb_chars=200),
userId=user.id, "userId": user.id,
agentGraphId=graph.id, "agentGraphId": graph.id,
agentGraphVersion=graph.version, "agentGraphVersion": graph.version,
isActive=True, "isActive": True,
) }
) )
agent_presets.append(preset) agent_presets.append(preset)
@@ -226,18 +220,18 @@ async def main():
) )
library_agent = await db.libraryagent.create( library_agent = await db.libraryagent.create(
data=LibraryAgentCreateInput( data={
userId=user.id, "userId": user.id,
agentGraphId=graph.id, "agentGraphId": graph.id,
agentGraphVersion=graph.version, "agentGraphVersion": graph.version,
creatorId=creator_profile.id if creator_profile else None, "creatorId": creator_profile.id if creator_profile else None,
imageUrl=get_image() if random.random() < 0.5 else None, "imageUrl": get_image() if random.random() < 0.5 else None,
useGraphIsActiveVersion=random.choice([True, False]), "useGraphIsActiveVersion": random.choice([True, False]),
isFavorite=random.choice([True, False]), "isFavorite": random.choice([True, False]),
isCreatedByUser=random.choice([True, False]), "isCreatedByUser": random.choice([True, False]),
isArchived=random.choice([True, False]), "isArchived": random.choice([True, False]),
isDeleted=random.choice([True, False]), "isDeleted": random.choice([True, False]),
) }
) )
library_agents.append(library_agent) library_agents.append(library_agent)
@@ -398,13 +392,13 @@ async def main():
user = random.choice(users) user = random.choice(users)
slug = faker.slug() slug = faker.slug()
listing = await db.storelisting.create( listing = await db.storelisting.create(
data=StoreListingCreateInput( data={
agentGraphId=graph.id, "agentGraphId": graph.id,
agentGraphVersion=graph.version, "agentGraphVersion": graph.version,
owningUserId=user.id, "owningUserId": user.id,
hasApprovedVersion=random.choice([True, False]), "hasApprovedVersion": random.choice([True, False]),
slug=slug, "slug": slug,
) }
) )
store_listings.append(listing) store_listings.append(listing)
@@ -414,26 +408,26 @@ async def main():
for listing in store_listings: for listing in store_listings:
graph = [g for g in agent_graphs if g.id == listing.agentGraphId][0] graph = [g for g in agent_graphs if g.id == listing.agentGraphId][0]
version = await db.storelistingversion.create( version = await db.storelistingversion.create(
data=StoreListingVersionCreateInput( data={
agentGraphId=graph.id, "agentGraphId": graph.id,
agentGraphVersion=graph.version, "agentGraphVersion": graph.version,
name=graph.name or faker.sentence(nb_words=3), "name": graph.name or faker.sentence(nb_words=3),
subHeading=faker.sentence(), "subHeading": faker.sentence(),
videoUrl=get_video_url() if random.random() < 0.3 else None, "videoUrl": get_video_url() if random.random() < 0.3 else None,
imageUrls=[get_image() for _ in range(3)], "imageUrls": [get_image() for _ in range(3)],
description=faker.text(), "description": faker.text(),
categories=[faker.word() for _ in range(3)], "categories": [faker.word() for _ in range(3)],
isFeatured=random.choice([True, False]), "isFeatured": random.choice([True, False]),
isAvailable=True, "isAvailable": True,
storeListingId=listing.id, "storeListingId": listing.id,
submissionStatus=random.choice( "submissionStatus": random.choice(
[ [
prisma.enums.SubmissionStatus.PENDING, prisma.enums.SubmissionStatus.PENDING,
prisma.enums.SubmissionStatus.APPROVED, prisma.enums.SubmissionStatus.APPROVED,
prisma.enums.SubmissionStatus.REJECTED, prisma.enums.SubmissionStatus.REJECTED,
] ]
), ),
) }
) )
store_listing_versions.append(version) store_listing_versions.append(version)
@@ -475,47 +469,51 @@ async def main():
try: try:
await db.useronboarding.create( await db.useronboarding.create(
data=UserOnboardingCreateInput( data={
userId=user.id, "userId": user.id,
completedSteps=completed_steps, "completedSteps": completed_steps,
walletShown=random.choice([True, False]), "walletShown": random.choice([True, False]),
notified=( "notified": (
random.sample(completed_steps, k=min(3, len(completed_steps))) random.sample(completed_steps, k=min(3, len(completed_steps)))
if completed_steps if completed_steps
else [] else []
), ),
rewardedFor=( "rewardedFor": (
random.sample(completed_steps, k=min(2, len(completed_steps))) random.sample(completed_steps, k=min(2, len(completed_steps)))
if completed_steps if completed_steps
else [] else []
), ),
usageReason=( "usageReason": (
random.choice(["personal", "business", "research", "learning"]) random.choice(["personal", "business", "research", "learning"])
if random.random() < 0.7 if random.random() < 0.7
else None else None
), ),
integrations=random.sample( "integrations": random.sample(
["github", "google", "discord", "slack"], k=random.randint(0, 2) ["github", "google", "discord", "slack"], k=random.randint(0, 2)
), ),
otherIntegrations=(faker.word() if random.random() < 0.2 else None), "otherIntegrations": (
selectedStoreListingVersionId=( faker.word() if random.random() < 0.2 else None
),
"selectedStoreListingVersionId": (
random.choice(store_listing_versions).id random.choice(store_listing_versions).id
if store_listing_versions and random.random() < 0.5 if store_listing_versions and random.random() < 0.5
else None else None
), ),
onboardingAgentExecutionId=( "onboardingAgentExecutionId": (
random.choice(agent_graph_executions).id random.choice(agent_graph_executions).id
if agent_graph_executions and random.random() < 0.3 if agent_graph_executions and random.random() < 0.3
else None else None
), ),
agentRuns=random.randint(0, 10), "agentRuns": random.randint(0, 10),
) }
) )
except Exception as e: except Exception as e:
print(f"Error creating onboarding for user {user.id}: {e}") print(f"Error creating onboarding for user {user.id}: {e}")
# Try simpler version # Try simpler version
await db.useronboarding.create( await db.useronboarding.create(
data=UserOnboardingCreateInput(userId=user.id) data={
"userId": user.id,
}
) )
# Insert IntegrationWebhooks for some users # Insert IntegrationWebhooks for some users
@@ -546,20 +544,20 @@ async def main():
for user in users: for user in users:
api_key = APIKeySmith().generate_key() api_key = APIKeySmith().generate_key()
await db.apikey.create( await db.apikey.create(
data=APIKeyCreateInput( data={
name=faker.word(), "name": faker.word(),
head=api_key.head, "head": api_key.head,
tail=api_key.tail, "tail": api_key.tail,
hash=api_key.hash, "hash": api_key.hash,
salt=api_key.salt, "salt": api_key.salt,
status=prisma.enums.APIKeyStatus.ACTIVE, "status": prisma.enums.APIKeyStatus.ACTIVE,
permissions=[ "permissions": [
prisma.enums.APIKeyPermission.EXECUTE_GRAPH, prisma.enums.APIKeyPermission.EXECUTE_GRAPH,
prisma.enums.APIKeyPermission.READ_GRAPH, prisma.enums.APIKeyPermission.READ_GRAPH,
], ],
description=faker.text(), "description": faker.text(),
userId=user.id, "userId": user.id,
) }
) )
# Refresh materialized views # Refresh materialized views

View File

@@ -16,7 +16,6 @@ from datetime import datetime, timedelta
import prisma.enums import prisma.enums
from faker import Faker from faker import Faker
from prisma import Json, Prisma from prisma import Json, Prisma
from prisma.types import CreditTransactionCreateInput, StoreListingReviewCreateInput
faker = Faker() faker = Faker()
@@ -167,16 +166,16 @@ async def main():
score = random.choices([1, 2, 3, 4, 5], weights=[5, 10, 20, 40, 25])[0] score = random.choices([1, 2, 3, 4, 5], weights=[5, 10, 20, 40, 25])[0]
await db.storelistingreview.create( await db.storelistingreview.create(
data=StoreListingReviewCreateInput( data={
storeListingVersionId=version.id, "storeListingVersionId": version.id,
reviewByUserId=reviewer.id, "reviewByUserId": reviewer.id,
score=score, "score": score,
comments=( "comments": (
faker.text(max_nb_chars=200) faker.text(max_nb_chars=200)
if random.random() < 0.7 if random.random() < 0.7
else None else None
), ),
) }
) )
new_reviews_count += 1 new_reviews_count += 1
@@ -245,17 +244,17 @@ async def main():
) )
await db.credittransaction.create( await db.credittransaction.create(
data=CreditTransactionCreateInput( data={
userId=user.id, "userId": user.id,
amount=amount, "amount": amount,
type=transaction_type, "type": transaction_type,
metadata=Json( "metadata": Json(
{ {
"source": "test_updater", "source": "test_updater",
"timestamp": datetime.now().isoformat(), "timestamp": datetime.now().isoformat(),
} }
), ),
) }
) )
transaction_count += 1 transaction_count += 1

View File

@@ -82,7 +82,7 @@
"lodash": "4.17.21", "lodash": "4.17.21",
"lucide-react": "0.552.0", "lucide-react": "0.552.0",
"moment": "2.30.1", "moment": "2.30.1",
"next": "15.4.8", "next": "15.4.10",
"next-themes": "0.4.6", "next-themes": "0.4.6",
"nuqs": "2.7.2", "nuqs": "2.7.2",
"party-js": "2.2.0", "party-js": "2.2.0",

View File

@@ -16,7 +16,7 @@ importers:
version: 5.2.2(react-hook-form@7.66.0(react@18.3.1)) version: 5.2.2(react-hook-form@7.66.0(react@18.3.1))
'@next/third-parties': '@next/third-parties':
specifier: 15.4.6 specifier: 15.4.6
version: 15.4.6(next@15.4.8(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1) version: 15.4.6(next@15.4.10(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)
'@phosphor-icons/react': '@phosphor-icons/react':
specifier: 2.1.10 specifier: 2.1.10
version: 2.1.10(react-dom@18.3.1(react@18.3.1))(react@18.3.1) version: 2.1.10(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
@@ -88,7 +88,7 @@ importers:
version: 5.24.13(@rjsf/utils@5.24.13(react@18.3.1)) version: 5.24.13(@rjsf/utils@5.24.13(react@18.3.1))
'@sentry/nextjs': '@sentry/nextjs':
specifier: 10.27.0 specifier: 10.27.0
version: 10.27.0(@opentelemetry/context-async-hooks@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/core@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-trace-base@2.2.0(@opentelemetry/api@1.9.0))(next@15.4.8(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)(webpack@5.101.3(esbuild@0.25.9)) version: 10.27.0(@opentelemetry/context-async-hooks@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/core@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-trace-base@2.2.0(@opentelemetry/api@1.9.0))(next@15.4.10(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)(webpack@5.101.3(esbuild@0.25.9))
'@supabase/ssr': '@supabase/ssr':
specifier: 0.7.0 specifier: 0.7.0
version: 0.7.0(@supabase/supabase-js@2.78.0) version: 0.7.0(@supabase/supabase-js@2.78.0)
@@ -106,10 +106,10 @@ importers:
version: 0.2.4 version: 0.2.4
'@vercel/analytics': '@vercel/analytics':
specifier: 1.5.0 specifier: 1.5.0
version: 1.5.0(next@15.4.8(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1) version: 1.5.0(next@15.4.10(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)
'@vercel/speed-insights': '@vercel/speed-insights':
specifier: 1.2.0 specifier: 1.2.0
version: 1.2.0(next@15.4.8(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1) version: 1.2.0(next@15.4.10(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)
'@xyflow/react': '@xyflow/react':
specifier: 12.9.2 specifier: 12.9.2
version: 12.9.2(@types/react@18.3.17)(immer@10.1.3)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) version: 12.9.2(@types/react@18.3.17)(immer@10.1.3)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
@@ -148,7 +148,7 @@ importers:
version: 12.23.24(@emotion/is-prop-valid@1.2.2)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) version: 12.23.24(@emotion/is-prop-valid@1.2.2)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
geist: geist:
specifier: 1.5.1 specifier: 1.5.1
version: 1.5.1(next@15.4.8(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)) version: 1.5.1(next@15.4.10(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))
highlight.js: highlight.js:
specifier: 11.11.1 specifier: 11.11.1
version: 11.11.1 version: 11.11.1
@@ -171,14 +171,14 @@ importers:
specifier: 2.30.1 specifier: 2.30.1
version: 2.30.1 version: 2.30.1
next: next:
specifier: 15.4.8 specifier: 15.4.10
version: 15.4.8(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) version: 15.4.10(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
next-themes: next-themes:
specifier: 0.4.6 specifier: 0.4.6
version: 0.4.6(react-dom@18.3.1(react@18.3.1))(react@18.3.1) version: 0.4.6(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
nuqs: nuqs:
specifier: 2.7.2 specifier: 2.7.2
version: 2.7.2(next@15.4.8(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1) version: 2.7.2(next@15.4.10(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)
party-js: party-js:
specifier: 2.2.0 specifier: 2.2.0
version: 2.2.0 version: 2.2.0
@@ -284,7 +284,7 @@ importers:
version: 9.1.5(storybook@9.1.5(@testing-library/dom@10.4.1)(msw@2.11.6(@types/node@24.10.0)(typescript@5.9.3))(prettier@3.6.2)) version: 9.1.5(storybook@9.1.5(@testing-library/dom@10.4.1)(msw@2.11.6(@types/node@24.10.0)(typescript@5.9.3))(prettier@3.6.2))
'@storybook/nextjs': '@storybook/nextjs':
specifier: 9.1.5 specifier: 9.1.5
version: 9.1.5(esbuild@0.25.9)(next@15.4.8(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(storybook@9.1.5(@testing-library/dom@10.4.1)(msw@2.11.6(@types/node@24.10.0)(typescript@5.9.3))(prettier@3.6.2))(type-fest@4.41.0)(typescript@5.9.3)(webpack-hot-middleware@2.26.1)(webpack@5.101.3(esbuild@0.25.9)) version: 9.1.5(esbuild@0.25.9)(next@15.4.10(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(storybook@9.1.5(@testing-library/dom@10.4.1)(msw@2.11.6(@types/node@24.10.0)(typescript@5.9.3))(prettier@3.6.2))(type-fest@4.41.0)(typescript@5.9.3)(webpack-hot-middleware@2.26.1)(webpack@5.101.3(esbuild@0.25.9))
'@tanstack/eslint-plugin-query': '@tanstack/eslint-plugin-query':
specifier: 5.91.2 specifier: 5.91.2
version: 5.91.2(eslint@8.57.1)(typescript@5.9.3) version: 5.91.2(eslint@8.57.1)(typescript@5.9.3)
@@ -1602,8 +1602,8 @@ packages:
'@neoconfetti/react@1.0.0': '@neoconfetti/react@1.0.0':
resolution: {integrity: sha512-klcSooChXXOzIm+SE5IISIAn3bYzYfPjbX7D7HoqZL84oAfgREeSg5vSIaSFH+DaGzzvImTyWe1OyrJ67vik4A==} resolution: {integrity: sha512-klcSooChXXOzIm+SE5IISIAn3bYzYfPjbX7D7HoqZL84oAfgREeSg5vSIaSFH+DaGzzvImTyWe1OyrJ67vik4A==}
'@next/env@15.4.8': '@next/env@15.4.10':
resolution: {integrity: sha512-LydLa2MDI1NMrOFSkO54mTc8iIHSttj6R6dthITky9ylXV2gCGi0bHQjVCtLGRshdRPjyh2kXbxJukDtBWQZtQ==} resolution: {integrity: sha512-knhmoJ0Vv7VRf6pZEPSnciUG1S4bIhWx+qTYBW/AjxEtlzsiNORPk8sFDCEvqLfmKuey56UB9FL1UdHEV3uBrg==}
'@next/eslint-plugin-next@15.5.2': '@next/eslint-plugin-next@15.5.2':
resolution: {integrity: sha512-lkLrRVxcftuOsJNhWatf1P2hNVfh98k/omQHrCEPPriUypR6RcS13IvLdIrEvkm9AH2Nu2YpR5vLqBuy6twH3Q==} resolution: {integrity: sha512-lkLrRVxcftuOsJNhWatf1P2hNVfh98k/omQHrCEPPriUypR6RcS13IvLdIrEvkm9AH2Nu2YpR5vLqBuy6twH3Q==}
@@ -5920,8 +5920,8 @@ packages:
react: ^16.8 || ^17 || ^18 || ^19 || ^19.0.0-rc react: ^16.8 || ^17 || ^18 || ^19 || ^19.0.0-rc
react-dom: ^16.8 || ^17 || ^18 || ^19 || ^19.0.0-rc react-dom: ^16.8 || ^17 || ^18 || ^19 || ^19.0.0-rc
next@15.4.8: next@15.4.10:
resolution: {integrity: sha512-jwOXTz/bo0Pvlf20FSb6VXVeWRssA2vbvq9SdrOPEg9x8E1B27C2rQtvriAn600o9hH61kjrVRexEffv3JybuA==} resolution: {integrity: sha512-itVlc79QjpKMFMRhP+kbGKaSG/gZM6RCvwhEbwmCNF06CdDiNaoHcbeg0PqkEa2GOcn8KJ0nnc7+yL7EjoYLHQ==}
engines: {node: ^18.18.0 || ^19.8.0 || >= 20.0.0} engines: {node: ^18.18.0 || ^19.8.0 || >= 20.0.0}
hasBin: true hasBin: true
peerDependencies: peerDependencies:
@@ -9003,7 +9003,7 @@ snapshots:
'@neoconfetti/react@1.0.0': {} '@neoconfetti/react@1.0.0': {}
'@next/env@15.4.8': {} '@next/env@15.4.10': {}
'@next/eslint-plugin-next@15.5.2': '@next/eslint-plugin-next@15.5.2':
dependencies: dependencies:
@@ -9033,9 +9033,9 @@ snapshots:
'@next/swc-win32-x64-msvc@15.4.8': '@next/swc-win32-x64-msvc@15.4.8':
optional: true optional: true
'@next/third-parties@15.4.6(next@15.4.8(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)': '@next/third-parties@15.4.6(next@15.4.10(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)':
dependencies: dependencies:
next: 15.4.8(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) next: 15.4.10(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
react: 18.3.1 react: 18.3.1
third-party-capital: 1.0.20 third-party-capital: 1.0.20
@@ -10267,7 +10267,7 @@ snapshots:
'@sentry/core@10.27.0': {} '@sentry/core@10.27.0': {}
'@sentry/nextjs@10.27.0(@opentelemetry/context-async-hooks@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/core@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-trace-base@2.2.0(@opentelemetry/api@1.9.0))(next@15.4.8(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)(webpack@5.101.3(esbuild@0.25.9))': '@sentry/nextjs@10.27.0(@opentelemetry/context-async-hooks@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/core@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-trace-base@2.2.0(@opentelemetry/api@1.9.0))(next@15.4.10(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)(webpack@5.101.3(esbuild@0.25.9))':
dependencies: dependencies:
'@opentelemetry/api': 1.9.0 '@opentelemetry/api': 1.9.0
'@opentelemetry/semantic-conventions': 1.37.0 '@opentelemetry/semantic-conventions': 1.37.0
@@ -10280,7 +10280,7 @@ snapshots:
'@sentry/react': 10.27.0(react@18.3.1) '@sentry/react': 10.27.0(react@18.3.1)
'@sentry/vercel-edge': 10.27.0 '@sentry/vercel-edge': 10.27.0
'@sentry/webpack-plugin': 4.3.0(webpack@5.101.3(esbuild@0.25.9)) '@sentry/webpack-plugin': 4.3.0(webpack@5.101.3(esbuild@0.25.9))
next: 15.4.8(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) next: 15.4.10(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
resolve: 1.22.8 resolve: 1.22.8
rollup: 4.52.2 rollup: 4.52.2
stacktrace-parser: 0.1.11 stacktrace-parser: 0.1.11
@@ -10642,7 +10642,7 @@ snapshots:
react: 18.3.1 react: 18.3.1
react-dom: 18.3.1(react@18.3.1) react-dom: 18.3.1(react@18.3.1)
'@storybook/nextjs@9.1.5(esbuild@0.25.9)(next@15.4.8(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(storybook@9.1.5(@testing-library/dom@10.4.1)(msw@2.11.6(@types/node@24.10.0)(typescript@5.9.3))(prettier@3.6.2))(type-fest@4.41.0)(typescript@5.9.3)(webpack-hot-middleware@2.26.1)(webpack@5.101.3(esbuild@0.25.9))': '@storybook/nextjs@9.1.5(esbuild@0.25.9)(next@15.4.10(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(storybook@9.1.5(@testing-library/dom@10.4.1)(msw@2.11.6(@types/node@24.10.0)(typescript@5.9.3))(prettier@3.6.2))(type-fest@4.41.0)(typescript@5.9.3)(webpack-hot-middleware@2.26.1)(webpack@5.101.3(esbuild@0.25.9))':
dependencies: dependencies:
'@babel/core': 7.28.4 '@babel/core': 7.28.4
'@babel/plugin-syntax-bigint': 7.8.3(@babel/core@7.28.4) '@babel/plugin-syntax-bigint': 7.8.3(@babel/core@7.28.4)
@@ -10666,7 +10666,7 @@ snapshots:
css-loader: 6.11.0(webpack@5.101.3(esbuild@0.25.9)) css-loader: 6.11.0(webpack@5.101.3(esbuild@0.25.9))
image-size: 2.0.2 image-size: 2.0.2
loader-utils: 3.3.1 loader-utils: 3.3.1
next: 15.4.8(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) next: 15.4.10(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
node-polyfill-webpack-plugin: 2.0.1(webpack@5.101.3(esbuild@0.25.9)) node-polyfill-webpack-plugin: 2.0.1(webpack@5.101.3(esbuild@0.25.9))
postcss: 8.5.6 postcss: 8.5.6
postcss-loader: 8.2.0(postcss@8.5.6)(typescript@5.9.3)(webpack@5.101.3(esbuild@0.25.9)) postcss-loader: 8.2.0(postcss@8.5.6)(typescript@5.9.3)(webpack@5.101.3(esbuild@0.25.9))
@@ -11271,14 +11271,14 @@ snapshots:
'@unrs/resolver-binding-win32-x64-msvc@1.11.1': '@unrs/resolver-binding-win32-x64-msvc@1.11.1':
optional: true optional: true
'@vercel/analytics@1.5.0(next@15.4.8(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)': '@vercel/analytics@1.5.0(next@15.4.10(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)':
optionalDependencies: optionalDependencies:
next: 15.4.8(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) next: 15.4.10(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
react: 18.3.1 react: 18.3.1
'@vercel/speed-insights@1.2.0(next@15.4.8(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)': '@vercel/speed-insights@1.2.0(next@15.4.10(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)':
optionalDependencies: optionalDependencies:
next: 15.4.8(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) next: 15.4.10(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
react: 18.3.1 react: 18.3.1
'@vitest/expect@3.2.4': '@vitest/expect@3.2.4':
@@ -12954,9 +12954,9 @@ snapshots:
functions-have-names@1.2.3: {} functions-have-names@1.2.3: {}
geist@1.5.1(next@15.4.8(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)): geist@1.5.1(next@15.4.10(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)):
dependencies: dependencies:
next: 15.4.8(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) next: 15.4.10(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
gensync@1.0.0-beta.2: {} gensync@1.0.0-beta.2: {}
@@ -14226,9 +14226,9 @@ snapshots:
react: 18.3.1 react: 18.3.1
react-dom: 18.3.1(react@18.3.1) react-dom: 18.3.1(react@18.3.1)
next@15.4.8(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1): next@15.4.10(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1):
dependencies: dependencies:
'@next/env': 15.4.8 '@next/env': 15.4.10
'@swc/helpers': 0.5.15 '@swc/helpers': 0.5.15
caniuse-lite: 1.0.30001741 caniuse-lite: 1.0.30001741
postcss: 8.4.31 postcss: 8.4.31
@@ -14321,12 +14321,12 @@ snapshots:
dependencies: dependencies:
boolbase: 1.0.0 boolbase: 1.0.0
nuqs@2.7.2(next@15.4.8(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1): nuqs@2.7.2(next@15.4.10(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1):
dependencies: dependencies:
'@standard-schema/spec': 1.0.0 '@standard-schema/spec': 1.0.0
react: 18.3.1 react: 18.3.1
optionalDependencies: optionalDependencies:
next: 15.4.8(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) next: 15.4.10(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
oas-kit-common@1.0.8: oas-kit-common@1.0.8:
dependencies: dependencies:

View File

@@ -8,8 +8,6 @@ import { shouldShowOnboarding } from "@/app/api/helpers";
export async function GET(request: Request) { export async function GET(request: Request) {
const { searchParams, origin } = new URL(request.url); const { searchParams, origin } = new URL(request.url);
const code = searchParams.get("code"); const code = searchParams.get("code");
const oauthSession = searchParams.get("oauth_session");
const connectSession = searchParams.get("connect_session");
let next = "/marketplace"; let next = "/marketplace";
@@ -27,22 +25,6 @@ export async function GET(request: Request) {
const api = new BackendAPI(); const api = new BackendAPI();
await api.createUser(); await api.createUser();
// Handle oauth_session redirect - resume OAuth flow after login
// Redirect to a frontend page that will handle the OAuth resume with proper auth
if (oauthSession) {
return NextResponse.redirect(
`${origin}/auth/oauth-resume?session_id=${encodeURIComponent(oauthSession)}`,
);
}
// Handle connect_session redirect - resume connect flow after login
// Redirect to a frontend page that will handle the connect resume with proper auth
if (connectSession) {
return NextResponse.redirect(
`${origin}/auth/connect-resume?session_id=${encodeURIComponent(connectSession)}`,
);
}
if (await shouldShowOnboarding()) { if (await shouldShowOnboarding()) {
next = "/onboarding"; next = "/onboarding";
revalidatePath("/onboarding", "layout"); revalidatePath("/onboarding", "layout");

View File

@@ -1,400 +0,0 @@
"use client";
import { useEffect, useState, useRef, useCallback } from "react";
import { useSearchParams } from "next/navigation";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import { getWebSocketToken } from "@/lib/supabase/actions";
// Module-level flag to prevent duplicate requests across React StrictMode re-renders
const attemptedSessions = new Set<string>();
interface ScopeInfo {
scope: string;
description: string;
}
interface CredentialInfo {
id: string;
title: string;
username: string;
}
interface ClientInfo {
name: string;
logo_url: string | null;
}
interface ConnectData {
connect_token: string;
client: ClientInfo;
provider: string;
scopes: ScopeInfo[];
credentials: CredentialInfo[];
action_url: string;
}
interface ErrorData {
error: string;
error_description: string;
}
type ResumeResponse = ConnectData | ErrorData;
function isConnectData(data: ResumeResponse): data is ConnectData {
return "connect_token" in data;
}
function isErrorData(data: ResumeResponse): data is ErrorData {
return "error" in data;
}
/**
* Connect Consent Form Component
*
* Renders a proper React component for the integration connect consent form
*/
function ConnectForm({
client,
provider,
scopes,
credentials,
connectToken,
actionUrl,
}: {
client: ClientInfo;
provider: string;
scopes: ScopeInfo[];
credentials: CredentialInfo[];
connectToken: string;
actionUrl: string;
}) {
const [isSubmitting, setIsSubmitting] = useState(false);
const [selectedCredential, setSelectedCredential] = useState<string>(
credentials.length > 0 ? credentials[0].id : "",
);
const backendUrl = process.env.NEXT_PUBLIC_AGPT_SERVER_URL;
const backendOrigin = backendUrl
? new URL(backendUrl).origin
: "http://localhost:8006";
const fullActionUrl = `${backendOrigin}${actionUrl}`;
function handleSubmit() {
setIsSubmitting(true);
}
return (
<div className="flex min-h-screen items-center justify-center bg-gradient-to-br from-slate-900 to-slate-800 p-5">
<div className="w-full max-w-md rounded-2xl bg-zinc-800 p-8 shadow-2xl">
{/* Header */}
<div className="mb-6 text-center">
<h1 className="text-xl font-semibold text-zinc-100">
Connect{" "}
<span className="rounded bg-zinc-700 px-2 py-1 text-sm capitalize">
{provider}
</span>
</h1>
<p className="mt-2 text-sm text-zinc-400">
<span className="font-semibold text-cyan-400">{client.name}</span>{" "}
wants to use your {provider} integration
</p>
</div>
{/* Divider */}
<div className="my-6 h-px bg-zinc-700" />
{/* Scopes Section */}
<div className="mb-6">
<h2 className="mb-4 text-sm font-medium text-zinc-400">
This will allow {client.name} to:
</h2>
<div className="space-y-2">
{scopes.map((scope) => (
<div key={scope.scope} className="flex items-start gap-2 py-2">
<span className="flex-shrink-0 text-cyan-400">&#10003;</span>
<span className="text-sm text-zinc-300">
{scope.description}
</span>
</div>
))}
</div>
</div>
{/* Divider */}
<div className="my-6 h-px bg-zinc-700" />
{/* Form */}
<form method="POST" action={fullActionUrl} onSubmit={handleSubmit}>
<input type="hidden" name="connect_token" value={connectToken} />
{/* Existing credentials selection */}
{credentials.length > 0 && (
<>
<h3 className="mb-3 text-sm font-medium text-zinc-400">
Select an existing credential:
</h3>
<div className="mb-4 space-y-2">
{credentials.map((cred) => (
<label
key={cred.id}
className={`flex cursor-pointer items-center gap-3 rounded-lg border p-3 transition-colors ${
selectedCredential === cred.id
? "border-cyan-400 bg-cyan-400/10"
: "border-zinc-700 hover:border-cyan-400/50"
}`}
>
<input
type="radio"
name="credential_id"
value={cred.id}
checked={selectedCredential === cred.id}
onChange={() => setSelectedCredential(cred.id)}
className="hidden"
/>
<div>
<div className="text-sm font-medium text-zinc-200">
{cred.title}
</div>
{cred.username && (
<div className="text-xs text-zinc-500">
{cred.username}
</div>
)}
</div>
</label>
))}
</div>
<div className="my-4 h-px bg-zinc-700" />
</>
)}
{/* Connect new account */}
<div className="mb-4">
{credentials.length > 0 ? (
<h3 className="mb-3 text-sm font-medium text-zinc-400">
Or connect a new account:
</h3>
) : (
<p className="mb-3 text-sm text-zinc-400">
You don&apos;t have any {provider} credentials yet.
</p>
)}
<button
type="submit"
name="action"
value="connect_new"
disabled={isSubmitting}
className="w-full rounded-lg bg-blue-500 px-6 py-3 text-sm font-medium text-white transition-colors hover:bg-blue-400 disabled:cursor-not-allowed disabled:opacity-50"
>
Connect {provider.charAt(0).toUpperCase() + provider.slice(1)}{" "}
Account
</button>
</div>
{/* Action buttons */}
<div className="flex gap-3">
<button
type="submit"
name="action"
value="deny"
disabled={isSubmitting}
className="flex-1 rounded-lg bg-zinc-700 px-6 py-3 text-sm font-medium text-zinc-200 transition-colors hover:bg-zinc-600 disabled:cursor-not-allowed disabled:opacity-50"
>
Cancel
</button>
{credentials.length > 0 && (
<button
type="submit"
name="action"
value="approve"
disabled={isSubmitting}
className="flex-1 rounded-lg bg-cyan-400 px-6 py-3 text-sm font-medium text-slate-900 transition-colors hover:bg-cyan-300 disabled:cursor-not-allowed disabled:opacity-50"
>
{isSubmitting ? "Approving..." : "Approve"}
</button>
)}
</div>
</form>
</div>
</div>
);
}
/**
* Connect Resume Page
*
* This page handles resuming the integration connect flow after a user logs in.
* It fetches the connect data from the backend via JSON API and renders the consent form.
*/
export default function ConnectResumePage() {
const searchParams = useSearchParams();
const sessionId = searchParams.get("session_id");
const { isUserLoading, refreshSession } = useSupabase();
const [connectData, setConnectData] = useState<ConnectData | null>(null);
const [error, setError] = useState<string | null>(null);
const [isLoading, setIsLoading] = useState(true);
const retryCountRef = useRef(0);
const maxRetries = 5;
const resumeConnectFlow = useCallback(async () => {
if (!sessionId) {
setError(
"Missing session ID. Please start the connection process again.",
);
setIsLoading(false);
return;
}
if (attemptedSessions.has(sessionId)) {
return;
}
if (isUserLoading) {
return;
}
attemptedSessions.add(sessionId);
try {
let tokenResult = await getWebSocketToken();
let accessToken = tokenResult.token;
while (!accessToken && retryCountRef.current < maxRetries) {
retryCountRef.current += 1;
console.log(
`Retrying to get access token (attempt ${retryCountRef.current}/${maxRetries})...`,
);
await refreshSession();
await new Promise((resolve) => setTimeout(resolve, 1000));
tokenResult = await getWebSocketToken();
accessToken = tokenResult.token;
}
if (!accessToken) {
setError(
"Unable to retrieve authentication token. Please log in again.",
);
setIsLoading(false);
return;
}
const backendUrl = process.env.NEXT_PUBLIC_AGPT_SERVER_URL;
if (!backendUrl) {
setError("Backend URL not configured.");
setIsLoading(false);
return;
}
let backendOrigin: string;
try {
const url = new URL(backendUrl);
backendOrigin = url.origin;
} catch {
setError("Invalid backend URL configuration.");
setIsLoading(false);
return;
}
const response = await fetch(
`${backendOrigin}/connect/resume?session_id=${encodeURIComponent(sessionId)}`,
{
method: "GET",
headers: {
Authorization: `Bearer ${accessToken}`,
Accept: "application/json",
},
},
);
const data: ResumeResponse = await response.json();
if (!response.ok) {
if (isErrorData(data)) {
setError(data.error_description || data.error);
} else {
setError(`Connection failed (${response.status}). Please try again.`);
}
setIsLoading(false);
return;
}
if (isConnectData(data)) {
setConnectData(data);
setIsLoading(false);
return;
}
setError("Unexpected response from server. Please try again.");
setIsLoading(false);
} catch (err) {
console.error("Connect resume error:", err);
setError(
"An error occurred while resuming connection. Please try again.",
);
setIsLoading(false);
}
}, [sessionId, isUserLoading, refreshSession]);
useEffect(() => {
resumeConnectFlow();
}, [resumeConnectFlow]);
if (isLoading || isUserLoading) {
return (
<div className="flex min-h-screen items-center justify-center bg-gradient-to-br from-slate-900 to-slate-800">
<div className="text-center">
<div className="mx-auto mb-4 h-8 w-8 animate-spin rounded-full border-4 border-zinc-600 border-t-cyan-400"></div>
<p className="text-zinc-400">Resuming connection...</p>
</div>
</div>
);
}
if (error) {
return (
<div className="flex min-h-screen items-center justify-center bg-gradient-to-br from-slate-900 to-slate-800">
<div className="mx-auto max-w-md rounded-2xl bg-zinc-800 p-8 text-center shadow-2xl">
<div className="mx-auto mb-4 h-16 w-16 text-red-500">
<svg
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
strokeWidth="2"
>
<circle cx="12" cy="12" r="10" />
<line x1="15" y1="9" x2="9" y2="15" />
<line x1="9" y1="9" x2="15" y2="15" />
</svg>
</div>
<h1 className="mb-2 text-xl font-semibold text-red-400">
Connection Error
</h1>
<p className="mb-6 text-zinc-400">{error}</p>
<button
onClick={() => window.close()}
className="rounded-lg bg-zinc-700 px-6 py-3 text-sm font-medium text-zinc-200 transition-colors hover:bg-zinc-600"
>
Close
</button>
</div>
</div>
);
}
if (connectData) {
return (
<ConnectForm
client={connectData.client}
provider={connectData.provider}
scopes={connectData.scopes}
credentials={connectData.credentials}
connectToken={connectData.connect_token}
actionUrl={connectData.action_url}
/>
);
}
return null;
}

View File

@@ -22,28 +22,20 @@ export async function GET(request: Request) {
console.debug("Sending message to opener:", message); console.debug("Sending message to opener:", message);
// Escape JSON to prevent XSS attacks via </script> injection
const safeJson = JSON.stringify(message)
.replace(/</g, "\\u003c")
.replace(/>/g, "\\u003e");
// Return a response with the message as JSON and a script to close the window // Return a response with the message as JSON and a script to close the window
return new NextResponse( return new NextResponse(
`<!DOCTYPE html> `
<html> <html>
<body> <body>
<script> <script>
window.opener.postMessage(${safeJson}, '*'); window.opener.postMessage(${JSON.stringify(message)});
window.close(); window.close();
</script> </script>
</body> </body>
</html>`, </html>
`,
{ {
headers: { headers: { "Content-Type": "text/html" },
"Content-Type": "text/html",
"Content-Security-Policy":
"default-src 'none'; script-src 'unsafe-inline'",
},
}, },
); );
} }

View File

@@ -1,399 +0,0 @@
"use client";
import { useEffect, useState, useRef, useCallback } from "react";
import { useSearchParams } from "next/navigation";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import { getWebSocketToken } from "@/lib/supabase/actions";
// Module-level flag to prevent duplicate requests across React StrictMode re-renders
// This is keyed by session_id to allow different sessions
const attemptedSessions = new Set<string>();
interface ScopeInfo {
scope: string;
description: string;
}
interface ClientInfo {
name: string;
logo_url: string | null;
privacy_policy_url: string | null;
terms_url: string | null;
}
interface ConsentData {
needs_consent: true;
consent_token: string;
client: ClientInfo;
scopes: ScopeInfo[];
action_url: string;
}
interface RedirectData {
redirect_url: string;
needs_consent: false;
}
interface ErrorData {
error: string;
error_description: string;
redirect_url?: string;
}
type ResumeResponse = ConsentData | RedirectData | ErrorData;
function isConsentData(data: ResumeResponse): data is ConsentData {
return "needs_consent" in data && data.needs_consent === true;
}
function isRedirectData(data: ResumeResponse): data is RedirectData {
return "redirect_url" in data && !("error" in data);
}
function isErrorData(data: ResumeResponse): data is ErrorData {
return "error" in data;
}
/**
* OAuth Consent Form Component
*
* Renders a proper React component for the consent form instead of dangerouslySetInnerHTML
*/
function ConsentForm({
client,
scopes,
consentToken,
actionUrl,
}: {
client: ClientInfo;
scopes: ScopeInfo[];
consentToken: string;
actionUrl: string;
}) {
const [isSubmitting, setIsSubmitting] = useState(false);
const backendUrl = process.env.NEXT_PUBLIC_AGPT_SERVER_URL;
const backendOrigin = backendUrl
? new URL(backendUrl).origin
: "http://localhost:8006";
// Full action URL for form submission
const fullActionUrl = `${backendOrigin}${actionUrl}`;
function handleSubmit() {
setIsSubmitting(true);
}
return (
<div className="flex min-h-screen items-center justify-center bg-gradient-to-br from-slate-900 to-slate-800 p-5">
<div className="w-full max-w-md rounded-2xl bg-zinc-800 p-8 shadow-2xl">
{/* Header */}
<div className="mb-6 text-center">
<div className="mx-auto mb-4 flex h-16 w-16 items-center justify-center rounded-xl bg-zinc-700">
{client.logo_url ? (
<img
src={client.logo_url}
alt={client.name}
className="h-12 w-12 rounded-lg"
/>
) : (
<span className="text-3xl text-zinc-400">
{client.name.charAt(0).toUpperCase()}
</span>
)}
</div>
<h1 className="text-xl font-semibold text-zinc-100">
Authorize <span className="text-cyan-400">{client.name}</span>
</h1>
<p className="mt-2 text-sm text-zinc-400">
wants to access your AutoGPT account
</p>
</div>
{/* Divider */}
<div className="my-6 h-px bg-zinc-700" />
{/* Scopes Section */}
<div className="mb-6">
<h2 className="mb-4 text-sm font-medium text-zinc-400">
This will allow {client.name} to:
</h2>
<div className="space-y-3">
{scopes.map((scope) => (
<div
key={scope.scope}
className="flex items-start gap-3 border-b border-zinc-700 pb-3 last:border-0"
>
<svg
className="mt-0.5 h-5 w-5 flex-shrink-0 text-cyan-400"
viewBox="0 0 20 20"
fill="currentColor"
>
<path
fillRule="evenodd"
d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z"
clipRule="evenodd"
/>
</svg>
<span className="text-sm leading-relaxed text-zinc-300">
{scope.description}
</span>
</div>
))}
</div>
</div>
{/* Form */}
<form method="POST" action={fullActionUrl} onSubmit={handleSubmit}>
<input type="hidden" name="consent_token" value={consentToken} />
<div className="flex gap-3">
<button
type="submit"
name="authorize"
value="false"
disabled={isSubmitting}
className="flex-1 rounded-lg bg-zinc-700 px-6 py-3 text-sm font-medium text-zinc-200 transition-colors hover:bg-zinc-600 disabled:cursor-not-allowed disabled:opacity-50"
>
Cancel
</button>
<button
type="submit"
name="authorize"
value="true"
disabled={isSubmitting}
className="flex-1 rounded-lg bg-cyan-400 px-6 py-3 text-sm font-medium text-slate-900 transition-colors hover:bg-cyan-300 disabled:cursor-not-allowed disabled:opacity-50"
>
{isSubmitting ? "Authorizing..." : "Allow"}
</button>
</div>
</form>
{/* Footer Links */}
{(client.privacy_policy_url || client.terms_url) && (
<div className="mt-6 text-center text-xs text-zinc-500">
{client.privacy_policy_url && (
<a
href={client.privacy_policy_url}
target="_blank"
rel="noopener noreferrer"
className="text-zinc-400 hover:underline"
>
Privacy Policy
</a>
)}
{client.privacy_policy_url && client.terms_url && (
<span className="mx-2"></span>
)}
{client.terms_url && (
<a
href={client.terms_url}
target="_blank"
rel="noopener noreferrer"
className="text-zinc-400 hover:underline"
>
Terms of Service
</a>
)}
</div>
)}
</div>
</div>
);
}
/**
* OAuth Resume Page
*
* This page handles resuming the OAuth authorization flow after a user logs in.
* It fetches the consent data from the backend via JSON API and renders the consent form.
*/
export default function OAuthResumePage() {
const searchParams = useSearchParams();
const sessionId = searchParams.get("session_id");
const { isUserLoading, refreshSession } = useSupabase();
const [consentData, setConsentData] = useState<ConsentData | null>(null);
const [error, setError] = useState<string | null>(null);
const [isLoading, setIsLoading] = useState(true);
const retryCountRef = useRef(0);
const maxRetries = 5;
const resumeOAuthFlow = useCallback(async () => {
// Prevent multiple attempts for the same session (handles React StrictMode)
if (!sessionId) {
setError(
"Missing session ID. Please start the authorization process again.",
);
setIsLoading(false);
return;
}
if (attemptedSessions.has(sessionId)) {
// Already attempted this session, don't retry
return;
}
if (isUserLoading) {
return; // Wait for auth state to load
}
// Mark this session as attempted IMMEDIATELY to prevent duplicate requests
attemptedSessions.add(sessionId);
try {
// Get the access token from server action (which reads cookies properly)
let tokenResult = await getWebSocketToken();
let accessToken = tokenResult.token;
// If no token, retry a few times with delays
while (!accessToken && retryCountRef.current < maxRetries) {
retryCountRef.current += 1;
console.log(
`Retrying to get access token (attempt ${retryCountRef.current}/${maxRetries})...`,
);
// Try refreshing the session
await refreshSession();
await new Promise((resolve) => setTimeout(resolve, 1000));
tokenResult = await getWebSocketToken();
accessToken = tokenResult.token;
}
if (!accessToken) {
setError(
"Unable to retrieve authentication token. Please log in again.",
);
setIsLoading(false);
return;
}
// Call the backend resume endpoint with JSON accept header
const backendUrl = process.env.NEXT_PUBLIC_AGPT_SERVER_URL;
if (!backendUrl) {
setError("Backend URL not configured.");
setIsLoading(false);
return;
}
// Extract the origin from the backend URL
let backendOrigin: string;
try {
const url = new URL(backendUrl);
backendOrigin = url.origin;
} catch {
setError("Invalid backend URL configuration.");
setIsLoading(false);
return;
}
// Use Accept: application/json to get JSON response instead of HTML
// This solves the CORS/redirect issue by letting us handle redirects client-side
const response = await fetch(
`${backendOrigin}/oauth/authorize/resume?session_id=${encodeURIComponent(sessionId)}`,
{
method: "GET",
headers: {
Authorization: `Bearer ${accessToken}`,
Accept: "application/json",
},
},
);
const data: ResumeResponse = await response.json();
if (!response.ok) {
if (isErrorData(data)) {
setError(data.error_description || data.error);
} else {
setError(
`Authorization failed (${response.status}). Please try again.`,
);
}
setIsLoading(false);
return;
}
// Handle redirect response (user already authorized these scopes)
if (isRedirectData(data)) {
window.location.href = data.redirect_url;
return;
}
// Handle consent required
if (isConsentData(data)) {
setConsentData(data);
setIsLoading(false);
return;
}
// Unexpected response
setError("Unexpected response from server. Please try again.");
setIsLoading(false);
} catch (err) {
console.error("OAuth resume error:", err);
setError(
"An error occurred while resuming authorization. Please try again.",
);
setIsLoading(false);
}
}, [sessionId, isUserLoading, refreshSession]);
useEffect(() => {
resumeOAuthFlow();
}, [resumeOAuthFlow]);
if (isLoading || isUserLoading) {
return (
<div className="flex min-h-screen items-center justify-center bg-gradient-to-br from-slate-900 to-slate-800">
<div className="text-center">
<div className="mx-auto mb-4 h-8 w-8 animate-spin rounded-full border-4 border-zinc-600 border-t-cyan-400"></div>
<p className="text-zinc-400">Resuming authorization...</p>
</div>
</div>
);
}
if (error) {
return (
<div className="flex min-h-screen items-center justify-center bg-gradient-to-br from-slate-900 to-slate-800">
<div className="mx-auto max-w-md rounded-2xl bg-zinc-800 p-8 text-center shadow-2xl">
<div className="mx-auto mb-4 h-16 w-16 text-red-500">
<svg
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
strokeWidth="2"
>
<circle cx="12" cy="12" r="10" />
<line x1="15" y1="9" x2="9" y2="15" />
<line x1="9" y1="9" x2="15" y2="15" />
</svg>
</div>
<h1 className="mb-2 text-xl font-semibold text-red-400">
Authorization Error
</h1>
<p className="mb-6 text-zinc-400">{error}</p>
<button
onClick={() => window.close()}
className="rounded-lg bg-zinc-700 px-6 py-3 text-sm font-medium text-zinc-200 transition-colors hover:bg-zinc-600"
>
Close
</button>
</div>
</div>
);
}
if (consentData) {
return (
<ConsentForm
client={consentData.client}
scopes={consentData.scopes}
consentToken={consentData.consent_token}
actionUrl={consentData.action_url}
/>
);
}
return null;
}

View File

@@ -1,24 +1,25 @@
import { useCallback } from "react"; import { useCallback } from "react";
import { useReactFlow } from "@xyflow/react"; import { useReactFlow } from "@xyflow/react";
import { Key, storage } from "@/services/storage/local-storage";
import { v4 as uuidv4 } from "uuid"; import { v4 as uuidv4 } from "uuid";
import { useNodeStore } from "../../../stores/nodeStore"; import { useNodeStore } from "../../../stores/nodeStore";
import { useEdgeStore } from "../../../stores/edgeStore"; import { useEdgeStore } from "../../../stores/edgeStore";
import { CustomNode } from "../nodes/CustomNode/CustomNode"; import { CustomNode } from "../nodes/CustomNode/CustomNode";
import { CustomEdge } from "../edges/CustomEdge"; import { CustomEdge } from "../edges/CustomEdge";
import { useToast } from "@/components/molecules/Toast/use-toast";
interface CopyableData { interface CopyableData {
nodes: CustomNode[]; nodes: CustomNode[];
edges: CustomEdge[]; edges: CustomEdge[];
} }
const CLIPBOARD_PREFIX = "autogpt-flow-data:";
export function useCopyPaste() { export function useCopyPaste() {
// Only use useReactFlow for viewport (not managed by stores)
const { getViewport } = useReactFlow(); const { getViewport } = useReactFlow();
const { toast } = useToast();
const handleCopyPaste = useCallback( const handleCopyPaste = useCallback(
(event: KeyboardEvent) => { (event: KeyboardEvent) => {
// Prevent copy/paste if any modal is open or if the focus is on an input element
const activeElement = document.activeElement; const activeElement = document.activeElement;
const isInputField = const isInputField =
activeElement?.tagName === "INPUT" || activeElement?.tagName === "INPUT" ||
@@ -28,7 +29,6 @@ export function useCopyPaste() {
if (isInputField) return; if (isInputField) return;
if (event.ctrlKey || event.metaKey) { if (event.ctrlKey || event.metaKey) {
// COPY: Ctrl+C or Cmd+C
if (event.key === "c" || event.key === "C") { if (event.key === "c" || event.key === "C") {
const { nodes } = useNodeStore.getState(); const { nodes } = useNodeStore.getState();
const { edges } = useEdgeStore.getState(); const { edges } = useEdgeStore.getState();
@@ -53,17 +53,32 @@ export function useCopyPaste() {
edges: selectedEdges, edges: selectedEdges,
}; };
storage.set(Key.COPIED_FLOW_DATA, JSON.stringify(copiedData)); const clipboardText = `${CLIPBOARD_PREFIX}${JSON.stringify(copiedData)}`;
navigator.clipboard
.writeText(clipboardText)
.then(() => {
toast({
title: "Copied successfully",
description: `${selectedNodes.length} node(s) copied to clipboard`,
});
})
.catch((error) => {
console.error("Failed to copy to clipboard:", error);
});
} }
// PASTE: Ctrl+V or Cmd+V
if (event.key === "v" || event.key === "V") { if (event.key === "v" || event.key === "V") {
const copiedDataString = storage.get(Key.COPIED_FLOW_DATA); navigator.clipboard
if (copiedDataString) { .readText()
const copiedData = JSON.parse(copiedDataString) as CopyableData; .then((clipboardText) => {
if (!clipboardText.startsWith(CLIPBOARD_PREFIX)) {
return; // Not our data, ignore
}
const jsonString = clipboardText.slice(CLIPBOARD_PREFIX.length);
const copiedData = JSON.parse(jsonString) as CopyableData;
const oldToNewIdMap: Record<string, string> = {}; const oldToNewIdMap: Record<string, string> = {};
// Get fresh viewport values at paste time to ensure correct positioning
const { x, y, zoom } = getViewport(); const { x, y, zoom } = getViewport();
const viewportCenter = { const viewportCenter = {
x: (window.innerWidth / 2 - x) / zoom, x: (window.innerWidth / 2 - x) / zoom,
@@ -86,7 +101,10 @@ export function useCopyPaste() {
// Deselect existing nodes first // Deselect existing nodes first
useNodeStore.setState((state) => ({ useNodeStore.setState((state) => ({
nodes: state.nodes.map((node) => ({ ...node, selected: false })), nodes: state.nodes.map((node) => ({
...node,
selected: false,
})),
})); }));
// Create and add new nodes with UNIQUE IDs using UUID // Create and add new nodes with UNIQUE IDs using UUID
@@ -123,11 +141,14 @@ export function useCopyPaste() {
}, },
}); });
}); });
} })
.catch((error) => {
console.error("Failed to read from clipboard:", error);
});
} }
} }
}, },
[getViewport], [getViewport, toast],
); );
return handleCopyPaste; return handleCopyPaste;

View File

@@ -42,7 +42,8 @@ export const useFlow = () => {
const setBlockMenuOpen = useControlPanelStore( const setBlockMenuOpen = useControlPanelStore(
useShallow((state) => state.setBlockMenuOpen), useShallow((state) => state.setBlockMenuOpen),
); );
const [{ flowID, flowVersion, flowExecutionID }] = useQueryStates({ const [{ flowID, flowVersion, flowExecutionID }, setQueryStates] =
useQueryStates({
flowID: parseAsString, flowID: parseAsString,
flowVersion: parseAsInteger, flowVersion: parseAsInteger,
flowExecutionID: parseAsString, flowExecutionID: parseAsString,
@@ -102,6 +103,9 @@ export const useFlow = () => {
// load graph schemas // load graph schemas
useEffect(() => { useEffect(() => {
if (graph) { if (graph) {
setQueryStates({
flowVersion: graph.version ?? 1,
});
setGraphSchemas( setGraphSchemas(
graph.input_schema as Record<string, any> | null, graph.input_schema as Record<string, any> | null,
graph.credentials_input_schema as Record<string, any> | null, graph.credentials_input_schema as Record<string, any> | null,

View File

@@ -1,7 +1,7 @@
import { useBlockMenuStore } from "../../../../stores/blockMenuStore"; import { useBlockMenuStore } from "../../../../stores/blockMenuStore";
import { useGetV2BuilderSearchInfinite } from "@/app/api/__generated__/endpoints/store/store"; import { useGetV2BuilderSearchInfinite } from "@/app/api/__generated__/endpoints/store/store";
import { SearchResponse } from "@/app/api/__generated__/models/searchResponse"; import { SearchResponse } from "@/app/api/__generated__/models/searchResponse";
import { useState } from "react"; import { useCallback, useEffect, useState } from "react";
import { useAddAgentToBuilder } from "../hooks/useAddAgentToBuilder"; import { useAddAgentToBuilder } from "../hooks/useAddAgentToBuilder";
import { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent"; import { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
import { getV2GetSpecificAgent } from "@/app/api/__generated__/endpoints/store/store"; import { getV2GetSpecificAgent } from "@/app/api/__generated__/endpoints/store/store";
@@ -9,16 +9,27 @@ import {
getGetV2ListLibraryAgentsQueryKey, getGetV2ListLibraryAgentsQueryKey,
usePostV2AddMarketplaceAgent, usePostV2AddMarketplaceAgent,
} from "@/app/api/__generated__/endpoints/library/library"; } from "@/app/api/__generated__/endpoints/library/library";
import { getGetV2GetBuilderItemCountsQueryKey } from "@/app/api/__generated__/endpoints/default/default"; import {
getGetV2GetBuilderItemCountsQueryKey,
getGetV2GetBuilderSuggestionsQueryKey,
} from "@/app/api/__generated__/endpoints/default/default";
import { getQueryClient } from "@/lib/react-query/queryClient"; import { getQueryClient } from "@/lib/react-query/queryClient";
import { useToast } from "@/components/molecules/Toast/use-toast"; import { useToast } from "@/components/molecules/Toast/use-toast";
import * as Sentry from "@sentry/nextjs"; import * as Sentry from "@sentry/nextjs";
export const useBlockMenuSearch = () => { export const useBlockMenuSearch = () => {
const { searchQuery } = useBlockMenuStore(); const { searchQuery, searchId, setSearchId } = useBlockMenuStore();
const { toast } = useToast(); const { toast } = useToast();
const { addAgentToBuilder, addLibraryAgentToBuilder } = const { addAgentToBuilder, addLibraryAgentToBuilder } =
useAddAgentToBuilder(); useAddAgentToBuilder();
const queryClient = getQueryClient();
const resetSearchSession = useCallback(() => {
setSearchId(undefined);
queryClient.invalidateQueries({
queryKey: getGetV2GetBuilderSuggestionsQueryKey(),
});
}, [queryClient, setSearchId]);
const [addingLibraryAgentId, setAddingLibraryAgentId] = useState< const [addingLibraryAgentId, setAddingLibraryAgentId] = useState<
string | null string | null
@@ -38,13 +49,19 @@ export const useBlockMenuSearch = () => {
page: 1, page: 1,
page_size: 8, page_size: 8,
search_query: searchQuery, search_query: searchQuery,
search_id: searchId,
}, },
{ {
query: { query: {
getNextPageParam: (lastPage, allPages) => { getNextPageParam: (lastPage) => {
const pagination = lastPage.data as SearchResponse; const response = lastPage.data as SearchResponse;
const isMore = pagination.more_pages; const { pagination } = response;
return isMore ? allPages.length + 1 : undefined; if (!pagination) {
return undefined;
}
const { current_page, total_pages } = pagination;
return current_page < total_pages ? current_page + 1 : undefined;
}, },
}, },
}, },
@@ -53,7 +70,6 @@ export const useBlockMenuSearch = () => {
const { mutateAsync: addMarketplaceAgent } = usePostV2AddMarketplaceAgent({ const { mutateAsync: addMarketplaceAgent } = usePostV2AddMarketplaceAgent({
mutation: { mutation: {
onSuccess: () => { onSuccess: () => {
const queryClient = getQueryClient();
queryClient.invalidateQueries({ queryClient.invalidateQueries({
queryKey: getGetV2ListLibraryAgentsQueryKey(), queryKey: getGetV2ListLibraryAgentsQueryKey(),
}); });
@@ -75,6 +91,24 @@ export const useBlockMenuSearch = () => {
}, },
}); });
useEffect(() => {
if (!searchData?.pages?.length) {
return;
}
const latestPage = searchData.pages[searchData.pages.length - 1];
const response = latestPage?.data as SearchResponse;
if (response?.search_id && response.search_id !== searchId) {
setSearchId(response.search_id);
}
}, [searchData, searchId, setSearchId]);
useEffect(() => {
if (searchId && !searchQuery) {
resetSearchSession();
}
}, [resetSearchSession, searchId, searchQuery]);
const allSearchData = const allSearchData =
searchData?.pages?.flatMap((page) => { searchData?.pages?.flatMap((page) => {
const response = page.data as SearchResponse; const response = page.data as SearchResponse;

View File

@@ -1,30 +1,32 @@
import { debounce } from "lodash"; import { debounce } from "lodash";
import { useCallback, useEffect, useRef, useState } from "react"; import { useCallback, useEffect, useRef, useState } from "react";
import { useBlockMenuStore } from "../../../../stores/blockMenuStore"; import { useBlockMenuStore } from "../../../../stores/blockMenuStore";
import { getQueryClient } from "@/lib/react-query/queryClient";
import { getGetV2GetBuilderSuggestionsQueryKey } from "@/app/api/__generated__/endpoints/default/default";
const SEARCH_DEBOUNCE_MS = 300; const SEARCH_DEBOUNCE_MS = 300;
export const useBlockMenuSearchBar = () => { export const useBlockMenuSearchBar = () => {
const inputRef = useRef<HTMLInputElement>(null); const inputRef = useRef<HTMLInputElement>(null);
const [localQuery, setLocalQuery] = useState(""); const [localQuery, setLocalQuery] = useState("");
const { setSearchQuery, setSearchId, searchId, searchQuery } = const { setSearchQuery, setSearchId, searchQuery } = useBlockMenuStore();
useBlockMenuStore(); const queryClient = getQueryClient();
const searchIdRef = useRef(searchId); const clearSearchSession = useCallback(() => {
useEffect(() => { setSearchId(undefined);
searchIdRef.current = searchId; queryClient.invalidateQueries({
}, [searchId]); queryKey: getGetV2GetBuilderSuggestionsQueryKey(),
});
}, [queryClient, setSearchId]);
const debouncedSetSearchQuery = useCallback( const debouncedSetSearchQuery = useCallback(
debounce((value: string) => { debounce((value: string) => {
setSearchQuery(value); setSearchQuery(value);
if (value.length === 0) { if (value.length === 0) {
setSearchId(undefined); clearSearchSession();
} else if (!searchIdRef.current) {
setSearchId(crypto.randomUUID());
} }
}, SEARCH_DEBOUNCE_MS), }, SEARCH_DEBOUNCE_MS),
[setSearchQuery, setSearchId], [clearSearchSession, setSearchQuery],
); );
useEffect(() => { useEffect(() => {
@@ -36,13 +38,13 @@ export const useBlockMenuSearchBar = () => {
const handleClear = () => { const handleClear = () => {
setLocalQuery(""); setLocalQuery("");
setSearchQuery(""); setSearchQuery("");
setSearchId(undefined); clearSearchSession();
debouncedSetSearchQuery.cancel(); debouncedSetSearchQuery.cancel();
}; };
useEffect(() => { useEffect(() => {
setLocalQuery(searchQuery); setLocalQuery(searchQuery);
}, []); }, [searchQuery]);
return { return {
handleClear, handleClear,

View File

@@ -0,0 +1,109 @@
import React, { useEffect, useRef, useState } from "react";
import { ArrowLeftIcon, ArrowRightIcon } from "@phosphor-icons/react";
import { cn } from "@/lib/utils";
interface HorizontalScrollAreaProps {
children: React.ReactNode;
wrapperClassName?: string;
scrollContainerClassName?: string;
scrollAmount?: number;
dependencyList?: React.DependencyList;
}
const defaultDependencies: React.DependencyList = [];
const baseScrollClasses =
"flex gap-2 overflow-x-auto px-8 [scrollbar-width:none] [-ms-overflow-style:'none'] [&::-webkit-scrollbar]:hidden";
export const HorizontalScroll: React.FC<HorizontalScrollAreaProps> = ({
children,
wrapperClassName,
scrollContainerClassName,
scrollAmount = 300,
dependencyList = defaultDependencies,
}) => {
const scrollRef = useRef<HTMLDivElement | null>(null);
const [canScrollLeft, setCanScrollLeft] = useState(false);
const [canScrollRight, setCanScrollRight] = useState(false);
const scrollByDelta = (delta: number) => {
if (!scrollRef.current) {
return;
}
scrollRef.current.scrollBy({ left: delta, behavior: "smooth" });
};
const updateScrollState = () => {
const element = scrollRef.current;
if (!element) {
setCanScrollLeft(false);
setCanScrollRight(false);
return;
}
setCanScrollLeft(element.scrollLeft > 0);
setCanScrollRight(
Math.ceil(element.scrollLeft + element.clientWidth) < element.scrollWidth,
);
};
useEffect(() => {
updateScrollState();
const element = scrollRef.current;
if (!element) {
return;
}
const handleScroll = () => updateScrollState();
element.addEventListener("scroll", handleScroll);
window.addEventListener("resize", handleScroll);
return () => {
element.removeEventListener("scroll", handleScroll);
window.removeEventListener("resize", handleScroll);
};
}, dependencyList);
return (
<div className={wrapperClassName}>
<div className="group relative">
<div
ref={scrollRef}
className={cn(baseScrollClasses, scrollContainerClassName)}
>
{children}
</div>
{canScrollLeft && (
<div className="pointer-events-none absolute inset-y-0 left-0 w-8 bg-gradient-to-r from-white via-white/80 to-white/0" />
)}
{canScrollRight && (
<div className="pointer-events-none absolute inset-y-0 right-0 w-8 bg-gradient-to-l from-white via-white/80 to-white/0" />
)}
{canScrollLeft && (
<button
type="button"
aria-label="Scroll left"
className="pointer-events-none absolute left-2 top-5 -translate-y-1/2 opacity-0 transition-opacity duration-200 group-hover:pointer-events-auto group-hover:opacity-100"
onClick={() => scrollByDelta(-scrollAmount)}
>
<ArrowLeftIcon
size={28}
className="rounded-full bg-zinc-700 p-1 text-white drop-shadow"
weight="light"
/>
</button>
)}
{canScrollRight && (
<button
type="button"
aria-label="Scroll right"
className="pointer-events-none absolute right-2 top-5 -translate-y-1/2 opacity-0 transition-opacity duration-200 group-hover:pointer-events-auto group-hover:opacity-100"
onClick={() => scrollByDelta(scrollAmount)}
>
<ArrowRightIcon
size={28}
className="rounded-full bg-zinc-700 p-1 text-white drop-shadow"
weight="light"
/>
</button>
)}
</div>
</div>
);
};

View File

@@ -6,10 +6,15 @@ import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard";
import { blockMenuContainerStyle } from "../style"; import { blockMenuContainerStyle } from "../style";
import { useBlockMenuStore } from "../../../../stores/blockMenuStore"; import { useBlockMenuStore } from "../../../../stores/blockMenuStore";
import { DefaultStateType } from "../types"; import { DefaultStateType } from "../types";
import { SearchHistoryChip } from "../SearchHistoryChip";
import { HorizontalScroll } from "../HorizontalScroll";
export const SuggestionContent = () => { export const SuggestionContent = () => {
const { setIntegration, setDefaultState } = useBlockMenuStore(); const { setIntegration, setDefaultState, setSearchQuery, setSearchId } =
useBlockMenuStore();
const { data, isLoading, isError, error, refetch } = useSuggestionContent(); const { data, isLoading, isError, error, refetch } = useSuggestionContent();
const suggestions = data?.suggestions;
const hasRecentSearches = (suggestions?.recent_searches?.length ?? 0) > 0;
if (isError) { if (isError) {
return ( return (
@@ -29,11 +34,45 @@ export const SuggestionContent = () => {
); );
} }
const suggestions = data?.suggestions;
return ( return (
<div className={blockMenuContainerStyle}> <div className={blockMenuContainerStyle}>
<div className="w-full space-y-6 pb-4"> <div className="w-full space-y-6 pb-4">
{/* Recent searches */}
{hasRecentSearches && (
<div className="space-y-2.5 px-4">
<p className="font-sans text-sm font-medium leading-[1.375rem] text-zinc-800">
Recent searches
</p>
<HorizontalScroll
wrapperClassName="-mx-8"
scrollContainerClassName="flex gap-2 overflow-x-auto px-8 [scrollbar-width:none] [-ms-overflow-style:'none'] [&::-webkit-scrollbar]:hidden"
dependencyList={[
suggestions?.recent_searches?.length ?? 0,
isLoading,
]}
>
{!isLoading && suggestions
? suggestions.recent_searches.map((entry, index) => (
<SearchHistoryChip
key={entry.search_id || `${entry.search_query}-${index}`}
content={entry.search_query || "Untitled search"}
onClick={() => {
setSearchQuery(entry.search_query || "");
setSearchId(entry.search_id || undefined);
}}
/>
))
: Array(3)
.fill(0)
.map((_, index) => (
<SearchHistoryChip.Skeleton
key={`recent-search-skeleton-${index}`}
/>
))}
</HorizontalScroll>
</div>
)}
{/* Integrations */} {/* Integrations */}
<div className="space-y-2.5 px-4"> <div className="space-y-2.5 px-4">
<p className="font-sans text-sm font-medium leading-[1.375rem] text-zinc-800"> <p className="font-sans text-sm font-medium leading-[1.375rem] text-zinc-800">

View File

@@ -24,11 +24,13 @@ import { useNewAgentLibraryView } from "./useNewAgentLibraryView";
export function NewAgentLibraryView() { export function NewAgentLibraryView() {
const { const {
agent,
hasAnyItems,
ready,
error,
agentId, agentId,
agent,
ready,
activeTemplate,
isTemplateLoading,
error,
hasAnyItems,
activeItem, activeItem,
sidebarLoading, sidebarLoading,
activeTab, activeTab,
@@ -36,6 +38,9 @@ export function NewAgentLibraryView() {
handleSelectRun, handleSelectRun,
handleCountsChange, handleCountsChange,
handleClearSelectedRun, handleClearSelectedRun,
onRunInitiated,
onTriggerSetup,
onScheduleCreated,
} = useNewAgentLibraryView(); } = useNewAgentLibraryView();
if (error) { if (error) {
@@ -65,14 +70,19 @@ export function NewAgentLibraryView() {
/> />
</div> </div>
<div className="flex min-h-0 flex-1"> <div className="flex min-h-0 flex-1">
<EmptyTasks agent={agent} /> <EmptyTasks
agent={agent}
onRun={onRunInitiated}
onTriggerSetup={onTriggerSetup}
onScheduleCreated={onScheduleCreated}
/>
</div> </div>
</div> </div>
); );
} }
return ( return (
<div className="ml-4 grid h-full grid-cols-1 gap-0 pt-3 md:gap-4 lg:grid-cols-[25%_70%]"> <div className="mx-4 grid h-full grid-cols-1 gap-0 pt-3 md:ml-4 md:mr-0 md:gap-4 lg:grid-cols-[25%_70%]">
<SectionWrap className="mb-3 block"> <SectionWrap className="mb-3 block">
<div <div
className={cn( className={cn(
@@ -82,16 +92,21 @@ export function NewAgentLibraryView() {
> >
<RunAgentModal <RunAgentModal
triggerSlot={ triggerSlot={
<Button variant="primary" size="large" className="w-full"> <Button
variant="primary"
size="large"
className="w-full"
disabled={isTemplateLoading && activeTab === "templates"}
>
<PlusIcon size={20} /> New task <PlusIcon size={20} /> New task
</Button> </Button>
} }
agent={agent} agent={agent}
agentId={agent.id.toString()} onRunCreated={onRunInitiated}
onRunCreated={(execution) => handleSelectRun(execution.id, "runs")} onScheduleCreated={onScheduleCreated}
onScheduleCreated={(schedule) => onTriggerSetup={onTriggerSetup}
handleSelectRun(schedule.id, "scheduled") initialInputValues={activeTemplate?.inputs}
} initialInputCredentials={activeTemplate?.credentials}
/> />
</div> </div>
@@ -151,7 +166,12 @@ export function NewAgentLibraryView() {
</SelectedViewLayout> </SelectedViewLayout>
) : ( ) : (
<SelectedViewLayout agentName={agent.name} agentId={agent.id}> <SelectedViewLayout agentName={agent.name} agentId={agent.id}>
<EmptyTasks agent={agent} /> <EmptyTasks
agent={agent}
onRun={onRunInitiated}
onTriggerSetup={onTriggerSetup}
onScheduleCreated={onScheduleCreated}
/>
</SelectedViewLayout> </SelectedViewLayout>
)} )}
</div> </div>

View File

@@ -1,7 +1,10 @@
"use client"; "use client";
import type { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent"; import type { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
import type { CredentialsMetaInput } from "@/lib/autogpt-server-api/types"; import type {
BlockIOSubSchema,
CredentialsMetaInput,
} from "@/lib/autogpt-server-api/types";
import { CredentialsInput } from "../CredentialsInputs/CredentialsInputs"; import { CredentialsInput } from "../CredentialsInputs/CredentialsInputs";
import { import {
getAgentCredentialsFields, getAgentCredentialsFields,
@@ -20,13 +23,21 @@ export function AgentInputsReadOnly({
inputs, inputs,
credentialInputs, credentialInputs,
}: Props) { }: Props) {
const fields = getAgentInputFields(agent); const inputFields = getAgentInputFields(agent);
const credentialFields = getAgentCredentialsFields(agent); const credentialFieldEntries = Object.entries(
const inputEntries = Object.entries(fields); getAgentCredentialsFields(agent),
const credentialEntries = Object.entries(credentialFields); );
const hasInputs = inputs && inputEntries.length > 0; // Take actual input entries as leading; augment with schema from input fields.
const hasCredentials = credentialInputs && credentialEntries.length > 0; // TODO: ensure consistent ordering.
const inputEntries =
inputs &&
Object.entries(inputs).map<[string, [BlockIOSubSchema | undefined, any]]>(
([k, v]) => [k, [inputFields[k], v]],
);
const hasInputs = inputEntries && inputEntries.length > 0;
const hasCredentials = credentialInputs && credentialFieldEntries.length > 0;
if (!hasInputs && !hasCredentials) { if (!hasInputs && !hasCredentials) {
return <div className="text-neutral-600">No input for this run.</div>; return <div className="text-neutral-600">No input for this run.</div>;
@@ -37,11 +48,13 @@ export function AgentInputsReadOnly({
{/* Regular inputs */} {/* Regular inputs */}
{hasInputs && ( {hasInputs && (
<div className="flex flex-col gap-4"> <div className="flex flex-col gap-4">
{inputEntries.map(([key, sub]) => ( {inputEntries.map(([key, [schema, value]]) => (
<div key={key} className="flex flex-col gap-1.5"> <div key={key} className="flex flex-col gap-1.5">
<label className="text-sm font-medium">{sub?.title || key}</label> <label className="text-sm font-medium">
{schema?.title || key}
</label>
<p className="whitespace-pre-wrap break-words text-sm text-neutral-700"> <p className="whitespace-pre-wrap break-words text-sm text-neutral-700">
{renderValue((inputs as Record<string, any>)[key])} {renderValue(value)}
</p> </p>
</div> </div>
))} ))}
@@ -52,7 +65,7 @@ export function AgentInputsReadOnly({
{hasCredentials && ( {hasCredentials && (
<div className="flex flex-col gap-6"> <div className="flex flex-col gap-6">
{hasInputs && <div className="border-t border-neutral-200 pt-4" />} {hasInputs && <div className="border-t border-neutral-200 pt-4" />}
{credentialEntries.map(([key, inputSubSchema]) => { {credentialFieldEntries.map(([key, inputSubSchema]) => {
const credential = credentialInputs![key]; const credential = credentialInputs![key];
if (!credential) return null; if (!credential) return null;

View File

@@ -13,7 +13,8 @@ export function getCredentialTypeDisplayName(type: string): string {
} }
export function getAgentInputFields(agent: LibraryAgent): Record<string, any> { export function getAgentInputFields(agent: LibraryAgent): Record<string, any> {
const schema = agent.input_schema as unknown as { const schema = (agent.trigger_setup_info?.config_schema ??
agent.input_schema) as unknown as {
properties?: Record<string, any>; properties?: Record<string, any>;
} | null; } | null;
if (!schema || !schema.properties) return {}; if (!schema || !schema.properties) return {};

View File

@@ -3,6 +3,7 @@
import { GraphExecutionJobInfo } from "@/app/api/__generated__/models/graphExecutionJobInfo"; import { GraphExecutionJobInfo } from "@/app/api/__generated__/models/graphExecutionJobInfo";
import { GraphExecutionMeta } from "@/app/api/__generated__/models/graphExecutionMeta"; import { GraphExecutionMeta } from "@/app/api/__generated__/models/graphExecutionMeta";
import { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent"; import { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
import { LibraryAgentPreset } from "@/app/api/__generated__/models/libraryAgentPreset";
import { Button } from "@/components/atoms/Button/Button"; import { Button } from "@/components/atoms/Button/Button";
import { import {
Tooltip, Tooltip,
@@ -22,16 +23,20 @@ import { useAgentRunModal } from "./useAgentRunModal";
interface Props { interface Props {
triggerSlot: React.ReactNode; triggerSlot: React.ReactNode;
agent: LibraryAgent; agent: LibraryAgent;
agentId: string; initialInputValues?: Record<string, any>;
agentVersion?: number; initialInputCredentials?: Record<string, any>;
onRunCreated?: (execution: GraphExecutionMeta) => void; onRunCreated?: (execution: GraphExecutionMeta) => void;
onTriggerSetup?: (preset: LibraryAgentPreset) => void;
onScheduleCreated?: (schedule: GraphExecutionJobInfo) => void; onScheduleCreated?: (schedule: GraphExecutionJobInfo) => void;
} }
export function RunAgentModal({ export function RunAgentModal({
triggerSlot, triggerSlot,
agent, agent,
initialInputValues,
initialInputCredentials,
onRunCreated, onRunCreated,
onTriggerSetup,
onScheduleCreated, onScheduleCreated,
}: Props) { }: Props) {
const { const {
@@ -71,6 +76,9 @@ export function RunAgentModal({
handleRun, handleRun,
} = useAgentRunModal(agent, { } = useAgentRunModal(agent, {
onRun: onRunCreated, onRun: onRunCreated,
onSetupTrigger: onTriggerSetup,
initialInputValues,
initialInputCredentials,
}); });
const [isScheduleModalOpen, setIsScheduleModalOpen] = useState(false); const [isScheduleModalOpen, setIsScheduleModalOpen] = useState(false);
@@ -79,6 +87,8 @@ export function RunAgentModal({
Object.keys(agentInputFields || {}).length > 0 || Object.keys(agentInputFields || {}).length > 0 ||
Object.keys(agentCredentialsInputFields || {}).length > 0; Object.keys(agentCredentialsInputFields || {}).length > 0;
const isTriggerRunType = defaultRunType.includes("trigger");
function handleInputChange(key: string, value: string) { function handleInputChange(key: string, value: string) {
setInputValues((prev) => ({ setInputValues((prev) => ({
...prev, ...prev,
@@ -153,7 +163,7 @@ export function RunAgentModal({
<Dialog.Footer className="mt-6 bg-white pt-4"> <Dialog.Footer className="mt-6 bg-white pt-4">
<div className="flex items-center justify-end gap-3"> <div className="flex items-center justify-end gap-3">
{!allRequiredInputsAreSet ? ( {isTriggerRunType ? null : !allRequiredInputsAreSet ? (
<TooltipProvider> <TooltipProvider>
<Tooltip> <Tooltip>
<TooltipTrigger asChild> <TooltipTrigger asChild>

View File

@@ -26,7 +26,8 @@ export function ModalRunSection() {
return ( return (
<div className="flex flex-col gap-4"> <div className="flex flex-col gap-4">
{defaultRunType === "automatic-trigger" ? ( {defaultRunType === "automatic-trigger" ||
defaultRunType === "manual-trigger" ? (
<ModalSection <ModalSection
title="Task Trigger" title="Task Trigger"
subtitle="Set up a trigger for the agent to run this task automatically" subtitle="Set up a trigger for the agent to run this task automatically"

View File

@@ -24,7 +24,8 @@ export function RunActions({
disabled={!isRunReady || isExecuting || isSettingUpTrigger} disabled={!isRunReady || isExecuting || isSettingUpTrigger}
loading={isExecuting || isSettingUpTrigger} loading={isExecuting || isSettingUpTrigger}
> >
{defaultRunType === "automatic-trigger" {defaultRunType === "automatic-trigger" ||
defaultRunType === "manual-trigger"
? "Set up Trigger" ? "Set up Trigger"
: "Start Task"} : "Start Task"}
</Button> </Button>

View File

@@ -1,12 +1,11 @@
import { import {
getGetV1ListGraphExecutionsInfiniteQueryOptions, getGetV1ListGraphExecutionsQueryKey,
usePostV1ExecuteGraphAgent, usePostV1ExecuteGraphAgent,
} from "@/app/api/__generated__/endpoints/graphs/graphs"; } from "@/app/api/__generated__/endpoints/graphs/graphs";
import { import {
getGetV2ListPresetsQueryKey, getGetV2ListPresetsQueryKey,
usePostV2SetupTrigger, usePostV2SetupTrigger,
} from "@/app/api/__generated__/endpoints/presets/presets"; } from "@/app/api/__generated__/endpoints/presets/presets";
import { GraphExecutionJobInfo } from "@/app/api/__generated__/models/graphExecutionJobInfo";
import { GraphExecutionMeta } from "@/app/api/__generated__/models/graphExecutionMeta"; import { GraphExecutionMeta } from "@/app/api/__generated__/models/graphExecutionMeta";
import { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent"; import { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
import { LibraryAgentPreset } from "@/app/api/__generated__/models/libraryAgentPreset"; import { LibraryAgentPreset } from "@/app/api/__generated__/models/libraryAgentPreset";
@@ -14,7 +13,7 @@ import { useToast } from "@/components/molecules/Toast/use-toast";
import { isEmpty } from "@/lib/utils"; import { isEmpty } from "@/lib/utils";
import { analytics } from "@/services/analytics"; import { analytics } from "@/services/analytics";
import { useQueryClient } from "@tanstack/react-query"; import { useQueryClient } from "@tanstack/react-query";
import { useCallback, useMemo, useState } from "react"; import { useCallback, useEffect, useMemo, useState } from "react";
import { showExecutionErrorToast } from "./errorHelpers"; import { showExecutionErrorToast } from "./errorHelpers";
export type RunVariant = export type RunVariant =
@@ -25,8 +24,9 @@ export type RunVariant =
interface UseAgentRunModalCallbacks { interface UseAgentRunModalCallbacks {
onRun?: (execution: GraphExecutionMeta) => void; onRun?: (execution: GraphExecutionMeta) => void;
onCreateSchedule?: (schedule: GraphExecutionJobInfo) => void;
onSetupTrigger?: (preset: LibraryAgentPreset) => void; onSetupTrigger?: (preset: LibraryAgentPreset) => void;
initialInputValues?: Record<string, any>;
initialInputCredentials?: Record<string, any>;
} }
export function useAgentRunModal( export function useAgentRunModal(
@@ -36,18 +36,28 @@ export function useAgentRunModal(
const { toast } = useToast(); const { toast } = useToast();
const queryClient = useQueryClient(); const queryClient = useQueryClient();
const [isOpen, setIsOpen] = useState(false); const [isOpen, setIsOpen] = useState(false);
const [inputValues, setInputValues] = useState<Record<string, any>>({}); const [inputValues, setInputValues] = useState<Record<string, any>>(
callbacks?.initialInputValues || {},
);
const [inputCredentials, setInputCredentials] = useState<Record<string, any>>( const [inputCredentials, setInputCredentials] = useState<Record<string, any>>(
{}, callbacks?.initialInputCredentials || {},
); );
const [presetName, setPresetName] = useState<string>(""); const [presetName, setPresetName] = useState<string>("");
const [presetDescription, setPresetDescription] = useState<string>(""); const [presetDescription, setPresetDescription] = useState<string>("");
// Determine the default run type based on agent capabilities // Determine the default run type based on agent capabilities
const defaultRunType: RunVariant = agent.has_external_trigger const defaultRunType: RunVariant = agent.trigger_setup_info
? agent.trigger_setup_info.credentials_input_name
? "automatic-trigger" ? "automatic-trigger"
: "manual-trigger"
: "manual"; : "manual";
// Update input values/credentials if template is selected/unselected
useEffect(() => {
setInputValues(callbacks?.initialInputValues || {});
setInputCredentials(callbacks?.initialInputCredentials || {});
}, [callbacks?.initialInputValues, callbacks?.initialInputCredentials]);
// API mutations // API mutations
const executeGraphMutation = usePostV1ExecuteGraphAgent({ const executeGraphMutation = usePostV1ExecuteGraphAgent({
mutation: { mutation: {
@@ -56,13 +66,11 @@ export function useAgentRunModal(
toast({ toast({
title: "Agent execution started", title: "Agent execution started",
}); });
callbacks?.onRun?.(response.data as unknown as GraphExecutionMeta);
// Invalidate runs list for this graph // Invalidate runs list for this graph
queryClient.invalidateQueries({ queryClient.invalidateQueries({
queryKey: getGetV1ListGraphExecutionsInfiniteQueryOptions( queryKey: getGetV1ListGraphExecutionsQueryKey(agent.graph_id),
agent.graph_id,
).queryKey,
}); });
callbacks?.onRun?.(response.data);
analytics.sendDatafastEvent("run_agent", { analytics.sendDatafastEvent("run_agent", {
name: agent.name, name: agent.name,
id: agent.graph_id, id: agent.graph_id,
@@ -81,17 +89,15 @@ export function useAgentRunModal(
const setupTriggerMutation = usePostV2SetupTrigger({ const setupTriggerMutation = usePostV2SetupTrigger({
mutation: { mutation: {
onSuccess: (response: any) => { onSuccess: (response) => {
if (response.status === 200) { if (response.status === 200) {
toast({ toast({
title: "Trigger setup complete", title: "Trigger setup complete",
}); });
callbacks?.onSetupTrigger?.(response.data);
queryClient.invalidateQueries({ queryClient.invalidateQueries({
queryKey: getGetV2ListPresetsQueryKey({ queryKey: getGetV2ListPresetsQueryKey({ graph_id: agent.graph_id }),
graph_id: agent.graph_id,
}),
}); });
callbacks?.onSetupTrigger?.(response.data);
setIsOpen(false); setIsOpen(false);
} }
}, },
@@ -105,11 +111,13 @@ export function useAgentRunModal(
}, },
}); });
// Input schema validation // Input schema validation (use trigger schema for triggered agents)
const agentInputSchema = useMemo( const agentInputSchema = useMemo(() => {
() => agent.input_schema || { properties: {}, required: [] }, if (agent.trigger_setup_info?.config_schema) {
[agent.input_schema], return agent.trigger_setup_info.config_schema;
); }
return agent.input_schema || { properties: {}, required: [] };
}, [agent.input_schema, agent.trigger_setup_info]);
const agentInputFields = useMemo(() => { const agentInputFields = useMemo(() => {
if ( if (
@@ -205,7 +213,10 @@ export function useAgentRunModal(
return; return;
} }
if (defaultRunType === "automatic-trigger") { if (
defaultRunType === "automatic-trigger" ||
defaultRunType === "manual-trigger"
) {
// Setup trigger // Setup trigger
if (!presetName.trim()) { if (!presetName.trim()) {
toast({ toast({
@@ -262,7 +273,7 @@ export function useAgentRunModal(
setIsOpen, setIsOpen,
// Run mode // Run mode
defaultRunType, defaultRunType: defaultRunType as RunVariant,
// Form: regular inputs // Form: regular inputs
inputValues, inputValues,

View File

@@ -1,17 +1,58 @@
"use client";
import { getV1GetGraphVersion } from "@/app/api/__generated__/endpoints/graphs/graphs";
import { GraphExecutionJobInfo } from "@/app/api/__generated__/models/graphExecutionJobInfo";
import { GraphExecutionMeta } from "@/app/api/__generated__/models/graphExecutionMeta";
import { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent"; import { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
import { LibraryAgentPreset } from "@/app/api/__generated__/models/libraryAgentPreset";
import { Button } from "@/components/atoms/Button/Button"; import { Button } from "@/components/atoms/Button/Button";
import { Text } from "@/components/atoms/Text/Text"; import { Text } from "@/components/atoms/Text/Text";
import { ShowMoreText } from "@/components/molecules/ShowMoreText/ShowMoreText"; import { ShowMoreText } from "@/components/molecules/ShowMoreText/ShowMoreText";
import { useToast } from "@/components/molecules/Toast/use-toast";
import { exportAsJSONFile } from "@/lib/utils";
import { formatDate } from "@/lib/utils/time"; import { formatDate } from "@/lib/utils/time";
import Link from "next/link";
import { RunAgentModal } from "../modals/RunAgentModal/RunAgentModal"; import { RunAgentModal } from "../modals/RunAgentModal/RunAgentModal";
import { RunDetailCard } from "../selected-views/RunDetailCard/RunDetailCard"; import { RunDetailCard } from "../selected-views/RunDetailCard/RunDetailCard";
import { EmptyTasksIllustration } from "./EmptyTasksIllustration"; import { EmptyTasksIllustration } from "./EmptyTasksIllustration";
type Props = { type Props = {
agent: LibraryAgent; agent: LibraryAgent;
onRun?: (run: GraphExecutionMeta) => void;
onTriggerSetup?: (preset: LibraryAgentPreset) => void;
onScheduleCreated?: (schedule: GraphExecutionJobInfo) => void;
}; };
export function EmptyTasks({ agent }: Props) { export function EmptyTasks({
agent,
onRun,
onTriggerSetup,
onScheduleCreated,
}: Props) {
const { toast } = useToast();
async function handleExport() {
try {
const res = await getV1GetGraphVersion(
agent.graph_id,
agent.graph_version,
{ for_export: true },
);
if (res.status === 200) {
const filename = `${agent.name}_v${agent.graph_version}.json`;
exportAsJSONFile(res.data as any, filename);
toast({ title: "Agent exported" });
} else {
toast({ title: "Failed to export agent", variant: "destructive" });
}
} catch (e: any) {
toast({
title: "Failed to export agent",
description: e?.message,
variant: "destructive",
});
}
}
const isPublished = Boolean(agent.marketplace_listing); const isPublished = Boolean(agent.marketplace_listing);
const createdAt = formatDate(agent.created_at); const createdAt = formatDate(agent.created_at);
const updatedAt = formatDate(agent.updated_at); const updatedAt = formatDate(agent.updated_at);
@@ -45,7 +86,9 @@ export function EmptyTasks({ agent }: Props) {
</Button> </Button>
} }
agent={agent} agent={agent}
agentId={agent.id.toString()} onRunCreated={onRun}
onTriggerSetup={onTriggerSetup}
onScheduleCreated={onScheduleCreated}
/> />
</div> </div>
</div> </div>
@@ -93,10 +136,15 @@ export function EmptyTasks({ agent }: Props) {
) : null} ) : null}
</div> </div>
<div className="mt-4 flex items-center gap-2"> <div className="mt-4 flex items-center gap-2">
<Button variant="secondary" size="small"> <Button variant="secondary" size="small" asChild>
<Link
href={`/build?flowID=${agent.graph_id}&flowVersion=${agent.graph_version}`}
target="_blank"
>
Edit agent Edit agent
</Link>
</Button> </Button>
<Button variant="secondary" size="small"> <Button variant="secondary" size="small" onClick={handleExport}>
Export agent to file Export agent to file
</Button> </Button>
</div> </div>

View File

@@ -0,0 +1,14 @@
import { cn } from "@/lib/utils";
import { AGENT_LIBRARY_SECTION_PADDING_X } from "../../helpers";
type Props = {
children: React.ReactNode;
};
export function AnchorLinksWrap({ children }: Props) {
return (
<div className={cn(AGENT_LIBRARY_SECTION_PADDING_X, "hidden lg:block")}>
<nav className="flex gap-8 px-3 pb-1">{children}</nav>
</div>
);
}

View File

@@ -166,7 +166,7 @@ function renderMarkdown(
className="prose prose-sm dark:prose-invert max-w-none" className="prose prose-sm dark:prose-invert max-w-none"
remarkPlugins={[ remarkPlugins={[
remarkGfm, // GitHub Flavored Markdown (tables, task lists, strikethrough) remarkGfm, // GitHub Flavored Markdown (tables, task lists, strikethrough)
remarkMath, // Math support for LaTeX [remarkMath, { singleDollarTextMath: false }], // Math support for LaTeX
]} ]}
rehypePlugins={[ rehypePlugins={[
rehypeKatex, // Render math with KaTeX rehypeKatex, // Render math with KaTeX

View File

@@ -0,0 +1,11 @@
type Props = {
children: React.ReactNode;
};
export function SelectedActionsWrap({ children }: Props) {
return (
<div className="my-0 ml-4 flex flex-row items-center gap-3 lg:mx-0 lg:my-4 lg:flex-col">
{children}
</div>
);
}

View File

@@ -13,10 +13,11 @@ import {
import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard"; import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard";
import { PendingReviewsList } from "@/components/organisms/PendingReviewsList/PendingReviewsList"; import { PendingReviewsList } from "@/components/organisms/PendingReviewsList/PendingReviewsList";
import { usePendingReviewsForExecution } from "@/hooks/usePendingReviews"; import { usePendingReviewsForExecution } from "@/hooks/usePendingReviews";
import { isLargeScreen, useBreakpoint } from "@/lib/hooks/useBreakpoint";
import { InfoIcon } from "@phosphor-icons/react"; import { InfoIcon } from "@phosphor-icons/react";
import { useEffect } from "react"; import { useEffect } from "react";
import { AGENT_LIBRARY_SECTION_PADDING_X } from "../../../helpers";
import { AgentInputsReadOnly } from "../../modals/AgentInputsReadOnly/AgentInputsReadOnly"; import { AgentInputsReadOnly } from "../../modals/AgentInputsReadOnly/AgentInputsReadOnly";
import { AnchorLinksWrap } from "../AnchorLinksWrap";
import { LoadingSelectedContent } from "../LoadingSelectedContent"; import { LoadingSelectedContent } from "../LoadingSelectedContent";
import { RunDetailCard } from "../RunDetailCard/RunDetailCard"; import { RunDetailCard } from "../RunDetailCard/RunDetailCard";
import { RunDetailHeader } from "../RunDetailHeader/RunDetailHeader"; import { RunDetailHeader } from "../RunDetailHeader/RunDetailHeader";
@@ -46,6 +47,9 @@ export function SelectedRunView({
const { run, preset, isLoading, responseError, httpError } = const { run, preset, isLoading, responseError, httpError } =
useSelectedRunView(agent.graph_id, runId); useSelectedRunView(agent.graph_id, runId);
const breakpoint = useBreakpoint();
const isLgScreenUp = isLargeScreen(breakpoint);
const { const {
pendingReviews, pendingReviews,
isLoading: reviewsLoading, isLoading: reviewsLoading,
@@ -89,6 +93,15 @@ export function SelectedRunView({
<div className="flex flex-col gap-4"> <div className="flex flex-col gap-4">
<RunDetailHeader agent={agent} run={run} /> <RunDetailHeader agent={agent} run={run} />
{!isLgScreenUp ? (
<SelectedRunActions
agent={agent}
run={run}
onSelectRun={onSelectRun}
onClearSelectedRun={onClearSelectedRun}
/>
) : null}
{preset && {preset &&
agent.trigger_setup_info && agent.trigger_setup_info &&
preset.webhook_id && preset.webhook_id &&
@@ -100,8 +113,7 @@ export function SelectedRunView({
)} )}
{/* Navigation Links */} {/* Navigation Links */}
<div className={AGENT_LIBRARY_SECTION_PADDING_X}> <AnchorLinksWrap>
<nav className="flex gap-8 px-3 pb-1">
{withSummary && ( {withSummary && (
<button <button
onClick={() => scrollToSection("summary")} onClick={() => scrollToSection("summary")}
@@ -130,8 +142,7 @@ export function SelectedRunView({
Reviews ({pendingReviews.length}) Reviews ({pendingReviews.length})
</button> </button>
)} )}
</nav> </AnchorLinksWrap>
</div>
{/* Summary Section */} {/* Summary Section */}
{withSummary && ( {withSummary && (
@@ -187,8 +198,8 @@ export function SelectedRunView({
<RunDetailCard title="Your input"> <RunDetailCard title="Your input">
<AgentInputsReadOnly <AgentInputsReadOnly
agent={agent} agent={agent}
inputs={(run as any)?.inputs} inputs={run?.inputs}
credentialInputs={(run as any)?.credential_inputs} credentialInputs={run?.credential_inputs}
/> />
</RunDetailCard> </RunDetailCard>
</div> </div>
@@ -216,7 +227,8 @@ export function SelectedRunView({
</div> </div>
</SelectedViewLayout> </SelectedViewLayout>
</div> </div>
<div className="-mt-2 max-w-[3.75rem] flex-shrink-0"> {isLgScreenUp ? (
<div className="max-w-[3.75rem] flex-shrink-0">
<SelectedRunActions <SelectedRunActions
agent={agent} agent={agent}
run={run} run={run}
@@ -224,6 +236,7 @@ export function SelectedRunView({
onClearSelectedRun={onClearSelectedRun} onClearSelectedRun={onClearSelectedRun}
/> />
</div> </div>
) : null}
</div> </div>
); );
} }

View File

@@ -12,6 +12,7 @@ import {
StopIcon, StopIcon,
} from "@phosphor-icons/react"; } from "@phosphor-icons/react";
import { AgentActionsDropdown } from "../../../AgentActionsDropdown"; import { AgentActionsDropdown } from "../../../AgentActionsDropdown";
import { SelectedActionsWrap } from "../../../SelectedActionsWrap";
import { ShareRunButton } from "../../../ShareRunButton/ShareRunButton"; import { ShareRunButton } from "../../../ShareRunButton/ShareRunButton";
import { CreateTemplateModal } from "../CreateTemplateModal/CreateTemplateModal"; import { CreateTemplateModal } from "../CreateTemplateModal/CreateTemplateModal";
import { useSelectedRunActions } from "./useSelectedRunActions"; import { useSelectedRunActions } from "./useSelectedRunActions";
@@ -19,13 +20,18 @@ import { useSelectedRunActions } from "./useSelectedRunActions";
type Props = { type Props = {
agent: LibraryAgent; agent: LibraryAgent;
run: GraphExecution | undefined; run: GraphExecution | undefined;
scheduleRecurrence?: string;
onSelectRun?: (id: string) => void; onSelectRun?: (id: string) => void;
onClearSelectedRun?: () => void; onClearSelectedRun?: () => void;
}; };
export function SelectedRunActions(props: Props) { export function SelectedRunActions({
agent,
run,
onSelectRun,
onClearSelectedRun,
}: Props) {
const { const {
canRunManually,
handleRunAgain, handleRunAgain,
handleStopRun, handleStopRun,
isRunningAgain, isRunningAgain,
@@ -36,21 +42,20 @@ export function SelectedRunActions(props: Props) {
isCreateTemplateModalOpen, isCreateTemplateModalOpen,
setIsCreateTemplateModalOpen, setIsCreateTemplateModalOpen,
} = useSelectedRunActions({ } = useSelectedRunActions({
agentGraphId: props.agent.graph_id, agentGraphId: agent.graph_id,
run: props.run, run: run,
agent: props.agent, agent: agent,
onSelectRun: props.onSelectRun, onSelectRun: onSelectRun,
onClearSelectedRun: props.onClearSelectedRun,
}); });
const shareExecutionResultsEnabled = useGetFlag(Flag.SHARE_EXECUTION_RESULTS); const shareExecutionResultsEnabled = useGetFlag(Flag.SHARE_EXECUTION_RESULTS);
const isRunning = props.run?.status === "RUNNING"; const isRunning = run?.status === "RUNNING";
if (!props.run || !props.agent) return null; if (!run || !agent) return null;
return ( return (
<div className="my-4 flex flex-col items-center gap-3"> <SelectedActionsWrap>
{!isRunning ? ( {canRunManually && !isRunning ? (
<Button <Button
variant="icon" variant="icon"
size="icon" size="icon"
@@ -102,17 +107,15 @@ export function SelectedRunActions(props: Props) {
) : null} ) : null}
{shareExecutionResultsEnabled && ( {shareExecutionResultsEnabled && (
<ShareRunButton <ShareRunButton
graphId={props.agent.graph_id} graphId={agent.graph_id}
executionId={props.run.id} executionId={run.id}
isShared={props.run.is_shared} isShared={run.is_shared}
shareToken={props.run.share_token} shareToken={run.share_token}
/> />
)} )}
<FloatingSafeModeToggle <FloatingSafeModeToggle graph={agent} variant="white" fullWidth={false} />
graph={props.agent} {canRunManually && (
variant="white" <>
fullWidth={false}
/>
<Button <Button
variant="icon" variant="icon"
size="icon" size="icon"
@@ -122,18 +125,20 @@ export function SelectedRunActions(props: Props) {
> >
<CardsThreeIcon weight="bold" size={18} className="text-zinc-700" /> <CardsThreeIcon weight="bold" size={18} className="text-zinc-700" />
</Button> </Button>
<AgentActionsDropdown
agent={props.agent}
run={props.run}
agentGraphId={props.agent.graph_id}
onClearSelectedRun={props.onClearSelectedRun}
/>
<CreateTemplateModal <CreateTemplateModal
isOpen={isCreateTemplateModalOpen} isOpen={isCreateTemplateModalOpen}
onClose={() => setIsCreateTemplateModalOpen(false)} onClose={() => setIsCreateTemplateModalOpen(false)}
onCreate={handleCreateTemplate} onCreate={handleCreateTemplate}
run={props.run} run={run}
/> />
</div> </>
)}
<AgentActionsDropdown
agent={agent}
run={run}
agentGraphId={agent.graph_id}
onClearSelectedRun={onClearSelectedRun}
/>
</SelectedActionsWrap>
); );
} }

View File

@@ -15,15 +15,19 @@ import { useToast } from "@/components/molecules/Toast/use-toast";
import { useQueryClient } from "@tanstack/react-query"; import { useQueryClient } from "@tanstack/react-query";
import { useState } from "react"; import { useState } from "react";
interface Args { interface Params {
agentGraphId: string; agentGraphId: string;
run?: GraphExecution; run?: GraphExecution;
agent?: LibraryAgent; agent?: LibraryAgent;
onSelectRun?: (id: string) => void; onSelectRun?: (id: string) => void;
onClearSelectedRun?: () => void;
} }
export function useSelectedRunActions(args: Args) { export function useSelectedRunActions({
agentGraphId,
run,
agent,
onSelectRun,
}: Params) {
const queryClient = useQueryClient(); const queryClient = useQueryClient();
const { toast } = useToast(); const { toast } = useToast();
@@ -31,8 +35,9 @@ export function useSelectedRunActions(args: Args) {
const [isCreateTemplateModalOpen, setIsCreateTemplateModalOpen] = const [isCreateTemplateModalOpen, setIsCreateTemplateModalOpen] =
useState(false); useState(false);
const canStop = const canStop = run?.status === "RUNNING" || run?.status === "QUEUED";
args.run?.status === "RUNNING" || args.run?.status === "QUEUED";
const canRunManually = !agent?.trigger_setup_info;
const { mutateAsync: stopRun, isPending: isStopping } = const { mutateAsync: stopRun, isPending: isStopping } =
usePostV1StopGraphExecution(); usePostV1StopGraphExecution();
@@ -46,16 +51,16 @@ export function useSelectedRunActions(args: Args) {
async function handleStopRun() { async function handleStopRun() {
try { try {
await stopRun({ await stopRun({
graphId: args.run?.graph_id ?? "", graphId: run?.graph_id ?? "",
graphExecId: args.run?.id ?? "", graphExecId: run?.id ?? "",
}); });
toast({ title: "Run stopped" }); toast({ title: "Run stopped" });
await queryClient.invalidateQueries({ await queryClient.invalidateQueries({
queryKey: getGetV1ListGraphExecutionsInfiniteQueryOptions( queryKey:
args.agentGraphId, getGetV1ListGraphExecutionsInfiniteQueryOptions(agentGraphId)
).queryKey, .queryKey,
}); });
} catch (error: unknown) { } catch (error: unknown) {
toast({ toast({
@@ -70,7 +75,7 @@ export function useSelectedRunActions(args: Args) {
} }
async function handleRunAgain() { async function handleRunAgain() {
if (!args.run) { if (!run) {
toast({ toast({
title: "Run not found", title: "Run not found",
description: "Run not found", description: "Run not found",
@@ -83,11 +88,11 @@ export function useSelectedRunActions(args: Args) {
toast({ title: "Run started" }); toast({ title: "Run started" });
const res = await executeRun({ const res = await executeRun({
graphId: args.run.graph_id, graphId: run.graph_id,
graphVersion: args.run.graph_version, graphVersion: run.graph_version,
data: { data: {
inputs: args.run.inputs || {}, inputs: run.inputs || {},
credentials_inputs: args.run.credential_inputs || {}, credentials_inputs: run.credential_inputs || {},
source: "library", source: "library",
}, },
}); });
@@ -95,12 +100,12 @@ export function useSelectedRunActions(args: Args) {
const newRunId = res?.status === 200 ? (res?.data?.id ?? "") : ""; const newRunId = res?.status === 200 ? (res?.data?.id ?? "") : "";
await queryClient.invalidateQueries({ await queryClient.invalidateQueries({
queryKey: getGetV1ListGraphExecutionsInfiniteQueryOptions( queryKey:
args.agentGraphId, getGetV1ListGraphExecutionsInfiniteQueryOptions(agentGraphId)
).queryKey, .queryKey,
}); });
if (newRunId && args.onSelectRun) args.onSelectRun(newRunId); if (newRunId && onSelectRun) onSelectRun(newRunId);
} catch (error: unknown) { } catch (error: unknown) {
toast({ toast({
title: "Failed to start run", title: "Failed to start run",
@@ -118,7 +123,7 @@ export function useSelectedRunActions(args: Args) {
} }
async function handleCreateTemplate(name: string, description: string) { async function handleCreateTemplate(name: string, description: string) {
if (!args.run) { if (!run) {
toast({ toast({
title: "Run not found", title: "Run not found",
description: "Cannot create template from missing run", description: "Cannot create template from missing run",
@@ -132,7 +137,7 @@ export function useSelectedRunActions(args: Args) {
data: { data: {
name, name,
description, description,
graph_execution_id: args.run.id, graph_execution_id: run.id,
}, },
}); });
@@ -141,10 +146,10 @@ export function useSelectedRunActions(args: Args) {
title: "Template created", title: "Template created",
}); });
if (args.agent) { if (agent) {
queryClient.invalidateQueries({ queryClient.invalidateQueries({
queryKey: getGetV2ListPresetsQueryKey({ queryKey: getGetV2ListPresetsQueryKey({
graph_id: args.agent.graph_id, graph_id: agent.graph_id,
}), }),
}); });
} }
@@ -164,8 +169,8 @@ export function useSelectedRunActions(args: Args) {
} }
// Open in builder URL helper // Open in builder URL helper
const openInBuilderHref = args.run const openInBuilderHref = run
? `/build?flowID=${args.run.graph_id}&flowVersion=${args.run.graph_version}&flowExecutionID=${args.run.id}` ? `/build?flowID=${run.graph_id}&flowVersion=${run.graph_version}&flowExecutionID=${run.id}`
: undefined; : undefined;
return { return {
@@ -173,6 +178,7 @@ export function useSelectedRunActions(args: Args) {
showDeleteDialog, showDeleteDialog,
canStop, canStop,
isStopping, isStopping,
canRunManually,
isRunningAgain, isRunningAgain,
handleShowDeleteDialog, handleShowDeleteDialog,
handleStopRun, handleStopRun,

View File

@@ -6,9 +6,10 @@ import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner
import { Text } from "@/components/atoms/Text/Text"; import { Text } from "@/components/atoms/Text/Text";
import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard"; import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard";
import { humanizeCronExpression } from "@/lib/cron-expression-utils"; import { humanizeCronExpression } from "@/lib/cron-expression-utils";
import { isLargeScreen, useBreakpoint } from "@/lib/hooks/useBreakpoint";
import { formatInTimezone, getTimezoneDisplayName } from "@/lib/timezone-utils"; import { formatInTimezone, getTimezoneDisplayName } from "@/lib/timezone-utils";
import { AGENT_LIBRARY_SECTION_PADDING_X } from "../../../helpers";
import { AgentInputsReadOnly } from "../../modals/AgentInputsReadOnly/AgentInputsReadOnly"; import { AgentInputsReadOnly } from "../../modals/AgentInputsReadOnly/AgentInputsReadOnly";
import { AnchorLinksWrap } from "../AnchorLinksWrap";
import { LoadingSelectedContent } from "../LoadingSelectedContent"; import { LoadingSelectedContent } from "../LoadingSelectedContent";
import { RunDetailCard } from "../RunDetailCard/RunDetailCard"; import { RunDetailCard } from "../RunDetailCard/RunDetailCard";
import { RunDetailHeader } from "../RunDetailHeader/RunDetailHeader"; import { RunDetailHeader } from "../RunDetailHeader/RunDetailHeader";
@@ -41,6 +42,9 @@ export function SelectedScheduleView({
}, },
}); });
const breakpoint = useBreakpoint();
const isLgScreenUp = isLargeScreen(breakpoint);
function scrollToSection(id: string) { function scrollToSection(id: string) {
const element = document.getElementById(id); const element = document.getElementById(id);
if (element) { if (element) {
@@ -83,7 +87,6 @@ export function SelectedScheduleView({
<div className="flex min-h-0 min-w-0 flex-1 flex-col"> <div className="flex min-h-0 min-w-0 flex-1 flex-col">
<SelectedViewLayout agentName={agent.name} agentId={agent.id}> <SelectedViewLayout agentName={agent.name} agentId={agent.id}>
<div className="flex flex-col gap-4"> <div className="flex flex-col gap-4">
<div className="flex w-full items-center justify-between">
<div className="flex w-full flex-col gap-0"> <div className="flex w-full flex-col gap-0">
<RunDetailHeader <RunDetailHeader
agent={agent} agent={agent}
@@ -94,12 +97,19 @@ export function SelectedScheduleView({
: undefined : undefined
} }
/> />
{schedule && !isLgScreenUp ? (
<div className="mt-4">
<SelectedScheduleActions
agent={agent}
scheduleId={schedule.id}
onDeleted={onClearSelectedRun}
/>
</div> </div>
) : null}
</div> </div>
{/* Navigation Links */} {/* Navigation Links */}
<div className={AGENT_LIBRARY_SECTION_PADDING_X}> <AnchorLinksWrap>
<nav className="flex gap-8 px-3 pb-1">
<button <button
onClick={() => scrollToSection("schedule")} onClick={() => scrollToSection("schedule")}
className={anchorStyles} className={anchorStyles}
@@ -112,8 +122,7 @@ export function SelectedScheduleView({
> >
Your input Your input
</button> </button>
</nav> </AnchorLinksWrap>
</div>
{/* Schedule Section */} {/* Schedule Section */}
<div id="schedule" className="scroll-mt-4"> <div id="schedule" className="scroll-mt-4">
@@ -172,10 +181,6 @@ export function SelectedScheduleView({
<div id="input" className="scroll-mt-4"> <div id="input" className="scroll-mt-4">
<RunDetailCard title="Your input"> <RunDetailCard title="Your input">
<div className="relative"> <div className="relative">
{/* {// TODO: re-enable edit inputs modal once the API supports it */}
{/* {schedule && Object.keys(schedule.input_data).length > 0 && (
<EditInputsModal agent={agent} schedule={schedule} />
)} */}
<AgentInputsReadOnly <AgentInputsReadOnly
agent={agent} agent={agent}
inputs={schedule?.input_data} inputs={schedule?.input_data}
@@ -187,8 +192,8 @@ export function SelectedScheduleView({
</div> </div>
</SelectedViewLayout> </SelectedViewLayout>
</div> </div>
{schedule ? ( {schedule && isLgScreenUp ? (
<div className="-mt-2 max-w-[3.75rem] flex-shrink-0"> <div className="max-w-[3.75rem] flex-shrink-0">
<SelectedScheduleActions <SelectedScheduleActions
agent={agent} agent={agent}
scheduleId={schedule.id} scheduleId={schedule.id}

View File

@@ -3,6 +3,7 @@ import { Button } from "@/components/atoms/Button/Button";
import { EyeIcon } from "@phosphor-icons/react"; import { EyeIcon } from "@phosphor-icons/react";
import { AgentActionsDropdown } from "../../AgentActionsDropdown"; import { AgentActionsDropdown } from "../../AgentActionsDropdown";
import { useScheduleDetailHeader } from "../../RunDetailHeader/useScheduleDetailHeader"; import { useScheduleDetailHeader } from "../../RunDetailHeader/useScheduleDetailHeader";
import { SelectedActionsWrap } from "../../SelectedActionsWrap";
type Props = { type Props = {
agent: LibraryAgent; agent: LibraryAgent;
@@ -19,7 +20,7 @@ export function SelectedScheduleActions({ agent, scheduleId }: Props) {
return ( return (
<> <>
<div className="my-4 flex flex-col items-center gap-3"> <SelectedActionsWrap>
{openInBuilderHref && ( {openInBuilderHref && (
<Button <Button
variant="icon" variant="icon"
@@ -32,7 +33,7 @@ export function SelectedScheduleActions({ agent, scheduleId }: Props) {
</Button> </Button>
)} )}
<AgentActionsDropdown agent={agent} scheduleId={scheduleId} /> <AgentActionsDropdown agent={agent} scheduleId={scheduleId} />
</div> </SelectedActionsWrap>
</> </>
); );
} }

View File

@@ -95,6 +95,7 @@ export function SelectedTemplateView({
return null; return null;
} }
const templateOrTrigger = agent.trigger_setup_info ? "Trigger" : "Template";
const hasWebhook = !!template.webhook_id && template.webhook; const hasWebhook = !!template.webhook_id && template.webhook;
return ( return (
@@ -111,14 +112,14 @@ export function SelectedTemplateView({
/> />
)} )}
<RunDetailCard title="Template Details"> <RunDetailCard title={`${templateOrTrigger} Details`}>
<div className="flex flex-col gap-2"> <div className="flex flex-col gap-2">
<Input <Input
id="template-name" id="template-name"
label="Name" label="Name"
value={name} value={name}
onChange={(e) => setName(e.target.value)} onChange={(e) => setName(e.target.value)}
placeholder="Enter template name" placeholder={`Enter ${templateOrTrigger.toLowerCase()} name`}
/> />
<Input <Input
@@ -128,7 +129,7 @@ export function SelectedTemplateView({
rows={3} rows={3}
value={description} value={description}
onChange={(e) => setDescription(e.target.value)} onChange={(e) => setDescription(e.target.value)}
placeholder="Enter template description" placeholder={`Enter ${templateOrTrigger.toLowerCase()} description`}
/> />
</div> </div>
</RunDetailCard> </RunDetailCard>

View File

@@ -15,6 +15,7 @@ import { useToast } from "@/components/molecules/Toast/use-toast";
import { FloppyDiskIcon, PlayIcon, TrashIcon } from "@phosphor-icons/react"; import { FloppyDiskIcon, PlayIcon, TrashIcon } from "@phosphor-icons/react";
import { useQueryClient } from "@tanstack/react-query"; import { useQueryClient } from "@tanstack/react-query";
import { useState } from "react"; import { useState } from "react";
import { AgentActionsDropdown } from "../../AgentActionsDropdown";
interface Props { interface Props {
agent: LibraryAgent; agent: LibraryAgent;
@@ -134,6 +135,7 @@ export function SelectedTemplateActions({
<TrashIcon weight="bold" size={18} /> <TrashIcon weight="bold" size={18} />
)} )}
</Button> </Button>
<AgentActionsDropdown agent={agent} />
</div> </div>
<Dialog <Dialog

View File

@@ -138,11 +138,21 @@ export function useSelectedTemplateView({
} }
function handleStartTask() { function handleStartTask() {
if (!query.data) return;
const inputsChanged =
JSON.stringify(inputs) !== JSON.stringify(query.data.inputs || {});
const credentialsChanged =
JSON.stringify(credentials) !==
JSON.stringify(query.data.credentials || {});
// Use changed unpersisted inputs if applicable
executeMutation.mutate({ executeMutation.mutate({
presetId: templateId, presetId: templateId,
data: { data: {
inputs: {}, inputs: inputsChanged ? inputs : undefined,
credential_inputs: {}, credential_inputs: credentialsChanged ? credentials : undefined,
}, },
}); });
} }

View File

@@ -15,6 +15,7 @@ import { useToast } from "@/components/molecules/Toast/use-toast";
import { FloppyDiskIcon, TrashIcon } from "@phosphor-icons/react"; import { FloppyDiskIcon, TrashIcon } from "@phosphor-icons/react";
import { useQueryClient } from "@tanstack/react-query"; import { useQueryClient } from "@tanstack/react-query";
import { useState } from "react"; import { useState } from "react";
import { AgentActionsDropdown } from "../../AgentActionsDropdown";
interface Props { interface Props {
agent: LibraryAgent; agent: LibraryAgent;
@@ -111,6 +112,7 @@ export function SelectedTriggerActions({
<TrashIcon weight="bold" size={18} /> <TrashIcon weight="bold" size={18} />
)} )}
</Button> </Button>
<AgentActionsDropdown agent={agent} />
</div> </div>
<Dialog <Dialog

View File

@@ -12,7 +12,7 @@ export function SelectedViewLayout(props: Props) {
return ( return (
<SectionWrap className="relative mb-3 flex min-h-0 flex-1 flex-col"> <SectionWrap className="relative mb-3 flex min-h-0 flex-1 flex-col">
<div <div
className={`${AGENT_LIBRARY_SECTION_PADDING_X} flex-shrink-0 border-b border-zinc-100 pb-4`} className={`${AGENT_LIBRARY_SECTION_PADDING_X} flex-shrink-0 border-b border-zinc-100 pb-0 lg:pb-4`}
> >
<Breadcrumbs <Breadcrumbs
items={[ items={[

Some files were not shown because too many files have changed in this diff Show More