Compare commits

...

69 Commits

Author SHA1 Message Date
Swifty
224411abd3 add updated_at as an option 2025-10-29 15:32:37 +01:00
Swifty
6b241af79e make store queries parameterised 2025-10-28 15:33:19 +01:00
Ubbe
320fb7d83a fix(frontend): waitlist modal copy (#11263)
### Changes 🏗️

### Before

<img width="800" height="649" alt="Screenshot_2025-10-23_at_00 44 59"
src="https://github.com/user-attachments/assets/fd717d39-772a-4331-bc54-4db15a9a3107"
/>

### After

<img width="800" height="555" alt="Screenshot 2025-10-27 at 23 19 10"
src="https://github.com/user-attachments/assets/64878bd0-3a96-4b3a-8344-1a88c89de52e"
/>

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Try to signup with a non-approved email
  - [x] You see the modal with an updated copy
2025-10-28 11:08:06 +00:00
Ubbe
54552248f7 fix(frontend): login not visible mobile (#11245)
## Changes 🏗️

The mobile 📱 experience is still a mess but this helps a little.

### Before

<img width="350" height="395" alt="Screenshot 2025-10-24 at 18 26 18"
src="https://github.com/user-attachments/assets/75eab232-8c37-41e7-a51d-dbe07db336a0"
/>

### After

<img width="350" height="406" alt="Screenshot 2025-10-24 at 18 25 54"
src="https://github.com/user-attachments/assets/ecbd8bbd-8a94-4775-b990-c8b51de48cf9"
/>


## Checklist 📋

### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Load the app
  - [x] Check the Tally popup button copy
  - [x] The button still works
2025-10-28 14:00:50 +04:00
Ubbe
d8a5780ea2 fix(frontend): feedback button copy (#11246)
## Changes 🏗️

<img width="800" height="827" alt="Screenshot 2025-10-24 at 17 45 48"
src="https://github.com/user-attachments/assets/ab18361e-6c58-43e9-bea6-c9172d06c0e7"
/>

- Shows the text `Give feedback` so the button is more explicit 🏁 
- Refactor the component to stick to [new code
conventions](https://github.com/Significant-Gravitas/AutoGPT/blob/master/autogpt_platform/frontend/CONTRIBUTING.md)

## Checklist 📋

### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Load the app
  - [x] Check the Tally popup button copy
  - [x] The button still works
2025-10-28 14:00:33 +04:00
seer-by-sentry[bot]
377657f8a1 fix(backend): Extract response from LLM response dictionary (#11262)
### Changes 🏗️

- Modifies the LLM block to extract the actual response from the
dictionary returned by the LLM, instead of yielding the entire
dictionary. This addresses
[AUTOGPT-SERVER-6EY](https://sentry.io/organizations/significant-gravitas/issues/6950850822/).

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  <!-- Put your test plan here: -->
- [x] After applying the fix, I ran the agent that triggered the Sentry
error and confirmed that it now completes successfully without errors.

---------

Co-authored-by: seer-by-sentry[bot] <157164994+seer-by-sentry[bot]@users.noreply.github.com>
Co-authored-by: Swifty <craigswift13@gmail.com>
2025-10-28 08:43:29 +00:00
seer-by-sentry[bot]
ff71c940c9 fix(backend): Properly encode hostname in URL validation (#11259)
Fixes
[AUTOGPT-SERVER-6KZ](https://sentry.io/organizations/significant-gravitas/issues/6976926125/).
The issue was that: Redirect handling strips the URL scheme, causing
subsequent requests to fail validation and hit a 404.

- Ensures the hostname in the URL is properly IDNA-encoded after
validation.
- Reconstructs the netloc with the encoded hostname and preserves the
port if it exists.

This fix was generated by Seer in Sentry, triggered by Craig Swift. 👁️
Run ID: 2204774

Not quite right? [Click here to continue debugging with
Seer.](https://sentry.io/organizations/significant-gravitas/issues/6976926125/?seerDrawer=true)

### Changes 🏗️

**backend/util/request.py:**
- Fixed URL validation to properly preserve port numbers when
reconstructing netloc
- Ensures IDNA-encoded hostname is combined with port (if present)
before URL reconstruction

**Test Results:**
-  Tested request to https://www.target.com/ (original failing URL from
Sentry issue)
-  Status: 200, Content retrieved successfully (339,846 bytes)
-  Port preservation verified for URLs with explicit ports

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Tested request to https://www.target.com/ (original failing URL)
  - [x] Verified status code 200 and successful content retrieval
  - [x] Verified port preservation in URL validation

<details>
  <summary>Example test plan</summary>
  
  - [ ] Create from scratch and execute an agent with at least 3 blocks
- [ ] Import an agent from file upload, and confirm it executes
correctly
  - [ ] Upload agent to marketplace
- [ ] Import an agent from marketplace and confirm it executes correctly
  - [ ] Edit an agent from monitor, and confirm it executes correctly
</details>

#### For configuration changes:

- [x] `.env.default` is updated or already compatible with my changes
- [x] `docker-compose.yml` is updated or already compatible with my
changes
- [x] I have included a list of my configuration changes in the PR
description (under **Changes**)

<details>
  <summary>Examples of configuration changes</summary>

  - Changing ports
  - Adding new services that need to communicate with each other
  - Secrets or environment variable changes
  - New or infrastructure changes such as databases
</details>

Co-authored-by: seer-by-sentry[bot] <157164994+seer-by-sentry[bot]@users.noreply.github.com>
Co-authored-by: Swifty <craigswift13@gmail.com>
2025-10-28 08:43:14 +00:00
Reinier van der Leer
9967b3a7ce fix(frontend/builder): Fix unnecessary graph re-saving (#11145)
- Resolves #10980
- 2nd attempt after #11075 broke some things

Fixes unnecessary graph re-saving when no changes were made after
initial save. More specifically, this PR fixes two causes of this issue:
- Frontend node IDs were being compared to backend IDs, which won't
match if the graph has been modified and saved since loading.
- `fillDefaults` was being applied to all nodes (including existing
ones) on element creation, and empty values were being stripped
*post-save* with `removeEmptyStringsAndNulls`. This invisible
auto-modification of node input data meant that in some common cases the
graph would never be in sync with the backend.

### Changes 🏗️

- Fix node ID handling
- Use `node.data.backend_id ?? node.id` instead of `node.id` in
`prepareSaveableGraph`
    - Also map link source/sink IDs to their corresponding backend IDs
  - Add note about `node.data.backend_id` to `_saveAgent`
  - Use `node.data.backend_id || node.id` as display ID in `CustomNode`

- Prevent auto-modification of node input data on existing nodes
- Prune empty values (`undefined`, `null`, `""`) from node input data
*pre-save* instead of post-save
- Related: improve typing and functionality of
`fillObjectDefaultsFromSchema` (moved and renamed from `fillDefaults`)

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Node display ID updates on save
- [x] Clicking save a second time (without making more changes) doesn't
cause re-save
- [x] Updating nodes with dynamic input links (e.g. Create Dictionary
Block) doesn't make the links disappear


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **Bug Fixes**
* Prevented unintended auto-modification of existing nodes during
editing
* Improved consistency of node and connection identifiers in saved
graphs

* **Improvements**
  * Enhanced node title display logic for clearer node identification
* Optimized data cleanup utilities for more robust input processing in
the builder

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2025-10-27 16:49:02 +00:00
Bently
9db443960a feat(blocks/claude): Remove Claude 3.5 Sonnet and Haiku model (#11260)
Removes CLAUDE_3_5_SONNET and CLAUDE_3_5_HAIKU from LlmModel enum, model
metadata, and cost configuration since they are deprecated

  ### Checklist 📋

  #### For code changes:
  - [x] I have clearly listed my changes in the PR description
  - [x] I have made a test plan
  - [x] I have tested my changes according to the test plan:
  - [x] Verify the models are gone from the llm blocks
2025-10-27 16:49:02 +00:00
Ubbe
9316100864 fix(frontend): agent activity graph names (#11233)
## Changes 🏗️

We weren't fetching all library agents, just the first 15... to compute
the agent map on the Agent Activity dropdown. We suspect that is causing
some agent executions coming as `Unknown agent`.

In this changes, I'm fetching all the library agents upfront ( _without
blocking page load_ ) and caching them on the browser, so we have all
the details to render the agent runs. This is re-used in the library as
well for fast initial load on the agents list page.

## Checklist 📋

### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] First request populates cache; subsequent identical requests hit
cache
- [x] Editing an agent invalidates relevant cache keys and serves fresh
data
  - [x] Different query params generate distinct cache entries
  - [x] Cache layer gracefully falls back to live data on errors
  - [x] 404 behavior for unknown agents unchanged

### For configuration changes:

None
2025-10-27 20:08:21 +04:00
Ubbe
cbe0cee0fc fix(frontend): Credentials disabling onboarding Run button (#11244)
## Changes 🏗️

The onboarding `Run` button is disabled sometimes when an agent
requiring credentials is selected. We think this can be because the
credentials load _async_ by a sub-component ( `<CredentialsInputs />` ),
and there wasn't a way for the parent component to know whether they
loaded or not.

- Refactored **Step 5** of onboarding to adhere to our code conventions
  - split concerns and colocated state
  - used generated API hooks
  - the UI will only render once API calls succeed
- Created a system where ``<CredentialsInputs />` notify the parent
component when they load
- Did minor adjustments here and there

## Checklist 📋

### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] I will know once I find an agent with credentials that I can
run....


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **New Features**
* Added visual agent selection card displaying agent details during
onboarding
  * Introduced credentials input management for agent configuration
  * Added onboarding guidance for initiating agent runs

* **Improvements**
  * Enhanced onboarding flow with improved state management
  * Refined login state handling
  * Adjusted spacing in agent rating display

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2025-10-27 19:53:14 +04:00
Swifty
7cbb1ed859 fix(backend/store): Sanitize all sql terms (#11228)
Categories and Creators where not sanitized in the full text search

### Changes 🏗️

- apply sanitization to categories and creators

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] run tests to check it still works
2025-10-27 12:59:05 +01:00
Reinier van der Leer
e06e7ff33f fix(backend): Implement graceful shutdown in AppService to prevent RPC errors (#11240)
We're currently seeing errors in the `DatabaseManager` while it's
shutting down, like:

```
WARNING [DatabaseManager] Termination request: SystemExit; 0 executing cleanup.
INFO [DatabaseManager]  Disconnecting Database...
INFO [PID-1|THREAD-29|DatabaseManager|Prisma-82fb1994-4b87-40c1-8869-fbd97bd33fc8] Releasing connection started...
INFO [PID-1|THREAD-29|DatabaseManager|Prisma-82fb1994-4b87-40c1-8869-fbd97bd33fc8] Releasing connection completed successfully.
INFO [DatabaseManager] Terminated.
ERROR POST /create_or_add_to_user_notification_batch failed: Failed to create or add to notification batch for user {user_id} and type AGENT_RUN: NoneType: None
```

This indicates two issues:
- The service doesn't wait for pending RPC calls to finish before
terminating
- We're using `logger.exception` outside an error handling context,
causing the confusing and not much useful `NoneType: None` to be printed
instead of error info

### Changes 🏗️

- Implement graceful shutdown in `AppService` so in-flight RPC calls can
finish
  - Add tests for graceful shutdown
  - Prevent `AppService` accepting new requests during shutdown
- Rework `AppService` lifecycle management; add support for async
`lifespan`
- Fix `AppService` endpoint error logging
- Improve logging in `AppProcess` and `AppService`

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- Deploy to Dev cluster, then `kubectl rollout restart` the different
services a few times
    - [x] -> `DatabaseManager` doesn't break on re-deployment
    - [x] -> `Scheduler` doesn't break on re-deployment
    - [x] -> `NotificationManager` doesn't break on re-deployment
2025-10-25 14:47:19 +00:00
Abhimanyu Yadav
acb946801b feat(frontend): add agent execution functionality in new builder (#11186)
This PR implements real-time agent execution functionality in the new
flow editor, enabling users to run, monitor, and view results of their
agent workflows directly within the builder interface.


https://github.com/user-attachments/assets/8a730e08-f88d-49d4-be31-980e2c7a2f83

#### Key Features Added:

##### 1. **Agent Execution Controls**
- Added "Run Agent" / "Stop Agent" button with gradient styling in the
builder interface
- Implemented execution state management through a new `graphStore` for
tracking running status
- Save graph automatically before execution to ensure latest changes are
persisted

##### 2. **Real-time Execution Monitoring**
- Implemented WebSocket-based real-time updates for node execution
status via `useFlowRealtime` hook
- Subscribe to graph execution events and node execution events for live
status tracking
- Visual execution status badges on nodes showing states: `QUEUED`,
`RUNNING`, `COMPLETED`, `FAILED`, etc.
   - Animated gradient border effect when agent is actively running

##### 3. **Node Execution Results Display**
- New `NodeDataRenderer` component to display input/output data for each
executed node
   - Collapsible result sections with formatted JSON display
- Prepared UI for future functionality: copy, info, and expand actions
for node data

#### Technical Implementation:

- **State Management**: Extended `nodeStore` with execution status and
result tracking methods
- **WebSocket Integration**: Real-time communication for execution
updates without polling
- **Component Architecture**: Modular components for execution controls,
status display, and result rendering
- **Visual Feedback**: Color-coded status badges and animated borders
for clear execution state indication


#### TODO Items for Future PRs:
- Complete implementation of node result action buttons (copy, info,
expand)
- Add agent output display component
- Implement schedule run functionality
- Handle credential and input parameters for graph execution
- Add tooltips for better UX

### Checklist

- [x] Create a new agent with at least 3 blocks and verify execution
starts correctly
- [x] Verify real-time status updates appear on nodes during execution
- [x] Confirm execution results display in the node output sections
- [x] Verify the animated border appears when agent is running
- [x] Check that node status badges show correct states (QUEUED,
RUNNING, COMPLETED, etc.)
- [x] Test WebSocket reconnection after connection loss
- [x] Verify graph is saved before execution begins
2025-10-24 12:05:09 +00:00
Bently
48ff225837 feat(blocks/revid): Add cost configs for revid video blocks (#11242)
Updated block costs in `backend/backend/data/block_cost_config.py`:
  - **AIShortformVideoCreatorBlock**: Updated from 50 credits to 307
  - **AIAdMakerVideoCreatorBlock**: Added cost of 714 credits
  - **AIScreenshotToVideoAdBlock**: Added cost of 612 credits

  ### Checklist 📋

  #### For code changes:
  - [x] I have clearly listed my changes in the PR description
  - [x] I have made a test plan
  - [x] I have tested my changes according to the test plan:
- [x] Verify AIShortformVideoCreatorBlock costs 307 credits when
executed
- [x] Verify AIAdMakerVideoCreatorBlock costs 714 credits when executed
- [x] Verify AIScreenshotToVideoAdBlock costs 612 credits when executed
2025-10-23 09:46:22 +00:00
Nicholas Tindle
e2a9923f30 feat(frontend): Improve waitlist error display & messages (#11206)
Improves the "not on waitlist" error display based on feedback.

- Follow-up to #11198
  - Follow-up to #11196

### Changes 🏗️

- Use standard `ErrorCard`
- Improve text strings
- Merge `isWaitlistError` and `isWaitlistErrorFromParams`

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  <!-- Put your test plan here: -->
- [x] We need to test in dev becasue we don't have a waitlist locally
and will revert if it doesnt work
- deploy to dev environment and sign up with a non approved account and
see if error appears
2025-10-22 13:37:42 +00:00
Reinier van der Leer
39792d517e fix(frontend): Filter out undefined query params in API requests (#11238)
Part of our effort to eliminate preventable warnings and errors.

- Resolves #11237

### Changes 🏗️

- Exclude `undefined` query params in API requests

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - Open the Builder without a `flowVersion` URL parameter
    - [x] -> `GET /api/library/agents/by-graph/{graph_id}` succeeds
  - Open the builder with a `flowVersion` URL parameter
    - [x] -> version is correctly included in request URL parameters
2025-10-22 13:25:34 +00:00
Bently
a6a2f71458 Merge commit from fork
* Replace urllib with Requests in RSS block to prevent SSRF

* Format
2025-10-22 14:18:34 +01:00
Bently
788b861bb7 Merge commit from fork 2025-10-22 14:17:26 +01:00
Ubbe
e203e65dc4 feat(frontend): setup datafast custom events (#11231)
## Changes 🏗️

- Add [custom events](https://datafa.st/docs/custom-goals) in
**Datafa.st** to track the user journey around core actions
  - track `add_to_library`
  - track `download_agent`
  - track `run_agent`
  - track `schedule_agent` 
- Refactor the analytics service to encapsulate both **GA** and
**Datafa.st**

## Checklist 📋

### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Analytics load correctly locally
  - [x] Events fire in production
 
### For configuration changes:

Once deployed to production we need to verify we are receiving analytics
and custom events in [Datafa.st](https://datafa.st/)
2025-10-22 16:56:30 +04:00
Ubbe
bd03697ff2 fix(frontend): URL substring sanitazion issue (#11232)
Potential fix for
[https://github.com/Significant-Gravitas/AutoGPT/security/code-scanning/145](https://github.com/Significant-Gravitas/AutoGPT/security/code-scanning/145)

To fix the issue, rather than using substring matching on the raw URL
string, we need to properly parse the URL and inspect its hostname. We
should confirm that the `hostname` property of the parsed URL is equal
to either `vimeo.com` or explicitly permitted subdomains like
`www.vimeo.com`. We can use the native JavaScript `URL` class for robust
parsing.

**File/Location:**  
- Only change line(s) in
`autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/AgentRunsView/components/OutputRenderers/renderers/MarkdownRenderer.tsx`
- Specifically, update the logic in function `isVideoUrl()` on line 45.

**Methods/Imports/Definitions:**  
- Use the standard `URL` class (no need to add a new import, as this is
available in browsers and in Node.js).
- Provide fallback in case the URL passed in is malformed (wrap in a
try-catch, treat as non-video in this case).
- Check the parsed hostname for equality with `vimeo.com` or,
optionally, specific allowed subdomains (`www.vimeo.com`).

---


_Suggested fixes powered by Copilot Autofix. Review carefully before
merging._

---------

Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
2025-10-22 16:56:12 +04:00
Reinier van der Leer
efd37b7a36 fix(frontend): Limit Sentry console capture to warnings and errors (#11223)
Debug and info level messages are currently ending up in Sentry,
polluting our issue feed.

### Changes 🏗️

- Limit Sentry console capture to warnings and worse

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - Trivial change, no test needed
2025-10-22 09:49:25 +00:00
Zamil Majdy
bb0b45d7f7 fix(backend): Make Jinja Error on TextFormatter as value error (#11236)
<!-- Clearly explain the need for these changes: -->

This PR converts Jinja2 TemplateError exceptions to ValueError in the
TextFormatter class to ensure proper error handling and HTTP status code
responses (400 instead of 500).

### Changes 🏗️

<!-- Concisely describe all of the changes made in this pull request:
-->

- Added import for `jinja2.exceptions.TemplateError` in
`backend/util/text.py:6`
- Wrapped template rendering in try-catch block in `format_string`
method (`backend/util/text.py:105-109`)
- Convert `TemplateError` to `ValueError` to ensure proper 400 HTTP
status code for client errors
- Added warning logging for template rendering errors before re-raising
as ValueError

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  <!-- Put your test plan: -->
- [x] Verified that invalid Jinja2 templates now raise ValueError
instead of TemplateError
  - [x] Confirmed that valid templates continue to work correctly
  - [x] Checked that warning logs are generated for template errors
  - [x] Validated that the exception chain is preserved with `from e`

#### For configuration changes:

- [x] `.env.default` is updated or already compatible with my changes
- [x] `docker-compose.yml` is updated or already compatible with my
changes
- [x] I have included a list of my configuration changes in the PR
description (under **Changes**)
2025-10-22 09:38:02 +00:00
Reinier van der Leer
04df981115 fix(backend): Fix structured logging for cloud environments (#11227)
- Resolves #11226

### Changes 🏗️

- Drop use of `CloudLoggingHandler` which docs state isn't for use in
GKE
- For cloud logging, output only structured log entries to `stdout`

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Test deploy to dev and check logs
2025-10-21 12:48:41 +00:00
Swifty
d25997b4f2 Revert "Merge branch 'swiftyos/secrt-1709-store-provider-names-and-en… (#11225)
Changes to providers blocks to store in db

### Changes 🏗️

- revet change

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  <!-- Put your test plan here: -->
  - [x] I have reverted the merge
2025-10-21 09:12:00 +00:00
Zamil Majdy
11d55f6055 fix(backend/executor): Avoid running direct query in executor (#11224)
## Summary
- Fixes database connection warnings in executor logs: "Client is not
connected to the query engine, you must call `connect()` before
attempting to query data"
- Implements resilient database client pattern already used elsewhere in
the codebase
- Adds caching to reduce database load for user context lookups

## Changes
- Updated `get_user_context()` to check `prisma.is_connected()` and fall
back to database manager client
- Added `@cached(maxsize=1000, ttl_seconds=3600)` decorator for
performance optimization
- Updated database manager to expose `get_user_by_id` method

## Test plan
- [x] Verify executor pods no longer show Prisma connection warnings
- [x] Confirm user timezone is still correctly retrieved
- [x] Test fallback behavior when Prisma is disconnected

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-authored-by: Claude <noreply@anthropic.com>
2025-10-21 08:46:40 +00:00
Ubbe
063dc5cf65 refactor(frontend): standardise with environment service (#11209)
## Changes 🏗️

Standardize all the runtime environment checks on the Front-end and
associated conditions to run against a single environment service where
all the environment config is centralized and hence easier to manage.

This helps prevent typos and bug when manually asserting against
environment variables ( which are typed as `string` ), the helper
functions are easier to read and re-use across the codebase.

## Checklist 📋

### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Run the app and click around
  - [x] Everything is smooth
  - [x] Test on the CI and types are green  

### For configuration changes:

None 🙏🏽
2025-10-21 08:44:34 +00:00
Ubbe
b7646f3e58 docs(frontend): contributing guidelines (#11210)
## Changes 🏗️

Document how to contribute on the Front-end so it is easier for
non-regular contributors.

## Checklist 📋

### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] Contribution guidelines make sense and look good considering the
AutoGPT stack

### For configuration changes:

None
2025-10-21 08:26:51 +00:00
Ubbe
0befaf0a47 feat(frontend): update tooltip and alert styles (#11212)
## Changes 🏗️

Matching updated changes in AutoGPT design system:

<img width="283" height="156" alt="Screenshot 2025-10-20 at 23 55 15"
src="https://github.com/user-attachments/assets/3a2e0ee7-cd53-4552-b72b-42f4631f1503"
/>
<img width="427" height="92" alt="Screenshot 2025-10-20 at 23 55 25"
src="https://github.com/user-attachments/assets/95344765-2155-4861-abdd-f5ec1497ace2"
/>
<img width="472" height="85" alt="Screenshot 2025-10-20 at 23 55 30"
src="https://github.com/user-attachments/assets/31084b40-0eea-4feb-a627-c5014790c40d"
/>
<img width="370" height="87" alt="Screenshot 2025-10-20 at 23 55 35"
src="https://github.com/user-attachments/assets/a81dba12-a792-4d41-b269-0bc32fc81271"
/>


## Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Check the stories for Tooltip and Alerts, they look good


#### For configuration changes:
None
2025-10-21 08:14:28 +00:00
Reinier van der Leer
93f58dec5e Merge branch 'master' into dev 2025-10-21 08:49:12 +02:00
Reinier van der Leer
3da595f599 fix(backend): Only try to initialize LaunchDarkly once (#11222)
We currently try to re-init the LaunchDarkly client every time a feature flag is checked.
This causes 5 second extra latency on the flag check when LD is down, such as now.
Since flag checks are performed on every block execution, this currently cripples the platform's executors.

- Follow-up to #11221

### Changes 🏗️

- Only try to init LaunchDarkly once
- Improve surrounding log statements in the `feature_flag` module

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - This is a critical hotfix; we'll see its effect once deployed
2025-10-21 08:46:07 +02:00
Reinier van der Leer
e5e60921a3 fix(backend): Handle LaunchDarkly init failure (#11221)
LaunchDarkly is currently down and it's keeping our executor pods from
spinning up.

### Changes 🏗️

- Wrap `LaunchDarklyIntegration` init in a try/except

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - We'll see if it works once it deploys
2025-10-21 07:53:40 +02:00
Copilot
90af8f8e1a feat(backend): Add language fallback for YouTube transcription block (#11057)
## Problem

The YouTube transcription block would fail when attempting to transcribe
videos that only had transcripts available in non-English languages.
Even when usable transcripts existed in other languages, the block would
raise a `NoTranscriptFound` error because it only requested English
transcripts.

**Example video that would fail:**
https://www.youtube.com/watch?v=3AMl5d2NKpQ (only has Hungarian
transcripts)

**Error message:**
```
Could not retrieve a transcript for the video https://www.youtube.com/watch?v=3AMl5d2NKpQ! 
No transcripts were found for any of the requested language codes: ('en',)

For this video (3AMl5d2NKpQ) transcripts are available in the following languages:
(GENERATED) - hu ("Hungarian (auto-generated)")
```

## Solution

Implemented intelligent language fallback in the
`TranscribeYoutubeVideoBlock.get_transcript()` method:

1. **First**, tries to fetch English transcript (maintains backward
compatibility)
2. **If English unavailable**, lists all available transcripts and
selects the first one using this priority:
   - Manually created transcripts (any language)
   - Auto-generated transcripts (any language)
3. **Only fails** if no transcripts exist at all

**Example behavior:**
```python
# Before: Video with only Hungarian transcript
get_transcript("3AMl5d2NKpQ")  #  Raises NoTranscriptFound

# After: Video with only Hungarian transcript  
get_transcript("3AMl5d2NKpQ")  #  Returns Hungarian transcript
```

## Changes

- **Modified** `backend/blocks/youtube.py`: Added try-catch logic to
fallback to any available language when English is not found
- **Added** `test/blocks/test_youtube.py`: Comprehensive test suite
covering URL extraction, language fallback, transcript preferences, and
error handling (7 tests)
- **Updated** `docs/content/platform/blocks/youtube.md`: Documented the
language fallback behavior and transcript priority order

## Testing

-  All 7 new unit tests pass
-  Block integration test passes
-  Full test suite: 621 passed, 0 failed (no regressions)
-  Code formatting and linting pass

## Impact

This fix enables the YouTube transcription block to work with
international content while maintaining full backward compatibility:

-  Videos in any language can now be transcribed
-  English is still preferred when available
-  No breaking changes to existing functionality
-  Graceful degradation to available languages

Fixes #10637
Fixes https://linear.app/autogpt/issue/OPEN-2626

> [!WARNING]
>
> <details>
> <summary>Firewall rules blocked me from connecting to one or more
addresses (expand for details)</summary>
>
> #### I tried to connect to the following addresses, but was blocked by
firewall rules:
>
> - `www.youtube.com`
> - Triggering command:
`/home/REDACTED/.cache/pypoetry/virtualenvs/autogpt-platform-backend-Ajv4iu2i-py3.11/bin/python3`
(dns block)
>
> If you need me to access, download, or install something from one of
these locations, you can either:
>
> - Configure [Actions setup
steps](https://gh.io/copilot/actions-setup-steps) to set up my
environment, which run before the firewall is enabled
> - Add the appropriate URLs or hosts to the custom allowlist in this
repository's [Copilot coding agent
settings](https://github.com/Significant-Gravitas/AutoGPT/settings/copilot/coding_agent)
(admins only)
>
> </details>

<!-- START COPILOT CODING AGENT SUFFIX -->



<details>

<summary>Original prompt</summary>

> Issue Title: if theres only one lanague available for transcribe
youtube return that langage not an error
> Issue Description: `Could not retrieve a transcript for the video
https://www.youtube.com/watch?v=3AMl5d2NKpQ! This is most likely caused
by: No transcripts were found for any of the requested language codes:
('en',) For this video (3AMl5d2NKpQ) transcripts are available in the
following languages: (MANUALLY CREATED) None (GENERATED) - hu
("Hungarian (auto-generated)") (TRANSLATION LANGUAGES) None If you are
sure that the described cause is not responsible for this error and that
a transcript should be retrievable, please create an issue at
https://github.com/jdepoix/youtube-transcript-api/issues. Please add
which version of youtube_transcript_api you are using and provide the
information needed to replicate the error. Also make sure that there are
no open issues which already describe your problem!` you can use this
video to test:
[https://www.youtube.com/watch?v=3AMl5d2NKpQ\`](https://www.youtube.com/watch?v=3AMl5d2NKpQ%60)
> Fixes
https://linear.app/autogpt/issue/OPEN-2626/if-theres-only-one-lanague-available-for-transcribe-youtube-return
> 
> 
> Comment by User :
> This thread is for an agent session with githubcopilotcodingagent.
> 
> Comment by User :
> This thread is for an agent session with githubcopilotcodingagent.
> 
> Comment by User :
> This comment thread is synced to a corresponding [GitHub
issue](https://github.com/Significant-Gravitas/AutoGPT/issues/10637).
All replies are displayed in both locations.
> 
> 


</details>


<!-- START COPILOT CODING AGENT TIPS -->
---

 Let Copilot coding agent [set things up for
you](https://github.com/Significant-Gravitas/AutoGPT/issues/new?title=+Set+up+Copilot+instructions&body=Configure%20instructions%20for%20this%20repository%20as%20documented%20in%20%5BBest%20practices%20for%20Copilot%20coding%20agent%20in%20your%20repository%5D%28https://gh.io/copilot-coding-agent-tips%29%2E%0A%0A%3COnboard%20this%20repo%3E&assignees=copilot)
— coding agent works faster and does higher quality work when set up for
your repo.

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: ntindle <8845353+ntindle@users.noreply.github.com>
Co-authored-by: Nicholas Tindle <nicholas.tindle@agpt.co>
2025-10-21 02:31:33 +00:00
Nicholas Tindle
eba67e0a4b fix(platform/blocks): update linear oauth to use refresh tokens (#10998)
<!-- Clearly explain the need for these changes: -->

### Need 💡

This PR addresses Linear issue SECRT-1665, which mandates an update to
Linear's OAuth2 implementation. Linear is transitioning from long-lived
access tokens to short-lived access tokens with refresh tokens, with a
deadline of April 1, 2026. This change is crucial to ensure continued
integration with Linear and to support their new token management
system, including a migration path for existing long-lived tokens.

### Changes 🏗️

-   **`autogpt_platform/backend/backend/blocks/linear/_oauth.py`**:
- Implemented full support for refresh tokens, including HTTP Basic
Authentication for token refresh requests.
- Added `migrate_old_token()` method to exchange old long-lived access
tokens for new short-lived tokens with refresh tokens using Linear's
`/oauth/migrate_old_token` endpoint.
- Enhanced `get_access_token()` to automatically detect and attempt
migration for old tokens, and to refresh short-lived tokens when they
expire.
    -   Improved error handling and token expiration management.
- Updated `_request_tokens` to handle both authorization code and
refresh token flows, supporting Linear's recommended authentication
methods.
-   **`autogpt_platform/backend/backend/blocks/linear/_config.py`**:
- Updated `TEST_CREDENTIALS_OAUTH` mock data to include realistic
`access_token_expires_at` and `refresh_token` for testing the new token
lifecycle.
-   **`LINEAR_OAUTH_IMPLEMENTATION.md`**:
- Added documentation detailing the new Linear OAuth refresh token
implementation, including technical details, migration strategy, and
testing notes.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Verified OAuth URL generation and parameter encoding.
- [x] Confirmed HTTP Basic Authentication header creation for refresh
requests.
  - [x] Tested token expiration logic with a 5-minute buffer.
  - [x] Validated migration detection for old vs. new token types.
  - [x] Checked code syntax and import compatibility.

#### For configuration changes:

- [ ] `.env.default` is updated or already compatible with my changes
- [ ] `docker-compose.yml` is updated or already compatible with my
changes
- [ ] I have included a list of my configuration changes in the PR
description (under **Changes**)

---
Linear Issue: [SECRT-1665](https://linear.app/autogpt/issue/SECRT-1665)

<a
href="https://cursor.com/background-agent?bcId=bc-95f4c668-f7fa-4057-87e5-622ac81c0783"><picture><source
media="(prefers-color-scheme: dark)"
srcset="https://cursor.com/open-in-cursor-dark.svg"><source
media="(prefers-color-scheme: light)"
srcset="https://cursor.com/open-in-cursor-light.svg"><img alt="Open in
Cursor"
src="https://cursor.com/open-in-cursor.svg"></picture></a>&nbsp;<a
href="https://cursor.com/agents?id=bc-95f4c668-f7fa-4057-87e5-622ac81c0783"><picture><source
media="(prefers-color-scheme: dark)"
srcset="https://cursor.com/open-in-web-dark.svg"><source
media="(prefers-color-scheme: light)"
srcset="https://cursor.com/open-in-web-light.svg"><img alt="Open in Web"
src="https://cursor.com/open-in-web.svg"></picture></a>

---------

Co-authored-by: Cursor Agent <cursoragent@cursor.com>
Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
Co-authored-by: Nicholas Tindle <ntindle@users.noreply.github.com>
Co-authored-by: Bentlybro <Github@bentlybro.com>
2025-10-20 20:44:58 +00:00
Nicholas Tindle
47bb89caeb fix(backend): Disable LaunchDarkly integration in metrics.py (#11217) 2025-10-20 14:07:21 -05:00
Ubbe
271a520afa feat(frontend): setup DataFast analytics (#11182)
## Changes 🏗️

Following https://datafa.st/docs/nextjs-app-router

## Checklist 📋

### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] We will see once we make a production deployment and get data into
the platform

### For configuration changes:

None
2025-10-20 16:18:04 +04:00
Swifty
3988057032 Merge branch 'swiftyos/secrt-1712-remove-error-handling-form-store-routes' into dev 2025-10-18 12:28:25 +02:00
Swifty
a6c6e48f00 Merge branch 'swiftyos/open-2791-featplatform-add-easy-test-data-creation' into dev 2025-10-18 12:28:17 +02:00
Swifty
e72ce2f9e7 Merge branch 'swiftyos/secrt-1709-store-provider-names-and-env-vars-in-db' into dev 2025-10-18 12:27:58 +02:00
Swifty
bd7a79a920 Merge branch 'swiftyos/secrt-1706-improve-store-search' into dev 2025-10-18 12:27:31 +02:00
Nicholas Tindle
3f546ae845 fix(frontend): improve waitlist error display for users not on allowlist (#11198)
fix issue with identifying errors :(
### Changes 🏗️

<!-- Concisely describe all of the changes made in this pull request:
-->

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  <!-- Put your test plan here: -->
- [x] we have to test in dev due to waitlist integration, so we are
merging. will revert if fails

---------

Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
2025-10-18 05:14:05 +00:00
Nicholas Tindle
097a19141d fix(frontend): improve waitlist error display for users not on allowlist (#11196)
## Summary

This PR improves the user experience for users who are not on the
waitlist during sign-up. When a user attempts to sign up or log in with
an email that's not on the allowlist, they now see a clear, helpful
modal with a direct call-to-action to join the waitlist.

Fixes
[OPEN-2794](https://linear.app/autogpt/issue/OPEN-2794/display-waitlist-error-for-users-not-on-waitlist-during-sign-up)

## Changes

-  Updated `EmailNotAllowedModal` with improved messaging and a "Join
Waitlist" button
- 🔧 Fixed OAuth provider signup/login to properly display the waitlist
modal
- 📝 Enhanced auth-code-error page to detect and display
waitlist-specific errors
- 💬 Added helpful guidance about checking email address and Discord
support link
- 🎯 Consistent waitlist error handling across all auth flows (regular
signup, OAuth, error pages)

## Test Plan

Tested locally by:
1. Attempting signup with non-allowlisted email - modal appears 
2. Attempting Google SSO with non-allowlisted account - modal appears 
3. Modal shows "Join Waitlist" button that opens
https://agpt.co/waitlist 
4. Help text about checking email and Discord support is visible 

## Screenshots

The new waitlist modal includes:
- Clear "Join the Waitlist" title
- Explanation that platform is in closed beta
- "Join Waitlist" button (opens in new tab)
- Help text about checking email address
- Discord support link for users who need help

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
2025-10-18 03:37:31 +00:00
Swifty
c958c95d6b fix incorrect type import 2025-10-17 20:36:49 +02:00
Swifty
3e50cbd2cb fix import 2025-10-17 19:19:17 +02:00
Swifty
1b69f1644d revert frontend type change 2025-10-17 17:26:08 +02:00
Swifty
d9035a233c Merge branch 'swiftyos/secrt-1709-store-provider-names-and-env-vars-in-db' of github.com:Significant-Gravitas/AutoGPT into swiftyos/secrt-1709-store-provider-names-and-env-vars-in-db 2025-10-17 17:20:27 +02:00
Swifty
972cbfc3de fix tests 2025-10-17 17:20:05 +02:00
Swifty
8f861b1bb2 removed error handling from routes 2025-10-17 17:08:17 +02:00
Swifty
fa2731bb8b Merge branch 'dev' into swiftyos/secrt-1709-store-provider-names-and-env-vars-in-db 2025-10-17 17:06:09 +02:00
Swifty
2dc0c97a52 Add block registry and updated 2025-10-17 16:49:04 +02:00
Zamil Majdy
0bb2b87c32 fix(backend): resolve UserBalance migration issues and credit spending bug (#11192)
## Summary
Fix critical UserBalance migration and spending issues affecting users
with credits from transaction history but no UserBalance records.

## Root Issues Fixed

### Issue 1: UserBalance Migration Complexity
- **Problem**: Complex data migration with timestamp logic issues and
potential race conditions
- **Solution**: Simplified to idempotent table creation only,
application handles auto-population

### Issue 2: Credit Spending Bug  
- **Problem**: Users with $10.0 from transaction history couldn't spend
$0.16
- **Root Cause**: `_add_transaction` and `_enable_transaction` only
checked UserBalance table, returning 0 balance for users without records
- **Solution**: Enhanced both methods with transaction history fallback
logic

### Issue 3: Exception Handling Inconsistency
- **Problem**: Raw SQL unique violations raised different exception
types than Prisma ORM
- **Solution**: Convert raw SQL unique violations to
`UniqueViolationError` at source

## Changes Made

### Migration Cleanup
- **Idempotent operations**: Use `CREATE TABLE IF NOT EXISTS`, `CREATE
INDEX IF NOT EXISTS`
- **Inline foreign key**: Define constraint within `CREATE TABLE`
instead of separate `ALTER TABLE`
- **Removed data migration**: Application creates UserBalance records
on-demand
- **Safe to re-run**: No errors if table/index/constraint already exists

### Credit Logic Fixes
- **Enhanced `_add_transaction`**: Added transaction history fallback in
`user_balance_lock` CTE
- **Enhanced `_enable_transaction`**: Added same fallback logic for
payment fulfillment
- **Exception normalization**: Convert raw SQL unique violations to
`UniqueViolationError`
- **Simplified `onboarding_reward`**: Use standardized
`UniqueViolationError` catching

### SQL Fallback Pattern
```sql
COALESCE(
    (SELECT balance FROM UserBalance WHERE userId = ? FOR UPDATE),
    -- Fallback: compute from transaction history if UserBalance doesn't exist
    (SELECT COALESCE(ct.runningBalance, 0) 
     FROM CreditTransaction ct 
     WHERE ct.userId = ? AND ct.isActive = true AND ct.runningBalance IS NOT NULL 
     ORDER BY ct.createdAt DESC LIMIT 1),
    0
) as balance
```

## Impact

### Before
-  Users with transaction history but no UserBalance couldn't spend
credits
-  Migration had complex timestamp logic with potential bugs  
-  Raw SQL and Prisma exceptions handled differently
-  Error: "Insufficient balance of $10.0, where this will cost $0.16"

### After  
-  Seamless spending for all users regardless of UserBalance record
existence
-  Simple, idempotent migration that's safe to re-run
-  Consistent exception handling across all credit operations
-  Automatic UserBalance record creation during first transaction
-  Backward compatible - existing users unaffected

## Business Value
- **Eliminates user frustration**: Users can spend their credits
immediately
- **Smooth migration path**: From old User.balance to new UserBalance
table
- **Better reliability**: Atomic operations with proper error handling
- **Maintainable code**: Consistent patterns across credit operations

## Test Plan
- [ ] Manual testing with users who have transaction history but no
UserBalance records
- [ ] Verify migration can be run multiple times safely
- [ ] Test spending credits works for all user scenarios
- [ ] Verify payment fulfillment (`_enable_transaction`) works correctly
- [ ] Add comprehensive test coverage for this scenario

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

---------

Co-authored-by: Claude <noreply@anthropic.com>
2025-10-17 19:46:13 +07:00
Swifty
a1d9b45238 updated openapi spec 2025-10-17 14:01:37 +02:00
Swifty
29895c290f store providers in db 2025-10-17 13:34:35 +02:00
Zamil Majdy
73c0b6899a fix(backend): Remove advisory locks for atomic credit operations (#11143)
## Problem
High QPS failures on `spend_credits` operations due to lock contention
from `pg_advisory_xact_lock` causing serialization and seconds of wait
time.

## Solution 
Replace PostgreSQL advisory locks with atomic database operations using
CTEs (Common Table Expressions).

### Key Changes
- **Add persistent balance column** to User table for O(1) balance
lookups
- **Atomic CTE-based operations** for all credit transactions using
UPDATE...RETURNING pattern
- **Comprehensive concurrency tests** with 7 test scenarios including
stress testing
- **Remove all advisory lock usage** from the credit system

### Implementation Details
1. **Migration**: Adds balance column with backfill from transaction
history
2. **Atomic Operations**: All credit operations now use single atomic
CTEs that update balance and create transaction in one query
3. **Race Condition Prevention**: WHERE clauses in UPDATE statements
ensure balance never goes negative
4. **BetaUserCredit Compatibility**: Preserved monthly refill logic with
updated `_add_transaction` signature

### Performance Impact
-  Eliminated lock contention bottlenecks
-  O(1) balance lookups instead of O(n) transaction aggregation  
-  Atomic operations prevent race conditions without locks
-  Supports high QPS without serialization delays

### Testing
- All existing tests pass
- New concurrency test suite (`credit_concurrency_test.py`) with:
  - Concurrent spends from same user
  - Insufficient balance handling
  - Mixed operations (spends, top-ups, balance checks)
  - Race condition prevention
  - Integer overflow protection
  - Stress testing with 100 concurrent operations

### Breaking Changes
None - all existing APIs maintain compatibility

🤖 Generated with [Claude Code](https://claude.ai/code)

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Enhanced top‑up flows with top‑up types, clearer credit→dollar
formatting, and idempotent onboarding rewards.

* **Bug Fixes**
* Fixed race conditions for concurrent spends/top‑ups, added
integer‑overflow and underflow protection, stronger input validation,
and improved refund/dispute handling.

* **Refactor**
* Persisted per‑user balance with atomic updates for reliable balances;
admin history now prefetches balances.

* **Tests**
* Added extensive concurrency, refund, ceiling/underflow and migration
test suites.

* **Chores**
* Database migration to add persisted user balance; APIKey status
extended (SUSPENDED).
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: Swifty <craigswift13@gmail.com>
2025-10-17 17:05:05 +07:00
Zamil Majdy
4c853a54d7 Merge commit 'e4bc728d40332e7c2b1edec5f1b200f1917950e2' into HEAD 2025-10-17 16:43:23 +07:00
Zamil Majdy
dfdd632161 fix(backend/util): handle nested Pydantic models in SafeJson (#11188)
## Summary

Fixes a critical serialization bug introduced in PR #11187 where
`SafeJson` failed to serialize dictionaries containing Pydantic models,
causing 500 Internal Server Errors in the executor service.

## Problem

The error manifested as:
```
CRITICAL: Operation Approaching Failure Threshold: Service communication: '_call_method_async'
Current attempt: 50/50
Error: HTTPServerError: HTTP 500: Server error '500 Internal Server Error' 
for url 'http://autogpt-database-manager.prod-agpt.svc.cluster.local:8005/create_graph_execution'
```

Root cause in `create_graph_execution`
(backend/data/execution.py:656-657):
```python
"credentialInputs": SafeJson(credential_inputs) if credential_inputs else Json({})
```

Where `credential_inputs: Mapping[str, CredentialsMetaInput]` is a dict
containing Pydantic models.

After PR #11187's refactor, `_sanitize_value()` only converted top-level
BaseModel instances to dicts, but didn't handle BaseModel instances
nested inside dicts/lists/tuples. This caused Prisma's JSON serializer
to fail with:
```
TypeError: Type <class 'backend.data.model.CredentialsMetaInput'> not serializable
```

## Solution

Added BaseModel handling to `_sanitize_value()` to recursively convert
Pydantic models to dicts before sanitizing:

```python
elif isinstance(value, BaseModel):
    # Convert Pydantic models to dict and recursively sanitize
    return _sanitize_value(value.model_dump(exclude_none=True))
```

This ensures all nested Pydantic models are properly serialized
regardless of nesting depth.

## Changes

- **backend/util/json.py**: Added BaseModel check to `_sanitize_value()`
function
- **backend/util/test_json.py**: Added 6 comprehensive tests covering:
  - Dict containing Pydantic models
  - Deeply nested Pydantic models  
  - Lists of Pydantic models in dicts
  - The exact CredentialsMetaInput scenario
  - Complex mixed structures
  - Models with control characters

## Testing

 All new tests pass  
 Verified fix resolves the production 500 error  
 Code formatted with `poetry run format`

## Related

- Fixes issues introduced in PR #11187
- Related to executor service 500 errors in production

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Bentlybro <Github@bentlybro.com>
Co-authored-by: Claude <noreply@anthropic.com>
2025-10-17 09:27:09 +00:00
Swifty
1ed224d481 simplify test and add reset-db make command 2025-10-17 11:12:00 +02:00
Swifty
3b5d919399 fix formatting 2025-10-17 10:56:45 +02:00
Swifty
3c16de22ef add test data creation to makefile and test it 2025-10-17 10:51:58 +02:00
Zamil Majdy
e4bc728d40 Revert "Revert "fix(backend/util): rewrite SafeJson to prevent Invalid \escape errors (#11187)""
This reverts commit 8258338caf.
2025-10-17 15:25:30 +07:00
Swifty
2c6d85d15e feat(platform): Shared cache (#11150)
### Problem
When running multiple backend pods in production, requests can be routed
to different pods causing inconsistent cache states. Additionally, the
current cache implementation in `autogpt_libs` doesn't support shared
caching across processes, leading to data inconsistency and redundant
cache misses.

### Changes 🏗️

- **Moved cache implementation from autogpt_libs to backend**
(`/backend/backend/util/cache.py`)
  - Removed `/autogpt_libs/autogpt_libs/utils/cache.py`
  - Centralized cache utilities within the backend module
  - Updated all import statements across the codebase

- **Implemented Redis-based shared caching**
- Added `shared_cache` parameter to `@cached` decorator for
cross-process caching
  - Implemented Redis connection pooling for efficient cache operations
  - Added support for cache key pattern matching and bulk deletion
  - Added TTL refresh on cache access with `refresh_ttl_on_get` option

- **Enhanced cache functionality**
  - Added thundering herd protection with double-checked locking
  - Implemented thread-local caching with `@thread_cached` decorator
- Added cache management methods: `cache_clear()`, `cache_info()`,
`cache_delete()`
  - Added support for both sync and async functions

- **Updated store caching** (`/backend/server/v2/store/cache.py`)
  - Enabled shared caching for all store-related cache functions
  - Set appropriate TTL values (5-15 minutes) for different cache types
  - Added `clear_all_caches()` function for cache invalidation

- **Added Redis configuration**
  - Added Redis connection settings to backend settings
  - Configured dedicated connection pool for cache operations
  - Set up binary mode for pickle serialization

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Verify Redis connection and cache operations work correctly
  - [x] Test shared cache across multiple backend instances
  - [x] Verify cache invalidation with `clear_all_caches()`
- [x] Run cache tests: `poetry run pytest
backend/backend/util/cache_test.py`
  - [x] Test thundering herd protection under concurrent load
  - [x] Verify TTL refresh functionality with `refresh_ttl_on_get=True`
  - [x] Test thread-local caching for request-scoped data
  - [x] Ensure no performance regression vs in-memory cache

#### For configuration changes:

- [x] `.env.default` is updated or already compatible with my changes
- [x] `docker-compose.yml` is updated or already compatible with my
changes (Redis already configured)
- [x] I have included a list of my configuration changes in the PR
description (under **Changes**)
- Redis cache configuration uses existing Redis service settings
(REDIS_HOST, REDIS_PORT, REDIS_PASSWORD)
  - No new environment variables required
2025-10-17 07:56:01 +00:00
Zamil Majdy
374f35874c feat(platform): Add LaunchDarkly flag for platform payment system (#11181)
## Summary

Implement selective rollout of payment functionality using LaunchDarkly
feature flags to enable gradual deployment to pilot users.

- Add `ENABLE_PLATFORM_PAYMENT` flag to control credit system behavior
- Update `get_user_credit_model` to use user-specific flag evaluation  
- Replace hardcoded `NEXT_PUBLIC_SHOW_BILLING_PAGE` with LaunchDarkly
flag
- Enable payment UI components only for flagged users
- Maintain backward compatibility with existing beta credit system
- Default to beta monthly credits when flag is disabled
- Fix tests to work with new async credit model function

## Key Changes

### Backend
- **Credit Model Selection**: The `get_user_credit_model()` function now
takes a `user_id` parameter and uses LaunchDarkly to determine which
credit model to return:
- Flag enabled → `UserCredit` (payment system enabled, no monthly
refills)
- Flag disabled → `BetaUserCredit` (current behavior with monthly
refills)
  
- **Flag Integration**: Added `ENABLE_PLATFORM_PAYMENT` flag and
integrated LaunchDarkly evaluation throughout the credit system

- **API Updates**: All credit-related endpoints now use the
user-specific credit model instead of a global instance

### Frontend
- **Dynamic UI**: Payment-related components (billing page, wallet
refill) now show/hide based on the LaunchDarkly flag
- **Removed Environment Variable**: Replaced
`NEXT_PUBLIC_SHOW_BILLING_PAGE` with runtime flag evaluation

### Testing
- **Test Fixes**: Updated all tests that referenced the removed global
`_user_credit_model` to use proper mocking of the new async function

## Deployment Strategy

This implementation enables a controlled rollout:
1. Deploy with flag disabled (default) - no behavior change for existing
users
2. Enable flag for pilot/beta users via LaunchDarkly dashboard
3. Monitor usage and feedback from pilot users
4. Gradually expand to more users
5. Eventually enable for all users once validated

## Test Plan

- [x] Unit tests pass for credit system components
- [x] Payment UI components show/hide correctly based on flag
- [x] Default behavior (flag disabled) maintains current functionality
- [x] Flag enabled users get payment system without monthly refills
- [x] Admin credit operations work correctly
- [x] Backward compatibility maintained

🤖 Generated with [Claude Code](https://claude.ai/code)

---------

Co-authored-by: Claude <noreply@anthropic.com>
2025-10-17 06:11:39 +00:00
Swifty
3ed1c93ec0 Merge branch 'dev' into swiftyos/secrt-1706-improve-store-search 2025-10-16 15:10:01 +02:00
Swifty
773f545cfd update existing rows when migration is ran 2025-10-16 13:38:01 +02:00
Swifty
84ad4a9f95 updated migration and query 2025-10-16 13:06:47 +02:00
Swifty
8610118ddc ai sucks - fixing 2025-10-16 12:14:26 +02:00
Swifty
ebb4ebb025 include parital types in second place 2025-10-16 12:10:38 +02:00
Swifty
cb532e1c4d update docker file to include partial types 2025-10-16 12:08:04 +02:00
Swifty
794aee25ab add full text search 2025-10-16 11:49:36 +02:00
192 changed files with 8869 additions and 2805 deletions

View File

@@ -12,6 +12,7 @@ This file provides comprehensive onboarding information for GitHub Copilot codin
- **Infrastructure** - Docker configurations, CI/CD, and development tools
**Primary Languages & Frameworks:**
- **Backend**: Python 3.10-3.13, FastAPI, Prisma ORM, PostgreSQL, RabbitMQ
- **Frontend**: TypeScript, Next.js 15, React, Tailwind CSS, Radix UI
- **Development**: Docker, Poetry, pnpm, Playwright, Storybook
@@ -23,15 +24,17 @@ This file provides comprehensive onboarding information for GitHub Copilot codin
**Always run these commands in the correct directory and in this order:**
1. **Initial Setup** (required once):
```bash
# Clone and enter repository
git clone <repo> && cd AutoGPT
# Start all services (database, redis, rabbitmq, clamav)
cd autogpt_platform && docker compose --profile local up deps --build --detach
```
2. **Backend Setup** (always run before backend development):
```bash
cd autogpt_platform/backend
poetry install # Install dependencies
@@ -48,6 +51,7 @@ This file provides comprehensive onboarding information for GitHub Copilot codin
### Runtime Requirements
**Critical:** Always ensure Docker services are running before starting development:
```bash
cd autogpt_platform && docker compose --profile local up deps --build --detach
```
@@ -58,6 +62,7 @@ cd autogpt_platform && docker compose --profile local up deps --build --detach
### Development Commands
**Backend Development:**
```bash
cd autogpt_platform/backend
poetry run serve # Start development server (port 8000)
@@ -68,6 +73,7 @@ poetry run lint # Lint code (ruff) - run after format
```
**Frontend Development:**
```bash
cd autogpt_platform/frontend
pnpm dev # Start development server (port 3000) - use for active development
@@ -81,23 +87,27 @@ pnpm storybook # Start component development server
### Testing Strategy
**Backend Tests:**
- **Block Tests**: `poetry run pytest backend/blocks/test/test_block.py -xvs` (validates all blocks)
- **Specific Block**: `poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[BlockName]' -xvs`
- **Snapshot Tests**: Use `--snapshot-update` when output changes, always review with `git diff`
**Frontend Tests:**
- **E2E Tests**: Always run `pnpm dev` before `pnpm test` (Playwright requires running instance)
- **Component Tests**: Use Storybook for isolated component development
### Critical Validation Steps
**Before committing changes:**
1. Run `poetry run format` (backend) and `pnpm format` (frontend)
2. Ensure all tests pass in modified areas
3. Verify Docker services are still running
4. Check that database migrations apply cleanly
**Common Issues & Workarounds:**
- **Prisma issues**: Run `poetry run prisma generate` after schema changes
- **Permission errors**: Ensure Docker has proper permissions
- **Port conflicts**: Check the `docker-compose.yml` file for the current list of exposed ports. You can list all mapped ports with:
@@ -108,6 +118,7 @@ pnpm storybook # Start component development server
### Core Architecture
**AutoGPT Platform** (`autogpt_platform/`):
- `backend/` - FastAPI server with async support
- `backend/backend/` - Core API logic
- `backend/blocks/` - Agent execution blocks
@@ -121,6 +132,7 @@ pnpm storybook # Start component development server
- `docker-compose.yml` - Development stack orchestration
**Key Configuration Files:**
- `pyproject.toml` - Python dependencies and tooling
- `package.json` - Node.js dependencies and scripts
- `schema.prisma` - Database schema and migrations
@@ -136,6 +148,7 @@ pnpm storybook # Start component development server
### Development Workflow
**GitHub Actions**: Multiple CI/CD workflows in `.github/workflows/`
- `platform-backend-ci.yml` - Backend testing and validation
- `platform-frontend-ci.yml` - Frontend testing and validation
- `platform-fullstack-ci.yml` - End-to-end integration tests
@@ -146,11 +159,13 @@ pnpm storybook # Start component development server
### Key Source Files
**Backend Entry Points:**
- `backend/backend/server/server.py` - FastAPI application setup
- `backend/backend/data/` - Database models and user management
- `backend/blocks/` - Agent execution blocks and logic
**Frontend Entry Points:**
- `frontend/src/app/layout.tsx` - Root application layout
- `frontend/src/app/page.tsx` - Home page
- `frontend/src/lib/supabase/` - Authentication and database client
@@ -160,6 +175,7 @@ pnpm storybook # Start component development server
### Agent Block System
Agents are built using a visual block-based system where each block performs a single action. Blocks are defined in `backend/blocks/` and must include:
- Block definition with input/output schemas
- Execution logic with proper error handling
- Tests validating functionality
@@ -167,6 +183,7 @@ Agents are built using a visual block-based system where each block performs a s
### Database & ORM
**Prisma ORM** with PostgreSQL backend including pgvector for embeddings:
- Schema in `schema.prisma`
- Migrations in `backend/migrations/`
- Always run `prisma migrate dev` and `prisma generate` after schema changes
@@ -174,13 +191,15 @@ Agents are built using a visual block-based system where each block performs a s
## Environment Configuration
### Configuration Files Priority Order
1. **Backend**: `/backend/.env.default` → `/backend/.env` (user overrides)
2. **Frontend**: `/frontend/.env.default` → `/frontend/.env` (user overrides)
2. **Frontend**: `/frontend/.env.default` → `/frontend/.env` (user overrides)
3. **Platform**: `/.env.default` (Supabase/shared) → `/.env` (user overrides)
4. Docker Compose `environment:` sections override file-based config
5. Shell environment variables have highest precedence
### Docker Environment Setup
- All services use hardcoded defaults (no `${VARIABLE}` substitutions)
- The `env_file` directive loads variables INTO containers at runtime
- Backend/Frontend services use YAML anchors for consistent configuration
@@ -189,6 +208,7 @@ Agents are built using a visual block-based system where each block performs a s
## Advanced Development Patterns
### Adding New Blocks
1. Create file in `/backend/backend/blocks/`
2. Inherit from `Block` base class with input/output schemas
3. Implement `run` method with proper error handling
@@ -198,6 +218,7 @@ Agents are built using a visual block-based system where each block performs a s
7. Consider how inputs/outputs connect with other blocks in graph editor
### API Development
1. Update routes in `/backend/backend/server/routers/`
2. Add/update Pydantic models in same directory
3. Write tests alongside route files
@@ -205,21 +226,76 @@ Agents are built using a visual block-based system where each block performs a s
5. Run `poetry run test` to verify changes
### Frontend Development
1. Components in `/frontend/src/components/`
2. Use existing UI components from `/frontend/src/components/ui/`
3. Add Storybook stories for component development
4. Test user-facing features with Playwright E2E tests
5. Update protected routes in middleware when needed
**📖 Complete Frontend Guide**: See `autogpt_platform/frontend/CONTRIBUTING.md` and `autogpt_platform/frontend/.cursorrules` for comprehensive patterns and conventions.
**Quick Reference:**
**Component Structure:**
- Separate render logic from data/behavior
- Structure: `ComponentName/ComponentName.tsx` + `useComponentName.ts` + `helpers.ts`
- Exception: Small components (3-4 lines of logic) can be inline
- Render-only components can be direct files without folders
**Data Fetching:**
- Use generated API hooks from `@/app/api/__generated__/endpoints/`
- Generated via Orval from backend OpenAPI spec
- Pattern: `use{Method}{Version}{OperationName}`
- Example: `useGetV2ListLibraryAgents`
- Regenerate with: `pnpm generate:api`
- **Never** use deprecated `BackendAPI` or `src/lib/autogpt-server-api/*`
**Code Conventions:**
- Use function declarations for components and handlers (not arrow functions)
- Only arrow functions for small inline lambdas (map, filter, etc.)
- Components: `PascalCase`, Hooks: `camelCase` with `use` prefix
- No barrel files or `index.ts` re-exports
- Minimal comments (code should be self-documenting)
**Styling:**
- Use Tailwind CSS utilities only
- Use design system components from `src/components/` (atoms, molecules, organisms)
- Never use `src/components/__legacy__/*`
- Only use Phosphor Icons (`@phosphor-icons/react`)
- Prefer design tokens over hardcoded values
**Error Handling:**
- Render errors: Use `<ErrorCard />` component
- Mutation errors: Display with toast notifications
- Manual exceptions: Use `Sentry.captureException()`
- Global error boundaries already configured
**Testing:**
- Add/update Storybook stories for UI components (`pnpm storybook`)
- Run Playwright E2E tests with `pnpm test`
- Verify in Chromatic after PR
**Architecture:**
- Default to client components ("use client")
- Server components only for SEO or extreme TTFB needs
- Use React Query for server state (via generated hooks)
- Co-locate UI state in components/hooks
### Security Guidelines
**Cache Protection Middleware** (`/backend/backend/server/middleware/security.py`):
- Default: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
- Uses allow list approach for cacheable paths (static assets, health checks, public pages)
- Prevents sensitive data caching in browsers/proxies
- Add new cacheable endpoints to `CACHEABLE_PATHS`
### CI/CD Alignment
The repository has comprehensive CI workflows that test:
- **Backend**: Python 3.11-3.13, services (Redis/RabbitMQ/ClamAV), Prisma migrations, Poetry lock validation
- **Frontend**: Node.js 21, pnpm, Playwright with Docker Compose stack, API schema validation
- **Integration**: Full-stack type checking and E2E testing
@@ -229,6 +305,7 @@ Match these patterns when developing locally - the copilot setup environment mir
## Collaboration with Other AI Assistants
This repository is actively developed with assistance from Claude (via CLAUDE.md files). When working on this codebase:
- Check for existing CLAUDE.md files that provide additional context
- Follow established patterns and conventions already in the codebase
- Maintain consistency with existing code style and architecture
@@ -237,8 +314,9 @@ This repository is actively developed with assistance from Claude (via CLAUDE.md
## Trust These Instructions
These instructions are comprehensive and tested. Only perform additional searches if:
1. Information here is incomplete for your specific task
2. You encounter errors not covered by the workarounds
3. You need to understand implementation details not covered above
For detailed platform development patterns, refer to `autogpt_platform/CLAUDE.md` and `AGENTS.md` in the repository root.
For detailed platform development patterns, refer to `autogpt_platform/CLAUDE.md` and `AGENTS.md` in the repository root.

View File

@@ -63,6 +63,9 @@ poetry run pytest path/to/test.py --snapshot-update
# Install dependencies
cd frontend && pnpm i
# Generate API client from OpenAPI spec
pnpm generate:api
# Start development server
pnpm dev
@@ -75,12 +78,23 @@ pnpm storybook
# Build production
pnpm build
# Format and lint
pnpm format
# Type checking
pnpm types
```
We have a components library in autogpt_platform/frontend/src/components/atoms that should be used when adding new pages and components.
**📖 Complete Guide**: See `/frontend/CONTRIBUTING.md` and `/frontend/.cursorrules` for comprehensive frontend patterns.
**Key Frontend Conventions:**
- Separate render logic from data/behavior in components
- Use generated API hooks from `@/app/api/__generated__/endpoints/`
- Use function declarations (not arrow functions) for components/handlers
- Use design system components from `src/components/` (atoms, molecules, organisms)
- Only use Phosphor Icons
- Never use `src/components/__legacy__/*` or deprecated `BackendAPI`
## Architecture Overview
@@ -95,11 +109,16 @@ We have a components library in autogpt_platform/frontend/src/components/atoms t
### Frontend Architecture
- **Framework**: Next.js App Router with React Server Components
- **State Management**: React hooks + Supabase client for real-time updates
- **Framework**: Next.js 15 App Router (client-first approach)
- **Data Fetching**: Type-safe generated API hooks via Orval + React Query
- **State Management**: React Query for server state, co-located UI state in components/hooks
- **Component Structure**: Separate render logic (`.tsx`) from business logic (`use*.ts` hooks)
- **Workflow Builder**: Visual graph editor using @xyflow/react
- **UI Components**: Radix UI primitives with Tailwind CSS styling
- **UI Components**: shadcn/ui (Radix UI primitives) with Tailwind CSS styling
- **Icons**: Phosphor Icons only
- **Feature Flags**: LaunchDarkly integration
- **Error Handling**: ErrorCard for render errors, toast for mutations, Sentry for exceptions
- **Testing**: Playwright for E2E, Storybook for component development
### Key Concepts
@@ -153,6 +172,7 @@ Key models (defined in `/backend/schema.prisma`):
**Adding a new block:**
Follow the comprehensive [Block SDK Guide](../../../docs/content/platform/block-sdk-guide.md) which covers:
- Provider configuration with `ProviderBuilder`
- Block schema definition
- Authentication (API keys, OAuth, webhooks)
@@ -160,6 +180,7 @@ Follow the comprehensive [Block SDK Guide](../../../docs/content/platform/block-
- File organization
Quick steps:
1. Create new file in `/backend/backend/blocks/`
2. Configure provider using `ProviderBuilder` in `_config.py`
3. Inherit from `Block` base class
@@ -180,10 +201,20 @@ ex: do the inputs and outputs tie well together?
**Frontend feature development:**
1. Components go in `/frontend/src/components/`
2. Use existing UI components from `/frontend/src/components/ui/`
3. Add Storybook stories for new components
4. Test with Playwright if user-facing
See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
1. **Pages**: Create in `src/app/(platform)/feature-name/page.tsx`
- Add `usePageName.ts` hook for logic
- Put sub-components in local `components/` folder
2. **Components**: Structure as `ComponentName/ComponentName.tsx` + `useComponentName.ts` + `helpers.ts`
- Use design system components from `src/components/` (atoms, molecules, organisms)
- Never use `src/components/__legacy__/*`
3. **Data fetching**: Use generated API hooks from `@/app/api/__generated__/endpoints/`
- Regenerate with `pnpm generate:api`
- Pattern: `use{Method}{Version}{OperationName}`
4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
5. **Testing**: Add Storybook stories for new components, Playwright for E2E
6. **Code conventions**: Function declarations (not arrow functions) for components/handlers
### Security Implementation

View File

@@ -8,6 +8,11 @@ start-core:
stop-core:
docker compose stop deps
reset-db:
rm -rf db/docker/volumes/db/data
cd backend && poetry run prisma migrate deploy
cd backend && poetry run prisma generate
# View logs for core services
logs-core:
docker compose logs -f deps
@@ -35,13 +40,18 @@ run-backend:
run-frontend:
cd frontend && pnpm dev
test-data:
cd backend && poetry run python test/test_data_creator.py
help:
@echo "Usage: make <target>"
@echo "Targets:"
@echo " start-core - Start just the core services (Supabase, Redis, RabbitMQ) in background"
@echo " stop-core - Stop the core services"
@echo " reset-db - Reset the database by deleting the volume"
@echo " logs-core - Tail the logs for core services"
@echo " format - Format & lint backend (Python) and frontend (TypeScript) code"
@echo " migrate - Run backend database migrations"
@echo " run-backend - Run the backend FastAPI server"
@echo " run-frontend - Run the frontend Next.js development server"
@echo " run-frontend - Run the frontend Next.js development server"
@echo " test-data - Run the test data creator"

View File

@@ -94,42 +94,36 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
config = LoggingConfig()
log_handlers: list[logging.Handler] = []
structured_logging = config.enable_cloud_logging or force_cloud_logging
# Console output handlers
stdout = logging.StreamHandler(stream=sys.stdout)
stdout.setLevel(config.level)
stdout.addFilter(BelowLevelFilter(logging.WARNING))
if config.level == logging.DEBUG:
stdout.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT))
else:
stdout.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))
if not structured_logging:
stdout = logging.StreamHandler(stream=sys.stdout)
stdout.setLevel(config.level)
stdout.addFilter(BelowLevelFilter(logging.WARNING))
if config.level == logging.DEBUG:
stdout.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT))
else:
stdout.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))
stderr = logging.StreamHandler()
stderr.setLevel(logging.WARNING)
if config.level == logging.DEBUG:
stderr.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT))
else:
stderr.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))
stderr = logging.StreamHandler()
stderr.setLevel(logging.WARNING)
if config.level == logging.DEBUG:
stderr.setFormatter(AGPTFormatter(DEBUG_LOG_FORMAT))
else:
stderr.setFormatter(AGPTFormatter(SIMPLE_LOG_FORMAT))
log_handlers += [stdout, stderr]
log_handlers += [stdout, stderr]
# Cloud logging setup
if config.enable_cloud_logging or force_cloud_logging:
import google.cloud.logging
from google.cloud.logging.handlers import CloudLoggingHandler
from google.cloud.logging_v2.handlers.transports import (
BackgroundThreadTransport,
)
else:
# Use Google Cloud Structured Log Handler. Log entries are printed to stdout
# in a JSON format which is automatically picked up by Google Cloud Logging.
from google.cloud.logging.handlers import StructuredLogHandler
client = google.cloud.logging.Client()
# Use BackgroundThreadTransport to prevent blocking the main thread
# and deadlocks when gRPC calls to Google Cloud Logging hang
cloud_handler = CloudLoggingHandler(
client,
name="autogpt_logs",
transport=BackgroundThreadTransport,
)
cloud_handler.setLevel(config.level)
log_handlers.append(cloud_handler)
structured_log_handler = StructuredLogHandler(stream=sys.stdout)
structured_log_handler.setLevel(config.level)
log_handlers.append(structured_log_handler)
# File logging setup
if config.enable_file_logging:
@@ -185,7 +179,13 @@ def configure_logging(force_cloud_logging: bool = False) -> None:
# Configure the root logger
logging.basicConfig(
format=DEBUG_LOG_FORMAT if config.level == logging.DEBUG else SIMPLE_LOG_FORMAT,
format=(
"%(levelname)s %(message)s"
if structured_logging
else (
DEBUG_LOG_FORMAT if config.level == logging.DEBUG else SIMPLE_LOG_FORMAT
)
),
level=config.level,
handlers=log_handlers,
)

View File

@@ -1,339 +0,0 @@
import asyncio
import inspect
import logging
import threading
import time
from functools import wraps
from typing import (
Any,
Callable,
ParamSpec,
Protocol,
TypeVar,
cast,
runtime_checkable,
)
P = ParamSpec("P")
R = TypeVar("R")
R_co = TypeVar("R_co", covariant=True)
logger = logging.getLogger(__name__)
def _make_hashable_key(
args: tuple[Any, ...], kwargs: dict[str, Any]
) -> tuple[Any, ...]:
"""
Convert args and kwargs into a hashable cache key.
Handles unhashable types like dict, list, set by converting them to
their sorted string representations.
"""
def make_hashable(obj: Any) -> Any:
"""Recursively convert an object to a hashable representation."""
if isinstance(obj, dict):
# Sort dict items to ensure consistent ordering
return (
"__dict__",
tuple(sorted((k, make_hashable(v)) for k, v in obj.items())),
)
elif isinstance(obj, (list, tuple)):
return ("__list__", tuple(make_hashable(item) for item in obj))
elif isinstance(obj, set):
return ("__set__", tuple(sorted(make_hashable(item) for item in obj)))
elif hasattr(obj, "__dict__"):
# Handle objects with __dict__ attribute
return ("__obj__", obj.__class__.__name__, make_hashable(obj.__dict__))
else:
# For basic hashable types (str, int, bool, None, etc.)
try:
hash(obj)
return obj
except TypeError:
# Fallback: convert to string representation
return ("__str__", str(obj))
hashable_args = tuple(make_hashable(arg) for arg in args)
hashable_kwargs = tuple(sorted((k, make_hashable(v)) for k, v in kwargs.items()))
return (hashable_args, hashable_kwargs)
@runtime_checkable
class CachedFunction(Protocol[P, R_co]):
"""Protocol for cached functions with cache management methods."""
def cache_clear(self) -> None:
"""Clear all cached entries."""
return None
def cache_info(self) -> dict[str, int | None]:
"""Get cache statistics."""
return {}
def cache_delete(self, *args: P.args, **kwargs: P.kwargs) -> bool:
"""Delete a specific cache entry by its arguments. Returns True if entry existed."""
return False
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R_co:
"""Call the cached function."""
return None # type: ignore
def cached(
*,
maxsize: int = 128,
ttl_seconds: int | None = None,
) -> Callable[[Callable], CachedFunction]:
"""
Thundering herd safe cache decorator for both sync and async functions.
Uses double-checked locking to prevent multiple threads/coroutines from
executing the expensive operation simultaneously during cache misses.
Args:
func: The function to cache (when used without parentheses)
maxsize: Maximum number of cached entries
ttl_seconds: Time to live in seconds. If None, entries never expire
Returns:
Decorated function or decorator
Example:
@cache() # Default: maxsize=128, no TTL
def expensive_sync_operation(param: str) -> dict:
return {"result": param}
@cache() # Works with async too
async def expensive_async_operation(param: str) -> dict:
return {"result": param}
@cache(maxsize=1000, ttl_seconds=300) # Custom maxsize and TTL
def another_operation(param: str) -> dict:
return {"result": param}
"""
def decorator(target_func):
# Cache storage and per-event-loop locks
cache_storage = {}
_event_loop_locks = {} # Maps event loop to its asyncio.Lock
if inspect.iscoroutinefunction(target_func):
def _get_cache_lock():
"""Get or create an asyncio.Lock for the current event loop."""
try:
loop = asyncio.get_running_loop()
except RuntimeError:
# No event loop, use None as default key
loop = None
if loop not in _event_loop_locks:
return _event_loop_locks.setdefault(loop, asyncio.Lock())
return _event_loop_locks[loop]
@wraps(target_func)
async def async_wrapper(*args: P.args, **kwargs: P.kwargs):
key = _make_hashable_key(args, kwargs)
current_time = time.time()
# Fast path: check cache without lock
if key in cache_storage:
if ttl_seconds is None:
logger.debug(f"Cache hit for {target_func.__name__}")
return cache_storage[key]
else:
cached_data = cache_storage[key]
if isinstance(cached_data, tuple):
result, timestamp = cached_data
if current_time - timestamp < ttl_seconds:
logger.debug(f"Cache hit for {target_func.__name__}")
return result
# Slow path: acquire lock for cache miss/expiry
async with _get_cache_lock():
# Double-check: another coroutine might have populated cache
if key in cache_storage:
if ttl_seconds is None:
return cache_storage[key]
else:
cached_data = cache_storage[key]
if isinstance(cached_data, tuple):
result, timestamp = cached_data
if current_time - timestamp < ttl_seconds:
return result
# Cache miss - execute function
logger.debug(f"Cache miss for {target_func.__name__}")
result = await target_func(*args, **kwargs)
# Store result
if ttl_seconds is None:
cache_storage[key] = result
else:
cache_storage[key] = (result, current_time)
# Cleanup if needed
if len(cache_storage) > maxsize:
cutoff = maxsize // 2
oldest_keys = (
list(cache_storage.keys())[:-cutoff] if cutoff > 0 else []
)
for old_key in oldest_keys:
cache_storage.pop(old_key, None)
return result
wrapper = async_wrapper
else:
# Sync function with threading.Lock
cache_lock = threading.Lock()
@wraps(target_func)
def sync_wrapper(*args: P.args, **kwargs: P.kwargs):
key = _make_hashable_key(args, kwargs)
current_time = time.time()
# Fast path: check cache without lock
if key in cache_storage:
if ttl_seconds is None:
logger.debug(f"Cache hit for {target_func.__name__}")
return cache_storage[key]
else:
cached_data = cache_storage[key]
if isinstance(cached_data, tuple):
result, timestamp = cached_data
if current_time - timestamp < ttl_seconds:
logger.debug(f"Cache hit for {target_func.__name__}")
return result
# Slow path: acquire lock for cache miss/expiry
with cache_lock:
# Double-check: another thread might have populated cache
if key in cache_storage:
if ttl_seconds is None:
return cache_storage[key]
else:
cached_data = cache_storage[key]
if isinstance(cached_data, tuple):
result, timestamp = cached_data
if current_time - timestamp < ttl_seconds:
return result
# Cache miss - execute function
logger.debug(f"Cache miss for {target_func.__name__}")
result = target_func(*args, **kwargs)
# Store result
if ttl_seconds is None:
cache_storage[key] = result
else:
cache_storage[key] = (result, current_time)
# Cleanup if needed
if len(cache_storage) > maxsize:
cutoff = maxsize // 2
oldest_keys = (
list(cache_storage.keys())[:-cutoff] if cutoff > 0 else []
)
for old_key in oldest_keys:
cache_storage.pop(old_key, None)
return result
wrapper = sync_wrapper
# Add cache management methods
def cache_clear() -> None:
cache_storage.clear()
def cache_info() -> dict[str, int | None]:
return {
"size": len(cache_storage),
"maxsize": maxsize,
"ttl_seconds": ttl_seconds,
}
def cache_delete(*args, **kwargs) -> bool:
"""Delete a specific cache entry. Returns True if entry existed."""
key = _make_hashable_key(args, kwargs)
if key in cache_storage:
del cache_storage[key]
return True
return False
setattr(wrapper, "cache_clear", cache_clear)
setattr(wrapper, "cache_info", cache_info)
setattr(wrapper, "cache_delete", cache_delete)
return cast(CachedFunction, wrapper)
return decorator
def thread_cached(func):
"""
Thread-local cache decorator for both sync and async functions.
Each thread gets its own cache, which is useful for request-scoped caching
in web applications where you want to cache within a single request but
not across requests.
Args:
func: The function to cache
Returns:
Decorated function with thread-local caching
Example:
@thread_cached
def expensive_operation(param: str) -> dict:
return {"result": param}
@thread_cached # Works with async too
async def expensive_async_operation(param: str) -> dict:
return {"result": param}
"""
thread_local = threading.local()
def _clear():
if hasattr(thread_local, "cache"):
del thread_local.cache
if inspect.iscoroutinefunction(func):
@wraps(func)
async def async_wrapper(*args, **kwargs):
cache = getattr(thread_local, "cache", None)
if cache is None:
cache = thread_local.cache = {}
key = _make_hashable_key(args, kwargs)
if key not in cache:
cache[key] = await func(*args, **kwargs)
return cache[key]
setattr(async_wrapper, "clear_cache", _clear)
return async_wrapper
else:
@wraps(func)
def sync_wrapper(*args, **kwargs):
cache = getattr(thread_local, "cache", None)
if cache is None:
cache = thread_local.cache = {}
key = _make_hashable_key(args, kwargs)
if key not in cache:
cache[key] = func(*args, **kwargs)
return cache[key]
setattr(sync_wrapper, "clear_cache", _clear)
return sync_wrapper
def clear_thread_cache(func: Callable) -> None:
"""Clear thread-local cache for a function."""
if clear := getattr(func, "clear_cache", None):
clear()

View File

@@ -47,6 +47,7 @@ RUN poetry install --no-ansi --no-root
# Generate Prisma client
COPY autogpt_platform/backend/schema.prisma ./
COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/partial_types.py
RUN poetry run prisma generate
FROM debian:13-slim AS server_dependencies
@@ -92,6 +93,7 @@ FROM server_dependencies AS migrate
# Migration stage only needs schema and migrations - much lighter than full backend
COPY autogpt_platform/backend/schema.prisma /app/autogpt_platform/backend/
COPY autogpt_platform/backend/backend/data/partial_types.py /app/autogpt_platform/backend/backend/data/partial_types.py
COPY autogpt_platform/backend/migrations /app/autogpt_platform/backend/migrations
FROM server_dependencies AS server

View File

@@ -5,7 +5,7 @@ import re
from pathlib import Path
from typing import TYPE_CHECKING, TypeVar
from autogpt_libs.utils.cache import cached
from backend.util.cache import cached
logger = logging.getLogger(__name__)
@@ -16,7 +16,7 @@ if TYPE_CHECKING:
T = TypeVar("T")
@cached()
@cached(ttl_seconds=3600)
def load_all_blocks() -> dict[str, type["Block"]]:
from backend.data.block import Block
from backend.util.settings import Config

View File

@@ -4,13 +4,13 @@ import mimetypes
from pathlib import Path
from typing import Any
import aiohttp
import discord
from pydantic import SecretStr
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import APIKeyCredentials, SchemaField
from backend.util.file import store_media_file
from backend.util.request import Requests
from backend.util.type import MediaFileType
from ._auth import (
@@ -114,10 +114,9 @@ class ReadDiscordMessagesBlock(Block):
if message.attachments:
attachment = message.attachments[0] # Process the first attachment
if attachment.filename.endswith((".txt", ".py")):
async with aiohttp.ClientSession() as session:
async with session.get(attachment.url) as response:
file_content = response.text()
self.output_data += f"\n\nFile from user: {attachment.filename}\nContent: {file_content}"
response = await Requests().get(attachment.url)
file_content = response.text()
self.output_data += f"\n\nFile from user: {attachment.filename}\nContent: {file_content}"
await client.close()
@@ -699,16 +698,15 @@ class SendDiscordFileBlock(Block):
elif file.startswith(("http://", "https://")):
# URL - download the file
async with aiohttp.ClientSession() as session:
async with session.get(file) as response:
file_bytes = await response.read()
response = await Requests().get(file)
file_bytes = response.content
# Try to get filename from URL if not provided
if not filename:
from urllib.parse import urlparse
# Try to get filename from URL if not provided
if not filename:
from urllib.parse import urlparse
path = urlparse(file).path
detected_filename = Path(path).name or "download"
path = urlparse(file).path
detected_filename = Path(path).name or "download"
else:
# Local file path - read from stored media file
# This would be a path from a previous block's output

View File

@@ -2,7 +2,7 @@ from typing import Any
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.json import json
from backend.util.json import loads
class StepThroughItemsBlock(Block):
@@ -68,7 +68,7 @@ class StepThroughItemsBlock(Block):
raise ValueError(
f"Input too large: {len(data)} bytes > {MAX_ITEM_SIZE} bytes"
)
items = json.loads(data)
items = loads(data)
else:
items = data

View File

@@ -62,10 +62,10 @@ TEST_CREDENTIALS_OAUTH = OAuth2Credentials(
title="Mock Linear API key",
username="mock-linear-username",
access_token=SecretStr("mock-linear-access-token"),
access_token_expires_at=None,
access_token_expires_at=1672531200, # Mock expiration time for short-lived token
refresh_token=SecretStr("mock-linear-refresh-token"),
refresh_token_expires_at=None,
scopes=["mock-linear-scopes"],
scopes=["read", "write"],
)
TEST_CREDENTIALS_API_KEY = APIKeyCredentials(

View File

@@ -2,7 +2,9 @@
Linear OAuth handler implementation.
"""
import base64
import json
import time
from typing import Optional
from urllib.parse import urlencode
@@ -38,8 +40,9 @@ class LinearOAuthHandler(BaseOAuthHandler):
self.client_secret = client_secret
self.redirect_uri = redirect_uri
self.auth_base_url = "https://linear.app/oauth/authorize"
self.token_url = "https://api.linear.app/oauth/token" # Correct token URL
self.token_url = "https://api.linear.app/oauth/token"
self.revoke_url = "https://api.linear.app/oauth/revoke"
self.migrate_url = "https://api.linear.app/oauth/migrate_old_token"
def get_login_url(
self, scopes: list[str], state: str, code_challenge: Optional[str]
@@ -82,19 +85,84 @@ class LinearOAuthHandler(BaseOAuthHandler):
return True # Linear doesn't return JSON on successful revoke
async def migrate_old_token(
self, credentials: OAuth2Credentials
) -> OAuth2Credentials:
"""
Migrate an old long-lived token to a new short-lived token with refresh token.
This uses Linear's /oauth/migrate_old_token endpoint to exchange current
long-lived tokens for short-lived tokens with refresh tokens without
requiring users to re-authorize.
"""
if not credentials.access_token:
raise ValueError("No access token to migrate")
request_body = {
"client_id": self.client_id,
"client_secret": self.client_secret,
}
headers = {
"Authorization": f"Bearer {credentials.access_token.get_secret_value()}",
"Content-Type": "application/x-www-form-urlencoded",
}
response = await Requests().post(
self.migrate_url, data=request_body, headers=headers
)
if not response.ok:
try:
error_data = response.json()
error_message = error_data.get("error", "Unknown error")
error_description = error_data.get("error_description", "")
if error_description:
error_message = f"{error_message}: {error_description}"
except json.JSONDecodeError:
error_message = response.text
raise LinearAPIException(
f"Failed to migrate Linear token ({response.status}): {error_message}",
response.status,
)
token_data = response.json()
# Extract token expiration
now = int(time.time())
expires_in = token_data.get("expires_in")
access_token_expires_at = None
if expires_in:
access_token_expires_at = now + expires_in
new_credentials = OAuth2Credentials(
provider=self.PROVIDER_NAME,
title=credentials.title,
username=credentials.username,
access_token=token_data["access_token"],
scopes=credentials.scopes, # Preserve original scopes
refresh_token=token_data.get("refresh_token"),
access_token_expires_at=access_token_expires_at,
refresh_token_expires_at=None,
)
new_credentials.id = credentials.id
return new_credentials
async def _refresh_tokens(
self, credentials: OAuth2Credentials
) -> OAuth2Credentials:
if not credentials.refresh_token:
raise ValueError(
"No refresh token available."
) # Linear uses non-expiring tokens
"No refresh token available. Token may need to be migrated to the new refresh token system."
)
return await self._request_tokens(
{
"refresh_token": credentials.refresh_token.get_secret_value(),
"grant_type": "refresh_token",
}
},
current_credentials=credentials,
)
async def _request_tokens(
@@ -102,16 +170,33 @@ class LinearOAuthHandler(BaseOAuthHandler):
params: dict[str, str],
current_credentials: Optional[OAuth2Credentials] = None,
) -> OAuth2Credentials:
# Determine if this is a refresh token request
is_refresh = params.get("grant_type") == "refresh_token"
# Build request body with appropriate grant_type
request_body = {
"client_id": self.client_id,
"client_secret": self.client_secret,
"grant_type": "authorization_code", # Ensure grant_type is correct
**params,
}
headers = {
"Content-Type": "application/x-www-form-urlencoded"
} # Correct header for token request
# Set default grant_type if not provided
if "grant_type" not in request_body:
request_body["grant_type"] = "authorization_code"
headers = {"Content-Type": "application/x-www-form-urlencoded"}
# For refresh token requests, support HTTP Basic Authentication as recommended
if is_refresh:
# Option 1: Use HTTP Basic Auth (preferred by Linear)
client_credentials = f"{self.client_id}:{self.client_secret}"
encoded_credentials = base64.b64encode(client_credentials.encode()).decode()
headers["Authorization"] = f"Basic {encoded_credentials}"
# Remove client credentials from body when using Basic Auth
request_body.pop("client_id", None)
request_body.pop("client_secret", None)
response = await Requests().post(
self.token_url, data=request_body, headers=headers
)
@@ -120,6 +205,9 @@ class LinearOAuthHandler(BaseOAuthHandler):
try:
error_data = response.json()
error_message = error_data.get("error", "Unknown error")
error_description = error_data.get("error_description", "")
if error_description:
error_message = f"{error_message}: {error_description}"
except json.JSONDecodeError:
error_message = response.text
raise LinearAPIException(
@@ -129,27 +217,84 @@ class LinearOAuthHandler(BaseOAuthHandler):
token_data = response.json()
# Note: Linear access tokens do not expire, so we set expires_at to None
# Extract token expiration if provided (for new refresh token implementation)
now = int(time.time())
expires_in = token_data.get("expires_in")
access_token_expires_at = None
if expires_in:
access_token_expires_at = now + expires_in
# Get username - preserve from current credentials if refreshing
username = None
if current_credentials and is_refresh:
username = current_credentials.username
elif "user" in token_data:
username = token_data["user"].get("name", "Unknown User")
else:
# Fetch username using the access token
username = await self._request_username(token_data["access_token"])
new_credentials = OAuth2Credentials(
provider=self.PROVIDER_NAME,
title=current_credentials.title if current_credentials else None,
username=token_data.get("user", {}).get(
"name", "Unknown User"
), # extract name or set appropriate
username=username or "Unknown User",
access_token=token_data["access_token"],
scopes=token_data["scope"].split(
","
), # Linear returns comma-separated scopes
refresh_token=token_data.get(
"refresh_token"
), # Linear uses non-expiring tokens so this might be null
access_token_expires_at=None,
refresh_token_expires_at=None,
scopes=(
token_data["scope"].split(",")
if "scope" in token_data
else (current_credentials.scopes if current_credentials else [])
),
refresh_token=token_data.get("refresh_token"),
access_token_expires_at=access_token_expires_at,
refresh_token_expires_at=None, # Linear doesn't provide refresh token expiration
)
if current_credentials:
new_credentials.id = current_credentials.id
return new_credentials
async def get_access_token(self, credentials: OAuth2Credentials) -> str:
"""
Returns a valid access token, handling migration and refresh as needed.
This overrides the base implementation to handle Linear's token migration
from old long-lived tokens to new short-lived tokens with refresh tokens.
"""
# If token has no expiration and no refresh token, it might be an old token
# that needs migration
if (
credentials.access_token_expires_at is None
and credentials.refresh_token is None
):
try:
# Attempt to migrate the old token
migrated_credentials = await self.migrate_old_token(credentials)
# Update the credentials store would need to be handled by the caller
# For now, use the migrated credentials for this request
credentials = migrated_credentials
except LinearAPIException:
# Migration failed, try to use the old token as-is
# This maintains backward compatibility
pass
# Use the standard refresh logic from the base class
if self.needs_refresh(credentials):
credentials = await self.refresh_tokens(credentials)
return credentials.access_token.get_secret_value()
def needs_migration(self, credentials: OAuth2Credentials) -> bool:
"""
Check if credentials represent an old long-lived token that needs migration.
Old tokens have no expiration time and no refresh token.
"""
return (
credentials.access_token_expires_at is None
and credentials.refresh_token is None
)
async def _request_username(self, access_token: str) -> Optional[str]:
# Use the LinearClient to fetch user details using GraphQL
from ._api import LinearClient

View File

@@ -104,8 +104,6 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
CLAUDE_4_5_HAIKU = "claude-haiku-4-5-20251001"
CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219"
CLAUDE_3_5_SONNET = "claude-3-5-sonnet-latest"
CLAUDE_3_5_HAIKU = "claude-3-5-haiku-latest"
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
# AI/ML API models
AIML_API_QWEN2_5_72B = "Qwen/Qwen2.5-72B-Instruct-Turbo"
@@ -224,12 +222,6 @@ MODEL_METADATA = {
LlmModel.CLAUDE_3_7_SONNET: ModelMetadata(
"anthropic", 200000, 64000
), # claude-3-7-sonnet-20250219
LlmModel.CLAUDE_3_5_SONNET: ModelMetadata(
"anthropic", 200000, 8192
), # claude-3-5-sonnet-20241022
LlmModel.CLAUDE_3_5_HAIKU: ModelMetadata(
"anthropic", 200000, 8192
), # claude-3-5-haiku-20241022
LlmModel.CLAUDE_3_HAIKU: ModelMetadata(
"anthropic", 200000, 4096
), # claude-3-haiku-20240307
@@ -1562,7 +1554,9 @@ class AIConversationBlock(AIBlockBase):
("prompt", list),
],
test_mock={
"llm_call": lambda *args, **kwargs: "The 2020 World Series was played at Globe Life Field in Arlington, Texas."
"llm_call": lambda *args, **kwargs: dict(
response="The 2020 World Series was played at Globe Life Field in Arlington, Texas."
)
},
)
@@ -1591,7 +1585,7 @@ class AIConversationBlock(AIBlockBase):
),
credentials=credentials,
)
yield "response", response
yield "response", response["response"]
yield "prompt", self.prompt

View File

@@ -1,7 +1,5 @@
import asyncio
import logging
import urllib.parse
import urllib.request
from datetime import datetime, timedelta, timezone
from typing import Any
@@ -10,6 +8,7 @@ import pydantic
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.request import Requests
class RSSEntry(pydantic.BaseModel):
@@ -103,35 +102,29 @@ class ReadRSSFeedBlock(Block):
)
@staticmethod
def parse_feed(url: str) -> dict[str, Any]:
async def parse_feed(url: str) -> dict[str, Any]:
# Security fix: Add protection against memory exhaustion attacks
MAX_FEED_SIZE = 10 * 1024 * 1024 # 10MB limit for RSS feeds
# Validate URL
parsed_url = urllib.parse.urlparse(url)
if parsed_url.scheme not in ("http", "https"):
raise ValueError(f"Invalid URL scheme: {parsed_url.scheme}")
# Download with size limit
# Download feed content with size limit
try:
with urllib.request.urlopen(url, timeout=30) as response:
# Check content length if available
content_length = response.headers.get("Content-Length")
if content_length and int(content_length) > MAX_FEED_SIZE:
raise ValueError(
f"Feed too large: {content_length} bytes exceeds {MAX_FEED_SIZE} limit"
)
response = await Requests(raise_for_status=True).get(url)
# Read with size limit
content = response.read(MAX_FEED_SIZE + 1)
if len(content) > MAX_FEED_SIZE:
raise ValueError(
f"Feed too large: exceeds {MAX_FEED_SIZE} byte limit"
)
# Check content length if available
content_length = response.headers.get("Content-Length")
if content_length and int(content_length) > MAX_FEED_SIZE:
raise ValueError(
f"Feed too large: {content_length} bytes exceeds {MAX_FEED_SIZE} limit"
)
# Parse with feedparser using the validated content
# feedparser has built-in protection against XML attacks
return feedparser.parse(content) # type: ignore
# Get content with size limit
content = response.content
if len(content) > MAX_FEED_SIZE:
raise ValueError(f"Feed too large: exceeds {MAX_FEED_SIZE} byte limit")
# Parse with feedparser using the validated content
# feedparser has built-in protection against XML attacks
return feedparser.parse(content) # type: ignore
except Exception as e:
# Log error and return empty feed
logging.warning(f"Failed to parse RSS feed from {url}: {e}")
@@ -145,7 +138,7 @@ class ReadRSSFeedBlock(Block):
while keep_going:
keep_going = input_data.run_continuously
feed = self.parse_feed(input_data.rss_url)
feed = await self.parse_feed(input_data.rss_url)
all_entries = []
for entry in feed["entries"]:

View File

@@ -1,6 +1,7 @@
import logging
import signal
import threading
import warnings
from contextlib import contextmanager
from enum import Enum
@@ -26,6 +27,13 @@ from backend.sdk import (
SchemaField,
)
# Suppress false positive cleanup warning of litellm (a dependency of stagehand)
warnings.filterwarnings(
"ignore",
message="coroutine 'close_litellm_async_clients' was never awaited",
category=RuntimeWarning,
)
# Store the original method
original_register_signal_handlers = stagehand.main.Stagehand._register_signal_handlers

View File

@@ -362,7 +362,7 @@ class TestLLMStatsTracking:
assert block.execution_stats.llm_call_count == 1
# Check output
assert outputs["response"] == {"response": "AI response to conversation"}
assert outputs["response"] == "AI response to conversation"
@pytest.mark.asyncio
async def test_ai_list_generator_with_retries(self):

View File

@@ -1,6 +1,7 @@
from urllib.parse import parse_qs, urlparse
from youtube_transcript_api._api import YouTubeTranscriptApi
from youtube_transcript_api._errors import NoTranscriptFound
from youtube_transcript_api._transcripts import FetchedTranscript
from youtube_transcript_api.formatters import TextFormatter
@@ -64,7 +65,29 @@ class TranscribeYoutubeVideoBlock(Block):
@staticmethod
def get_transcript(video_id: str) -> FetchedTranscript:
return YouTubeTranscriptApi().fetch(video_id=video_id)
"""
Get transcript for a video, preferring English but falling back to any available language.
:param video_id: The YouTube video ID
:return: The fetched transcript
:raises: Any exception except NoTranscriptFound for requested languages
"""
api = YouTubeTranscriptApi()
try:
# Try to get English transcript first (default behavior)
return api.fetch(video_id=video_id)
except NoTranscriptFound:
# If English is not available, get the first available transcript
transcript_list = api.list(video_id)
# Try manually created transcripts first, then generated ones
available_transcripts = list(
transcript_list._manually_created_transcripts.values()
) + list(transcript_list._generated_transcripts.values())
if available_transcripts:
# Fetch the first available transcript
return available_transcripts[0].fetch()
# If no transcripts at all, re-raise the original error
raise
@staticmethod
def format_transcript(transcript: FetchedTranscript) -> str:

View File

@@ -45,9 +45,6 @@ class MainApp(AppProcess):
app.main(silent=True)
def cleanup(self):
pass
@click.group()
def main():

View File

@@ -20,7 +20,6 @@ from typing import (
import jsonref
import jsonschema
from autogpt_libs.utils.cache import cached
from prisma.models import AgentBlock
from prisma.types import AgentBlockCreateInput
from pydantic import BaseModel
@@ -28,6 +27,7 @@ from pydantic import BaseModel
from backend.data.model import NodeExecutionStats
from backend.integrations.providers import ProviderName
from backend.util import json
from backend.util.cache import cached
from backend.util.settings import Config
from .model import (
@@ -722,7 +722,7 @@ def get_block(block_id: str) -> Block[BlockSchema, BlockSchema] | None:
return cls() if cls else None
@cached()
@cached(ttl_seconds=3600)
def get_webhook_block_ids() -> Sequence[str]:
return [
id
@@ -731,7 +731,7 @@ def get_webhook_block_ids() -> Sequence[str]:
]
@cached()
@cached(ttl_seconds=3600)
def get_io_block_ids() -> Sequence[str]:
return [
id

View File

@@ -1,7 +1,11 @@
from typing import Type
from backend.blocks.ai_music_generator import AIMusicGeneratorBlock
from backend.blocks.ai_shortform_video_block import AIShortformVideoCreatorBlock
from backend.blocks.ai_shortform_video_block import (
AIAdMakerVideoCreatorBlock,
AIScreenshotToVideoAdBlock,
AIShortformVideoCreatorBlock,
)
from backend.blocks.apollo.organization import SearchOrganizationsBlock
from backend.blocks.apollo.people import SearchPeopleBlock
from backend.blocks.apollo.person import GetPersonDetailBlock
@@ -72,8 +76,6 @@ MODEL_COST: dict[LlmModel, int] = {
LlmModel.CLAUDE_4_5_HAIKU: 4,
LlmModel.CLAUDE_4_5_SONNET: 9,
LlmModel.CLAUDE_3_7_SONNET: 5,
LlmModel.CLAUDE_3_5_SONNET: 4,
LlmModel.CLAUDE_3_5_HAIKU: 1, # $0.80 / $4.00
LlmModel.CLAUDE_3_HAIKU: 1,
LlmModel.AIML_API_QWEN2_5_72B: 1,
LlmModel.AIML_API_LLAMA3_1_70B: 1,
@@ -323,7 +325,31 @@ BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = {
],
AIShortformVideoCreatorBlock: [
BlockCost(
cost_amount=50,
cost_amount=307,
cost_filter={
"credentials": {
"id": revid_credentials.id,
"provider": revid_credentials.provider,
"type": revid_credentials.type,
}
},
)
],
AIAdMakerVideoCreatorBlock: [
BlockCost(
cost_amount=714,
cost_filter={
"credentials": {
"id": revid_credentials.id,
"provider": revid_credentials.provider,
"type": revid_credentials.type,
}
},
)
],
AIScreenshotToVideoAdBlock: [
BlockCost(
cost_amount=612,
cost_filter={
"credentials": {
"id": revid_credentials.id,

View File

@@ -5,7 +5,6 @@ from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, cast
import stripe
from prisma import Json
from prisma.enums import (
CreditRefundRequestStatus,
CreditTransactionType,
@@ -13,16 +12,12 @@ from prisma.enums import (
OnboardingStep,
)
from prisma.errors import UniqueViolationError
from prisma.models import CreditRefundRequest, CreditTransaction, User
from prisma.types import (
CreditRefundRequestCreateInput,
CreditTransactionCreateInput,
CreditTransactionWhereInput,
)
from prisma.models import CreditRefundRequest, CreditTransaction, User, UserBalance
from prisma.types import CreditRefundRequestCreateInput, CreditTransactionWhereInput
from pydantic import BaseModel
from backend.data import db
from backend.data.block_cost_config import BLOCK_COSTS
from backend.data.db import query_raw_with_schema
from backend.data.includes import MAX_CREDIT_REFUND_REQUESTS_FETCH
from backend.data.model import (
AutoTopUpConfig,
@@ -36,7 +31,8 @@ from backend.data.user import get_user_by_id, get_user_email_by_id
from backend.notifications.notifications import queue_notification_async
from backend.server.v2.admin.model import UserHistoryResponse
from backend.util.exceptions import InsufficientBalanceError
from backend.util.json import SafeJson
from backend.util.feature_flag import Flag, is_feature_enabled
from backend.util.json import SafeJson, dumps
from backend.util.models import Pagination
from backend.util.retry import func_retry
from backend.util.settings import Settings
@@ -49,6 +45,10 @@ stripe.api_key = settings.secrets.stripe_api_key
logger = logging.getLogger(__name__)
base_url = settings.config.frontend_base_url or settings.config.platform_base_url
# Constants for test compatibility
POSTGRES_INT_MAX = 2147483647
POSTGRES_INT_MIN = -2147483648
class UsageTransactionMetadata(BaseModel):
graph_exec_id: str | None = None
@@ -139,14 +139,20 @@ class UserCreditBase(ABC):
pass
@abstractmethod
async def onboarding_reward(self, user_id: str, credits: int, step: OnboardingStep):
async def onboarding_reward(
self, user_id: str, credits: int, step: OnboardingStep
) -> bool:
"""
Reward the user with credits for completing an onboarding step.
Won't reward if the user has already received credits for the step.
Args:
user_id (str): The user ID.
credits (int): The amount to reward.
step (OnboardingStep): The onboarding step.
Returns:
bool: True if rewarded, False if already rewarded.
"""
pass
@@ -236,6 +242,12 @@ class UserCreditBase(ABC):
"""
Returns the current balance of the user & the latest balance snapshot time.
"""
# Check UserBalance first for efficiency and consistency
user_balance = await UserBalance.prisma().find_unique(where={"userId": user_id})
if user_balance:
return user_balance.balance, user_balance.updatedAt
# Fallback to transaction history computation if UserBalance doesn't exist
top_time = self.time_now()
snapshot = await CreditTransaction.prisma().find_first(
where={
@@ -250,72 +262,86 @@ class UserCreditBase(ABC):
snapshot_balance = snapshot.runningBalance or 0 if snapshot else 0
snapshot_time = snapshot.createdAt if snapshot else datetime_min
# Get transactions after the snapshot, this should not exist, but just in case.
transactions = await CreditTransaction.prisma().group_by(
by=["userId"],
sum={"amount": True},
max={"createdAt": True},
where={
"userId": user_id,
"createdAt": {
"gt": snapshot_time,
"lte": top_time,
},
"isActive": True,
},
)
transaction_balance = (
int(transactions[0].get("_sum", {}).get("amount", 0) + snapshot_balance)
if transactions
else snapshot_balance
)
transaction_time = (
datetime.fromisoformat(
str(transactions[0].get("_max", {}).get("createdAt", datetime_min))
)
if transactions
else snapshot_time
)
return transaction_balance, transaction_time
return snapshot_balance, snapshot_time
@func_retry
async def _enable_transaction(
self,
transaction_key: str,
user_id: str,
metadata: Json,
metadata: SafeJson,
new_transaction_key: str | None = None,
):
transaction = await CreditTransaction.prisma().find_first_or_raise(
where={"transactionKey": transaction_key, "userId": user_id}
# First check if transaction exists and is inactive (safety check)
transaction = await CreditTransaction.prisma().find_first(
where={
"transactionKey": transaction_key,
"userId": user_id,
"isActive": False,
}
)
if transaction.isActive:
return
if not transaction:
# Transaction doesn't exist or is already active, return early
return None
async with db.locked_transaction(f"usr_trx_{user_id}"):
transaction = await CreditTransaction.prisma().find_first_or_raise(
where={"transactionKey": transaction_key, "userId": user_id}
# Atomic operation to enable transaction and update user balance using UserBalance
result = await query_raw_with_schema(
"""
WITH user_balance_lock AS (
SELECT
$2::text as userId,
COALESCE(
(SELECT balance FROM {schema_prefix}"UserBalance" WHERE "userId" = $2 FOR UPDATE),
-- Fallback: compute balance from transaction history if UserBalance doesn't exist
(SELECT COALESCE(ct."runningBalance", 0)
FROM {schema_prefix}"CreditTransaction" ct
WHERE ct."userId" = $2
AND ct."isActive" = true
AND ct."runningBalance" IS NOT NULL
ORDER BY ct."createdAt" DESC
LIMIT 1),
0
) as balance
),
transaction_check AS (
SELECT * FROM {schema_prefix}"CreditTransaction"
WHERE "transactionKey" = $1 AND "userId" = $2 AND "isActive" = false
),
balance_update AS (
INSERT INTO {schema_prefix}"UserBalance" ("userId", "balance", "updatedAt")
SELECT
$2::text,
user_balance_lock.balance + transaction_check.amount,
CURRENT_TIMESTAMP
FROM user_balance_lock, transaction_check
ON CONFLICT ("userId") DO UPDATE SET
"balance" = EXCLUDED."balance",
"updatedAt" = EXCLUDED."updatedAt"
RETURNING "balance", "updatedAt"
),
transaction_update AS (
UPDATE {schema_prefix}"CreditTransaction"
SET "transactionKey" = COALESCE($4, $1),
"isActive" = true,
"runningBalance" = balance_update.balance,
"createdAt" = balance_update."updatedAt",
"metadata" = $3::jsonb
FROM balance_update, transaction_check
WHERE {schema_prefix}"CreditTransaction"."transactionKey" = transaction_check."transactionKey"
AND {schema_prefix}"CreditTransaction"."userId" = transaction_check."userId"
RETURNING {schema_prefix}"CreditTransaction"."runningBalance"
)
if transaction.isActive:
return
SELECT "runningBalance" as balance FROM transaction_update;
""",
transaction_key, # $1
user_id, # $2
dumps(metadata.data), # $3 - use pre-serialized JSON string for JSONB
new_transaction_key, # $4
)
user_balance, _ = await self._get_credits(user_id)
await CreditTransaction.prisma().update(
where={
"creditTransactionIdentifier": {
"transactionKey": transaction_key,
"userId": user_id,
}
},
data={
"transactionKey": new_transaction_key or transaction_key,
"isActive": True,
"runningBalance": user_balance + transaction.amount,
"createdAt": self.time_now(),
"metadata": metadata,
},
)
if result:
# UserBalance is already updated by the CTE
return result[0]["balance"]
async def _add_transaction(
self,
@@ -326,12 +352,54 @@ class UserCreditBase(ABC):
transaction_key: str | None = None,
ceiling_balance: int | None = None,
fail_insufficient_credits: bool = True,
metadata: Json = SafeJson({}),
metadata: SafeJson = SafeJson({}),
) -> tuple[int, str]:
"""
Add a new transaction for the user.
This is the only method that should be used to add a new transaction.
ATOMIC OPERATION DESIGN DECISION:
================================
This method uses PostgreSQL row-level locking (FOR UPDATE) for atomic credit operations.
After extensive analysis of concurrency patterns and correctness requirements, we determined
that the FOR UPDATE approach is necessary despite the latency overhead.
WHY FOR UPDATE LOCKING IS REQUIRED:
----------------------------------
1. **Data Consistency**: Credit operations must be ACID-compliant. The balance check,
calculation, and update must be atomic to prevent race conditions where:
- Multiple spend operations could exceed available balance
- Lost update problems could occur with concurrent top-ups
- Refunds could create negative balances incorrectly
2. **Serializability**: FOR UPDATE ensures operations are serialized at the database level,
guaranteeing that each transaction sees a consistent view of the balance before applying changes.
3. **Correctness Over Performance**: Financial operations require absolute correctness.
The ~10-50ms latency increase from row locking is acceptable for the guarantee that
no user will ever have an incorrect balance due to race conditions.
4. **PostgreSQL Optimization**: Modern PostgreSQL versions optimize row locks efficiently.
The performance cost is minimal compared to the complexity and risk of lock-free approaches.
ALTERNATIVES CONSIDERED AND REJECTED:
------------------------------------
- **Optimistic Concurrency**: Using version numbers or timestamps would require complex
retry logic and could still fail under high contention scenarios.
- **Application-Level Locking**: Redis locks or similar would add network overhead and
single points of failure while being less reliable than database locks.
- **Event Sourcing**: Would require complete architectural changes and eventual consistency
models that don't fit our real-time balance requirements.
PERFORMANCE CHARACTERISTICS:
---------------------------
- Single user operations: 10-50ms latency (acceptable for financial operations)
- Concurrent operations on same user: Serialized (prevents data corruption)
- Concurrent operations on different users: Fully parallel (no blocking)
This design prioritizes correctness and data integrity over raw performance,
which is the appropriate choice for a credit/payment system.
Args:
user_id (str): The user ID.
amount (int): The amount of credits to add.
@@ -345,40 +413,142 @@ class UserCreditBase(ABC):
Returns:
tuple[int, str]: The new balance & the transaction key.
"""
async with db.locked_transaction(f"usr_trx_{user_id}"):
# Get latest balance snapshot
user_balance, _ = await self._get_credits(user_id)
if ceiling_balance and amount > 0 and user_balance >= ceiling_balance:
# Quick validation for ceiling balance to avoid unnecessary database operations
if ceiling_balance and amount > 0:
current_balance, _ = await self._get_credits(user_id)
if current_balance >= ceiling_balance:
raise ValueError(
f"You already have enough balance of ${user_balance/100}, top-up is not required when you already have at least ${ceiling_balance/100}"
f"You already have enough balance of ${current_balance/100}, top-up is not required when you already have at least ${ceiling_balance/100}"
)
if amount < 0 and user_balance + amount < 0:
if fail_insufficient_credits:
raise InsufficientBalanceError(
message=f"Insufficient balance of ${user_balance/100}, where this will cost ${abs(amount)/100}",
user_id=user_id,
balance=user_balance,
amount=amount,
# Single unified atomic operation for all transaction types using UserBalance
try:
result = await query_raw_with_schema(
"""
WITH user_balance_lock AS (
SELECT
$1::text as userId,
-- CRITICAL: FOR UPDATE lock prevents concurrent modifications to the same user's balance
-- This ensures atomic read-modify-write operations and prevents race conditions
COALESCE(
(SELECT balance FROM {schema_prefix}"UserBalance" WHERE "userId" = $1 FOR UPDATE),
-- Fallback: compute balance from transaction history if UserBalance doesn't exist
(SELECT COALESCE(ct."runningBalance", 0)
FROM {schema_prefix}"CreditTransaction" ct
WHERE ct."userId" = $1
AND ct."isActive" = true
AND ct."runningBalance" IS NOT NULL
ORDER BY ct."createdAt" DESC
LIMIT 1),
0
) as balance
),
balance_update AS (
INSERT INTO {schema_prefix}"UserBalance" ("userId", "balance", "updatedAt")
SELECT
$1::text,
CASE
-- For inactive transactions: Don't update balance
WHEN $5::boolean = false THEN user_balance_lock.balance
-- For ceiling balance (amount > 0): Apply ceiling
WHEN $2 > 0 AND $7::int IS NOT NULL AND user_balance_lock.balance > $7::int - $2 THEN $7::int
-- For regular operations: Apply with overflow/underflow protection
WHEN user_balance_lock.balance + $2 > $6::int THEN $6::int
WHEN user_balance_lock.balance + $2 < $10::int THEN $10::int
ELSE user_balance_lock.balance + $2
END,
CURRENT_TIMESTAMP
FROM user_balance_lock
WHERE (
$5::boolean = false OR -- Allow inactive transactions
$2 >= 0 OR -- Allow positive amounts (top-ups, grants)
$8::boolean = false OR -- Allow when insufficient balance check is disabled
user_balance_lock.balance + $2 >= 0 -- Allow spending only when sufficient balance
)
ON CONFLICT ("userId") DO UPDATE SET
"balance" = EXCLUDED."balance",
"updatedAt" = EXCLUDED."updatedAt"
RETURNING "balance", "updatedAt"
),
transaction_insert AS (
INSERT INTO {schema_prefix}"CreditTransaction" (
"userId", "amount", "type", "runningBalance",
"metadata", "isActive", "createdAt", "transactionKey"
)
SELECT
$1::text,
$2::int,
$3::text::{schema_prefix}"CreditTransactionType",
CASE
-- For inactive transactions: Set runningBalance to original balance (don't apply the change yet)
WHEN $5::boolean = false THEN user_balance_lock.balance
ELSE COALESCE(balance_update.balance, user_balance_lock.balance)
END,
$4::jsonb,
$5::boolean,
COALESCE(balance_update."updatedAt", CURRENT_TIMESTAMP),
COALESCE($9, gen_random_uuid()::text)
FROM user_balance_lock
LEFT JOIN balance_update ON true
WHERE (
$5::boolean = false OR -- Allow inactive transactions
$2 >= 0 OR -- Allow positive amounts (top-ups, grants)
$8::boolean = false OR -- Allow when insufficient balance check is disabled
user_balance_lock.balance + $2 >= 0 -- Allow spending only when sufficient balance
)
RETURNING "runningBalance", "transactionKey"
)
SELECT "runningBalance" as balance, "transactionKey" FROM transaction_insert;
""",
user_id, # $1
amount, # $2
transaction_type.value, # $3
dumps(metadata.data), # $4 - use pre-serialized JSON string for JSONB
is_active, # $5
POSTGRES_INT_MAX, # $6 - overflow protection
ceiling_balance, # $7 - ceiling balance (nullable)
fail_insufficient_credits, # $8 - check balance for spending
transaction_key, # $9 - transaction key (nullable)
POSTGRES_INT_MIN, # $10 - underflow protection
)
except Exception as e:
# Convert raw SQL unique constraint violations to UniqueViolationError
# for consistent exception handling throughout the codebase
error_str = str(e).lower()
if (
"already exists" in error_str
or "duplicate key" in error_str
or "unique constraint" in error_str
):
# Extract table and constraint info for better error messages
# Re-raise as a UniqueViolationError but with proper format
# Create a minimal data structure that the error constructor expects
raise UniqueViolationError({"error": str(e), "user_facing_error": {}})
# For any other error, re-raise as-is
raise
amount = min(-user_balance, 0)
if result:
new_balance, tx_key = result[0]["balance"], result[0]["transactionKey"]
# UserBalance is already updated by the CTE
return new_balance, tx_key
# Create the transaction
transaction_data: CreditTransactionCreateInput = {
"userId": user_id,
"amount": amount,
"runningBalance": user_balance + amount,
"type": transaction_type,
"metadata": metadata,
"isActive": is_active,
"createdAt": self.time_now(),
}
if transaction_key:
transaction_data["transactionKey"] = transaction_key
tx = await CreditTransaction.prisma().create(data=transaction_data)
return user_balance + amount, tx.transactionKey
# If no result, either user doesn't exist or insufficient balance
user = await User.prisma().find_unique(where={"id": user_id})
if not user:
raise ValueError(f"User {user_id} not found")
# Must be insufficient balance for spending operation
if amount < 0 and fail_insufficient_credits:
current_balance, _ = await self._get_credits(user_id)
raise InsufficientBalanceError(
message=f"Insufficient balance of ${current_balance/100}, where this will cost ${abs(amount)/100}",
user_id=user_id,
balance=current_balance,
amount=amount,
)
# Unexpected case
raise ValueError(f"Transaction failed for user {user_id}, amount {amount}")
class UserCredit(UserCreditBase):
@@ -453,9 +623,10 @@ class UserCredit(UserCreditBase):
{"reason": f"Reward for completing {step.value} onboarding step."}
),
)
return True
except UniqueViolationError:
# Already rewarded for this step
pass
# User already received this reward
return False
async def top_up_refund(
self, user_id: str, transaction_key: str, metadata: dict[str, str]
@@ -644,7 +815,7 @@ class UserCredit(UserCreditBase):
):
# init metadata, without sharing it with the world
metadata = metadata or {}
if not metadata["reason"]:
if not metadata.get("reason"):
match top_up_type:
case TopUpType.MANUAL:
metadata["reason"] = {"reason": f"Top up credits for {user_id}"}
@@ -974,8 +1145,8 @@ class DisabledUserCredit(UserCreditBase):
async def top_up_credits(self, *args, **kwargs):
pass
async def onboarding_reward(self, *args, **kwargs):
pass
async def onboarding_reward(self, *args, **kwargs) -> bool:
return True
async def top_up_intent(self, *args, **kwargs) -> str:
return ""
@@ -993,14 +1164,31 @@ class DisabledUserCredit(UserCreditBase):
pass
def get_user_credit_model() -> UserCreditBase:
async def get_user_credit_model(user_id: str) -> UserCreditBase:
"""
Get the credit model for a user, considering LaunchDarkly flags.
Args:
user_id (str): The user ID to check flags for.
Returns:
UserCreditBase: The appropriate credit model for the user
"""
if not settings.config.enable_credit:
return DisabledUserCredit()
if settings.config.enable_beta_monthly_credit:
return BetaUserCredit(settings.config.num_user_credits_refill)
# Check LaunchDarkly flag for payment pilot users
# Default to False (beta monthly credit behavior) to maintain current behavior
is_payment_enabled = await is_feature_enabled(
Flag.ENABLE_PLATFORM_PAYMENT, user_id, default=False
)
return UserCredit()
if is_payment_enabled:
# Payment enabled users get UserCredit (no monthly refills, enable payments)
return UserCredit()
else:
# Default behavior: users get beta monthly credits
return BetaUserCredit(settings.config.num_user_credits_refill)
def get_block_costs() -> dict[str, list["BlockCost"]]:
@@ -1090,7 +1278,8 @@ async def admin_get_user_history(
)
reason = metadata.get("reason", "No reason provided")
balance, last_update = await get_user_credit_model()._get_credits(tx.userId)
user_credit_model = await get_user_credit_model(tx.userId)
balance, _ = await user_credit_model._get_credits(tx.userId)
history.append(
UserTransaction(

View File

@@ -0,0 +1,172 @@
"""
Test ceiling balance functionality to ensure auto top-up limits work correctly.
This test was added to cover a previously untested code path that could lead to
incorrect balance capping behavior.
"""
from uuid import uuid4
import pytest
from prisma.enums import CreditTransactionType
from prisma.errors import UniqueViolationError
from prisma.models import CreditTransaction, User, UserBalance
from backend.data.credit import UserCredit
from backend.util.json import SafeJson
from backend.util.test import SpinTestServer
async def create_test_user(user_id: str) -> None:
"""Create a test user for ceiling tests."""
try:
await User.prisma().create(
data={
"id": user_id,
"email": f"test-{user_id}@example.com",
"name": f"Test User {user_id[:8]}",
}
)
except UniqueViolationError:
# User already exists, continue
pass
await UserBalance.prisma().upsert(
where={"userId": user_id},
data={"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
)
async def cleanup_test_user(user_id: str) -> None:
"""Clean up test user and their transactions."""
try:
await CreditTransaction.prisma().delete_many(where={"userId": user_id})
await User.prisma().delete_many(where={"id": user_id})
except Exception as e:
# Log cleanup failures but don't fail the test
print(f"Warning: Failed to cleanup test user {user_id}: {e}")
@pytest.mark.asyncio(loop_scope="session")
async def test_ceiling_balance_rejects_when_above_threshold(server: SpinTestServer):
"""Test that ceiling balance correctly rejects top-ups when balance is above threshold."""
credit_system = UserCredit()
user_id = f"ceiling-test-{uuid4()}"
await create_test_user(user_id)
try:
# Give user balance of 1000 ($10) using internal method (bypasses Stripe)
await credit_system._add_transaction(
user_id=user_id,
amount=1000,
transaction_type=CreditTransactionType.TOP_UP,
metadata=SafeJson({"test": "initial_balance"}),
)
current_balance = await credit_system.get_credits(user_id)
assert current_balance == 1000
# Try to add 200 more with ceiling of 800 (should reject since 1000 > 800)
with pytest.raises(ValueError, match="You already have enough balance"):
await credit_system._add_transaction(
user_id=user_id,
amount=200,
transaction_type=CreditTransactionType.TOP_UP,
ceiling_balance=800, # Ceiling lower than current balance
)
# Balance should remain unchanged
final_balance = await credit_system.get_credits(user_id)
assert final_balance == 1000, f"Balance should remain 1000, got {final_balance}"
finally:
await cleanup_test_user(user_id)
@pytest.mark.asyncio(loop_scope="session")
async def test_ceiling_balance_clamps_when_would_exceed(server: SpinTestServer):
"""Test that ceiling balance correctly clamps amounts that would exceed the ceiling."""
credit_system = UserCredit()
user_id = f"ceiling-clamp-test-{uuid4()}"
await create_test_user(user_id)
try:
# Give user balance of 500 ($5) using internal method (bypasses Stripe)
await credit_system._add_transaction(
user_id=user_id,
amount=500,
transaction_type=CreditTransactionType.TOP_UP,
metadata=SafeJson({"test": "initial_balance"}),
)
# Add 800 more with ceiling of 1000 (should clamp to 1000, not reach 1300)
final_balance, _ = await credit_system._add_transaction(
user_id=user_id,
amount=800,
transaction_type=CreditTransactionType.TOP_UP,
ceiling_balance=1000, # Ceiling should clamp 500 + 800 = 1300 to 1000
)
# Balance should be clamped to ceiling
assert (
final_balance == 1000
), f"Balance should be clamped to 1000, got {final_balance}"
# Verify with get_credits too
stored_balance = await credit_system.get_credits(user_id)
assert (
stored_balance == 1000
), f"Stored balance should be 1000, got {stored_balance}"
# Verify transaction shows the clamped amount
transactions = await CreditTransaction.prisma().find_many(
where={"userId": user_id, "type": CreditTransactionType.TOP_UP},
order={"createdAt": "desc"},
)
# Should have 2 transactions: 500 + (500 to reach ceiling of 1000)
assert len(transactions) == 2
# The second transaction should show it only added 500, not 800
second_tx = transactions[0] # Most recent
assert second_tx.runningBalance == 1000
# The actual amount recorded could be 800 (what was requested) but balance was clamped
finally:
await cleanup_test_user(user_id)
@pytest.mark.asyncio(loop_scope="session")
async def test_ceiling_balance_allows_when_under_threshold(server: SpinTestServer):
"""Test that ceiling balance allows top-ups when balance is under threshold."""
credit_system = UserCredit()
user_id = f"ceiling-under-test-{uuid4()}"
await create_test_user(user_id)
try:
# Give user balance of 300 ($3) using internal method (bypasses Stripe)
await credit_system._add_transaction(
user_id=user_id,
amount=300,
transaction_type=CreditTransactionType.TOP_UP,
metadata=SafeJson({"test": "initial_balance"}),
)
# Add 200 more with ceiling of 1000 (should succeed: 300 + 200 = 500 < 1000)
final_balance, _ = await credit_system._add_transaction(
user_id=user_id,
amount=200,
transaction_type=CreditTransactionType.TOP_UP,
ceiling_balance=1000,
)
# Balance should be exactly 500
assert final_balance == 500, f"Balance should be 500, got {final_balance}"
# Verify with get_credits too
stored_balance = await credit_system.get_credits(user_id)
assert (
stored_balance == 500
), f"Stored balance should be 500, got {stored_balance}"
finally:
await cleanup_test_user(user_id)

View File

@@ -0,0 +1,737 @@
"""
Concurrency and atomicity tests for the credit system.
These tests ensure the credit system handles high-concurrency scenarios correctly
without race conditions, deadlocks, or inconsistent state.
"""
import asyncio
import random
from uuid import uuid4
import prisma.enums
import pytest
from prisma.enums import CreditTransactionType
from prisma.errors import UniqueViolationError
from prisma.models import CreditTransaction, User, UserBalance
from backend.data.credit import POSTGRES_INT_MAX, UsageTransactionMetadata, UserCredit
from backend.util.exceptions import InsufficientBalanceError
from backend.util.json import SafeJson
from backend.util.test import SpinTestServer
# Test with both UserCredit and BetaUserCredit if needed
credit_system = UserCredit()
async def create_test_user(user_id: str) -> None:
"""Create a test user with initial balance."""
try:
await User.prisma().create(
data={
"id": user_id,
"email": f"test-{user_id}@example.com",
"name": f"Test User {user_id[:8]}",
}
)
except UniqueViolationError:
# User already exists, continue
pass
# Ensure UserBalance record exists
await UserBalance.prisma().upsert(
where={"userId": user_id},
data={"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
)
async def cleanup_test_user(user_id: str) -> None:
"""Clean up test user and their transactions."""
try:
await CreditTransaction.prisma().delete_many(where={"userId": user_id})
await UserBalance.prisma().delete_many(where={"userId": user_id})
await User.prisma().delete_many(where={"id": user_id})
except Exception as e:
# Log cleanup failures but don't fail the test
print(f"Warning: Failed to cleanup test user {user_id}: {e}")
@pytest.mark.asyncio(loop_scope="session")
async def test_concurrent_spends_same_user(server: SpinTestServer):
"""Test multiple concurrent spends from the same user don't cause race conditions."""
user_id = f"concurrent-test-{uuid4()}"
await create_test_user(user_id)
try:
# Give user initial balance using internal method (bypasses Stripe)
await credit_system._add_transaction(
user_id=user_id,
amount=1000,
transaction_type=CreditTransactionType.TOP_UP,
metadata=SafeJson({"test": "initial_balance"}),
)
# Try to spend 10 x $1 concurrently
async def spend_one_dollar(idx: int):
try:
return await credit_system.spend_credits(
user_id,
100, # $1
UsageTransactionMetadata(
graph_exec_id=f"concurrent-{idx}",
reason=f"Concurrent spend {idx}",
),
)
except InsufficientBalanceError:
return None
# Run 10 concurrent spends
results = await asyncio.gather(
*[spend_one_dollar(i) for i in range(10)], return_exceptions=True
)
# Count successful spends
successful = [
r for r in results if r is not None and not isinstance(r, Exception)
]
failed = [r for r in results if isinstance(r, InsufficientBalanceError)]
# All 10 should succeed since we have exactly $10
assert len(successful) == 10, f"Expected 10 successful, got {len(successful)}"
assert len(failed) == 0, f"Expected 0 failures, got {len(failed)}"
# Final balance should be exactly 0
final_balance = await credit_system.get_credits(user_id)
assert final_balance == 0, f"Expected balance 0, got {final_balance}"
# Verify transaction history is consistent
transactions = await CreditTransaction.prisma().find_many(
where={"userId": user_id, "type": prisma.enums.CreditTransactionType.USAGE}
)
assert (
len(transactions) == 10
), f"Expected 10 transactions, got {len(transactions)}"
finally:
await cleanup_test_user(user_id)
@pytest.mark.asyncio(loop_scope="session")
async def test_concurrent_spends_insufficient_balance(server: SpinTestServer):
"""Test that concurrent spends correctly enforce balance limits."""
user_id = f"insufficient-test-{uuid4()}"
await create_test_user(user_id)
try:
# Give user limited balance using internal method (bypasses Stripe)
await credit_system._add_transaction(
user_id=user_id,
amount=500,
transaction_type=CreditTransactionType.TOP_UP,
metadata=SafeJson({"test": "limited_balance"}),
)
# Try to spend 10 x $1 concurrently (but only have $5)
async def spend_one_dollar(idx: int):
try:
return await credit_system.spend_credits(
user_id,
100, # $1
UsageTransactionMetadata(
graph_exec_id=f"insufficient-{idx}",
reason=f"Insufficient spend {idx}",
),
)
except InsufficientBalanceError:
return "FAILED"
# Run 10 concurrent spends
results = await asyncio.gather(
*[spend_one_dollar(i) for i in range(10)], return_exceptions=True
)
# Count successful vs failed
successful = [
r
for r in results
if r not in ["FAILED", None] and not isinstance(r, Exception)
]
failed = [r for r in results if r == "FAILED"]
# Exactly 5 should succeed, 5 should fail
assert len(successful) == 5, f"Expected 5 successful, got {len(successful)}"
assert len(failed) == 5, f"Expected 5 failures, got {len(failed)}"
# Final balance should be exactly 0
final_balance = await credit_system.get_credits(user_id)
assert final_balance == 0, f"Expected balance 0, got {final_balance}"
finally:
await cleanup_test_user(user_id)
@pytest.mark.asyncio(loop_scope="session")
async def test_concurrent_mixed_operations(server: SpinTestServer):
"""Test concurrent mix of spends, top-ups, and balance checks."""
user_id = f"mixed-test-{uuid4()}"
await create_test_user(user_id)
try:
# Initial balance using internal method (bypasses Stripe)
await credit_system._add_transaction(
user_id=user_id,
amount=1000,
transaction_type=CreditTransactionType.TOP_UP,
metadata=SafeJson({"test": "initial_balance"}),
)
# Mix of operations
async def mixed_operations():
operations = []
# 5 spends of $1 each
for i in range(5):
operations.append(
credit_system.spend_credits(
user_id,
100,
UsageTransactionMetadata(reason=f"Mixed spend {i}"),
)
)
# 3 top-ups of $2 each using internal method
for i in range(3):
operations.append(
credit_system._add_transaction(
user_id=user_id,
amount=200,
transaction_type=CreditTransactionType.TOP_UP,
metadata=SafeJson({"test": f"concurrent_topup_{i}"}),
)
)
# 10 balance checks
for i in range(10):
operations.append(credit_system.get_credits(user_id))
return await asyncio.gather(*operations, return_exceptions=True)
results = await mixed_operations()
# Check no exceptions occurred
exceptions = [
r
for r in results
if isinstance(r, Exception) and not isinstance(r, InsufficientBalanceError)
]
assert len(exceptions) == 0, f"Unexpected exceptions: {exceptions}"
# Final balance should be: 1000 - 500 + 600 = 1100
final_balance = await credit_system.get_credits(user_id)
assert final_balance == 1100, f"Expected balance 1100, got {final_balance}"
finally:
await cleanup_test_user(user_id)
@pytest.mark.asyncio(loop_scope="session")
async def test_race_condition_exact_balance(server: SpinTestServer):
"""Test spending exact balance amount concurrently doesn't go negative."""
user_id = f"exact-balance-{uuid4()}"
await create_test_user(user_id)
try:
# Give exact amount using internal method (bypasses Stripe)
await credit_system._add_transaction(
user_id=user_id,
amount=100,
transaction_type=CreditTransactionType.TOP_UP,
metadata=SafeJson({"test": "exact_amount"}),
)
# Try to spend $1 twice concurrently
async def spend_exact():
try:
return await credit_system.spend_credits(
user_id, 100, UsageTransactionMetadata(reason="Exact spend")
)
except InsufficientBalanceError:
return "FAILED"
# Both try to spend the full balance
result1, result2 = await asyncio.gather(spend_exact(), spend_exact())
# Exactly one should succeed
results = [result1, result2]
successful = [
r for r in results if r != "FAILED" and not isinstance(r, Exception)
]
failed = [r for r in results if r == "FAILED"]
assert len(successful) == 1, f"Expected 1 success, got {len(successful)}"
assert len(failed) == 1, f"Expected 1 failure, got {len(failed)}"
# Balance should be exactly 0, never negative
final_balance = await credit_system.get_credits(user_id)
assert final_balance == 0, f"Expected balance 0, got {final_balance}"
finally:
await cleanup_test_user(user_id)
@pytest.mark.asyncio(loop_scope="session")
async def test_onboarding_reward_idempotency(server: SpinTestServer):
"""Test that onboarding rewards are idempotent (can't be claimed twice)."""
user_id = f"onboarding-test-{uuid4()}"
await create_test_user(user_id)
try:
# Use WELCOME step which is defined in the OnboardingStep enum
# Try to claim same reward multiple times concurrently
async def claim_reward():
try:
result = await credit_system.onboarding_reward(
user_id, 500, prisma.enums.OnboardingStep.WELCOME
)
return "SUCCESS" if result else "DUPLICATE"
except Exception as e:
print(f"Claim reward failed: {e}")
return "FAILED"
# Try 5 concurrent claims of the same reward
results = await asyncio.gather(*[claim_reward() for _ in range(5)])
# Count results
success_count = results.count("SUCCESS")
failed_count = results.count("FAILED")
# At least one should succeed, others should be duplicates
assert success_count >= 1, f"At least one claim should succeed, got {results}"
assert failed_count == 0, f"No claims should fail, got {results}"
# Check balance - should only have 500, not 2500
final_balance = await credit_system.get_credits(user_id)
assert final_balance == 500, f"Expected balance 500, got {final_balance}"
# Check only one transaction exists
transactions = await CreditTransaction.prisma().find_many(
where={
"userId": user_id,
"type": prisma.enums.CreditTransactionType.GRANT,
"transactionKey": f"REWARD-{user_id}-WELCOME",
}
)
assert (
len(transactions) == 1
), f"Expected 1 reward transaction, got {len(transactions)}"
finally:
await cleanup_test_user(user_id)
@pytest.mark.asyncio(loop_scope="session")
async def test_integer_overflow_protection(server: SpinTestServer):
"""Test that integer overflow is prevented by clamping to POSTGRES_INT_MAX."""
user_id = f"overflow-test-{uuid4()}"
await create_test_user(user_id)
try:
# Try to add amount that would overflow
max_int = POSTGRES_INT_MAX
# First, set balance near max
await UserBalance.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, "balance": max_int - 100},
"update": {"balance": max_int - 100},
},
)
# Try to add more than possible - should clamp to POSTGRES_INT_MAX
await credit_system._add_transaction(
user_id=user_id,
amount=200,
transaction_type=CreditTransactionType.TOP_UP,
metadata=SafeJson({"test": "overflow_protection"}),
)
# Balance should be clamped to max_int, not overflowed
final_balance = await credit_system.get_credits(user_id)
assert (
final_balance == max_int
), f"Balance should be clamped to {max_int}, got {final_balance}"
# Verify transaction was created with clamped amount
transactions = await CreditTransaction.prisma().find_many(
where={
"userId": user_id,
"type": prisma.enums.CreditTransactionType.TOP_UP,
},
order={"createdAt": "desc"},
)
assert len(transactions) > 0, "Transaction should be created"
assert (
transactions[0].runningBalance == max_int
), "Transaction should show clamped balance"
finally:
await cleanup_test_user(user_id)
@pytest.mark.asyncio(loop_scope="session")
async def test_high_concurrency_stress(server: SpinTestServer):
"""Stress test with many concurrent operations."""
user_id = f"stress-test-{uuid4()}"
await create_test_user(user_id)
try:
# Initial balance using internal method (bypasses Stripe)
initial_balance = 10000 # $100
await credit_system._add_transaction(
user_id=user_id,
amount=initial_balance,
transaction_type=CreditTransactionType.TOP_UP,
metadata=SafeJson({"test": "stress_test_balance"}),
)
# Run many concurrent operations
async def random_operation(idx: int):
operation = random.choice(["spend", "check"])
if operation == "spend":
amount = random.randint(1, 50) # $0.01 to $0.50
try:
return (
"spend",
amount,
await credit_system.spend_credits(
user_id,
amount,
UsageTransactionMetadata(reason=f"Stress {idx}"),
),
)
except InsufficientBalanceError:
return ("spend_failed", amount, None)
else:
balance = await credit_system.get_credits(user_id)
return ("check", 0, balance)
# Run 100 concurrent operations
results = await asyncio.gather(
*[random_operation(i) for i in range(100)], return_exceptions=True
)
# Calculate expected final balance
total_spent = sum(
r[1]
for r in results
if not isinstance(r, Exception) and isinstance(r, tuple) and r[0] == "spend"
)
expected_balance = initial_balance - total_spent
# Verify final balance
final_balance = await credit_system.get_credits(user_id)
assert (
final_balance == expected_balance
), f"Expected {expected_balance}, got {final_balance}"
assert final_balance >= 0, "Balance went negative!"
finally:
await cleanup_test_user(user_id)
@pytest.mark.asyncio(loop_scope="session")
async def test_concurrent_multiple_spends_sufficient_balance(server: SpinTestServer):
"""Test multiple concurrent spends when there's sufficient balance for all."""
user_id = f"multi-spend-test-{uuid4()}"
await create_test_user(user_id)
try:
# Give user 150 balance ($1.50) using internal method (bypasses Stripe)
await credit_system._add_transaction(
user_id=user_id,
amount=150,
transaction_type=CreditTransactionType.TOP_UP,
metadata=SafeJson({"test": "sufficient_balance"}),
)
# Track individual timing to see serialization
timings = {}
async def spend_with_detailed_timing(amount: int, label: str):
start = asyncio.get_event_loop().time()
try:
await credit_system.spend_credits(
user_id,
amount,
UsageTransactionMetadata(
graph_exec_id=f"concurrent-{label}",
reason=f"Concurrent spend {label}",
),
)
end = asyncio.get_event_loop().time()
timings[label] = {"start": start, "end": end, "duration": end - start}
return f"{label}-SUCCESS"
except Exception as e:
end = asyncio.get_event_loop().time()
timings[label] = {
"start": start,
"end": end,
"duration": end - start,
"error": str(e),
}
return f"{label}-FAILED: {e}"
# Run concurrent spends: 10, 20, 30 (total 60, well under 150)
overall_start = asyncio.get_event_loop().time()
results = await asyncio.gather(
spend_with_detailed_timing(10, "spend-10"),
spend_with_detailed_timing(20, "spend-20"),
spend_with_detailed_timing(30, "spend-30"),
return_exceptions=True,
)
overall_end = asyncio.get_event_loop().time()
print(f"Results: {results}")
print(f"Overall duration: {overall_end - overall_start:.4f}s")
# Analyze timing to detect serialization vs true concurrency
print("\nTiming analysis:")
for label, timing in timings.items():
print(
f" {label}: started at {timing['start']:.4f}, ended at {timing['end']:.4f}, duration {timing['duration']:.4f}s"
)
# Check if operations overlapped (true concurrency) or were serialized
sorted_timings = sorted(timings.items(), key=lambda x: x[1]["start"])
print("\nExecution order by start time:")
for i, (label, timing) in enumerate(sorted_timings):
print(f" {i+1}. {label}: {timing['start']:.4f} -> {timing['end']:.4f}")
# Check for overlap (true concurrency) vs serialization
overlaps = []
for i in range(len(sorted_timings) - 1):
current = sorted_timings[i]
next_op = sorted_timings[i + 1]
if current[1]["end"] > next_op[1]["start"]:
overlaps.append(f"{current[0]} overlaps with {next_op[0]}")
if overlaps:
print(f"✅ TRUE CONCURRENCY detected: {overlaps}")
else:
print("🔒 SERIALIZATION detected: No overlapping execution times")
# Check final balance
final_balance = await credit_system.get_credits(user_id)
print(f"Final balance: {final_balance}")
# Count successes/failures
successful = [r for r in results if "SUCCESS" in str(r)]
failed = [r for r in results if "FAILED" in str(r)]
print(f"Successful: {len(successful)}, Failed: {len(failed)}")
# All should succeed since 150 - (10 + 20 + 30) = 90 > 0
assert (
len(successful) == 3
), f"Expected all 3 to succeed, got {len(successful)} successes: {results}"
assert final_balance == 90, f"Expected balance 90, got {final_balance}"
# Check transaction timestamps to confirm database-level serialization
transactions = await CreditTransaction.prisma().find_many(
where={"userId": user_id, "type": prisma.enums.CreditTransactionType.USAGE},
order={"createdAt": "asc"},
)
print("\nDatabase transaction order (by createdAt):")
for i, tx in enumerate(transactions):
print(
f" {i+1}. Amount {tx.amount}, Running balance: {tx.runningBalance}, Created: {tx.createdAt}"
)
# Verify running balances are chronologically consistent (ordered by createdAt)
actual_balances = [
tx.runningBalance for tx in transactions if tx.runningBalance is not None
]
print(f"Running balances: {actual_balances}")
# The balances should be valid intermediate states regardless of execution order
# Starting balance: 150, spending 10+20+30=60, so final should be 90
# The intermediate balances depend on execution order but should all be valid
expected_possible_balances = {
# If order is 10, 20, 30: [140, 120, 90]
# If order is 10, 30, 20: [140, 110, 90]
# If order is 20, 10, 30: [130, 120, 90]
# If order is 20, 30, 10: [130, 100, 90]
# If order is 30, 10, 20: [120, 110, 90]
# If order is 30, 20, 10: [120, 100, 90]
90,
100,
110,
120,
130,
140, # All possible intermediate balances
}
# Verify all balances are valid intermediate states
for balance in actual_balances:
assert (
balance in expected_possible_balances
), f"Invalid balance {balance}, expected one of {expected_possible_balances}"
# Final balance should always be 90 (150 - 60)
assert (
min(actual_balances) == 90
), f"Final balance should be 90, got {min(actual_balances)}"
# The final transaction should always have balance 90
# The other transactions should have valid intermediate balances
assert (
90 in actual_balances
), f"Final balance 90 should be in actual_balances: {actual_balances}"
# All balances should be >= 90 (the final state)
assert all(
balance >= 90 for balance in actual_balances
), f"All balances should be >= 90, got {actual_balances}"
# CRITICAL: Transactions are atomic but can complete in any order
# What matters is that all running balances are valid intermediate states
# Each balance should be between 90 (final) and 140 (after first transaction)
for balance in actual_balances:
assert (
90 <= balance <= 140
), f"Balance {balance} is outside valid range [90, 140]"
# Final balance (minimum) should always be 90
assert (
min(actual_balances) == 90
), f"Final balance should be 90, got {min(actual_balances)}"
finally:
await cleanup_test_user(user_id)
@pytest.mark.asyncio(loop_scope="session")
async def test_prove_database_locking_behavior(server: SpinTestServer):
"""Definitively prove whether database locking causes waiting vs failures."""
user_id = f"locking-test-{uuid4()}"
await create_test_user(user_id)
try:
# Set balance to exact amount that can handle all spends using internal method (bypasses Stripe)
await credit_system._add_transaction(
user_id=user_id,
amount=60, # Exactly 10+20+30
transaction_type=CreditTransactionType.TOP_UP,
metadata=SafeJson({"test": "exact_amount_test"}),
)
async def spend_with_precise_timing(amount: int, label: str):
request_start = asyncio.get_event_loop().time()
db_operation_start = asyncio.get_event_loop().time()
try:
# Add a small delay to increase chance of true concurrency
await asyncio.sleep(0.001)
db_operation_start = asyncio.get_event_loop().time()
await credit_system.spend_credits(
user_id,
amount,
UsageTransactionMetadata(
graph_exec_id=f"locking-{label}",
reason=f"Locking test {label}",
),
)
db_operation_end = asyncio.get_event_loop().time()
return {
"label": label,
"status": "SUCCESS",
"request_start": request_start,
"db_start": db_operation_start,
"db_end": db_operation_end,
"db_duration": db_operation_end - db_operation_start,
}
except Exception as e:
db_operation_end = asyncio.get_event_loop().time()
return {
"label": label,
"status": "FAILED",
"error": str(e),
"request_start": request_start,
"db_start": db_operation_start,
"db_end": db_operation_end,
"db_duration": db_operation_end - db_operation_start,
}
# Launch all requests simultaneously
results = await asyncio.gather(
spend_with_precise_timing(10, "A"),
spend_with_precise_timing(20, "B"),
spend_with_precise_timing(30, "C"),
return_exceptions=True,
)
print("\n🔍 LOCKING BEHAVIOR ANALYSIS:")
print("=" * 50)
successful = [
r for r in results if isinstance(r, dict) and r.get("status") == "SUCCESS"
]
failed = [
r for r in results if isinstance(r, dict) and r.get("status") == "FAILED"
]
print(f"✅ Successful operations: {len(successful)}")
print(f"❌ Failed operations: {len(failed)}")
if len(failed) > 0:
print(
"\n🚫 CONCURRENT FAILURES - Some requests failed due to insufficient balance:"
)
for result in failed:
if isinstance(result, dict):
print(
f" {result['label']}: {result.get('error', 'Unknown error')}"
)
if len(successful) == 3:
print(
"\n🔒 SERIALIZATION CONFIRMED - All requests succeeded, indicating they were queued:"
)
# Sort by actual execution time to see order
dict_results = [r for r in results if isinstance(r, dict)]
sorted_results = sorted(dict_results, key=lambda x: x["db_start"])
for i, result in enumerate(sorted_results):
print(
f" {i+1}. {result['label']}: DB operation took {result['db_duration']:.4f}s"
)
# Check if any operations overlapped at the database level
print("\n⏱️ Database operation timeline:")
for result in sorted_results:
print(
f" {result['label']}: {result['db_start']:.4f} -> {result['db_end']:.4f}"
)
# Verify final state
final_balance = await credit_system.get_credits(user_id)
print(f"\n💰 Final balance: {final_balance}")
if len(successful) == 3:
assert (
final_balance == 0
), f"If all succeeded, balance should be 0, got {final_balance}"
print(
"✅ CONCLUSION: Database row locking causes requests to WAIT and execute serially"
)
else:
print(
"❌ CONCLUSION: Some requests failed, indicating different concurrency behavior"
)
finally:
await cleanup_test_user(user_id)

View File

@@ -0,0 +1,277 @@
"""
Integration tests for credit system to catch SQL enum casting issues.
These tests run actual database operations to ensure SQL queries work correctly,
which would have caught the CreditTransactionType enum casting bug.
"""
import pytest
from prisma.enums import CreditTransactionType
from prisma.models import CreditTransaction, User, UserBalance
from backend.data.credit import (
AutoTopUpConfig,
BetaUserCredit,
UsageTransactionMetadata,
get_auto_top_up,
set_auto_top_up,
)
from backend.util.json import SafeJson
@pytest.fixture
async def cleanup_test_user():
"""Clean up test user data before and after tests."""
import uuid
user_id = str(uuid.uuid4()) # Use unique user ID for each test
# Create the user first
try:
await User.prisma().create(
data={
"id": user_id,
"email": f"test-{user_id}@example.com",
"topUpConfig": SafeJson({}),
"timezone": "UTC",
}
)
except Exception:
# User might already exist, that's fine
pass
yield user_id
# Cleanup after test
await CreditTransaction.prisma().delete_many(where={"userId": user_id})
await UserBalance.prisma().delete_many(where={"userId": user_id})
# Clear auto-top-up config before deleting user
await User.prisma().update(
where={"id": user_id}, data={"topUpConfig": SafeJson({})}
)
await User.prisma().delete(where={"id": user_id})
@pytest.mark.asyncio(loop_scope="session")
async def test_credit_transaction_enum_casting_integration(cleanup_test_user):
"""
Integration test to verify CreditTransactionType enum casting works in SQL queries.
This test would have caught the enum casting bug where PostgreSQL expected
platform."CreditTransactionType" but got "CreditTransactionType".
"""
user_id = cleanup_test_user
credit_system = BetaUserCredit(1000)
# Test each transaction type to ensure enum casting works
test_cases = [
(CreditTransactionType.TOP_UP, 100, "Test top-up"),
(CreditTransactionType.USAGE, -50, "Test usage"),
(CreditTransactionType.GRANT, 200, "Test grant"),
(CreditTransactionType.REFUND, -25, "Test refund"),
(CreditTransactionType.CARD_CHECK, 0, "Test card check"),
]
for transaction_type, amount, reason in test_cases:
metadata = SafeJson({"reason": reason, "test": "enum_casting"})
# This call would fail with enum casting error before the fix
balance, tx_key = await credit_system._add_transaction(
user_id=user_id,
amount=amount,
transaction_type=transaction_type,
metadata=metadata,
is_active=True,
)
# Verify transaction was created with correct type
transaction = await CreditTransaction.prisma().find_first(
where={"userId": user_id, "transactionKey": tx_key}
)
assert transaction is not None
assert transaction.type == transaction_type
assert transaction.amount == amount
assert transaction.metadata is not None
# Verify metadata content
assert transaction.metadata["reason"] == reason
assert transaction.metadata["test"] == "enum_casting"
@pytest.mark.asyncio(loop_scope="session")
async def test_auto_top_up_integration(cleanup_test_user, monkeypatch):
"""
Integration test for auto-top-up functionality that triggers enum casting.
This tests the complete auto-top-up flow which involves SQL queries with
CreditTransactionType enums, ensuring enum casting works end-to-end.
"""
# Enable credits for this test
from backend.data.credit import settings
monkeypatch.setattr(settings.config, "enable_credit", True)
monkeypatch.setattr(settings.config, "enable_beta_monthly_credit", True)
monkeypatch.setattr(settings.config, "num_user_credits_refill", 1000)
user_id = cleanup_test_user
credit_system = BetaUserCredit(1000)
# First add some initial credits so we can test the configuration and subsequent behavior
balance, _ = await credit_system._add_transaction(
user_id=user_id,
amount=50, # Below threshold that we'll set
transaction_type=CreditTransactionType.GRANT,
metadata=SafeJson({"reason": "Initial credits before auto top-up config"}),
)
assert balance == 50
# Configure auto top-up with threshold above current balance
config = AutoTopUpConfig(threshold=100, amount=500)
await set_auto_top_up(user_id, config)
# Verify configuration was saved but no immediate top-up occurred
current_balance = await credit_system.get_credits(user_id)
assert current_balance == 50 # Balance should be unchanged
# Simulate spending credits that would trigger auto top-up
# This involves multiple SQL operations with enum casting
try:
metadata = UsageTransactionMetadata(reason="Test spend to trigger auto top-up")
await credit_system.spend_credits(user_id=user_id, cost=10, metadata=metadata)
# The auto top-up mechanism should have been triggered
# Verify the transaction types were handled correctly
transactions = await CreditTransaction.prisma().find_many(
where={"userId": user_id}, order={"createdAt": "desc"}
)
# Should have at least: GRANT (initial), USAGE (spend), and TOP_UP (auto top-up)
assert len(transactions) >= 3
# Verify different transaction types exist and enum casting worked
transaction_types = {t.type for t in transactions}
assert CreditTransactionType.GRANT in transaction_types
assert CreditTransactionType.USAGE in transaction_types
assert (
CreditTransactionType.TOP_UP in transaction_types
) # Auto top-up should have triggered
except Exception as e:
# If this fails with enum casting error, the test successfully caught the bug
if "CreditTransactionType" in str(e) and (
"cast" in str(e).lower() or "type" in str(e).lower()
):
pytest.fail(f"Enum casting error detected: {e}")
else:
# Re-raise other unexpected errors
raise
@pytest.mark.asyncio(loop_scope="session")
async def test_enable_transaction_enum_casting_integration(cleanup_test_user):
"""
Integration test for _enable_transaction with enum casting.
Tests the scenario where inactive transactions are enabled, which also
involves SQL queries with CreditTransactionType enum casting.
"""
user_id = cleanup_test_user
credit_system = BetaUserCredit(1000)
# Create an inactive transaction
balance, tx_key = await credit_system._add_transaction(
user_id=user_id,
amount=100,
transaction_type=CreditTransactionType.TOP_UP,
metadata=SafeJson({"reason": "Inactive transaction test"}),
is_active=False, # Create as inactive
)
# Balance should be 0 since transaction is inactive
assert balance == 0
# Enable the transaction with new metadata
enable_metadata = SafeJson(
{
"payment_method": "test_payment",
"activation_reason": "Integration test activation",
}
)
# This would fail with enum casting error before the fix
final_balance = await credit_system._enable_transaction(
transaction_key=tx_key,
user_id=user_id,
metadata=enable_metadata,
)
# Now balance should reflect the activated transaction
assert final_balance == 100
# Verify transaction was properly enabled with correct enum type
transaction = await CreditTransaction.prisma().find_first(
where={"userId": user_id, "transactionKey": tx_key}
)
assert transaction is not None
assert transaction.isActive is True
assert transaction.type == CreditTransactionType.TOP_UP
assert transaction.runningBalance == 100
# Verify metadata was updated
assert transaction.metadata is not None
assert transaction.metadata["payment_method"] == "test_payment"
assert transaction.metadata["activation_reason"] == "Integration test activation"
@pytest.mark.asyncio(loop_scope="session")
async def test_auto_top_up_configuration_storage(cleanup_test_user, monkeypatch):
"""
Test that auto-top-up configuration is properly stored and retrieved.
The immediate top-up logic is handled by the API routes, not the core
set_auto_top_up function. This test verifies the configuration is correctly
saved and can be retrieved.
"""
# Enable credits for this test
from backend.data.credit import settings
monkeypatch.setattr(settings.config, "enable_credit", True)
monkeypatch.setattr(settings.config, "enable_beta_monthly_credit", True)
monkeypatch.setattr(settings.config, "num_user_credits_refill", 1000)
user_id = cleanup_test_user
credit_system = BetaUserCredit(1000)
# Set initial balance
balance, _ = await credit_system._add_transaction(
user_id=user_id,
amount=50,
transaction_type=CreditTransactionType.GRANT,
metadata=SafeJson({"reason": "Initial balance for config test"}),
)
assert balance == 50
# Configure auto top-up
config = AutoTopUpConfig(threshold=100, amount=200)
await set_auto_top_up(user_id, config)
# Verify the configuration was saved
retrieved_config = await get_auto_top_up(user_id)
assert retrieved_config.threshold == config.threshold
assert retrieved_config.amount == config.amount
# Verify balance is unchanged (no immediate top-up from set_auto_top_up)
final_balance = await credit_system.get_credits(user_id)
assert final_balance == 50 # Should be unchanged
# Verify no immediate auto-top-up transaction was created by set_auto_top_up
transactions = await CreditTransaction.prisma().find_many(
where={"userId": user_id}, order={"createdAt": "desc"}
)
# Should only have the initial GRANT transaction
assert len(transactions) == 1
assert transactions[0].type == CreditTransactionType.GRANT

View File

@@ -0,0 +1,141 @@
"""
Tests for credit system metadata handling to ensure JSON casting works correctly.
This test verifies that metadata parameters are properly serialized when passed
to raw SQL queries with JSONB columns.
"""
# type: ignore
from typing import Any
import pytest
from prisma.enums import CreditTransactionType
from prisma.models import CreditTransaction, UserBalance
from backend.data.credit import BetaUserCredit
from backend.data.user import DEFAULT_USER_ID
from backend.util.json import SafeJson
@pytest.fixture
async def setup_test_user():
"""Setup test user and cleanup after test."""
user_id = DEFAULT_USER_ID
# Cleanup before test
await CreditTransaction.prisma().delete_many(where={"userId": user_id})
await UserBalance.prisma().delete_many(where={"userId": user_id})
yield user_id
# Cleanup after test
await CreditTransaction.prisma().delete_many(where={"userId": user_id})
await UserBalance.prisma().delete_many(where={"userId": user_id})
@pytest.mark.asyncio(loop_scope="session")
async def test_metadata_json_serialization(setup_test_user):
"""Test that metadata is properly serialized for JSONB column in raw SQL."""
user_id = setup_test_user
credit_system = BetaUserCredit(1000)
# Test with complex metadata that would fail if not properly serialized
complex_metadata = SafeJson(
{
"graph_exec_id": "test-12345",
"reason": "Testing metadata serialization",
"nested_data": {
"key1": "value1",
"key2": ["array", "of", "values"],
"key3": {"deeply": {"nested": "object"}},
},
"special_chars": "Testing 'quotes' and \"double quotes\" and unicode: 🚀",
}
)
# This should work without throwing a JSONB casting error
balance, tx_key = await credit_system._add_transaction(
user_id=user_id,
amount=500, # $5 top-up
transaction_type=CreditTransactionType.TOP_UP,
metadata=complex_metadata,
is_active=True,
)
# Verify the transaction was created successfully
assert balance == 500
# Verify the metadata was stored correctly in the database
transaction = await CreditTransaction.prisma().find_first(
where={"userId": user_id, "transactionKey": tx_key}
)
assert transaction is not None
assert transaction.metadata is not None
# Verify the metadata contains our complex data
metadata_dict: dict[str, Any] = dict(transaction.metadata) # type: ignore
assert metadata_dict["graph_exec_id"] == "test-12345"
assert metadata_dict["reason"] == "Testing metadata serialization"
assert metadata_dict["nested_data"]["key1"] == "value1"
assert metadata_dict["nested_data"]["key3"]["deeply"]["nested"] == "object"
assert (
metadata_dict["special_chars"]
== "Testing 'quotes' and \"double quotes\" and unicode: 🚀"
)
@pytest.mark.asyncio(loop_scope="session")
async def test_enable_transaction_metadata_serialization(setup_test_user):
"""Test that _enable_transaction also handles metadata JSON serialization correctly."""
user_id = setup_test_user
credit_system = BetaUserCredit(1000)
# First create an inactive transaction
balance, tx_key = await credit_system._add_transaction(
user_id=user_id,
amount=300,
transaction_type=CreditTransactionType.TOP_UP,
metadata=SafeJson({"initial": "inactive_transaction"}),
is_active=False, # Create as inactive
)
# Initial balance should be 0 because transaction is inactive
assert balance == 0
# Now enable the transaction with new metadata
enable_metadata = SafeJson(
{
"payment_method": "stripe",
"payment_intent": "pi_test_12345",
"activation_reason": "Payment confirmed",
"complex_data": {"array": [1, 2, 3], "boolean": True, "null_value": None},
}
)
# This should work without JSONB casting errors
final_balance = await credit_system._enable_transaction(
transaction_key=tx_key,
user_id=user_id,
metadata=enable_metadata,
)
# Now balance should reflect the activated transaction
assert final_balance == 300
# Verify the metadata was updated correctly
transaction = await CreditTransaction.prisma().find_first(
where={"userId": user_id, "transactionKey": tx_key}
)
assert transaction is not None
assert transaction.isActive is True
# Verify the metadata was updated with enable_metadata
metadata_dict: dict[str, Any] = dict(transaction.metadata) # type: ignore
assert metadata_dict["payment_method"] == "stripe"
assert metadata_dict["payment_intent"] == "pi_test_12345"
assert metadata_dict["complex_data"]["array"] == [1, 2, 3]
assert metadata_dict["complex_data"]["boolean"] is True
assert metadata_dict["complex_data"]["null_value"] is None

View File

@@ -0,0 +1,372 @@
"""
Tests for credit system refund and dispute operations.
These tests ensure that refund operations (deduct_credits, handle_dispute)
are atomic and maintain data consistency.
"""
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
import pytest
import stripe
from prisma.enums import CreditTransactionType
from prisma.models import CreditRefundRequest, CreditTransaction, User, UserBalance
from backend.data.credit import UserCredit
from backend.util.json import SafeJson
from backend.util.test import SpinTestServer
credit_system = UserCredit()
# Test user ID for refund tests
REFUND_TEST_USER_ID = "refund-test-user"
async def setup_test_user_with_topup():
"""Create a test user with initial balance and a top-up transaction."""
# Clean up any existing data
await CreditRefundRequest.prisma().delete_many(
where={"userId": REFUND_TEST_USER_ID}
)
await CreditTransaction.prisma().delete_many(where={"userId": REFUND_TEST_USER_ID})
await UserBalance.prisma().delete_many(where={"userId": REFUND_TEST_USER_ID})
await User.prisma().delete_many(where={"id": REFUND_TEST_USER_ID})
# Create user
await User.prisma().create(
data={
"id": REFUND_TEST_USER_ID,
"email": f"{REFUND_TEST_USER_ID}@example.com",
"name": "Refund Test User",
}
)
# Create user balance
await UserBalance.prisma().create(
data={
"userId": REFUND_TEST_USER_ID,
"balance": 1000, # $10
}
)
# Create a top-up transaction that can be refunded
topup_tx = await CreditTransaction.prisma().create(
data={
"userId": REFUND_TEST_USER_ID,
"amount": 1000,
"type": CreditTransactionType.TOP_UP,
"transactionKey": "pi_test_12345",
"runningBalance": 1000,
"isActive": True,
"metadata": SafeJson({"stripe_payment_intent": "pi_test_12345"}),
}
)
return topup_tx
async def cleanup_test_user():
"""Clean up test data."""
await CreditRefundRequest.prisma().delete_many(
where={"userId": REFUND_TEST_USER_ID}
)
await CreditTransaction.prisma().delete_many(where={"userId": REFUND_TEST_USER_ID})
await UserBalance.prisma().delete_many(where={"userId": REFUND_TEST_USER_ID})
await User.prisma().delete_many(where={"id": REFUND_TEST_USER_ID})
@pytest.mark.asyncio(loop_scope="session")
async def test_deduct_credits_atomic(server: SpinTestServer):
"""Test that deduct_credits is atomic and creates transaction correctly."""
topup_tx = await setup_test_user_with_topup()
try:
# Create a mock refund object
refund = MagicMock(spec=stripe.Refund)
refund.id = "re_test_refund_123"
refund.payment_intent = topup_tx.transactionKey
refund.amount = 500 # Refund $5 of the $10 top-up
refund.status = "succeeded"
refund.reason = "requested_by_customer"
refund.created = int(datetime.now(timezone.utc).timestamp())
# Create refund request record (simulating webhook flow)
await CreditRefundRequest.prisma().create(
data={
"userId": REFUND_TEST_USER_ID,
"amount": 500,
"transactionKey": topup_tx.transactionKey, # Should match the original transaction
"reason": "Test refund",
}
)
# Call deduct_credits
await credit_system.deduct_credits(refund)
# Verify the user's balance was deducted
user_balance = await UserBalance.prisma().find_unique(
where={"userId": REFUND_TEST_USER_ID}
)
assert user_balance is not None
assert (
user_balance.balance == 500
), f"Expected balance 500, got {user_balance.balance}"
# Verify refund transaction was created
refund_tx = await CreditTransaction.prisma().find_first(
where={
"userId": REFUND_TEST_USER_ID,
"type": CreditTransactionType.REFUND,
"transactionKey": refund.id,
}
)
assert refund_tx is not None
assert refund_tx.amount == -500
assert refund_tx.runningBalance == 500
assert refund_tx.isActive
# Verify refund request was updated
refund_request = await CreditRefundRequest.prisma().find_first(
where={
"userId": REFUND_TEST_USER_ID,
"transactionKey": topup_tx.transactionKey,
}
)
assert refund_request is not None
assert (
refund_request.result
== "The refund request has been approved, the amount will be credited back to your account."
)
finally:
await cleanup_test_user()
@pytest.mark.asyncio(loop_scope="session")
async def test_deduct_credits_user_not_found(server: SpinTestServer):
"""Test that deduct_credits raises error if transaction not found (which means user doesn't exist)."""
# Create a mock refund object that references a non-existent payment intent
refund = MagicMock(spec=stripe.Refund)
refund.id = "re_test_refund_nonexistent"
refund.payment_intent = "pi_test_nonexistent" # This payment intent doesn't exist
refund.amount = 500
refund.status = "succeeded"
refund.reason = "requested_by_customer"
refund.created = int(datetime.now(timezone.utc).timestamp())
# Should raise error for missing transaction
with pytest.raises(Exception): # Should raise NotFoundError for missing transaction
await credit_system.deduct_credits(refund)
@pytest.mark.asyncio(loop_scope="session")
@patch("backend.data.credit.settings")
@patch("stripe.Dispute.modify")
@patch("backend.data.credit.get_user_by_id")
async def test_handle_dispute_with_sufficient_balance(
mock_get_user, mock_stripe_modify, mock_settings, server: SpinTestServer
):
"""Test handling dispute when user has sufficient balance (dispute gets closed)."""
topup_tx = await setup_test_user_with_topup()
try:
# Mock settings to have a low tolerance threshold
mock_settings.config.refund_credit_tolerance_threshold = 0
# Mock the user lookup
mock_user = MagicMock()
mock_user.email = f"{REFUND_TEST_USER_ID}@example.com"
mock_get_user.return_value = mock_user
# Create a mock dispute object for small amount (user has 1000, disputing 100)
dispute = MagicMock(spec=stripe.Dispute)
dispute.id = "dp_test_dispute_123"
dispute.payment_intent = topup_tx.transactionKey
dispute.amount = 100 # Small dispute amount
dispute.status = "pending"
dispute.reason = "fraudulent"
dispute.created = int(datetime.now(timezone.utc).timestamp())
# Mock the close method to prevent real API calls
dispute.close = MagicMock()
# Handle the dispute
await credit_system.handle_dispute(dispute)
# Verify dispute.close() was called (since user has sufficient balance)
dispute.close.assert_called_once()
# Verify no stripe evidence was added since dispute was closed
mock_stripe_modify.assert_not_called()
# Verify the user's balance was NOT deducted (dispute was closed)
user_balance = await UserBalance.prisma().find_unique(
where={"userId": REFUND_TEST_USER_ID}
)
assert user_balance is not None
assert (
user_balance.balance == 1000
), f"Balance should remain 1000, got {user_balance.balance}"
finally:
await cleanup_test_user()
@pytest.mark.asyncio(loop_scope="session")
@patch("backend.data.credit.settings")
@patch("stripe.Dispute.modify")
@patch("backend.data.credit.get_user_by_id")
async def test_handle_dispute_with_insufficient_balance(
mock_get_user, mock_stripe_modify, mock_settings, server: SpinTestServer
):
"""Test handling dispute when user has insufficient balance (evidence gets added)."""
topup_tx = await setup_test_user_with_topup()
# Save original method for restoration before any try blocks
original_get_history = credit_system.get_transaction_history
try:
# Mock settings to have a high tolerance threshold so dispute isn't closed
mock_settings.config.refund_credit_tolerance_threshold = 2000
# Mock the user lookup
mock_user = MagicMock()
mock_user.email = f"{REFUND_TEST_USER_ID}@example.com"
mock_get_user.return_value = mock_user
# Mock the transaction history method to return an async result
from unittest.mock import AsyncMock
mock_history = MagicMock()
mock_history.transactions = []
credit_system.get_transaction_history = AsyncMock(return_value=mock_history)
# Create a mock dispute object for full amount (user has 1000, disputing 1000)
dispute = MagicMock(spec=stripe.Dispute)
dispute.id = "dp_test_dispute_pending"
dispute.payment_intent = topup_tx.transactionKey
dispute.amount = 1000
dispute.status = "warning_needs_response"
dispute.created = int(datetime.now(timezone.utc).timestamp())
# Mock the close method to prevent real API calls
dispute.close = MagicMock()
# Handle the dispute (evidence should be added)
await credit_system.handle_dispute(dispute)
# Verify dispute.close() was NOT called (insufficient balance after tolerance)
dispute.close.assert_not_called()
# Verify stripe evidence was added since dispute wasn't closed
mock_stripe_modify.assert_called_once()
# Verify the user's balance was NOT deducted (handle_dispute doesn't deduct credits)
user_balance = await UserBalance.prisma().find_unique(
where={"userId": REFUND_TEST_USER_ID}
)
assert user_balance is not None
assert user_balance.balance == 1000, "Balance should remain unchanged"
finally:
credit_system.get_transaction_history = original_get_history
await cleanup_test_user()
@pytest.mark.asyncio(loop_scope="session")
async def test_concurrent_refunds(server: SpinTestServer):
"""Test that concurrent refunds are handled atomically."""
import asyncio
topup_tx = await setup_test_user_with_topup()
try:
# Create multiple refund requests
refund_requests = []
for i in range(5):
req = await CreditRefundRequest.prisma().create(
data={
"userId": REFUND_TEST_USER_ID,
"amount": 100, # $1 each
"transactionKey": topup_tx.transactionKey,
"reason": f"Test refund {i}",
}
)
refund_requests.append(req)
# Create refund tasks to run concurrently
async def process_refund(index: int):
refund = MagicMock(spec=stripe.Refund)
refund.id = f"re_test_concurrent_{index}"
refund.payment_intent = topup_tx.transactionKey
refund.amount = 100 # $1 refund
refund.status = "succeeded"
refund.reason = "requested_by_customer"
refund.created = int(datetime.now(timezone.utc).timestamp())
try:
await credit_system.deduct_credits(refund)
return "success"
except Exception as e:
return f"error: {e}"
# Run refunds concurrently
results = await asyncio.gather(
*[process_refund(i) for i in range(5)], return_exceptions=True
)
# All should succeed
assert all(r == "success" for r in results), f"Some refunds failed: {results}"
# Verify final balance - with non-atomic implementation, this will demonstrate race condition
# EXPECTED BEHAVIOR: Due to race conditions, not all refunds will be properly processed
# The balance will be incorrect (higher than expected) showing lost updates
user_balance = await UserBalance.prisma().find_unique(
where={"userId": REFUND_TEST_USER_ID}
)
assert user_balance is not None
# With atomic implementation, this should be 500 (1000 - 5*100)
# With current non-atomic implementation, this will likely be wrong due to race conditions
print(f"DEBUG: Final balance = {user_balance.balance}, expected = 500")
# With atomic implementation, all 5 refunds should process correctly
assert (
user_balance.balance == 500
), f"Expected balance 500 after 5 refunds of 100 each, got {user_balance.balance}"
# Verify all refund transactions exist
refund_txs = await CreditTransaction.prisma().find_many(
where={
"userId": REFUND_TEST_USER_ID,
"type": CreditTransactionType.REFUND,
}
)
assert (
len(refund_txs) == 5
), f"Expected 5 refund transactions, got {len(refund_txs)}"
running_balances: set[int] = {
tx.runningBalance for tx in refund_txs if tx.runningBalance is not None
}
# Verify all balances are valid intermediate states
for balance in running_balances:
assert (
500 <= balance <= 1000
), f"Invalid balance {balance}, should be between 500 and 1000"
# Final balance should be present
assert (
500 in running_balances
), f"Final balance 500 should be in {running_balances}"
# All balances should be unique and form a valid sequence
sorted_balances = sorted(running_balances, reverse=True)
assert (
len(sorted_balances) == 5
), f"Expected 5 unique balances, got {len(sorted_balances)}"
finally:
await cleanup_test_user()

View File

@@ -1,8 +1,8 @@
from datetime import datetime, timezone
from datetime import datetime, timedelta, timezone
import pytest
from prisma.enums import CreditTransactionType
from prisma.models import CreditTransaction
from prisma.models import CreditTransaction, UserBalance
from backend.blocks.llm import AITextGeneratorBlock
from backend.data.block import get_block
@@ -19,14 +19,24 @@ user_credit = BetaUserCredit(REFILL_VALUE)
async def disable_test_user_transactions():
await CreditTransaction.prisma().delete_many(where={"userId": DEFAULT_USER_ID})
# Also reset the balance to 0 and set updatedAt to old date to trigger monthly refill
old_date = datetime.now(timezone.utc) - timedelta(days=35) # More than a month ago
await UserBalance.prisma().upsert(
where={"userId": DEFAULT_USER_ID},
data={
"create": {"userId": DEFAULT_USER_ID, "balance": 0},
"update": {"balance": 0, "updatedAt": old_date},
},
)
async def top_up(amount: int):
await user_credit._add_transaction(
balance, _ = await user_credit._add_transaction(
DEFAULT_USER_ID,
amount,
CreditTransactionType.TOP_UP,
)
return balance
async def spend_credits(entry: NodeExecutionEntry) -> int:
@@ -111,29 +121,90 @@ async def test_block_credit_top_up(server: SpinTestServer):
@pytest.mark.asyncio(loop_scope="session")
async def test_block_credit_reset(server: SpinTestServer):
"""Test that BetaUserCredit provides monthly refills correctly."""
await disable_test_user_transactions()
month1 = 1
month2 = 2
# set the calendar to month 2 but use current time from now
user_credit.time_now = lambda: datetime.now(timezone.utc).replace(
month=month2, day=1
)
month2credit = await user_credit.get_credits(DEFAULT_USER_ID)
# Save original time_now function for restoration
original_time_now = user_credit.time_now
# Month 1 result should only affect month 1
user_credit.time_now = lambda: datetime.now(timezone.utc).replace(
month=month1, day=1
)
month1credit = await user_credit.get_credits(DEFAULT_USER_ID)
await top_up(100)
assert await user_credit.get_credits(DEFAULT_USER_ID) == month1credit + 100
try:
# Test month 1 behavior
month1 = datetime.now(timezone.utc).replace(month=1, day=1)
user_credit.time_now = lambda: month1
# Month 2 balance is unaffected
user_credit.time_now = lambda: datetime.now(timezone.utc).replace(
month=month2, day=1
)
assert await user_credit.get_credits(DEFAULT_USER_ID) == month2credit
# First call in month 1 should trigger refill
balance = await user_credit.get_credits(DEFAULT_USER_ID)
assert balance == REFILL_VALUE # Should get 1000 credits
# Manually create a transaction with month 1 timestamp to establish history
await CreditTransaction.prisma().create(
data={
"userId": DEFAULT_USER_ID,
"amount": 100,
"type": CreditTransactionType.TOP_UP,
"runningBalance": 1100,
"isActive": True,
"createdAt": month1, # Set specific timestamp
}
)
# Update user balance to match
await UserBalance.prisma().upsert(
where={"userId": DEFAULT_USER_ID},
data={
"create": {"userId": DEFAULT_USER_ID, "balance": 1100},
"update": {"balance": 1100},
},
)
# Now test month 2 behavior
month2 = datetime.now(timezone.utc).replace(month=2, day=1)
user_credit.time_now = lambda: month2
# In month 2, since balance (1100) > refill (1000), no refill should happen
month2_balance = await user_credit.get_credits(DEFAULT_USER_ID)
assert month2_balance == 1100 # Balance persists, no reset
# Now test the refill behavior when balance is low
# Set balance below refill threshold
await UserBalance.prisma().update(
where={"userId": DEFAULT_USER_ID}, data={"balance": 400}
)
# Create a month 2 transaction to update the last transaction time
await CreditTransaction.prisma().create(
data={
"userId": DEFAULT_USER_ID,
"amount": -700, # Spent 700 to get to 400
"type": CreditTransactionType.USAGE,
"runningBalance": 400,
"isActive": True,
"createdAt": month2,
}
)
# Move to month 3
month3 = datetime.now(timezone.utc).replace(month=3, day=1)
user_credit.time_now = lambda: month3
# Should get refilled since balance (400) < refill value (1000)
month3_balance = await user_credit.get_credits(DEFAULT_USER_ID)
assert month3_balance == REFILL_VALUE # Should be refilled to 1000
# Verify the refill transaction was created
refill_tx = await CreditTransaction.prisma().find_first(
where={
"userId": DEFAULT_USER_ID,
"type": CreditTransactionType.GRANT,
"transactionKey": {"contains": "MONTHLY-CREDIT-TOP-UP"},
},
order={"createdAt": "desc"},
)
assert refill_tx is not None, "Monthly refill transaction should be created"
assert refill_tx.amount == 600, "Refill should be 600 (1000 - 400)"
finally:
# Restore original time_now function
user_credit.time_now = original_time_now
@pytest.mark.asyncio(loop_scope="session")

View File

@@ -0,0 +1,361 @@
"""
Test underflow protection for cumulative refunds and negative transactions.
This test ensures that when multiple large refunds are processed, the user balance
doesn't underflow below POSTGRES_INT_MIN, which could cause integer wraparound issues.
"""
import asyncio
from uuid import uuid4
import pytest
from prisma.enums import CreditTransactionType
from prisma.errors import UniqueViolationError
from prisma.models import CreditTransaction, User, UserBalance
from backend.data.credit import POSTGRES_INT_MIN, UserCredit
from backend.util.test import SpinTestServer
async def create_test_user(user_id: str) -> None:
"""Create a test user for underflow tests."""
try:
await User.prisma().create(
data={
"id": user_id,
"email": f"test-{user_id}@example.com",
"name": f"Test User {user_id[:8]}",
}
)
except UniqueViolationError:
# User already exists, continue
pass
await UserBalance.prisma().upsert(
where={"userId": user_id},
data={"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
)
async def cleanup_test_user(user_id: str) -> None:
"""Clean up test user and their transactions."""
try:
await CreditTransaction.prisma().delete_many(where={"userId": user_id})
await UserBalance.prisma().delete_many(where={"userId": user_id})
await User.prisma().delete_many(where={"id": user_id})
except Exception as e:
# Log cleanup failures but don't fail the test
print(f"Warning: Failed to cleanup test user {user_id}: {e}")
@pytest.mark.asyncio(loop_scope="session")
async def test_debug_underflow_step_by_step(server: SpinTestServer):
"""Debug underflow behavior step by step."""
credit_system = UserCredit()
user_id = f"debug-underflow-{uuid4()}"
await create_test_user(user_id)
try:
print(f"POSTGRES_INT_MIN: {POSTGRES_INT_MIN}")
# Test 1: Set up balance close to underflow threshold
print("\n=== Test 1: Setting up balance close to underflow threshold ===")
# First, manually set balance to a value very close to POSTGRES_INT_MIN
# We'll set it to POSTGRES_INT_MIN + 100, then try to subtract 200
# This should trigger underflow protection: (POSTGRES_INT_MIN + 100) + (-200) = POSTGRES_INT_MIN - 100
initial_balance_target = POSTGRES_INT_MIN + 100
# Use direct database update to set the balance close to underflow
from prisma.models import UserBalance
await UserBalance.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, "balance": initial_balance_target},
"update": {"balance": initial_balance_target},
},
)
current_balance = await credit_system.get_credits(user_id)
print(f"Set balance to: {current_balance}")
assert current_balance == initial_balance_target
# Test 2: Apply amount that should cause underflow
print("\n=== Test 2: Testing underflow protection ===")
test_amount = (
-200
) # This should cause underflow: (POSTGRES_INT_MIN + 100) + (-200) = POSTGRES_INT_MIN - 100
expected_without_protection = current_balance + test_amount
print(f"Current balance: {current_balance}")
print(f"Test amount: {test_amount}")
print(f"Without protection would be: {expected_without_protection}")
print(f"Should be clamped to POSTGRES_INT_MIN: {POSTGRES_INT_MIN}")
# Apply the amount that should trigger underflow protection
balance_result, _ = await credit_system._add_transaction(
user_id=user_id,
amount=test_amount,
transaction_type=CreditTransactionType.REFUND,
fail_insufficient_credits=False,
)
print(f"Actual result: {balance_result}")
# Check if underflow protection worked
assert (
balance_result == POSTGRES_INT_MIN
), f"Expected underflow protection to clamp balance to {POSTGRES_INT_MIN}, got {balance_result}"
# Test 3: Edge case - exactly at POSTGRES_INT_MIN
print("\n=== Test 3: Testing exact POSTGRES_INT_MIN boundary ===")
# Set balance to exactly POSTGRES_INT_MIN
await UserBalance.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, "balance": POSTGRES_INT_MIN},
"update": {"balance": POSTGRES_INT_MIN},
},
)
edge_balance = await credit_system.get_credits(user_id)
print(f"Balance set to exactly POSTGRES_INT_MIN: {edge_balance}")
# Try to subtract 1 - should stay at POSTGRES_INT_MIN
edge_result, _ = await credit_system._add_transaction(
user_id=user_id,
amount=-1,
transaction_type=CreditTransactionType.REFUND,
fail_insufficient_credits=False,
)
print(f"After subtracting 1: {edge_result}")
assert (
edge_result == POSTGRES_INT_MIN
), f"Expected balance to remain clamped at {POSTGRES_INT_MIN}, got {edge_result}"
finally:
await cleanup_test_user(user_id)
@pytest.mark.asyncio(loop_scope="session")
async def test_underflow_protection_large_refunds(server: SpinTestServer):
"""Test that large cumulative refunds don't cause integer underflow."""
credit_system = UserCredit()
user_id = f"underflow-test-{uuid4()}"
await create_test_user(user_id)
try:
# Set up balance close to underflow threshold to test the protection
# Set balance to POSTGRES_INT_MIN + 1000, then try to subtract 2000
# This should trigger underflow protection
from prisma.models import UserBalance
test_balance = POSTGRES_INT_MIN + 1000
await UserBalance.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, "balance": test_balance},
"update": {"balance": test_balance},
},
)
current_balance = await credit_system.get_credits(user_id)
assert current_balance == test_balance
# Try to deduct amount that would cause underflow: test_balance + (-2000) = POSTGRES_INT_MIN - 1000
underflow_amount = -2000
expected_without_protection = (
current_balance + underflow_amount
) # Should be POSTGRES_INT_MIN - 1000
# Use _add_transaction directly with amount that would cause underflow
final_balance, _ = await credit_system._add_transaction(
user_id=user_id,
amount=underflow_amount,
transaction_type=CreditTransactionType.REFUND,
fail_insufficient_credits=False, # Allow going negative for refunds
)
# Balance should be clamped to POSTGRES_INT_MIN, not the calculated underflow value
assert (
final_balance == POSTGRES_INT_MIN
), f"Balance should be clamped to {POSTGRES_INT_MIN}, got {final_balance}"
assert (
final_balance > expected_without_protection
), f"Balance should be greater than underflow result {expected_without_protection}, got {final_balance}"
# Verify with get_credits too
stored_balance = await credit_system.get_credits(user_id)
assert (
stored_balance == POSTGRES_INT_MIN
), f"Stored balance should be {POSTGRES_INT_MIN}, got {stored_balance}"
# Verify transaction was created with the underflow-protected balance
transactions = await CreditTransaction.prisma().find_many(
where={"userId": user_id, "type": CreditTransactionType.REFUND},
order={"createdAt": "desc"},
)
assert len(transactions) > 0, "Refund transaction should be created"
assert (
transactions[0].runningBalance == POSTGRES_INT_MIN
), f"Transaction should show clamped balance {POSTGRES_INT_MIN}, got {transactions[0].runningBalance}"
finally:
await cleanup_test_user(user_id)
@pytest.mark.asyncio(loop_scope="session")
async def test_multiple_large_refunds_cumulative_underflow(server: SpinTestServer):
"""Test that multiple large refunds applied sequentially don't cause underflow."""
credit_system = UserCredit()
user_id = f"cumulative-underflow-test-{uuid4()}"
await create_test_user(user_id)
try:
# Set up balance close to underflow threshold
from prisma.models import UserBalance
initial_balance = POSTGRES_INT_MIN + 500 # Close to minimum but with some room
await UserBalance.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, "balance": initial_balance},
"update": {"balance": initial_balance},
},
)
# Apply multiple refunds that would cumulatively underflow
refund_amount = -300 # Each refund that would cause underflow when cumulative
# First refund: (POSTGRES_INT_MIN + 500) + (-300) = POSTGRES_INT_MIN + 200 (still above minimum)
balance_1, _ = await credit_system._add_transaction(
user_id=user_id,
amount=refund_amount,
transaction_type=CreditTransactionType.REFUND,
fail_insufficient_credits=False,
)
# Should be above minimum for first refund
expected_balance_1 = (
initial_balance + refund_amount
) # Should be POSTGRES_INT_MIN + 200
assert (
balance_1 == expected_balance_1
), f"First refund should result in {expected_balance_1}, got {balance_1}"
assert (
balance_1 >= POSTGRES_INT_MIN
), f"First refund should not go below {POSTGRES_INT_MIN}, got {balance_1}"
# Second refund: (POSTGRES_INT_MIN + 200) + (-300) = POSTGRES_INT_MIN - 100 (would underflow)
balance_2, _ = await credit_system._add_transaction(
user_id=user_id,
amount=refund_amount,
transaction_type=CreditTransactionType.REFUND,
fail_insufficient_credits=False,
)
# Should be clamped to minimum due to underflow protection
assert (
balance_2 == POSTGRES_INT_MIN
), f"Second refund should be clamped to {POSTGRES_INT_MIN}, got {balance_2}"
# Third refund: Should stay at minimum
balance_3, _ = await credit_system._add_transaction(
user_id=user_id,
amount=refund_amount,
transaction_type=CreditTransactionType.REFUND,
fail_insufficient_credits=False,
)
# Should still be at minimum
assert (
balance_3 == POSTGRES_INT_MIN
), f"Third refund should stay at {POSTGRES_INT_MIN}, got {balance_3}"
# Final balance check
final_balance = await credit_system.get_credits(user_id)
assert (
final_balance == POSTGRES_INT_MIN
), f"Final balance should be {POSTGRES_INT_MIN}, got {final_balance}"
finally:
await cleanup_test_user(user_id)
@pytest.mark.asyncio(loop_scope="session")
async def test_concurrent_large_refunds_no_underflow(server: SpinTestServer):
"""Test that concurrent large refunds don't cause race condition underflow."""
credit_system = UserCredit()
user_id = f"concurrent-underflow-test-{uuid4()}"
await create_test_user(user_id)
try:
# Set up balance close to underflow threshold
from prisma.models import UserBalance
initial_balance = POSTGRES_INT_MIN + 1000 # Close to minimum
await UserBalance.prisma().upsert(
where={"userId": user_id},
data={
"create": {"userId": user_id, "balance": initial_balance},
"update": {"balance": initial_balance},
},
)
async def large_refund(amount: int, label: str):
try:
return await credit_system._add_transaction(
user_id=user_id,
amount=-amount,
transaction_type=CreditTransactionType.REFUND,
fail_insufficient_credits=False,
)
except Exception as e:
return f"FAILED-{label}: {e}"
# Run concurrent refunds that would cause underflow if not protected
# Each refund of 500 would cause underflow: initial_balance + (-500) could go below POSTGRES_INT_MIN
refund_amount = 500
results = await asyncio.gather(
large_refund(refund_amount, "A"),
large_refund(refund_amount, "B"),
large_refund(refund_amount, "C"),
return_exceptions=True,
)
# Check all results are valid and no underflow occurred
valid_results = []
for i, result in enumerate(results):
if isinstance(result, tuple):
balance, _ = result
assert (
balance >= POSTGRES_INT_MIN
), f"Result {i} balance {balance} underflowed below {POSTGRES_INT_MIN}"
valid_results.append(balance)
elif isinstance(result, str) and "FAILED" in result:
# Some operations might fail due to validation, that's okay
pass
else:
# Unexpected exception
assert not isinstance(
result, Exception
), f"Unexpected exception in result {i}: {result}"
# At least one operation should succeed
assert (
len(valid_results) > 0
), f"At least one refund should succeed, got results: {results}"
# All successful results should be >= POSTGRES_INT_MIN
for balance in valid_results:
assert (
balance >= POSTGRES_INT_MIN
), f"Balance {balance} should not be below {POSTGRES_INT_MIN}"
# Final balance should be valid and at or above POSTGRES_INT_MIN
final_balance = await credit_system.get_credits(user_id)
assert (
final_balance >= POSTGRES_INT_MIN
), f"Final balance {final_balance} should not underflow below {POSTGRES_INT_MIN}"
finally:
await cleanup_test_user(user_id)

View File

@@ -0,0 +1,217 @@
"""
Integration test to verify complete migration from User.balance to UserBalance table.
This test ensures that:
1. No User.balance queries exist in the system
2. All balance operations go through UserBalance table
3. User and UserBalance stay synchronized properly
"""
import asyncio
from datetime import datetime
import pytest
from prisma.enums import CreditTransactionType
from prisma.errors import UniqueViolationError
from prisma.models import CreditTransaction, User, UserBalance
from backend.data.credit import UsageTransactionMetadata, UserCredit
from backend.util.json import SafeJson
from backend.util.test import SpinTestServer
async def create_test_user(user_id: str) -> None:
"""Create a test user for migration tests."""
try:
await User.prisma().create(
data={
"id": user_id,
"email": f"test-{user_id}@example.com",
"name": f"Test User {user_id[:8]}",
}
)
except UniqueViolationError:
# User already exists, continue
pass
async def cleanup_test_user(user_id: str) -> None:
"""Clean up test user and their data."""
try:
await CreditTransaction.prisma().delete_many(where={"userId": user_id})
await UserBalance.prisma().delete_many(where={"userId": user_id})
await User.prisma().delete_many(where={"id": user_id})
except Exception as e:
# Log cleanup failures but don't fail the test
print(f"Warning: Failed to cleanup test user {user_id}: {e}")
@pytest.mark.asyncio(loop_scope="session")
async def test_user_balance_migration_complete(server: SpinTestServer):
"""Test that User table balance is never used and UserBalance is source of truth."""
credit_system = UserCredit()
user_id = f"migration-test-{datetime.now().timestamp()}"
await create_test_user(user_id)
try:
# 1. Verify User table does NOT have balance set initially
user = await User.prisma().find_unique(where={"id": user_id})
assert user is not None
# User.balance should not exist or should be None/0 if it exists
user_balance_attr = getattr(user, "balance", None)
if user_balance_attr is not None:
assert (
user_balance_attr == 0 or user_balance_attr is None
), f"User.balance should be 0 or None, got {user_balance_attr}"
# 2. Perform various credit operations using internal method (bypasses Stripe)
await credit_system._add_transaction(
user_id=user_id,
amount=1000,
transaction_type=CreditTransactionType.TOP_UP,
metadata=SafeJson({"test": "migration_test"}),
)
balance1 = await credit_system.get_credits(user_id)
assert balance1 == 1000
await credit_system.spend_credits(
user_id,
300,
UsageTransactionMetadata(
graph_exec_id="test", reason="Migration test spend"
),
)
balance2 = await credit_system.get_credits(user_id)
assert balance2 == 700
# 3. Verify UserBalance table has correct values
user_balance = await UserBalance.prisma().find_unique(where={"userId": user_id})
assert user_balance is not None
assert (
user_balance.balance == 700
), f"UserBalance should be 700, got {user_balance.balance}"
# 4. CRITICAL: Verify User.balance is NEVER updated during operations
user_after = await User.prisma().find_unique(where={"id": user_id})
assert user_after is not None
user_balance_after = getattr(user_after, "balance", None)
if user_balance_after is not None:
# If User.balance exists, it should still be 0 (never updated)
assert (
user_balance_after == 0 or user_balance_after is None
), f"User.balance should remain 0/None after operations, got {user_balance_after}. This indicates User.balance is still being used!"
# 5. Verify get_credits always returns UserBalance value, not User.balance
final_balance = await credit_system.get_credits(user_id)
assert (
final_balance == user_balance.balance
), f"get_credits should return UserBalance value {user_balance.balance}, got {final_balance}"
finally:
await cleanup_test_user(user_id)
@pytest.mark.asyncio(loop_scope="session")
async def test_detect_stale_user_balance_queries(server: SpinTestServer):
"""Test to detect if any operations are still using User.balance instead of UserBalance."""
credit_system = UserCredit()
user_id = f"stale-query-test-{datetime.now().timestamp()}"
await create_test_user(user_id)
try:
# Create UserBalance with specific value
await UserBalance.prisma().create(
data={"userId": user_id, "balance": 5000} # $50
)
# Verify that get_credits returns UserBalance value (5000), not any stale User.balance value
balance = await credit_system.get_credits(user_id)
assert (
balance == 5000
), f"Expected get_credits to return 5000 from UserBalance, got {balance}"
# Verify all operations use UserBalance using internal method (bypasses Stripe)
await credit_system._add_transaction(
user_id=user_id,
amount=1000,
transaction_type=CreditTransactionType.TOP_UP,
metadata=SafeJson({"test": "final_verification"}),
)
final_balance = await credit_system.get_credits(user_id)
assert final_balance == 6000, f"Expected 6000, got {final_balance}"
# Verify UserBalance table has the correct value
user_balance = await UserBalance.prisma().find_unique(where={"userId": user_id})
assert user_balance is not None
assert (
user_balance.balance == 6000
), f"UserBalance should be 6000, got {user_balance.balance}"
finally:
await cleanup_test_user(user_id)
@pytest.mark.asyncio(loop_scope="session")
async def test_concurrent_operations_use_userbalance_only(server: SpinTestServer):
"""Test that concurrent operations all use UserBalance locking, not User.balance."""
credit_system = UserCredit()
user_id = f"concurrent-userbalance-test-{datetime.now().timestamp()}"
await create_test_user(user_id)
try:
# Set initial balance in UserBalance
await UserBalance.prisma().create(data={"userId": user_id, "balance": 1000})
# Run concurrent operations to ensure they all use UserBalance atomic operations
async def concurrent_spend(amount: int, label: str):
try:
await credit_system.spend_credits(
user_id,
amount,
UsageTransactionMetadata(
graph_exec_id=f"concurrent-{label}",
reason=f"Concurrent test {label}",
),
)
return f"{label}-SUCCESS"
except Exception as e:
return f"{label}-FAILED: {e}"
# Run concurrent operations
results = await asyncio.gather(
concurrent_spend(100, "A"),
concurrent_spend(200, "B"),
concurrent_spend(300, "C"),
return_exceptions=True,
)
# All should succeed (1000 >= 100+200+300)
successful = [r for r in results if "SUCCESS" in str(r)]
assert len(successful) == 3, f"All operations should succeed, got {results}"
# Final balance should be 1000 - 600 = 400
final_balance = await credit_system.get_credits(user_id)
assert final_balance == 400, f"Expected final balance 400, got {final_balance}"
# Verify UserBalance has correct value
user_balance = await UserBalance.prisma().find_unique(where={"userId": user_id})
assert user_balance is not None
assert (
user_balance.balance == 400
), f"UserBalance should be 400, got {user_balance.balance}"
# Critical: If User.balance exists and was used, it might have wrong value
try:
user = await User.prisma().find_unique(where={"id": user_id})
user_balance_attr = getattr(user, "balance", None)
if user_balance_attr is not None:
# If User.balance exists, it should NOT be used for operations
# The fact that our final balance is correct from UserBalance proves the system is working
print(
f"✅ User.balance exists ({user_balance_attr}) but UserBalance ({user_balance.balance}) is being used correctly"
)
except Exception:
print("✅ User.balance column doesn't exist - migration is complete")
finally:
await cleanup_test_user(user_id)

View File

@@ -98,42 +98,6 @@ async def transaction(timeout: int = TRANSACTION_TIMEOUT):
yield tx
@asynccontextmanager
async def locked_transaction(key: str, timeout: int = TRANSACTION_TIMEOUT):
"""
Create a transaction and take a per-key advisory *transaction* lock.
- Uses a 64-bit lock id via hashtextextended(key, 0) to avoid 32-bit collisions.
- Bound by lock_timeout and statement_timeout so it won't block indefinitely.
- Lock is held for the duration of the transaction and auto-released on commit/rollback.
Args:
key: String lock key (e.g., "usr_trx_<uuid>").
timeout: Transaction/lock/statement timeout in milliseconds.
"""
async with transaction(timeout=timeout) as tx:
# Ensure we don't wait longer than desired
# Note: SET LOCAL doesn't support parameterized queries, must use string interpolation
await tx.execute_raw(f"SET LOCAL statement_timeout = '{int(timeout)}ms'") # type: ignore[arg-type]
await tx.execute_raw(f"SET LOCAL lock_timeout = '{int(timeout)}ms'") # type: ignore[arg-type]
# Block until acquired or lock_timeout hits
try:
await tx.execute_raw(
"SELECT pg_advisory_xact_lock(hashtextextended($1, 0))",
key,
)
except Exception as e:
# Normalize PG's lock timeout error to TimeoutError for callers
if "lock timeout" in str(e).lower():
raise TimeoutError(
f"Could not acquire lock for key={key!r} within {timeout}ms"
) from e
raise
yield tx
def get_database_schema() -> str:
"""Extract database schema from DATABASE_URL."""
parsed_url = urlparse(DATABASE_URL)

View File

@@ -347,6 +347,9 @@ class APIKeyCredentials(_BaseCredentials):
"""Unix timestamp (seconds) indicating when the API key expires (if at all)"""
def auth_header(self) -> str:
# Linear API keys should not have Bearer prefix
if self.provider == "linear":
return self.api_key.get_secret_value()
return f"Bearer {self.api_key.get_secret_value()}"

View File

@@ -4,7 +4,6 @@ from typing import Any, Optional
import prisma
import pydantic
from autogpt_libs.utils.cache import cached
from prisma.enums import OnboardingStep
from prisma.models import UserOnboarding
from prisma.types import UserOnboardingCreateInput, UserOnboardingUpdateInput
@@ -13,6 +12,7 @@ from backend.data.block import get_blocks
from backend.data.credit import get_user_credit_model
from backend.data.model import CredentialsMetaInput
from backend.server.v2.store.model import StoreAgentDetails
from backend.util.cache import cached
from backend.util.json import SafeJson
# Mapping from user reason id to categories to search for when choosing agent to show
@@ -26,8 +26,6 @@ REASON_MAPPING: dict[str, list[str]] = {
POINTS_AGENT_COUNT = 50 # Number of agents to calculate points for
MIN_AGENT_COUNT = 2 # Minimum number of marketplace agents to enable onboarding
user_credit = get_user_credit_model()
class UserOnboardingUpdate(pydantic.BaseModel):
completedSteps: Optional[list[OnboardingStep]] = None
@@ -147,7 +145,8 @@ async def reward_user(user_id: str, step: OnboardingStep):
return
onboarding.rewardedFor.append(step)
await user_credit.onboarding_reward(user_id, reward, step)
user_credit_model = await get_user_credit_model(user_id)
await user_credit_model.onboarding_reward(user_id, reward, step)
await UserOnboarding.prisma().update(
where={"userId": user_id},
data={

View File

@@ -0,0 +1,5 @@
import prisma.models
class StoreAgentWithRank(prisma.models.StoreAgent):
rank: float

View File

@@ -1,11 +1,11 @@
import logging
import os
from autogpt_libs.utils.cache import cached, thread_cached
from dotenv import load_dotenv
from redis import Redis
from redis.asyncio import Redis as AsyncRedis
from backend.util.cache import cached, thread_cached
from backend.util.retry import conn_retry
load_dotenv()
@@ -34,7 +34,7 @@ def disconnect():
get_redis().close()
@cached()
@cached(ttl_seconds=3600)
def get_redis() -> Redis:
return connect()

View File

@@ -7,7 +7,6 @@ from typing import Optional, cast
from urllib.parse import quote_plus
from autogpt_libs.auth.models import DEFAULT_USER_ID
from autogpt_libs.utils.cache import cached
from fastapi import HTTPException
from prisma.enums import NotificationType
from prisma.models import User as PrismaUser
@@ -16,6 +15,7 @@ from prisma.types import JsonFilter, UserCreateInput, UserUpdateInput
from backend.data.db import prisma
from backend.data.model import User, UserIntegrations, UserMetadata
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
from backend.util.cache import cached
from backend.util.encryption import JSONCryptor
from backend.util.exceptions import DatabaseError
from backend.util.json import SafeJson

View File

@@ -1,5 +1,6 @@
import logging
from typing import Callable, Concatenate, ParamSpec, TypeVar, cast
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, Callable, Concatenate, ParamSpec, TypeVar, cast
from backend.data import db
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
@@ -39,6 +40,7 @@ from backend.data.notifications import (
)
from backend.data.user import (
get_active_user_ids_in_timerange,
get_user_by_id,
get_user_email_by_id,
get_user_email_verification,
get_user_integrations,
@@ -56,8 +58,10 @@ from backend.util.service import (
)
from backend.util.settings import Config
if TYPE_CHECKING:
from fastapi import FastAPI
config = Config()
_user_credit_model = get_user_credit_model()
logger = logging.getLogger(__name__)
P = ParamSpec("P")
R = TypeVar("R")
@@ -66,23 +70,27 @@ R = TypeVar("R")
async def _spend_credits(
user_id: str, cost: int, metadata: UsageTransactionMetadata
) -> int:
return await _user_credit_model.spend_credits(user_id, cost, metadata)
user_credit_model = await get_user_credit_model(user_id)
return await user_credit_model.spend_credits(user_id, cost, metadata)
async def _get_credits(user_id: str) -> int:
return await _user_credit_model.get_credits(user_id)
user_credit_model = await get_user_credit_model(user_id)
return await user_credit_model.get_credits(user_id)
class DatabaseManager(AppService):
def run_service(self) -> None:
logger.info(f"[{self.service_name}] ⏳ Connecting to Database...")
self.run_and_wait(db.connect())
super().run_service()
@asynccontextmanager
async def lifespan(self, app: "FastAPI"):
async with super().lifespan(app):
logger.info(f"[{self.service_name}] ⏳ Connecting to Database...")
await db.connect()
def cleanup(self):
super().cleanup()
logger.info(f"[{self.service_name}] ⏳ Disconnecting Database...")
self.run_and_wait(db.disconnect())
logger.info(f"[{self.service_name}] ✅ Ready")
yield
logger.info(f"[{self.service_name}] ⏳ Disconnecting Database...")
await db.disconnect()
async def health_check(self) -> str:
if not db.is_connected():
@@ -145,6 +153,7 @@ class DatabaseManager(AppService):
# User Comms - async
get_active_user_ids_in_timerange = _(get_active_user_ids_in_timerange)
get_user_by_id = _(get_user_by_id)
get_user_email_by_id = _(get_user_email_by_id)
get_user_email_verification = _(get_user_email_verification)
get_user_notification_preference = _(get_user_notification_preference)
@@ -230,6 +239,7 @@ class DatabaseManagerAsyncClient(AppServiceClient):
get_node = d.get_node
get_node_execution = d.get_node_execution
get_node_executions = d.get_node_executions
get_user_by_id = d.get_user_by_id
get_user_integrations = d.get_user_integrations
upsert_execution_input = d.upsert_execution_input
upsert_execution_output = d.upsert_execution_output

View File

@@ -246,7 +246,7 @@ async def execute_node(
async for output_name, output_data in node_block.execute(
input_data, **extra_exec_kwargs
):
output_data = json.convert_pydantic_to_json(output_data)
output_data = json.to_dict(output_data)
output_size += len(json.dumps(output_data))
log_metadata.debug("Node produced output", **{output_name: output_data})
yield output_name, output_data
@@ -1548,11 +1548,12 @@ class ExecutionManager(AppProcess):
logger.warning(
f"[{self.service_name}] Graph {graph_exec_id} already running on pod {current_owner}"
)
_ack_message(reject=True, requeue=False)
else:
logger.warning(
f"[{self.service_name}] Could not acquire lock for {graph_exec_id} - Redis unavailable"
)
_ack_message(reject=True, requeue=True)
_ack_message(reject=True, requeue=True)
return
self._execution_locks[graph_exec_id] = cluster_lock
@@ -1713,6 +1714,8 @@ class ExecutionManager(AppProcess):
logger.info(f"{prefix} ✅ Finished GraphExec cleanup")
super().cleanup()
# ------- UTILITIES ------- #

View File

@@ -248,7 +248,7 @@ class Scheduler(AppService):
raise UnhealthyServiceError("Scheduler is still initializing")
# Check if we're in the middle of cleanup
if self.cleaned_up:
if self._shutting_down:
return await super().health_check()
# Normal operation - check if scheduler is running
@@ -375,7 +375,6 @@ class Scheduler(AppService):
super().run_service()
def cleanup(self):
super().cleanup()
if self.scheduler:
logger.info("⏳ Shutting down scheduler...")
self.scheduler.shutdown(wait=True)
@@ -390,7 +389,7 @@ class Scheduler(AppService):
logger.info("⏳ Waiting for event loop thread to finish...")
_event_loop_thread.join(timeout=SCHEDULER_OPERATION_TIMEOUT_SECONDS)
logger.info("Scheduler cleanup complete.")
super().cleanup()
@expose
def add_graph_execution_schedule(

View File

@@ -34,6 +34,7 @@ from backend.data.graph import GraphModel, Node
from backend.data.model import CredentialsMetaInput
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
from backend.data.user import get_user_by_id
from backend.util.cache import cached
from backend.util.clients import (
get_async_execution_event_bus,
get_async_execution_queue,
@@ -41,11 +42,12 @@ from backend.util.clients import (
get_integration_credentials_store,
)
from backend.util.exceptions import GraphValidationError, NotFoundError
from backend.util.logging import TruncatedLogger
from backend.util.logging import TruncatedLogger, is_structured_logging_enabled
from backend.util.settings import Config
from backend.util.type import convert
@cached(maxsize=1000, ttl_seconds=3600)
async def get_user_context(user_id: str) -> UserContext:
"""
Get UserContext for a user, always returns a valid context with timezone.
@@ -53,7 +55,11 @@ async def get_user_context(user_id: str) -> UserContext:
"""
user_context = UserContext(timezone="UTC") # Default to UTC
try:
user = await get_user_by_id(user_id)
if prisma.is_connected():
user = await get_user_by_id(user_id)
else:
user = await get_database_manager_async_client().get_user_by_id(user_id)
if user and user.timezone and user.timezone != "not-set":
user_context.timezone = user.timezone
logger.debug(f"Retrieved user context: timezone={user.timezone}")
@@ -93,7 +99,11 @@ class LogMetadata(TruncatedLogger):
"node_id": node_id,
"block_name": block_name,
}
prefix = f"[ExecutionManager|uid:{user_id}|gid:{graph_id}|nid:{node_id}]|geid:{graph_eid}|neid:{node_eid}|{block_name}]"
prefix = (
"[ExecutionManager]"
if is_structured_logging_enabled()
else f"[ExecutionManager|uid:{user_id}|gid:{graph_id}|nid:{node_id}]|geid:{graph_eid}|neid:{node_eid}|{block_name}]" # noqa
)
super().__init__(
logger,
max_length=max_length,

View File

@@ -1,6 +1,6 @@
from typing import TYPE_CHECKING
from autogpt_libs.utils.cache import cached
from backend.util.cache import cached
if TYPE_CHECKING:
from ..providers import ProviderName
@@ -8,7 +8,7 @@ if TYPE_CHECKING:
# --8<-- [start:load_webhook_managers]
@cached()
@cached(ttl_seconds=3600)
def load_webhook_managers() -> dict["ProviderName", type["BaseWebhooksManager"]]:
webhook_managers = {}

View File

@@ -1017,10 +1017,14 @@ class NotificationManager(AppService):
logger.exception(f"Fatal error in consumer for {queue_name}: {e}")
raise
@continuous_retry()
def run_service(self):
self.run_and_wait(self._run_service())
# Queue the main _run_service task
asyncio.run_coroutine_threadsafe(self._run_service(), self.shared_event_loop)
# Start the main event loop
super().run_service()
@continuous_retry()
async def _run_service(self):
logger.info(f"[{self.service_name}] ⏳ Configuring RabbitMQ...")
self.rabbitmq_service = rabbitmq.AsyncRabbitMQ(self.rabbitmq_config)
@@ -1086,10 +1090,11 @@ class NotificationManager(AppService):
def cleanup(self):
"""Cleanup service resources"""
self.running = False
super().cleanup()
logger.info(f"[{self.service_name}] ⏳ Disconnecting RabbitMQ...")
logger.info("⏳ Disconnecting RabbitMQ...")
self.run_and_wait(self.rabbitmq_service.disconnect())
super().cleanup()
class NotificationManagerClient(AppServiceClient):
@classmethod

View File

@@ -14,19 +14,49 @@ def configured_snapshot(snapshot: Snapshot) -> Snapshot:
@pytest.fixture
def test_user_id() -> str:
"""Test user ID fixture."""
return "test-user-id"
return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
@pytest.fixture
def admin_user_id() -> str:
"""Admin user ID fixture."""
return "admin-user-id"
return "4e53486c-cf57-477e-ba2a-cb02dc828e1b"
@pytest.fixture
def target_user_id() -> str:
"""Target user ID fixture."""
return "target-user-id"
return "5e53486c-cf57-477e-ba2a-cb02dc828e1c"
@pytest.fixture
async def setup_test_user(test_user_id):
"""Create test user in database before tests."""
from backend.data.user import get_or_create_user
# Create the test user in the database using JWT token format
user_data = {
"sub": test_user_id,
"email": "test@example.com",
"user_metadata": {"name": "Test User"},
}
await get_or_create_user(user_data)
return test_user_id
@pytest.fixture
async def setup_admin_user(admin_user_id):
"""Create admin user in database before tests."""
from backend.data.user import get_or_create_user
# Create the admin user in the database using JWT token format
user_data = {
"sub": admin_user_id,
"email": "test-admin@example.com",
"user_metadata": {"name": "Test Admin"},
}
await get_or_create_user(user_data)
return admin_user_id
@pytest.fixture

View File

@@ -321,10 +321,6 @@ class AgentServer(backend.util.service.AppProcess):
uvicorn.run(**uvicorn_config)
def cleanup(self):
super().cleanup()
logger.info(f"[{self.service_name}] ⏳ Shutting down Agent Server...")
@staticmethod
async def test_execute_graph(
graph_id: str,

View File

@@ -11,7 +11,6 @@ import pydantic
import stripe
from autogpt_libs.auth import get_user_id, requires_user
from autogpt_libs.auth.jwt_utils import get_jwt_payload
from autogpt_libs.utils.cache import cached
from fastapi import (
APIRouter,
Body,
@@ -40,6 +39,7 @@ from backend.data.credit import (
AutoTopUpConfig,
RefundRequest,
TransactionHistory,
UserCredit,
get_auto_top_up,
get_user_credit_model,
set_auto_top_up,
@@ -84,6 +84,7 @@ from backend.server.model import (
UpdateTimezoneRequest,
UploadFileResponse,
)
from backend.util.cache import cached
from backend.util.clients import get_scheduler_client
from backend.util.cloud_storage import get_cloud_storage_handler
from backend.util.exceptions import GraphValidationError, NotFoundError
@@ -107,9 +108,6 @@ def _create_file_size_error(size_bytes: int, max_size_mb: int) -> HTTPException:
settings = Settings()
logger = logging.getLogger(__name__)
_user_credit_model = get_user_credit_model()
# Define the API routes
v1_router = APIRouter()
@@ -291,7 +289,7 @@ def _compute_blocks_sync() -> str:
return dumps(result)
@cached()
@cached(ttl_seconds=3600)
async def _get_cached_blocks() -> str:
"""
Async cached function with thundering herd protection.
@@ -478,7 +476,8 @@ async def upload_file(
async def get_user_credits(
user_id: Annotated[str, Security(get_user_id)],
) -> dict[str, int]:
return {"credits": await _user_credit_model.get_credits(user_id)}
user_credit_model = await get_user_credit_model(user_id)
return {"credits": await user_credit_model.get_credits(user_id)}
@v1_router.post(
@@ -490,9 +489,8 @@ async def get_user_credits(
async def request_top_up(
request: RequestTopUp, user_id: Annotated[str, Security(get_user_id)]
):
checkout_url = await _user_credit_model.top_up_intent(
user_id, request.credit_amount
)
user_credit_model = await get_user_credit_model(user_id)
checkout_url = await user_credit_model.top_up_intent(user_id, request.credit_amount)
return {"checkout_url": checkout_url}
@@ -507,7 +505,8 @@ async def refund_top_up(
transaction_key: str,
metadata: dict[str, str],
) -> int:
return await _user_credit_model.top_up_refund(user_id, transaction_key, metadata)
user_credit_model = await get_user_credit_model(user_id)
return await user_credit_model.top_up_refund(user_id, transaction_key, metadata)
@v1_router.patch(
@@ -517,7 +516,8 @@ async def refund_top_up(
dependencies=[Security(requires_user)],
)
async def fulfill_checkout(user_id: Annotated[str, Security(get_user_id)]):
await _user_credit_model.fulfill_checkout(user_id=user_id)
user_credit_model = await get_user_credit_model(user_id)
await user_credit_model.fulfill_checkout(user_id=user_id)
return Response(status_code=200)
@@ -531,18 +531,23 @@ async def configure_user_auto_top_up(
request: AutoTopUpConfig, user_id: Annotated[str, Security(get_user_id)]
) -> str:
if request.threshold < 0:
raise ValueError("Threshold must be greater than 0")
raise HTTPException(status_code=422, detail="Threshold must be greater than 0")
if request.amount < 500 and request.amount != 0:
raise ValueError("Amount must be greater than or equal to 500")
if request.amount < request.threshold:
raise ValueError("Amount must be greater than or equal to threshold")
raise HTTPException(
status_code=422, detail="Amount must be greater than or equal to 500"
)
if request.amount != 0 and request.amount < request.threshold:
raise HTTPException(
status_code=422, detail="Amount must be greater than or equal to threshold"
)
current_balance = await _user_credit_model.get_credits(user_id)
user_credit_model = await get_user_credit_model(user_id)
current_balance = await user_credit_model.get_credits(user_id)
if current_balance < request.threshold:
await _user_credit_model.top_up_credits(user_id, request.amount)
await user_credit_model.top_up_credits(user_id, request.amount)
else:
await _user_credit_model.top_up_credits(user_id, 0)
await user_credit_model.top_up_credits(user_id, 0)
await set_auto_top_up(
user_id, AutoTopUpConfig(threshold=request.threshold, amount=request.amount)
@@ -590,15 +595,13 @@ async def stripe_webhook(request: Request):
event["type"] == "checkout.session.completed"
or event["type"] == "checkout.session.async_payment_succeeded"
):
await _user_credit_model.fulfill_checkout(
session_id=event["data"]["object"]["id"]
)
await UserCredit().fulfill_checkout(session_id=event["data"]["object"]["id"])
if event["type"] == "charge.dispute.created":
await _user_credit_model.handle_dispute(event["data"]["object"])
await UserCredit().handle_dispute(event["data"]["object"])
if event["type"] == "refund.created" or event["type"] == "charge.dispute.closed":
await _user_credit_model.deduct_credits(event["data"]["object"])
await UserCredit().deduct_credits(event["data"]["object"])
return Response(status_code=200)
@@ -612,7 +615,8 @@ async def stripe_webhook(request: Request):
async def manage_payment_method(
user_id: Annotated[str, Security(get_user_id)],
) -> dict[str, str]:
return {"url": await _user_credit_model.create_billing_portal_session(user_id)}
user_credit_model = await get_user_credit_model(user_id)
return {"url": await user_credit_model.create_billing_portal_session(user_id)}
@v1_router.get(
@@ -630,7 +634,8 @@ async def get_credit_history(
if transaction_count_limit < 1 or transaction_count_limit > 1000:
raise ValueError("Transaction count limit must be between 1 and 1000")
return await _user_credit_model.get_transaction_history(
user_credit_model = await get_user_credit_model(user_id)
return await user_credit_model.get_transaction_history(
user_id=user_id,
transaction_time_ceiling=transaction_time,
transaction_count_limit=transaction_count_limit,
@@ -647,7 +652,8 @@ async def get_credit_history(
async def get_refund_requests(
user_id: Annotated[str, Security(get_user_id)],
) -> list[RefundRequest]:
return await _user_credit_model.get_refund_requests(user_id)
user_credit_model = await get_user_credit_model(user_id)
return await user_credit_model.get_refund_requests(user_id)
########################################################
@@ -869,7 +875,8 @@ async def execute_graph(
graph_version: Optional[int] = None,
preset_id: Optional[str] = None,
) -> execution_db.GraphExecutionMeta:
current_balance = await _user_credit_model.get_credits(user_id)
user_credit_model = await get_user_credit_model(user_id)
current_balance = await user_credit_model.get_credits(user_id)
if current_balance <= 0:
raise HTTPException(
status_code=402,

View File

@@ -23,10 +23,13 @@ client = fastapi.testclient.TestClient(app)
@pytest.fixture(autouse=True)
def setup_app_auth(mock_jwt_user):
def setup_app_auth(mock_jwt_user, setup_test_user):
"""Setup auth overrides for all tests in this module"""
from autogpt_libs.auth.jwt_utils import get_jwt_payload
# setup_test_user fixture already executed and user is created in database
# It returns the user_id which we don't need to await
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
yield
app.dependency_overrides.clear()
@@ -194,8 +197,12 @@ def test_get_user_credits(
snapshot: Snapshot,
) -> None:
"""Test get user credits endpoint"""
mock_credit_model = mocker.patch("backend.server.routers.v1._user_credit_model")
mock_credit_model = Mock()
mock_credit_model.get_credits = AsyncMock(return_value=1000)
mocker.patch(
"backend.server.routers.v1.get_user_credit_model",
return_value=mock_credit_model,
)
response = client.get("/credits")
@@ -215,10 +222,14 @@ def test_request_top_up(
snapshot: Snapshot,
) -> None:
"""Test request top up endpoint"""
mock_credit_model = mocker.patch("backend.server.routers.v1._user_credit_model")
mock_credit_model = Mock()
mock_credit_model.top_up_intent = AsyncMock(
return_value="https://checkout.example.com/session123"
)
mocker.patch(
"backend.server.routers.v1.get_user_credit_model",
return_value=mock_credit_model,
)
request_data = {"credit_amount": 500}
@@ -261,6 +272,74 @@ def test_get_auto_top_up(
)
def test_configure_auto_top_up(
mocker: pytest_mock.MockFixture,
snapshot: Snapshot,
) -> None:
"""Test configure auto top-up endpoint - this test would have caught the enum casting bug"""
# Mock the set_auto_top_up function to avoid database operations
mocker.patch(
"backend.server.routers.v1.set_auto_top_up",
return_value=None,
)
# Mock credit model to avoid Stripe API calls
mock_credit_model = mocker.AsyncMock()
mock_credit_model.get_credits.return_value = 50 # Current balance below threshold
mock_credit_model.top_up_credits.return_value = None
mocker.patch(
"backend.server.routers.v1.get_user_credit_model",
return_value=mock_credit_model,
)
# Test data
request_data = {
"threshold": 100,
"amount": 500,
}
response = client.post("/credits/auto-top-up", json=request_data)
# This should succeed with our fix, but would have failed before with the enum casting error
assert response.status_code == 200
assert response.json() == "Auto top-up settings updated"
def test_configure_auto_top_up_validation_errors(
mocker: pytest_mock.MockFixture,
) -> None:
"""Test configure auto top-up endpoint validation"""
# Mock set_auto_top_up to avoid database operations for successful case
mocker.patch("backend.server.routers.v1.set_auto_top_up")
# Mock credit model to avoid Stripe API calls for the successful case
mock_credit_model = mocker.AsyncMock()
mock_credit_model.get_credits.return_value = 50
mock_credit_model.top_up_credits.return_value = None
mocker.patch(
"backend.server.routers.v1.get_user_credit_model",
return_value=mock_credit_model,
)
# Test negative threshold
response = client.post(
"/credits/auto-top-up", json={"threshold": -1, "amount": 500}
)
assert response.status_code == 422 # Validation error
# Test amount too small (but not 0)
response = client.post(
"/credits/auto-top-up", json={"threshold": 100, "amount": 100}
)
assert response.status_code == 422 # Validation error
# Test amount = 0 (should be allowed)
response = client.post("/credits/auto-top-up", json={"threshold": 100, "amount": 0})
assert response.status_code == 200 # Should succeed
# Graphs endpoints tests
def test_get_graphs(
mocker: pytest_mock.MockFixture,

View File

@@ -11,8 +11,6 @@ from backend.util.json import SafeJson
logger = logging.getLogger(__name__)
_user_credit_model = get_user_credit_model()
router = APIRouter(
prefix="/admin",
@@ -33,7 +31,8 @@ async def add_user_credits(
logger.info(
f"Admin user {admin_user_id} is adding {amount} credits to user {user_id}"
)
new_balance, transaction_key = await _user_credit_model._add_transaction(
user_credit_model = await get_user_credit_model(user_id)
new_balance, transaction_key = await user_credit_model._add_transaction(
user_id,
amount,
transaction_type=CreditTransactionType.GRANT,

View File

@@ -1,5 +1,5 @@
import json
from unittest.mock import AsyncMock
from unittest.mock import AsyncMock, Mock
import fastapi
import fastapi.testclient
@@ -7,12 +7,12 @@ import prisma.enums
import pytest
import pytest_mock
from autogpt_libs.auth.jwt_utils import get_jwt_payload
from prisma import Json
from pytest_snapshot.plugin import Snapshot
import backend.server.v2.admin.credit_admin_routes as credit_admin_routes
import backend.server.v2.admin.model as admin_model
from backend.data.model import UserTransaction
from backend.util.json import SafeJson
from backend.util.models import Pagination
app = fastapi.FastAPI()
@@ -37,12 +37,14 @@ def test_add_user_credits_success(
) -> None:
"""Test successful credit addition by admin"""
# Mock the credit model
mock_credit_model = mocker.patch(
"backend.server.v2.admin.credit_admin_routes._user_credit_model"
)
mock_credit_model = Mock()
mock_credit_model._add_transaction = AsyncMock(
return_value=(1500, "transaction-123-uuid")
)
mocker.patch(
"backend.server.v2.admin.credit_admin_routes.get_user_credit_model",
return_value=mock_credit_model,
)
request_data = {
"user_id": target_user_id,
@@ -62,11 +64,17 @@ def test_add_user_credits_success(
call_args = mock_credit_model._add_transaction.call_args
assert call_args[0] == (target_user_id, 500)
assert call_args[1]["transaction_type"] == prisma.enums.CreditTransactionType.GRANT
# Check that metadata is a Json object with the expected content
assert isinstance(call_args[1]["metadata"], Json)
assert call_args[1]["metadata"] == Json(
{"admin_id": admin_user_id, "reason": "Test credit grant for debugging"}
)
# Check that metadata is a SafeJson object with the expected content
assert isinstance(call_args[1]["metadata"], SafeJson)
actual_metadata = call_args[1]["metadata"]
expected_data = {
"admin_id": admin_user_id,
"reason": "Test credit grant for debugging",
}
# SafeJson inherits from Json which stores parsed data in the .data attribute
assert actual_metadata.data["admin_id"] == expected_data["admin_id"]
assert actual_metadata.data["reason"] == expected_data["reason"]
# Snapshot test the response
configured_snapshot.assert_match(
@@ -81,12 +89,14 @@ def test_add_user_credits_negative_amount(
) -> None:
"""Test credit deduction by admin (negative amount)"""
# Mock the credit model
mock_credit_model = mocker.patch(
"backend.server.v2.admin.credit_admin_routes._user_credit_model"
)
mock_credit_model = Mock()
mock_credit_model._add_transaction = AsyncMock(
return_value=(200, "transaction-456-uuid")
)
mocker.patch(
"backend.server.v2.admin.credit_admin_routes.get_user_credit_model",
return_value=mock_credit_model,
)
request_data = {
"user_id": "target-user-id",

View File

@@ -2,7 +2,6 @@ import logging
from datetime import datetime, timedelta, timezone
import prisma
from autogpt_libs.utils.cache import cached
import backend.data.block
from backend.blocks import load_all_blocks
@@ -18,6 +17,7 @@ from backend.server.v2.builder.model import (
ProviderResponse,
SearchBlocksResponse,
)
from backend.util.cache import cached
from backend.util.models import Pagination
logger = logging.getLogger(__name__)
@@ -307,7 +307,7 @@ def _matches_llm_model(schema_cls: type[BlockSchema], query: str) -> bool:
return False
@cached()
@cached(ttl_seconds=3600)
def _get_all_providers() -> dict[ProviderName, Provider]:
providers: dict[ProviderName, Provider] = {}

View File

@@ -1,6 +1,5 @@
from autogpt_libs.utils.cache import cached
import backend.server.v2.store.db
from backend.util.cache import cached
##############################################
############### Caches #######################
@@ -17,7 +16,7 @@ def clear_all_caches():
# Cache store agents list for 5 minutes
# Different cache entries for different query combinations
@cached(maxsize=5000, ttl_seconds=300)
@cached(maxsize=5000, ttl_seconds=300, shared_cache=True)
async def _get_cached_store_agents(
featured: bool,
creator: str | None,
@@ -40,7 +39,7 @@ async def _get_cached_store_agents(
# Cache individual agent details for 15 minutes
@cached(maxsize=200, ttl_seconds=300)
@cached(maxsize=200, ttl_seconds=300, shared_cache=True)
async def _get_cached_agent_details(username: str, agent_name: str):
"""Cached helper to get agent details."""
return await backend.server.v2.store.db.get_store_agent_details(
@@ -49,7 +48,7 @@ async def _get_cached_agent_details(username: str, agent_name: str):
# Cache creators list for 5 minutes
@cached(maxsize=200, ttl_seconds=300)
@cached(maxsize=200, ttl_seconds=300, shared_cache=True)
async def _get_cached_store_creators(
featured: bool,
search_query: str | None,
@@ -68,7 +67,7 @@ async def _get_cached_store_creators(
# Cache individual creator details for 5 minutes
@cached(maxsize=100, ttl_seconds=300)
@cached(maxsize=100, ttl_seconds=300, shared_cache=True)
async def _get_cached_creator_details(username: str):
"""Cached helper to get creator details."""
return await backend.server.v2.store.db.get_store_creator_details(

View File

@@ -1,5 +1,6 @@
import asyncio
import logging
import typing
from datetime import datetime, timezone
import fastapi
@@ -71,64 +72,199 @@ async def get_store_agents(
logger.debug(
f"Getting store agents. featured={featured}, creators={creators}, sorted_by={sorted_by}, search={search_query}, category={category}, page={page}"
)
search_term = sanitize_query(search_query)
where_clause: prisma.types.StoreAgentWhereInput = {"is_available": True}
if featured:
where_clause["featured"] = featured
sanitized_creators = []
if creators:
where_clause["creator_username"] = {"in": creators}
for c in creators:
sanitized_creators.append(sanitize_query(c))
sanitized_category = None
if category:
where_clause["categories"] = {"has": category}
if search_term:
where_clause["OR"] = [
{"agent_name": {"contains": search_term, "mode": "insensitive"}},
{"description": {"contains": search_term, "mode": "insensitive"}},
]
order_by = []
if sorted_by == "rating":
order_by.append({"rating": "desc"})
elif sorted_by == "runs":
order_by.append({"runs": "desc"})
elif sorted_by == "name":
order_by.append({"agent_name": "asc"})
sanitized_category = sanitize_query(category)
try:
agents = await prisma.models.StoreAgent.prisma().find_many(
where=where_clause,
order=order_by,
skip=(page - 1) * page_size,
take=page_size,
)
total = await prisma.models.StoreAgent.prisma().count(where=where_clause)
total_pages = (total + page_size - 1) // page_size
store_agents: list[backend.server.v2.store.model.StoreAgent] = []
for agent in agents:
try:
# Create the StoreAgent object safely
store_agent = backend.server.v2.store.model.StoreAgent(
slug=agent.slug,
agent_name=agent.agent_name,
agent_image=agent.agent_image[0] if agent.agent_image else "",
creator=agent.creator_username or "Needs Profile",
creator_avatar=agent.creator_avatar or "",
sub_heading=agent.sub_heading,
description=agent.description,
runs=agent.runs,
rating=agent.rating,
# If search_query is provided, use full-text search
if search_query:
search_term = sanitize_query(search_query)
if not search_term:
# Return empty results for invalid search query
return backend.server.v2.store.model.StoreAgentsResponse(
agents=[],
pagination=backend.server.v2.store.model.Pagination(
current_page=page,
total_items=0,
total_pages=0,
page_size=page_size,
),
)
# Add to the list only if creation was successful
store_agents.append(store_agent)
except Exception as e:
# Skip this agent if there was an error
# You could log the error here if needed
logger.error(
f"Error parsing Store agent when getting store agents from db: {e}"
)
continue
offset = (page - 1) * page_size
# Whitelist allowed order_by columns
ALLOWED_ORDER_BY = {
"rating": "rating DESC, rank DESC",
"runs": "runs DESC, rank DESC",
"name": "agent_name ASC, rank DESC",
"updated_at": "updated_at DESC, rank DESC",
}
# Validate and get order clause
if sorted_by and sorted_by in ALLOWED_ORDER_BY:
order_by_clause = ALLOWED_ORDER_BY[sorted_by]
else:
order_by_clause = "updated_at DESC, rank DESC",
# Build WHERE conditions and parameters list
where_parts: list[str] = []
params: list[typing.Any] = [search_term] # $1 - search term
param_index = 2 # Start at $2 for next parameter
# Always filter for available agents
where_parts.append("is_available = true")
if featured:
where_parts.append("featured = true")
if creators and sanitized_creators:
# Use ANY with array parameter
where_parts.append(f"creator_username = ANY(${param_index})")
params.append(sanitized_creators)
param_index += 1
if category and sanitized_category:
where_parts.append(f"${param_index} = ANY(categories)")
params.append(sanitized_category)
param_index += 1
sql_where_clause: str = " AND ".join(where_parts) if where_parts else "1=1"
# Add pagination params
params.extend([page_size, offset])
limit_param = f"${param_index}"
offset_param = f"${param_index + 1}"
# Execute full-text search query with parameterized values
sql_query = f"""
SELECT
slug,
agent_name,
agent_image,
creator_username,
creator_avatar,
sub_heading,
description,
runs,
rating,
categories,
featured,
is_available,
updated_at,
ts_rank_cd(search, query) AS rank
FROM "StoreAgent",
plainto_tsquery('english', $1) AS query
WHERE {sql_where_clause}
AND search @@ query
ORDER BY {order_by_clause}
LIMIT {limit_param} OFFSET {offset_param}
"""
# Count query for pagination - only uses search term parameter
count_query = f"""
SELECT COUNT(*) as count
FROM "StoreAgent",
plainto_tsquery('english', $1) AS query
WHERE {sql_where_clause}
AND search @@ query
"""
# Execute both queries with parameters
agents = await prisma.client.get_client().query_raw(
typing.cast(typing.LiteralString, sql_query), *params
)
# For count, use params without pagination (last 2 params)
count_params = params[:-2]
count_result = await prisma.client.get_client().query_raw(
typing.cast(typing.LiteralString, count_query), *count_params
)
total = count_result[0]["count"] if count_result else 0
total_pages = (total + page_size - 1) // page_size
# Convert raw results to StoreAgent models
store_agents: list[backend.server.v2.store.model.StoreAgent] = []
for agent in agents:
try:
store_agent = backend.server.v2.store.model.StoreAgent(
slug=agent["slug"],
agent_name=agent["agent_name"],
agent_image=(
agent["agent_image"][0] if agent["agent_image"] else ""
),
creator=agent["creator_username"] or "Needs Profile",
creator_avatar=agent["creator_avatar"] or "",
sub_heading=agent["sub_heading"],
description=agent["description"],
runs=agent["runs"],
rating=agent["rating"],
)
store_agents.append(store_agent)
except Exception as e:
logger.error(f"Error parsing Store agent from search results: {e}")
continue
else:
# Non-search query path (original logic)
where_clause: prisma.types.StoreAgentWhereInput = {"is_available": True}
if featured:
where_clause["featured"] = featured
if creators:
where_clause["creator_username"] = {"in": sanitized_creators}
if sanitized_category:
where_clause["categories"] = {"has": sanitized_category}
order_by = []
if sorted_by == "rating":
order_by.append({"rating": "desc"})
elif sorted_by == "runs":
order_by.append({"runs": "desc"})
elif sorted_by == "name":
order_by.append({"agent_name": "asc"})
agents = await prisma.models.StoreAgent.prisma().find_many(
where=where_clause,
order=order_by,
skip=(page - 1) * page_size,
take=page_size,
)
total = await prisma.models.StoreAgent.prisma().count(where=where_clause)
total_pages = (total + page_size - 1) // page_size
store_agents: list[backend.server.v2.store.model.StoreAgent] = []
for agent in agents:
try:
# Create the StoreAgent object safely
store_agent = backend.server.v2.store.model.StoreAgent(
slug=agent.slug,
agent_name=agent.agent_name,
agent_image=agent.agent_image[0] if agent.agent_image else "",
creator=agent.creator_username or "Needs Profile",
creator_avatar=agent.creator_avatar or "",
sub_heading=agent.sub_heading,
description=agent.description,
runs=agent.runs,
rating=agent.rating,
)
# Add to the list only if creation was successful
store_agents.append(store_agent)
except Exception as e:
# Skip this agent if there was an error
# You could log the error here if needed
logger.error(
f"Error parsing Store agent when getting store agents from db: {e}"
)
continue
logger.debug(f"Found {len(store_agents)} agents")
return backend.server.v2.store.model.StoreAgentsResponse(

View File

@@ -20,7 +20,7 @@ async def setup_prisma():
yield
@pytest.mark.asyncio
@pytest.mark.asyncio(loop_scope="session")
async def test_get_store_agents(mocker):
# Mock data
mock_agents = [
@@ -64,7 +64,7 @@ async def test_get_store_agents(mocker):
mock_store_agent.return_value.count.assert_called_once()
@pytest.mark.asyncio
@pytest.mark.asyncio(loop_scope="session")
async def test_get_store_agent_details(mocker):
# Mock data
mock_agent = prisma.models.StoreAgent(
@@ -173,7 +173,7 @@ async def test_get_store_agent_details(mocker):
mock_store_listing_db.return_value.find_first.assert_called_once()
@pytest.mark.asyncio
@pytest.mark.asyncio(loop_scope="session")
async def test_get_store_creator_details(mocker):
# Mock data
mock_creator_data = prisma.models.Creator(
@@ -210,7 +210,7 @@ async def test_get_store_creator_details(mocker):
)
@pytest.mark.asyncio
@pytest.mark.asyncio(loop_scope="session")
async def test_create_store_submission(mocker):
# Mock data
mock_agent = prisma.models.AgentGraph(
@@ -282,7 +282,7 @@ async def test_create_store_submission(mocker):
mock_store_listing.return_value.create.assert_called_once()
@pytest.mark.asyncio
@pytest.mark.asyncio(loop_scope="session")
async def test_update_profile(mocker):
# Mock data
mock_profile = prisma.models.Profile(
@@ -327,7 +327,7 @@ async def test_update_profile(mocker):
mock_profile_db.return_value.update.assert_called_once()
@pytest.mark.asyncio
@pytest.mark.asyncio(loop_scope="session")
async def test_get_user_profile(mocker):
# Mock data
mock_profile = prisma.models.Profile(
@@ -359,3 +359,63 @@ async def test_get_user_profile(mocker):
assert result.description == "Test description"
assert result.links == ["link1", "link2"]
assert result.avatar_url == "avatar.jpg"
@pytest.mark.asyncio(loop_scope="session")
async def test_get_store_agents_with_search_parameterized(mocker):
"""Test that search query uses parameterized SQL - validates the fix works"""
# Call function with search query containing potential SQL injection
malicious_search = "test'; DROP TABLE StoreAgent; --"
result = await db.get_store_agents(search_query=malicious_search)
# Verify query executed safely
assert isinstance(result.agents, list)
@pytest.mark.asyncio(loop_scope="session")
async def test_get_store_agents_with_search_and_filters_parameterized():
"""Test parameterized SQL with multiple filters"""
# Call with multiple filters including potential injection attempts
result = await db.get_store_agents(
search_query="test",
creators=["creator1'; DROP TABLE Users; --", "creator2"],
category="AI'; DELETE FROM StoreAgent; --",
featured=True,
sorted_by="rating",
page=1,
page_size=20,
)
# Verify the query executed without error
assert isinstance(result.agents, list)
@pytest.mark.asyncio(loop_scope="session")
async def test_get_store_agents_search_with_invalid_sort_by():
"""Test that invalid sorted_by value doesn't cause SQL injection""" # Try to inject SQL via sorted_by parameter
malicious_sort = "rating; DROP TABLE Users; --"
result = await db.get_store_agents(
search_query="test",
sorted_by=malicious_sort,
)
# Verify the query executed without error
# Invalid sort_by should fall back to default, not cause SQL injection
assert isinstance(result.agents, list)
@pytest.mark.asyncio(loop_scope="session")
async def test_get_store_agents_search_category_array_injection():
"""Test that category parameter is safely passed as a parameter"""
# Try SQL injection via category
malicious_category = "AI'; DROP TABLE StoreAgent; --"
result = await db.get_store_agents(
search_query="test",
category=malicious_category,
)
# Verify the query executed without error
# Category should be parameterized, preventing SQL injection
assert isinstance(result.agents, list)

View File

@@ -40,23 +40,13 @@ async def get_profile(
Get the profile details for the authenticated user.
Cached for 1 hour per user.
"""
try:
profile = await backend.server.v2.store.db.get_user_profile(user_id)
if profile is None:
return fastapi.responses.JSONResponse(
status_code=404,
content={"detail": "Profile not found"},
)
return profile
except Exception as e:
logger.exception("Failed to fetch user profile for %s: %s", user_id, e)
profile = await backend.server.v2.store.db.get_user_profile(user_id)
if profile is None:
return fastapi.responses.JSONResponse(
status_code=500,
content={
"detail": "Failed to retrieve user profile",
"hint": "Check database connection.",
},
status_code=404,
content={"detail": "Profile not found"},
)
return profile
@router.post(
@@ -83,20 +73,10 @@ async def update_or_create_profile(
Raises:
HTTPException: If there is an error updating the profile
"""
try:
updated_profile = await backend.server.v2.store.db.update_profile(
user_id=user_id, profile=profile
)
return updated_profile
except Exception as e:
logger.exception("Failed to update profile for user %s: %s", user_id, e)
return fastapi.responses.JSONResponse(
status_code=500,
content={
"detail": "Failed to update user profile",
"hint": "Validate request data.",
},
)
updated_profile = await backend.server.v2.store.db.update_profile(
user_id=user_id, profile=profile
)
return updated_profile
##############################################
@@ -155,26 +135,16 @@ async def get_agents(
status_code=422, detail="Page size must be greater than 0"
)
try:
agents = await store_cache._get_cached_store_agents(
featured=featured,
creator=creator,
sorted_by=sorted_by,
search_query=search_query,
category=category,
page=page,
page_size=page_size,
)
return agents
except Exception as e:
logger.exception("Failed to retrieve store agents: %s", e)
return fastapi.responses.JSONResponse(
status_code=500,
content={
"detail": "Failed to retrieve store agents",
"hint": "Check database or search parameters.",
},
)
agents = await store_cache._get_cached_store_agents(
featured=featured,
creator=creator,
sorted_by=sorted_by,
search_query=search_query,
category=category,
page=page,
page_size=page_size,
)
return agents
@router.get(
@@ -189,22 +159,13 @@ async def get_agent(username: str, agent_name: str):
It returns the store listing agents details.
"""
try:
username = urllib.parse.unquote(username).lower()
# URL decode the agent name since it comes from the URL path
agent_name = urllib.parse.unquote(agent_name).lower()
agent = await store_cache._get_cached_agent_details(
username=username, agent_name=agent_name
)
return agent
except Exception:
logger.exception("Exception occurred whilst getting store agent details")
return fastapi.responses.JSONResponse(
status_code=500,
content={
"detail": "An error occurred while retrieving the store agent details"
},
)
username = urllib.parse.unquote(username).lower()
# URL decode the agent name since it comes from the URL path
agent_name = urllib.parse.unquote(agent_name).lower()
agent = await store_cache._get_cached_agent_details(
username=username, agent_name=agent_name
)
return agent
@router.get(
@@ -217,17 +178,10 @@ async def get_graph_meta_by_store_listing_version_id(store_listing_version_id: s
"""
Get Agent Graph from Store Listing Version ID.
"""
try:
graph = await backend.server.v2.store.db.get_available_graph(
store_listing_version_id
)
return graph
except Exception:
logger.exception("Exception occurred whilst getting agent graph")
return fastapi.responses.JSONResponse(
status_code=500,
content={"detail": "An error occurred while retrieving the agent graph"},
)
graph = await backend.server.v2.store.db.get_available_graph(
store_listing_version_id
)
return graph
@router.get(
@@ -241,18 +195,11 @@ async def get_store_agent(store_listing_version_id: str):
"""
Get Store Agent Details from Store Listing Version ID.
"""
try:
agent = await backend.server.v2.store.db.get_store_agent_by_version_id(
store_listing_version_id
)
agent = await backend.server.v2.store.db.get_store_agent_by_version_id(
store_listing_version_id
)
return agent
except Exception:
logger.exception("Exception occurred whilst getting store agent")
return fastapi.responses.JSONResponse(
status_code=500,
content={"detail": "An error occurred while retrieving the store agent"},
)
return agent
@router.post(
@@ -280,24 +227,17 @@ async def create_review(
Returns:
The created review
"""
try:
username = urllib.parse.unquote(username).lower()
agent_name = urllib.parse.unquote(agent_name).lower()
# Create the review
created_review = await backend.server.v2.store.db.create_store_review(
user_id=user_id,
store_listing_version_id=review.store_listing_version_id,
score=review.score,
comments=review.comments,
)
username = urllib.parse.unquote(username).lower()
agent_name = urllib.parse.unquote(agent_name).lower()
# Create the review
created_review = await backend.server.v2.store.db.create_store_review(
user_id=user_id,
store_listing_version_id=review.store_listing_version_id,
score=review.score,
comments=review.comments,
)
return created_review
except Exception:
logger.exception("Exception occurred whilst creating store review")
return fastapi.responses.JSONResponse(
status_code=500,
content={"detail": "An error occurred while creating the store review"},
)
return created_review
##############################################
@@ -340,21 +280,14 @@ async def get_creators(
status_code=422, detail="Page size must be greater than 0"
)
try:
creators = await store_cache._get_cached_store_creators(
featured=featured,
search_query=search_query,
sorted_by=sorted_by,
page=page,
page_size=page_size,
)
return creators
except Exception:
logger.exception("Exception occurred whilst getting store creators")
return fastapi.responses.JSONResponse(
status_code=500,
content={"detail": "An error occurred while retrieving the store creators"},
)
creators = await store_cache._get_cached_store_creators(
featured=featured,
search_query=search_query,
sorted_by=sorted_by,
page=page,
page_size=page_size,
)
return creators
@router.get(
@@ -370,18 +303,9 @@ async def get_creator(
Get the details of a creator.
- Creator Details Page
"""
try:
username = urllib.parse.unquote(username).lower()
creator = await store_cache._get_cached_creator_details(username=username)
return creator
except Exception:
logger.exception("Exception occurred whilst getting creator details")
return fastapi.responses.JSONResponse(
status_code=500,
content={
"detail": "An error occurred while retrieving the creator details"
},
)
username = urllib.parse.unquote(username).lower()
creator = await store_cache._get_cached_creator_details(username=username)
return creator
############################################
@@ -404,17 +328,10 @@ async def get_my_agents(
"""
Get user's own agents.
"""
try:
agents = await backend.server.v2.store.db.get_my_agents(
user_id, page=page, page_size=page_size
)
return agents
except Exception:
logger.exception("Exception occurred whilst getting my agents")
return fastapi.responses.JSONResponse(
status_code=500,
content={"detail": "An error occurred while retrieving the my agents"},
)
agents = await backend.server.v2.store.db.get_my_agents(
user_id, page=page, page_size=page_size
)
return agents
@router.delete(
@@ -438,19 +355,12 @@ async def delete_submission(
Returns:
bool: True if the submission was successfully deleted, False otherwise
"""
try:
result = await backend.server.v2.store.db.delete_store_submission(
user_id=user_id,
submission_id=submission_id,
)
result = await backend.server.v2.store.db.delete_store_submission(
user_id=user_id,
submission_id=submission_id,
)
return result
except Exception:
logger.exception("Exception occurred whilst deleting store submission")
return fastapi.responses.JSONResponse(
status_code=500,
content={"detail": "An error occurred while deleting the store submission"},
)
return result
@router.get(
@@ -488,21 +398,12 @@ async def get_submissions(
raise fastapi.HTTPException(
status_code=422, detail="Page size must be greater than 0"
)
try:
listings = await backend.server.v2.store.db.get_store_submissions(
user_id=user_id,
page=page,
page_size=page_size,
)
return listings
except Exception:
logger.exception("Exception occurred whilst getting store submissions")
return fastapi.responses.JSONResponse(
status_code=500,
content={
"detail": "An error occurred while retrieving the store submissions"
},
)
listings = await backend.server.v2.store.db.get_store_submissions(
user_id=user_id,
page=page,
page_size=page_size,
)
return listings
@router.post(
@@ -529,36 +430,23 @@ async def create_submission(
Raises:
HTTPException: If there is an error creating the submission
"""
try:
result = await backend.server.v2.store.db.create_store_submission(
user_id=user_id,
agent_id=submission_request.agent_id,
agent_version=submission_request.agent_version,
slug=submission_request.slug,
name=submission_request.name,
video_url=submission_request.video_url,
image_urls=submission_request.image_urls,
description=submission_request.description,
instructions=submission_request.instructions,
sub_heading=submission_request.sub_heading,
categories=submission_request.categories,
changes_summary=submission_request.changes_summary or "Initial Submission",
recommended_schedule_cron=submission_request.recommended_schedule_cron,
)
result = await backend.server.v2.store.db.create_store_submission(
user_id=user_id,
agent_id=submission_request.agent_id,
agent_version=submission_request.agent_version,
slug=submission_request.slug,
name=submission_request.name,
video_url=submission_request.video_url,
image_urls=submission_request.image_urls,
description=submission_request.description,
instructions=submission_request.instructions,
sub_heading=submission_request.sub_heading,
categories=submission_request.categories,
changes_summary=submission_request.changes_summary or "Initial Submission",
recommended_schedule_cron=submission_request.recommended_schedule_cron,
)
return result
except backend.server.v2.store.exceptions.SlugAlreadyInUseError as e:
logger.warning("Slug already in use: %s", str(e))
return fastapi.responses.JSONResponse(
status_code=409,
content={"detail": str(e)},
)
except Exception:
logger.exception("Exception occurred whilst creating store submission")
return fastapi.responses.JSONResponse(
status_code=500,
content={"detail": "An error occurred while creating the store submission"},
)
return result
@router.put(
@@ -627,36 +515,10 @@ async def upload_submission_media(
Raises:
HTTPException: If there is an error uploading the media
"""
try:
media_url = await backend.server.v2.store.media.upload_media(
user_id=user_id, file=file
)
return media_url
except backend.server.v2.store.exceptions.VirusDetectedError as e:
logger.warning(f"Virus detected in uploaded file: {e.threat_name}")
return fastapi.responses.JSONResponse(
status_code=400,
content={
"detail": f"File rejected due to virus detection: {e.threat_name}",
"error_type": "virus_detected",
"threat_name": e.threat_name,
},
)
except backend.server.v2.store.exceptions.VirusScanError as e:
logger.error(f"Virus scanning failed: {str(e)}")
return fastapi.responses.JSONResponse(
status_code=503,
content={
"detail": "Virus scanning service unavailable. Please try again later.",
"error_type": "virus_scan_failed",
},
)
except Exception:
logger.exception("Exception occurred whilst uploading submission media")
return fastapi.responses.JSONResponse(
status_code=500,
content={"detail": "An error occurred while uploading the media file"},
)
media_url = await backend.server.v2.store.media.upload_media(
user_id=user_id, file=file
)
return media_url
@router.post(
@@ -679,44 +541,35 @@ async def generate_image(
Returns:
JSONResponse: JSON containing the URL of the generated image
"""
try:
agent = await backend.data.graph.get_graph(agent_id, user_id=user_id)
agent = await backend.data.graph.get_graph(agent_id, user_id=user_id)
if not agent:
raise fastapi.HTTPException(
status_code=404, detail=f"Agent with ID {agent_id} not found"
)
# Use .jpeg here since we are generating JPEG images
filename = f"agent_{agent_id}.jpeg"
if not agent:
raise fastapi.HTTPException(
status_code=404, detail=f"Agent with ID {agent_id} not found"
)
# Use .jpeg here since we are generating JPEG images
filename = f"agent_{agent_id}.jpeg"
existing_url = await backend.server.v2.store.media.check_media_exists(
user_id, filename
)
if existing_url:
logger.info(f"Using existing image for agent {agent_id}")
return fastapi.responses.JSONResponse(content={"image_url": existing_url})
# Generate agent image as JPEG
image = await backend.server.v2.store.image_gen.generate_agent_image(
agent=agent
)
existing_url = await backend.server.v2.store.media.check_media_exists(
user_id, filename
)
if existing_url:
logger.info(f"Using existing image for agent {agent_id}")
return fastapi.responses.JSONResponse(content={"image_url": existing_url})
# Generate agent image as JPEG
image = await backend.server.v2.store.image_gen.generate_agent_image(agent=agent)
# Create UploadFile with the correct filename and content_type
image_file = fastapi.UploadFile(
file=image,
filename=filename,
)
# Create UploadFile with the correct filename and content_type
image_file = fastapi.UploadFile(
file=image,
filename=filename,
)
image_url = await backend.server.v2.store.media.upload_media(
user_id=user_id, file=image_file, use_file_name=True
)
image_url = await backend.server.v2.store.media.upload_media(
user_id=user_id, file=image_file, use_file_name=True
)
return fastapi.responses.JSONResponse(content={"image_url": image_url})
except Exception:
logger.exception("Exception occurred whilst generating submission image")
return fastapi.responses.JSONResponse(
status_code=500,
content={"detail": "An error occurred while generating the image"},
)
return fastapi.responses.JSONResponse(content={"image_url": image_url})
@router.get(

View File

@@ -329,7 +329,3 @@ class WebsocketServer(AppProcess):
port=Config().websocket_server_port,
log_config=None,
)
def cleanup(self):
super().cleanup()
logger.info(f"[{self.service_name}] ⏳ Shutting down WebSocket Server...")

View File

@@ -0,0 +1,457 @@
"""
Caching utilities for the AutoGPT platform.
Provides decorators for caching function results with support for:
- In-memory caching with TTL
- Shared Redis-backed caching across processes
- Thread-local caching for request-scoped data
- Thundering herd protection
- LRU eviction with optional TTL refresh
"""
import asyncio
import inspect
import logging
import pickle
import threading
import time
from dataclasses import dataclass
from functools import wraps
from typing import Any, Callable, ParamSpec, Protocol, TypeVar, cast, runtime_checkable
from redis import ConnectionPool, Redis
from backend.util.retry import conn_retry
from backend.util.settings import Settings
P = ParamSpec("P")
R = TypeVar("R")
R_co = TypeVar("R_co", covariant=True)
logger = logging.getLogger(__name__)
settings = Settings()
# RECOMMENDED REDIS CONFIGURATION FOR PRODUCTION:
# Configure Redis with the following settings for optimal caching performance:
# maxmemory-policy allkeys-lru # Evict least recently used keys when memory limit reached
# maxmemory 2gb # Set memory limit (adjust based on your needs)
# save "" # Disable persistence if using Redis purely for caching
# Create a dedicated Redis connection pool for caching (binary mode for pickle)
_cache_pool: ConnectionPool | None = None
@conn_retry("Redis", "Acquiring cache connection pool")
def _get_cache_pool() -> ConnectionPool:
"""Get or create a connection pool for cache operations."""
global _cache_pool
if _cache_pool is None:
_cache_pool = ConnectionPool(
host=settings.config.redis_host,
port=settings.config.redis_port,
password=settings.config.redis_password or None,
decode_responses=False, # Binary mode for pickle
max_connections=50,
socket_keepalive=True,
socket_connect_timeout=5,
retry_on_timeout=True,
)
return _cache_pool
redis = Redis(connection_pool=_get_cache_pool())
@dataclass
class CachedValue:
"""Wrapper for cached values with timestamp to avoid tuple ambiguity."""
result: Any
timestamp: float
def _make_hashable_key(
args: tuple[Any, ...], kwargs: dict[str, Any]
) -> tuple[Any, ...]:
"""
Convert args and kwargs into a hashable cache key.
Handles unhashable types like dict, list, set by converting them to
their sorted string representations.
"""
def make_hashable(obj: Any) -> Any:
"""Recursively convert an object to a hashable representation."""
if isinstance(obj, dict):
# Sort dict items to ensure consistent ordering
return (
"__dict__",
tuple(sorted((k, make_hashable(v)) for k, v in obj.items())),
)
elif isinstance(obj, (list, tuple)):
return ("__list__", tuple(make_hashable(item) for item in obj))
elif isinstance(obj, set):
return ("__set__", tuple(sorted(make_hashable(item) for item in obj)))
elif hasattr(obj, "__dict__"):
# Handle objects with __dict__ attribute
return ("__obj__", obj.__class__.__name__, make_hashable(obj.__dict__))
else:
# For basic hashable types (str, int, bool, None, etc.)
try:
hash(obj)
return obj
except TypeError:
# Fallback: convert to string representation
return ("__str__", str(obj))
hashable_args = tuple(make_hashable(arg) for arg in args)
hashable_kwargs = tuple(sorted((k, make_hashable(v)) for k, v in kwargs.items()))
return (hashable_args, hashable_kwargs)
def _make_redis_key(key: tuple[Any, ...], func_name: str) -> str:
"""Convert a hashable key tuple to a Redis key string."""
# Ensure key is already hashable
hashable_key = key if isinstance(key, tuple) else (key,)
return f"cache:{func_name}:{hash(hashable_key)}"
@runtime_checkable
class CachedFunction(Protocol[P, R_co]):
"""Protocol for cached functions with cache management methods."""
def cache_clear(self, pattern: str | None = None) -> None:
"""Clear cached entries. If pattern provided, clear matching entries only."""
return None
def cache_info(self) -> dict[str, int | None]:
"""Get cache statistics."""
return {}
def cache_delete(self, *args: P.args, **kwargs: P.kwargs) -> bool:
"""Delete a specific cache entry by its arguments. Returns True if entry existed."""
return False
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R_co:
"""Call the cached function."""
return None # type: ignore
def cached(
*,
maxsize: int = 128,
ttl_seconds: int,
shared_cache: bool = False,
refresh_ttl_on_get: bool = False,
) -> Callable[[Callable], CachedFunction]:
"""
Thundering herd safe cache decorator for both sync and async functions.
Uses double-checked locking to prevent multiple threads/coroutines from
executing the expensive operation simultaneously during cache misses.
Args:
maxsize: Maximum number of cached entries (only for in-memory cache)
ttl_seconds: Time to live in seconds. Required - entries must expire.
shared_cache: If True, use Redis for cross-process caching
refresh_ttl_on_get: If True, refresh TTL when cache entry is accessed (LRU behavior)
Returns:
Decorated function with caching capabilities
Example:
@cached(ttl_seconds=300) # 5 minute TTL
def expensive_sync_operation(param: str) -> dict:
return {"result": param}
@cached(ttl_seconds=600, shared_cache=True, refresh_ttl_on_get=True)
async def expensive_async_operation(param: str) -> dict:
return {"result": param}
"""
def decorator(target_func):
cache_storage: dict[tuple, CachedValue] = {}
_event_loop_locks: dict[Any, asyncio.Lock] = {}
def _get_from_redis(redis_key: str) -> Any | None:
"""Get value from Redis, optionally refreshing TTL."""
try:
if refresh_ttl_on_get:
# Use GETEX to get value and refresh expiry atomically
cached_bytes = redis.getex(redis_key, ex=ttl_seconds)
else:
cached_bytes = redis.get(redis_key)
if cached_bytes and isinstance(cached_bytes, bytes):
return pickle.loads(cached_bytes)
except Exception as e:
logger.error(
f"Redis error during cache check for {target_func.__name__}: {e}"
)
return None
def _set_to_redis(redis_key: str, value: Any) -> None:
"""Set value in Redis with TTL."""
try:
pickled_value = pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL)
redis.setex(redis_key, ttl_seconds, pickled_value)
except Exception as e:
logger.error(
f"Redis error storing cache for {target_func.__name__}: {e}"
)
def _get_from_memory(key: tuple) -> Any | None:
"""Get value from in-memory cache, checking TTL."""
if key in cache_storage:
cached_data = cache_storage[key]
if time.time() - cached_data.timestamp < ttl_seconds:
logger.debug(
f"Cache hit for {target_func.__name__} args: {key[0]} kwargs: {key[1]}"
)
return cached_data.result
return None
def _set_to_memory(key: tuple, value: Any) -> None:
"""Set value in in-memory cache with timestamp."""
cache_storage[key] = CachedValue(result=value, timestamp=time.time())
# Cleanup if needed
if len(cache_storage) > maxsize:
cutoff = maxsize // 2
oldest_keys = list(cache_storage.keys())[:-cutoff] if cutoff > 0 else []
for old_key in oldest_keys:
cache_storage.pop(old_key, None)
if inspect.iscoroutinefunction(target_func):
def _get_cache_lock():
"""Get or create an asyncio.Lock for the current event loop."""
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
if loop not in _event_loop_locks:
return _event_loop_locks.setdefault(loop, asyncio.Lock())
return _event_loop_locks[loop]
@wraps(target_func)
async def async_wrapper(*args: P.args, **kwargs: P.kwargs):
key = _make_hashable_key(args, kwargs)
redis_key = (
_make_redis_key(key, target_func.__name__) if shared_cache else ""
)
# Fast path: check cache without lock
if shared_cache:
result = _get_from_redis(redis_key)
if result is not None:
return result
else:
result = _get_from_memory(key)
if result is not None:
return result
# Slow path: acquire lock for cache miss/expiry
async with _get_cache_lock():
# Double-check: another coroutine might have populated cache
if shared_cache:
result = _get_from_redis(redis_key)
if result is not None:
return result
else:
result = _get_from_memory(key)
if result is not None:
return result
# Cache miss - execute function
logger.debug(f"Cache miss for {target_func.__name__}")
result = await target_func(*args, **kwargs)
# Store result
if shared_cache:
_set_to_redis(redis_key, result)
else:
_set_to_memory(key, result)
return result
wrapper = async_wrapper
else:
# Sync function with threading.Lock
cache_lock = threading.Lock()
@wraps(target_func)
def sync_wrapper(*args: P.args, **kwargs: P.kwargs):
key = _make_hashable_key(args, kwargs)
redis_key = (
_make_redis_key(key, target_func.__name__) if shared_cache else ""
)
# Fast path: check cache without lock
if shared_cache:
result = _get_from_redis(redis_key)
if result is not None:
return result
else:
result = _get_from_memory(key)
if result is not None:
return result
# Slow path: acquire lock for cache miss/expiry
with cache_lock:
# Double-check: another thread might have populated cache
if shared_cache:
result = _get_from_redis(redis_key)
if result is not None:
return result
else:
result = _get_from_memory(key)
if result is not None:
return result
# Cache miss - execute function
logger.debug(f"Cache miss for {target_func.__name__}")
result = target_func(*args, **kwargs)
# Store result
if shared_cache:
_set_to_redis(redis_key, result)
else:
_set_to_memory(key, result)
return result
wrapper = sync_wrapper
# Add cache management methods
def cache_clear(pattern: str | None = None) -> None:
"""Clear cache entries. If pattern provided, clear matching entries."""
if shared_cache:
if pattern:
# Clear entries matching pattern
keys = list(
redis.scan_iter(f"cache:{target_func.__name__}:{pattern}")
)
else:
# Clear all cache keys
keys = list(redis.scan_iter(f"cache:{target_func.__name__}:*"))
if keys:
pipeline = redis.pipeline()
for key in keys:
pipeline.delete(key)
pipeline.execute()
else:
if pattern:
# For in-memory cache, pattern matching not supported
logger.warning(
"Pattern-based clearing not supported for in-memory cache"
)
else:
cache_storage.clear()
def cache_info() -> dict[str, int | None]:
if shared_cache:
cache_keys = list(redis.scan_iter(f"cache:{target_func.__name__}:*"))
return {
"size": len(cache_keys),
"maxsize": None, # Redis manages its own size
"ttl_seconds": ttl_seconds,
}
else:
return {
"size": len(cache_storage),
"maxsize": maxsize,
"ttl_seconds": ttl_seconds,
}
def cache_delete(*args, **kwargs) -> bool:
"""Delete a specific cache entry. Returns True if entry existed."""
key = _make_hashable_key(args, kwargs)
if shared_cache:
redis_key = _make_redis_key(key, target_func.__name__)
if redis.exists(redis_key):
redis.delete(redis_key)
return True
return False
else:
if key in cache_storage:
del cache_storage[key]
return True
return False
setattr(wrapper, "cache_clear", cache_clear)
setattr(wrapper, "cache_info", cache_info)
setattr(wrapper, "cache_delete", cache_delete)
return cast(CachedFunction, wrapper)
return decorator
def thread_cached(func):
"""
Thread-local cache decorator for both sync and async functions.
Each thread gets its own cache, which is useful for request-scoped caching
in web applications where you want to cache within a single request but
not across requests.
Args:
func: The function to cache
Returns:
Decorated function with thread-local caching
Example:
@thread_cached
def expensive_operation(param: str) -> dict:
return {"result": param}
@thread_cached # Works with async too
async def expensive_async_operation(param: str) -> dict:
return {"result": param}
"""
thread_local = threading.local()
def _clear():
if hasattr(thread_local, "cache"):
del thread_local.cache
if inspect.iscoroutinefunction(func):
@wraps(func)
async def async_wrapper(*args, **kwargs):
cache = getattr(thread_local, "cache", None)
if cache is None:
cache = thread_local.cache = {}
key = _make_hashable_key(args, kwargs)
if key not in cache:
cache[key] = await func(*args, **kwargs)
return cache[key]
setattr(async_wrapper, "clear_cache", _clear)
return async_wrapper
else:
@wraps(func)
def sync_wrapper(*args, **kwargs):
cache = getattr(thread_local, "cache", None)
if cache is None:
cache = thread_local.cache = {}
key = _make_hashable_key(args, kwargs)
if key not in cache:
cache[key] = func(*args, **kwargs)
return cache[key]
setattr(sync_wrapper, "clear_cache", _clear)
return sync_wrapper
def clear_thread_cache(func: Callable) -> None:
"""Clear thread-local cache for a function."""
if clear := getattr(func, "clear_cache", None):
clear()

View File

@@ -16,7 +16,7 @@ from unittest.mock import Mock
import pytest
from autogpt_libs.utils.cache import cached, clear_thread_cache, thread_cached
from backend.util.cache import cached, clear_thread_cache, thread_cached
class TestThreadCached:
@@ -332,7 +332,7 @@ class TestCache:
"""Test basic sync caching functionality."""
call_count = 0
@cached()
@cached(ttl_seconds=300)
def expensive_sync_function(x: int, y: int = 0) -> int:
nonlocal call_count
call_count += 1
@@ -358,7 +358,7 @@ class TestCache:
"""Test basic async caching functionality."""
call_count = 0
@cached()
@cached(ttl_seconds=300)
async def expensive_async_function(x: int, y: int = 0) -> int:
nonlocal call_count
call_count += 1
@@ -385,7 +385,7 @@ class TestCache:
call_count = 0
results = []
@cached()
@cached(ttl_seconds=300)
def slow_function(x: int) -> int:
nonlocal call_count
call_count += 1
@@ -412,7 +412,7 @@ class TestCache:
"""Test that concurrent async calls don't cause thundering herd."""
call_count = 0
@cached()
@cached(ttl_seconds=300)
async def slow_async_function(x: int) -> int:
nonlocal call_count
call_count += 1
@@ -508,7 +508,7 @@ class TestCache:
"""Test cache clearing functionality."""
call_count = 0
@cached()
@cached(ttl_seconds=300)
def clearable_function(x: int) -> int:
nonlocal call_count
call_count += 1
@@ -537,7 +537,7 @@ class TestCache:
"""Test cache clearing functionality with async function."""
call_count = 0
@cached()
@cached(ttl_seconds=300)
async def async_clearable_function(x: int) -> int:
nonlocal call_count
call_count += 1
@@ -567,7 +567,7 @@ class TestCache:
"""Test that cached async functions return actual results, not coroutines."""
call_count = 0
@cached()
@cached(ttl_seconds=300)
async def async_result_function(x: int) -> str:
nonlocal call_count
call_count += 1
@@ -593,7 +593,7 @@ class TestCache:
"""Test selective cache deletion functionality."""
call_count = 0
@cached()
@cached(ttl_seconds=300)
def deletable_function(x: int) -> int:
nonlocal call_count
call_count += 1
@@ -636,7 +636,7 @@ class TestCache:
"""Test selective cache deletion functionality with async function."""
call_count = 0
@cached()
@cached(ttl_seconds=300)
async def async_deletable_function(x: int) -> int:
nonlocal call_count
call_count += 1
@@ -674,3 +674,450 @@ class TestCache:
# Try to delete non-existent entry
was_deleted = async_deletable_function.cache_delete(99)
assert was_deleted is False
class TestSharedCache:
"""Tests for shared_cache (Redis-backed) functionality."""
def test_sync_shared_cache_basic(self):
"""Test basic shared cache functionality with sync function."""
call_count = 0
@cached(ttl_seconds=30, shared_cache=True)
def shared_sync_function(x: int, y: int = 0) -> int:
nonlocal call_count
call_count += 1
return x + y
# Clear any existing cache
shared_sync_function.cache_clear()
# First call
result1 = shared_sync_function(10, 20)
assert result1 == 30
assert call_count == 1
# Second call - should use Redis cache
result2 = shared_sync_function(10, 20)
assert result2 == 30
assert call_count == 1
# Different args - should call function again
result3 = shared_sync_function(15, 25)
assert result3 == 40
assert call_count == 2
# Cleanup
shared_sync_function.cache_clear()
@pytest.mark.asyncio
async def test_async_shared_cache_basic(self):
"""Test basic shared cache functionality with async function."""
call_count = 0
@cached(ttl_seconds=30, shared_cache=True)
async def shared_async_function(x: int, y: int = 0) -> int:
nonlocal call_count
call_count += 1
await asyncio.sleep(0.01)
return x + y
# Clear any existing cache
shared_async_function.cache_clear()
# First call
result1 = await shared_async_function(10, 20)
assert result1 == 30
assert call_count == 1
# Second call - should use Redis cache
result2 = await shared_async_function(10, 20)
assert result2 == 30
assert call_count == 1
# Different args - should call function again
result3 = await shared_async_function(15, 25)
assert result3 == 40
assert call_count == 2
# Cleanup
shared_async_function.cache_clear()
def test_shared_cache_ttl_refresh(self):
"""Test TTL refresh functionality with shared cache."""
call_count = 0
@cached(ttl_seconds=2, shared_cache=True, refresh_ttl_on_get=True)
def ttl_refresh_function(x: int) -> int:
nonlocal call_count
call_count += 1
return x * 10
# Clear any existing cache
ttl_refresh_function.cache_clear()
# First call
result1 = ttl_refresh_function(3)
assert result1 == 30
assert call_count == 1
# Wait 1 second
time.sleep(1)
# Second call - should refresh TTL and use cache
result2 = ttl_refresh_function(3)
assert result2 == 30
assert call_count == 1
# Wait another 1.5 seconds (total 2.5s from first call, 1.5s from second)
time.sleep(1.5)
# Third call - TTL should have been refreshed, so still cached
result3 = ttl_refresh_function(3)
assert result3 == 30
assert call_count == 1
# Wait 2.1 seconds - now it should expire
time.sleep(2.1)
# Fourth call - should call function again
result4 = ttl_refresh_function(3)
assert result4 == 30
assert call_count == 2
# Cleanup
ttl_refresh_function.cache_clear()
def test_shared_cache_without_ttl_refresh(self):
"""Test that TTL doesn't refresh when refresh_ttl_on_get=False."""
call_count = 0
@cached(ttl_seconds=2, shared_cache=True, refresh_ttl_on_get=False)
def no_ttl_refresh_function(x: int) -> int:
nonlocal call_count
call_count += 1
return x * 10
# Clear any existing cache
no_ttl_refresh_function.cache_clear()
# First call
result1 = no_ttl_refresh_function(4)
assert result1 == 40
assert call_count == 1
# Wait 1 second
time.sleep(1)
# Second call - should use cache but NOT refresh TTL
result2 = no_ttl_refresh_function(4)
assert result2 == 40
assert call_count == 1
# Wait another 1.1 seconds (total 2.1s from first call)
time.sleep(1.1)
# Third call - should have expired
result3 = no_ttl_refresh_function(4)
assert result3 == 40
assert call_count == 2
# Cleanup
no_ttl_refresh_function.cache_clear()
def test_shared_cache_complex_objects(self):
"""Test caching complex objects with shared cache (pickle serialization)."""
call_count = 0
@cached(ttl_seconds=30, shared_cache=True)
def complex_object_function(x: int) -> dict:
nonlocal call_count
call_count += 1
return {
"number": x,
"squared": x**2,
"nested": {"list": [1, 2, x], "tuple": (x, x * 2)},
"string": f"value_{x}",
}
# Clear any existing cache
complex_object_function.cache_clear()
# First call
result1 = complex_object_function(5)
assert result1["number"] == 5
assert result1["squared"] == 25
assert result1["nested"]["list"] == [1, 2, 5]
assert call_count == 1
# Second call - should use cache
result2 = complex_object_function(5)
assert result2 == result1
assert call_count == 1
# Cleanup
complex_object_function.cache_clear()
def test_shared_cache_info(self):
"""Test cache_info for shared cache."""
@cached(ttl_seconds=30, shared_cache=True)
def info_shared_function(x: int) -> int:
return x * 2
# Clear any existing cache
info_shared_function.cache_clear()
# Check initial info
info = info_shared_function.cache_info()
assert info["size"] == 0
assert info["maxsize"] is None # Redis manages size
assert info["ttl_seconds"] == 30
# Add some entries
info_shared_function(1)
info_shared_function(2)
info_shared_function(3)
info = info_shared_function.cache_info()
assert info["size"] == 3
# Cleanup
info_shared_function.cache_clear()
def test_shared_cache_delete(self):
"""Test selective deletion with shared cache."""
call_count = 0
@cached(ttl_seconds=30, shared_cache=True)
def delete_shared_function(x: int) -> int:
nonlocal call_count
call_count += 1
return x * 3
# Clear any existing cache
delete_shared_function.cache_clear()
# Add entries
delete_shared_function(1)
delete_shared_function(2)
delete_shared_function(3)
assert call_count == 3
# Verify cached
delete_shared_function(1)
delete_shared_function(2)
assert call_count == 3
# Delete specific entry
was_deleted = delete_shared_function.cache_delete(2)
assert was_deleted is True
# Entry for x=2 should be gone
delete_shared_function(2)
assert call_count == 4
# Others should still be cached
delete_shared_function(1)
delete_shared_function(3)
assert call_count == 4
# Try to delete non-existent
was_deleted = delete_shared_function.cache_delete(99)
assert was_deleted is False
# Cleanup
delete_shared_function.cache_clear()
@pytest.mark.asyncio
async def test_async_shared_cache_thundering_herd(self):
"""Test that shared cache prevents thundering herd for async functions."""
call_count = 0
@cached(ttl_seconds=30, shared_cache=True)
async def shared_slow_function(x: int) -> int:
nonlocal call_count
call_count += 1
await asyncio.sleep(0.1)
return x * x
# Clear any existing cache
shared_slow_function.cache_clear()
# Launch multiple concurrent tasks
tasks = [shared_slow_function(8) for _ in range(10)]
results = await asyncio.gather(*tasks)
# All should return same result
assert all(r == 64 for r in results)
# Only one should have executed
assert call_count == 1
# Cleanup
shared_slow_function.cache_clear()
def test_shared_cache_clear_pattern(self):
"""Test pattern-based cache clearing (Redis feature)."""
@cached(ttl_seconds=30, shared_cache=True)
def pattern_function(category: str, item: int) -> str:
return f"{category}_{item}"
# Clear any existing cache
pattern_function.cache_clear()
# Add various entries
pattern_function("fruit", 1)
pattern_function("fruit", 2)
pattern_function("vegetable", 1)
pattern_function("vegetable", 2)
info = pattern_function.cache_info()
assert info["size"] == 4
# Note: Pattern clearing with wildcards requires specific Redis scan
# implementation. The current code clears by pattern but needs
# adjustment for partial matching. For now, test full clear.
pattern_function.cache_clear()
info = pattern_function.cache_info()
assert info["size"] == 0
def test_shared_vs_local_cache_isolation(self):
"""Test that shared and local caches are isolated."""
shared_count = 0
local_count = 0
@cached(ttl_seconds=30, shared_cache=True)
def shared_function(x: int) -> int:
nonlocal shared_count
shared_count += 1
return x * 2
@cached(ttl_seconds=30, shared_cache=False)
def local_function(x: int) -> int:
nonlocal local_count
local_count += 1
return x * 2
# Clear caches
shared_function.cache_clear()
local_function.cache_clear()
# Call both with same args
shared_result = shared_function(5)
local_result = local_function(5)
assert shared_result == local_result == 10
assert shared_count == 1
assert local_count == 1
# Call again - both should use their respective caches
shared_function(5)
local_function(5)
assert shared_count == 1
assert local_count == 1
# Clear only shared cache
shared_function.cache_clear()
# Shared should recompute, local should still use cache
shared_function(5)
local_function(5)
assert shared_count == 2
assert local_count == 1
# Cleanup
shared_function.cache_clear()
local_function.cache_clear()
@pytest.mark.asyncio
async def test_shared_cache_concurrent_different_keys(self):
"""Test that concurrent calls with different keys work correctly."""
call_counts = {}
@cached(ttl_seconds=30, shared_cache=True)
async def multi_key_function(key: str) -> str:
if key not in call_counts:
call_counts[key] = 0
call_counts[key] += 1
await asyncio.sleep(0.05)
return f"result_{key}"
# Clear cache
multi_key_function.cache_clear()
# Launch concurrent tasks with different keys
keys = ["a", "b", "c", "d", "e"]
tasks = []
for key in keys:
# Multiple calls per key
tasks.extend([multi_key_function(key) for _ in range(3)])
results = await asyncio.gather(*tasks)
# Verify results
for i, key in enumerate(keys):
expected = f"result_{key}"
# Each key appears 3 times in results
key_results = results[i * 3 : (i + 1) * 3]
assert all(r == expected for r in key_results)
# Each key should only be computed once
for key in keys:
assert call_counts[key] == 1
# Cleanup
multi_key_function.cache_clear()
def test_shared_cache_performance_comparison(self):
"""Compare performance of shared vs local cache."""
import statistics
shared_times = []
local_times = []
@cached(ttl_seconds=30, shared_cache=True)
def shared_perf_function(x: int) -> int:
time.sleep(0.01) # Simulate work
return x * 2
@cached(ttl_seconds=30, shared_cache=False)
def local_perf_function(x: int) -> int:
time.sleep(0.01) # Simulate work
return x * 2
# Clear caches
shared_perf_function.cache_clear()
local_perf_function.cache_clear()
# Warm up both caches
for i in range(5):
shared_perf_function(i)
local_perf_function(i)
# Measure cache hit times
for i in range(5):
# Shared cache hit
start = time.time()
shared_perf_function(i)
shared_times.append(time.time() - start)
# Local cache hit
start = time.time()
local_perf_function(i)
local_times.append(time.time() - start)
# Local cache should be faster (no Redis round-trip)
avg_shared = statistics.mean(shared_times)
avg_local = statistics.mean(local_times)
print(f"Avg shared cache hit time: {avg_shared:.6f}s")
print(f"Avg local cache hit time: {avg_local:.6f}s")
# Local should be significantly faster for cache hits
# Redis adds network latency even for cache hits
assert avg_local < avg_shared
# Cleanup
shared_perf_function.cache_clear()
local_perf_function.cache_clear()

View File

@@ -4,8 +4,7 @@ Centralized service client helpers with thread caching.
from typing import TYPE_CHECKING
from autogpt_libs.utils.cache import cached, thread_cached
from backend.util.cache import cached, thread_cached
from backend.util.settings import Settings
settings = Settings()
@@ -120,7 +119,7 @@ def get_integration_credentials_store() -> "IntegrationCredentialsStore":
# ============ Supabase Clients ============ #
@cached()
@cached(ttl_seconds=3600)
def get_supabase() -> "Client":
"""Get a process-cached synchronous Supabase client instance."""
from supabase import create_client
@@ -130,7 +129,7 @@ def get_supabase() -> "Client":
)
@cached()
@cached(ttl_seconds=3600)
async def get_async_supabase() -> "AClient":
"""Get a process-cached asynchronous Supabase client instance."""
from supabase import create_async_client

View File

@@ -5,12 +5,12 @@ from functools import wraps
from typing import Any, Awaitable, Callable, TypeVar
import ldclient
from autogpt_libs.utils.cache import cached
from fastapi import HTTPException
from ldclient import Context, LDClient
from ldclient.config import Config
from typing_extensions import ParamSpec
from backend.util.cache import cached
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
@@ -35,6 +35,7 @@ class Flag(str, Enum):
AI_ACTIVITY_STATUS = "ai-agent-execution-summary"
BETA_BLOCKS = "beta-blocks"
AGENT_ACTIVITY = "agent-activity"
ENABLE_PLATFORM_PAYMENT = "enable-platform-payment"
def is_configured() -> bool:
@@ -62,9 +63,9 @@ def initialize_launchdarkly() -> None:
config = Config(sdk_key)
ldclient.set_config(config)
global _is_initialized
_is_initialized = True
if ldclient.get().is_initialized():
global _is_initialized
_is_initialized = True
logger.info("LaunchDarkly client initialized successfully")
else:
logger.error("LaunchDarkly client failed to initialize")
@@ -217,7 +218,8 @@ def feature_flag(
if not get_client().is_initialized():
logger.warning(
f"LaunchDarkly not initialized, using default={default}"
"LaunchDarkly not initialized, "
f"using default {flag_key}={repr(default)}"
)
is_enabled = default
else:
@@ -231,8 +233,9 @@ def feature_flag(
else:
# Log warning and use default for non-boolean values
logger.warning(
f"Feature flag {flag_key} returned non-boolean value: {flag_value} (type: {type(flag_value).__name__}). "
f"Using default={default}"
f"Feature flag {flag_key} returned non-boolean value: "
f"{repr(flag_value)} (type: {type(flag_value).__name__}). "
f"Using default value {repr(default)}"
)
is_enabled = default

View File

@@ -1,32 +1,21 @@
import json
import logging
import re
from typing import Any, Type, TypeGuard, TypeVar, overload
from typing import Any, Type, TypeVar, overload
import jsonschema
import orjson
from fastapi.encoders import jsonable_encoder
from fastapi.encoders import jsonable_encoder as to_dict
from prisma import Json
from pydantic import BaseModel
from .truncate import truncate
from .type import type_match
logger = logging.getLogger(__name__)
# Precompiled regex to remove PostgreSQL-incompatible control characters
# Removes \u0000-\u0008, \u000B-\u000C, \u000E-\u001F, \u007F (keeps tab \u0009, newline \u000A, carriage return \u000D)
POSTGRES_CONTROL_CHARS = re.compile(r"[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F]")
# Comprehensive regex to remove all PostgreSQL-incompatible control character sequences in JSON
# Handles both Unicode escapes (\\u0000-\\u0008, \\u000B-\\u000C, \\u000E-\\u001F, \\u007F)
# and JSON single-char escapes (\\b, \\f) while preserving legitimate file paths
POSTGRES_JSON_ESCAPES = re.compile(
r"\\u000[0-8]|\\u000[bB]|\\u000[cC]|\\u00[0-1][0-9a-fA-F]|\\u007[fF]|(?<!\\)\\[bf](?!\\)"
)
def to_dict(data) -> dict:
if isinstance(data, BaseModel):
data = data.model_dump()
return jsonable_encoder(data)
def dumps(
data: Any, *args: Any, indent: int | None = None, option: int = 0, **kwargs: Any
@@ -116,38 +105,57 @@ def validate_with_jsonschema(
return str(e)
def is_list_of_basemodels(value: object) -> TypeGuard[list[BaseModel]]:
return isinstance(value, list) and all(
isinstance(item, BaseModel) for item in value
)
def _sanitize_string(value: str) -> str:
"""Remove PostgreSQL-incompatible control characters from string."""
return POSTGRES_CONTROL_CHARS.sub("", value)
def convert_pydantic_to_json(output_data: Any) -> Any:
if isinstance(output_data, BaseModel):
return output_data.model_dump()
if is_list_of_basemodels(output_data):
return [item.model_dump() for item in output_data]
return output_data
def sanitize_json(data: Any) -> Any:
try:
# Use two-pass approach for consistent string sanitization:
# 1. First convert to basic JSON-serializable types (handles Pydantic models)
# 2. Then sanitize strings in the result
basic_result = to_dict(data)
return to_dict(basic_result, custom_encoder={str: _sanitize_string})
except Exception as e:
# Log the failure and fall back to string representation
logger.error(
"SafeJson fallback to string representation due to serialization error: %s (%s). "
"Data type: %s, Data preview: %s",
type(e).__name__,
truncate(str(e), 200),
type(data).__name__,
truncate(str(data), 100),
)
# Ultimate fallback: convert to string representation and sanitize
return _sanitize_string(str(data))
def SafeJson(data: Any) -> Json:
class SafeJson(Json):
"""
Safely serialize data and return Prisma's Json type.
Sanitizes null bytes to prevent PostgreSQL 22P05 errors.
Sanitizes control characters to prevent PostgreSQL 22P05 errors.
This function:
1. Converts Pydantic models to dicts (recursively using to_dict)
2. Recursively removes PostgreSQL-incompatible control characters from strings
3. Returns a Prisma Json object safe for database storage
Uses to_dict (jsonable_encoder) with a custom encoder to handle both Pydantic
conversion and control character sanitization in a two-pass approach.
Args:
data: Input data to sanitize and convert to Json
Returns:
Prisma Json object with control characters removed
Examples:
>>> SafeJson({"text": "Hello\\x00World"}) # null char removed
>>> SafeJson({"path": "C:\\\\temp"}) # backslashes preserved
>>> SafeJson({"data": "Text\\\\u0000here"}) # literal backslash-u preserved
"""
if isinstance(data, BaseModel):
json_string = data.model_dump_json(
warnings="error",
exclude_none=True,
fallback=lambda v: None,
)
else:
json_string = dumps(data, default=lambda v: None)
# Remove PostgreSQL-incompatible control characters in JSON string
# Single comprehensive regex handles all control character sequences
sanitized_json = POSTGRES_JSON_ESCAPES.sub("", json_string)
# Remove any remaining raw control characters (fallback safety net)
sanitized_json = POSTGRES_CONTROL_CHARS.sub("", sanitized_json)
return Json(json.loads(sanitized_json))
def __init__(self, data: Any):
super().__init__(sanitize_json(data))

View File

@@ -8,10 +8,7 @@ settings = Settings()
def configure_logging():
import autogpt_libs.logging.config
if (
settings.config.behave_as == BehaveAs.LOCAL
or settings.config.app_env == AppEnvironment.LOCAL
):
if not is_structured_logging_enabled():
autogpt_libs.logging.config.configure_logging(force_cloud_logging=False)
else:
autogpt_libs.logging.config.configure_logging(force_cloud_logging=True)
@@ -20,6 +17,14 @@ def configure_logging():
logging.getLogger("httpx").setLevel(logging.WARNING)
def is_structured_logging_enabled() -> bool:
"""Check if structured logging (cloud logging) is enabled."""
return not (
settings.config.behave_as == BehaveAs.LOCAL
or settings.config.app_env == AppEnvironment.LOCAL
)
class TruncatedLogger:
def __init__(
self,

View File

@@ -3,15 +3,17 @@ from enum import Enum
import sentry_sdk
from pydantic import SecretStr
from sentry_sdk.integrations import DidNotEnable
from sentry_sdk.integrations.anthropic import AnthropicIntegration
from sentry_sdk.integrations.asyncio import AsyncioIntegration
from sentry_sdk.integrations.launchdarkly import LaunchDarklyIntegration
from sentry_sdk.integrations.logging import LoggingIntegration
from backend.util.feature_flag import get_client, is_configured
from backend.util import feature_flag
from backend.util.settings import Settings
settings = Settings()
logger = logging.getLogger(__name__)
class DiscordChannel(str, Enum):
@@ -22,8 +24,11 @@ class DiscordChannel(str, Enum):
def sentry_init():
sentry_dsn = settings.secrets.sentry_dsn
integrations = []
if is_configured():
integrations.append(LaunchDarklyIntegration(get_client()))
if feature_flag.is_configured():
try:
integrations.append(LaunchDarklyIntegration(feature_flag.get_client()))
except DidNotEnable as e:
logger.error(f"Error enabling LaunchDarklyIntegration for Sentry: {e}")
sentry_sdk.init(
dsn=sentry_dsn,
traces_sample_rate=1.0,

View File

@@ -8,18 +8,9 @@ from typing import Optional
from backend.util.logging import configure_logging
from backend.util.metrics import sentry_init
from backend.util.settings import set_service_name
logger = logging.getLogger(__name__)
_SERVICE_NAME = "MainProcess"
def get_service_name():
return _SERVICE_NAME
def set_service_name(name: str):
global _SERVICE_NAME
_SERVICE_NAME = name
class AppProcess(ABC):
@@ -28,7 +19,8 @@ class AppProcess(ABC):
"""
process: Optional[Process] = None
cleaned_up = False
_shutting_down: bool = False
_cleaned_up: bool = False
if "forkserver" in get_all_start_methods():
set_start_method("forkserver", force=True)
@@ -52,7 +44,6 @@ class AppProcess(ABC):
def service_name(self) -> str:
return self.__class__.__name__
@abstractmethod
def cleanup(self):
"""
Implement this method on a subclass to do post-execution cleanup,
@@ -74,7 +65,8 @@ class AppProcess(ABC):
self.run()
except BaseException as e:
logger.warning(
f"[{self.service_name}] Termination request: {type(e).__name__}; {e} executing cleanup."
f"[{self.service_name}] 🛑 Terminating because of {type(e).__name__}: {e}", # noqa
exc_info=e if not isinstance(e, SystemExit) else None,
)
# Send error to Sentry before cleanup
if not isinstance(e, (KeyboardInterrupt, SystemExit)):
@@ -85,8 +77,12 @@ class AppProcess(ABC):
except Exception:
pass # Silently ignore if Sentry isn't available
finally:
self.cleanup()
logger.info(f"[{self.service_name}] Terminated.")
if not self._cleaned_up:
self._cleaned_up = True
logger.info(f"[{self.service_name}] 🧹 Running cleanup")
self.cleanup()
logger.info(f"[{self.service_name}] ✅ Cleanup done")
logger.info(f"[{self.service_name}] 🛑 Terminated")
@staticmethod
def llprint(message: str):
@@ -97,8 +93,8 @@ class AppProcess(ABC):
os.write(sys.stdout.fileno(), (message + "\n").encode())
def _self_terminate(self, signum: int, frame):
if not self.cleaned_up:
self.cleaned_up = True
if not self._shutting_down:
self._shutting_down = True
sys.exit(0)
else:
self.llprint(

View File

@@ -13,7 +13,7 @@ import idna
from aiohttp import FormData, abc
from tenacity import retry, retry_if_result, wait_exponential_jitter
from backend.util.json import json
from backend.util.json import loads
# Retry status codes for which we will automatically retry the request
THROTTLE_RETRY_STATUS_CODES: set[int] = {429, 500, 502, 503, 504, 408}
@@ -175,10 +175,15 @@ async def validate_url(
f"for hostname {ascii_hostname} is not allowed."
)
# Reconstruct the netloc with IDNA-encoded hostname and preserve port
netloc = ascii_hostname
if parsed.port:
netloc = f"{ascii_hostname}:{parsed.port}"
return (
URL(
parsed.scheme,
ascii_hostname,
netloc,
quote(parsed.path, safe="/%:@"),
parsed.params,
parsed.query,
@@ -259,7 +264,7 @@ class Response:
"""
Parse the body as JSON and return the resulting Python object.
"""
return json.loads(
return loads(
self.content.decode(encoding or "utf-8", errors="replace"), **kwargs
)

View File

@@ -13,7 +13,7 @@ from tenacity import (
wait_exponential_jitter,
)
from backend.util.process import get_service_name
from backend.util.settings import get_service_name
logger = logging.getLogger(__name__)

View File

@@ -4,9 +4,12 @@ import concurrent.futures
import inspect
import logging
import os
import signal
import sys
import threading
import time
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from functools import update_wrapper
from typing import (
Any,
@@ -31,9 +34,9 @@ import backend.util.exceptions as exceptions
from backend.monitoring.instrumentation import instrument_fastapi
from backend.util.json import to_dict
from backend.util.metrics import sentry_init
from backend.util.process import AppProcess, get_service_name
from backend.util.process import AppProcess
from backend.util.retry import conn_retry, create_retry_decorator
from backend.util.settings import Config
from backend.util.settings import Config, get_service_name
logger = logging.getLogger(__name__)
T = TypeVar("T")
@@ -111,14 +114,44 @@ class BaseAppService(AppProcess, ABC):
return target_host
def run_service(self) -> None:
while True:
time.sleep(10)
# HACK: run the main event loop outside the main thread to disable Uvicorn's
# internal signal handlers, since there is no config option for this :(
shared_asyncio_thread = threading.Thread(
target=self._run_shared_event_loop,
daemon=True,
name=f"{self.service_name}-shared-event-loop",
)
shared_asyncio_thread.start()
shared_asyncio_thread.join()
def _run_shared_event_loop(self) -> None:
try:
self.shared_event_loop.run_forever()
finally:
logger.info(f"[{self.service_name}] 🛑 Shared event loop stopped")
self.shared_event_loop.close() # ensure held resources are released
def run_and_wait(self, coro: Coroutine[Any, Any, T]) -> T:
return asyncio.run_coroutine_threadsafe(coro, self.shared_event_loop).result()
def run(self):
self.shared_event_loop = asyncio.get_event_loop()
self.shared_event_loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.shared_event_loop)
def cleanup(self):
"""
**💡 Overriding `AppService.lifespan` may be a more convenient option.**
Implement this method on a subclass to do post-execution cleanup,
e.g. disconnecting from a database or terminating child processes.
**Note:** if you override this method in a subclass, it must call
`super().cleanup()` *at the end*!
"""
# Stop the shared event loop to allow resource clean-up
self.shared_event_loop.call_soon_threadsafe(self.shared_event_loop.stop)
super().cleanup()
class RemoteCallError(BaseModel):
@@ -179,6 +212,7 @@ EXCEPTION_MAPPING = {
class AppService(BaseAppService, ABC):
fastapi_app: FastAPI
http_server: uvicorn.Server | None = None
log_level: str = "info"
def set_log_level(self, log_level: str):
@@ -190,11 +224,10 @@ class AppService(BaseAppService, ABC):
def _handle_internal_http_error(status_code: int = 500, log_error: bool = True):
def handler(request: Request, exc: Exception):
if log_error:
if status_code == 500:
log = logger.exception
else:
log = logger.error
log(f"{request.method} {request.url.path} failed: {exc}")
logger.error(
f"{request.method} {request.url.path} failed: {exc}",
exc_info=exc if status_code == 500 else None,
)
return responses.JSONResponse(
status_code=status_code,
content=RemoteCallError(
@@ -256,13 +289,13 @@ class AppService(BaseAppService, ABC):
return sync_endpoint
@conn_retry("FastAPI server", "Starting FastAPI server")
@conn_retry("FastAPI server", "Running FastAPI server")
def __start_fastapi(self):
logger.info(
f"[{self.service_name}] Starting RPC server at http://{api_host}:{self.get_port()}"
)
server = uvicorn.Server(
self.http_server = uvicorn.Server(
uvicorn.Config(
self.fastapi_app,
host=api_host,
@@ -271,18 +304,76 @@ class AppService(BaseAppService, ABC):
log_level=self.log_level,
)
)
self.shared_event_loop.run_until_complete(server.serve())
self.run_and_wait(self.http_server.serve())
# Perform clean-up when the server exits
if not self._cleaned_up:
self._cleaned_up = True
logger.info(f"[{self.service_name}] 🧹 Running cleanup")
self.cleanup()
logger.info(f"[{self.service_name}] ✅ Cleanup done")
def _self_terminate(self, signum: int, frame):
"""Pass SIGTERM to Uvicorn so it can shut down gracefully"""
signame = signal.Signals(signum).name
if not self._shutting_down:
self._shutting_down = True
if self.http_server:
logger.info(
f"[{self.service_name}] 🛑 Received {signame} ({signum}) - "
"Entering RPC server graceful shutdown"
)
self.http_server.handle_exit(signum, frame) # stop accepting requests
# NOTE: Actually stopping the process is triggered by:
# 1. The call to self.cleanup() at the end of __start_fastapi() 👆🏼
# 2. BaseAppService.cleanup() stopping the shared event loop
else:
logger.warning(
f"[{self.service_name}] {signame} received before HTTP server init."
" Terminating..."
)
sys.exit(0)
else:
# Expedite shutdown on second SIGTERM
logger.info(
f"[{self.service_name}] 🛑🛑 Received {signame} ({signum}), "
"but shutdown is already underway. Terminating..."
)
sys.exit(0)
@asynccontextmanager
async def lifespan(self, app: FastAPI):
"""
The FastAPI/Uvicorn server's lifespan manager, used for setup and shutdown.
You can extend and use this in a subclass like:
```
@asynccontextmanager
async def lifespan(self, app: FastAPI):
async with super().lifespan(app):
await db.connect()
yield
await db.disconnect()
```
"""
# Startup - this runs before Uvicorn starts accepting connections
yield
# Shutdown - this runs when FastAPI/Uvicorn shuts down
logger.info(f"[{self.service_name}] ✅ FastAPI has finished")
async def health_check(self) -> str:
"""
A method to check the health of the process.
"""
"""A method to check the health of the process."""
return "OK"
def run(self):
sentry_init()
super().run()
self.fastapi_app = FastAPI()
self.fastapi_app = FastAPI(lifespan=self.lifespan)
# Add Prometheus instrumentation to all services
try:
@@ -325,7 +416,11 @@ class AppService(BaseAppService, ABC):
)
# Start the FastAPI server in a separate thread.
api_thread = threading.Thread(target=self.__start_fastapi, daemon=True)
api_thread = threading.Thread(
target=self.__start_fastapi,
daemon=True,
name=f"{self.service_name}-http-server",
)
api_thread.start()
# Run the main service loop (blocking).

View File

@@ -1,3 +1,5 @@
import asyncio
import contextlib
import time
from functools import cached_property
from unittest.mock import Mock
@@ -18,20 +20,11 @@ from backend.util.service import (
TEST_SERVICE_PORT = 8765
def wait_for_service_ready(service_client_type, timeout_seconds=30):
"""Helper method to wait for a service to be ready using health check with retry."""
client = get_service_client(service_client_type, request_retry=True)
client.health_check() # This will retry until service is ready
class ServiceTest(AppService):
def __init__(self):
super().__init__()
self.fail_count = 0
def cleanup(self):
pass
@classmethod
def get_port(cls) -> int:
return TEST_SERVICE_PORT
@@ -41,10 +34,17 @@ class ServiceTest(AppService):
result = super().__enter__()
# Wait for the service to be ready
wait_for_service_ready(ServiceTestClient)
self.wait_until_ready()
return result
def wait_until_ready(self, timeout_seconds: int = 5):
"""Helper method to wait for a service to be ready using health check with retry."""
client = get_service_client(
ServiceTestClient, call_timeout=timeout_seconds, request_retry=True
)
client.health_check() # This will retry until service is ready\
@expose
def add(self, a: int, b: int) -> int:
return a + b
@@ -490,3 +490,167 @@ class TestHTTPErrorRetryBehavior:
)
assert exc_info.value.status_code == status_code
class TestGracefulShutdownService(AppService):
"""Test service with slow endpoints for testing graceful shutdown"""
@classmethod
def get_port(cls) -> int:
return 18999 # Use a specific test port
def __init__(self):
super().__init__()
self.request_log = []
self.cleanup_called = False
self.cleanup_completed = False
@expose
async def slow_endpoint(self, duration: int = 5) -> dict:
"""Endpoint that takes time to complete"""
start_time = time.time()
self.request_log.append(f"slow_endpoint started at {start_time}")
await asyncio.sleep(duration)
end_time = time.time()
result = {
"message": "completed",
"duration": end_time - start_time,
"start_time": start_time,
"end_time": end_time,
}
self.request_log.append(f"slow_endpoint completed at {end_time}")
return result
@expose
def fast_endpoint(self) -> dict:
"""Fast endpoint for testing rejection during shutdown"""
timestamp = time.time()
self.request_log.append(f"fast_endpoint called at {timestamp}")
return {"message": "fast", "timestamp": timestamp}
def cleanup(self):
"""Override cleanup to track when it's called"""
self.cleanup_called = True
self.request_log.append(f"cleanup started at {time.time()}")
# Call parent cleanup
super().cleanup()
self.cleanup_completed = True
self.request_log.append(f"cleanup completed at {time.time()}")
@pytest.fixture(scope="function")
async def test_service():
"""Run the test service in a separate process"""
service = TestGracefulShutdownService()
service.start(background=True)
base_url = f"http://localhost:{service.get_port()}"
await wait_until_service_ready(base_url)
yield service, base_url
service.stop()
async def wait_until_service_ready(base_url: str, timeout: float = 10):
start_time = time.time()
while time.time() - start_time <= timeout:
async with httpx.AsyncClient(timeout=5) as client:
with contextlib.suppress(httpx.ConnectError):
response = await client.get(f"{base_url}/health_check", timeout=5)
if response.status_code == 200 and response.json() == "OK":
return
await asyncio.sleep(0.5)
raise RuntimeError(f"Service at {base_url} not available after {timeout} seconds")
async def send_slow_request(base_url: str) -> dict:
"""Send a slow request and return the result"""
async with httpx.AsyncClient(timeout=30) as client:
response = await client.post(f"{base_url}/slow_endpoint", json={"duration": 5})
assert response.status_code == 200
return response.json()
@pytest.mark.asyncio
async def test_graceful_shutdown(test_service):
"""Test that AppService handles graceful shutdown correctly"""
service, test_service_url = test_service
# Start a slow request that should complete even after shutdown
slow_task = asyncio.create_task(send_slow_request(test_service_url))
# Give the slow request time to start
await asyncio.sleep(1)
# Send SIGTERM to the service process
shutdown_start_time = time.time()
service.process.terminate() # This sends SIGTERM
# Wait a moment for shutdown to start
await asyncio.sleep(0.5)
# Try to send a new request - should be rejected or connection refused
try:
async with httpx.AsyncClient(timeout=5) as client:
response = await client.post(f"{test_service_url}/fast_endpoint", json={})
# Should get 503 Service Unavailable during shutdown
assert response.status_code == 503
assert "shutting down" in response.json()["detail"].lower()
except httpx.ConnectError:
# Connection refused is also acceptable - server stopped accepting
pass
# The slow request should still complete successfully
slow_result = await slow_task
assert slow_result["message"] == "completed"
assert 4.9 < slow_result["duration"] < 5.5 # Should have taken ~5 seconds
# Wait for the service to fully shut down
service.process.join(timeout=15)
shutdown_end_time = time.time()
# Verify the service actually terminated
assert not service.process.is_alive()
# Verify shutdown took reasonable time (slow request - 1s + cleanup)
shutdown_duration = shutdown_end_time - shutdown_start_time
assert 4 <= shutdown_duration <= 6 # ~5s request - 1s + buffer
print(f"Shutdown took {shutdown_duration:.2f} seconds")
print(f"Slow request completed in: {slow_result['duration']:.2f} seconds")
@pytest.mark.asyncio
async def test_health_check_during_shutdown(test_service):
"""Test that health checks behave correctly during shutdown"""
service, test_service_url = test_service
# Health check should pass initially
async with httpx.AsyncClient(timeout=5) as client:
response = await client.get(f"{test_service_url}/health_check")
assert response.status_code == 200
# Send SIGTERM
service.process.terminate()
# Wait for shutdown to begin
await asyncio.sleep(1)
# Health check should now fail or connection should be refused
try:
async with httpx.AsyncClient(timeout=5) as client:
response = await client.get(f"{test_service_url}/health_check")
# Could either get 503, 500 (unhealthy), or connection error
assert response.status_code in [500, 503]
except (httpx.ConnectError, httpx.ConnectTimeout):
# Connection refused/timeout is also acceptable
pass

View File

@@ -15,6 +15,17 @@ from backend.util.data import get_data_path
T = TypeVar("T", bound=BaseSettings)
_SERVICE_NAME = "MainProcess"
def get_service_name():
return _SERVICE_NAME
def set_service_name(name: str):
global _SERVICE_NAME
_SERVICE_NAME = name
class AppEnvironment(str, Enum):
LOCAL = "local"
@@ -254,6 +265,7 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
default="localhost",
description="The host for the RabbitMQ server",
)
rabbitmq_port: int = Field(
default=5672,
description="The port for the RabbitMQ server",
@@ -264,6 +276,21 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
description="The vhost for the RabbitMQ server",
)
redis_host: str = Field(
default="localhost",
description="The host for the Redis server",
)
redis_port: int = Field(
default=6379,
description="The port for the Redis server",
)
redis_password: str = Field(
default="",
description="The password for the Redis server (empty string if no password)",
)
postmark_sender_email: str = Field(
default="invalid@invalid.com",
description="The email address to use for sending emails",

View File

@@ -411,3 +411,346 @@ class TestSafeJson:
assert "C:\\temp\\file" in str(file_path_with_null)
assert ".txt" in str(file_path_with_null)
assert "\x00" not in str(file_path_with_null) # Null removed from path
def test_invalid_escape_error_prevention(self):
"""Test that SafeJson prevents 'Invalid \\escape' errors that occurred in upsert_execution_output."""
# This reproduces the exact scenario that was causing the error:
# POST /upsert_execution_output failed: Invalid \escape: line 1 column 36404 (char 36403)
# Create data with various problematic escape sequences that could cause JSON parsing errors
problematic_output_data = {
"web_content": "Article text\x00with null\x01and control\x08chars\x0C\x1F\x7F",
"file_path": "C:\\Users\\test\\file\x00.txt",
"json_like_string": '{"text": "data\x00\x08\x1F"}',
"escaped_sequences": "Text with \\u0000 and \\u0008 sequences",
"mixed_content": "Normal text\tproperly\nformatted\rwith\x00invalid\x08chars\x1Fmixed",
"large_text": "A" * 35000
+ "\x00\x08\x1F"
+ "B" * 5000, # Large text like in the error
}
# This should not raise any JSON parsing errors
result = SafeJson(problematic_output_data)
assert isinstance(result, Json)
# Verify the result is a valid Json object that can be safely stored in PostgreSQL
result_data = cast(dict[str, Any], result.data)
assert isinstance(result_data, dict)
# Verify problematic characters are removed but safe content preserved
web_content = result_data.get("web_content", "")
file_path = result_data.get("file_path", "")
large_text = result_data.get("large_text", "")
# Check that control characters are removed
assert "\x00" not in str(web_content)
assert "\x01" not in str(web_content)
assert "\x08" not in str(web_content)
assert "\x0C" not in str(web_content)
assert "\x1F" not in str(web_content)
assert "\x7F" not in str(web_content)
# Check that legitimate content is preserved
assert "Article text" in str(web_content)
assert "with null" in str(web_content)
assert "and control" in str(web_content)
assert "chars" in str(web_content)
# Check file path handling
assert "C:\\Users\\test\\file" in str(file_path)
assert ".txt" in str(file_path)
assert "\x00" not in str(file_path)
# Check large text handling (the scenario from the error at char 36403)
assert len(str(large_text)) > 35000 # Content preserved
assert "A" * 1000 in str(large_text) # A's preserved
assert "B" * 1000 in str(large_text) # B's preserved
assert "\x00" not in str(large_text) # Control chars removed
assert "\x08" not in str(large_text)
assert "\x1F" not in str(large_text)
# Most importantly: ensure the result can be JSON-serialized without errors
# This would have failed with the old approach
import json
json_string = json.dumps(result.data) # Should not raise "Invalid \escape"
assert len(json_string) > 0
# And can be parsed back
parsed_back = json.loads(json_string)
assert isinstance(parsed_back, dict)
def test_dict_containing_pydantic_models(self):
"""Test that dicts containing Pydantic models are properly serialized."""
# This reproduces the bug from PR #11187 where credential_inputs failed
model1 = SamplePydanticModel(name="Alice", age=30)
model2 = SamplePydanticModel(name="Bob", age=25)
data = {
"user1": model1,
"user2": model2,
"regular_data": "test",
}
result = SafeJson(data)
assert isinstance(result, Json)
# Verify it can be JSON serialized (this was the bug)
import json
json_string = json.dumps(result.data)
assert "Alice" in json_string
assert "Bob" in json_string
def test_nested_pydantic_in_dict(self):
"""Test deeply nested Pydantic models in dicts."""
inner_model = SamplePydanticModel(name="Inner", age=20)
middle_model = SamplePydanticModel(
name="Middle", age=30, metadata={"inner": inner_model}
)
data = {
"level1": {
"level2": {
"model": middle_model,
"other": "data",
}
}
}
result = SafeJson(data)
assert isinstance(result, Json)
import json
json_string = json.dumps(result.data)
assert "Middle" in json_string
assert "Inner" in json_string
def test_list_containing_pydantic_models_in_dict(self):
"""Test list of Pydantic models inside a dict."""
models = [SamplePydanticModel(name=f"User{i}", age=20 + i) for i in range(5)]
data = {
"users": models,
"count": len(models),
}
result = SafeJson(data)
assert isinstance(result, Json)
import json
json_string = json.dumps(result.data)
assert "User0" in json_string
assert "User4" in json_string
def test_credentials_meta_input_scenario(self):
"""Test the exact scenario from create_graph_execution that was failing."""
# Simulate CredentialsMetaInput structure
class MockCredentialsMetaInput(BaseModel):
id: str
title: Optional[str] = None
provider: str
type: str
cred_input = MockCredentialsMetaInput(
id="test-123", title="Test Credentials", provider="github", type="oauth2"
)
# This is how credential_inputs is structured in create_graph_execution
credential_inputs = {"github_creds": cred_input}
# This should work without TypeError
result = SafeJson(credential_inputs)
assert isinstance(result, Json)
# Verify it can be JSON serialized
import json
json_string = json.dumps(result.data)
assert "test-123" in json_string
assert "github" in json_string
assert "oauth2" in json_string
def test_mixed_pydantic_and_primitives(self):
"""Test complex mix of Pydantic models and primitive types."""
model = SamplePydanticModel(name="Test", age=25)
data = {
"models": [model, {"plain": "dict"}, "string", 123],
"nested": {
"model": model,
"list": [1, 2, model, 4],
"plain": "text",
},
"plain_list": [1, 2, 3],
}
result = SafeJson(data)
assert isinstance(result, Json)
import json
json_string = json.dumps(result.data)
assert "Test" in json_string
assert "plain" in json_string
def test_pydantic_model_with_control_chars_in_dict(self):
"""Test Pydantic model with control chars when nested in dict."""
model = SamplePydanticModel(
name="Test\x00User", # Has null byte
age=30,
metadata={"info": "data\x08with\x0Ccontrols"},
)
data = {"credential": model}
result = SafeJson(data)
assert isinstance(result, Json)
# Verify control characters are removed
import json
json_string = json.dumps(result.data)
assert "\x00" not in json_string
assert "\x08" not in json_string
assert "\x0C" not in json_string
assert "TestUser" in json_string # Name preserved minus null byte
def test_deeply_nested_pydantic_models_control_char_sanitization(self):
"""Test that control characters are sanitized in deeply nested Pydantic models."""
# Create nested Pydantic models with control characters at different levels
class InnerModel(BaseModel):
deep_string: str
value: int = 42
metadata: dict = {}
class MiddleModel(BaseModel):
middle_string: str
inner: InnerModel
data: str
class OuterModel(BaseModel):
outer_string: str
middle: MiddleModel
# Create test data with control characters at every nesting level
inner = InnerModel(
deep_string="Deepest\x00Level\x08Control\x0CChars", # Multiple control chars at deepest level
metadata={
"nested_key": "Nested\x1FValue\x7FDelete"
}, # Control chars in nested dict
)
middle = MiddleModel(
middle_string="Middle\x01StartOfHeading\x1FUnitSeparator",
inner=inner,
data="Some\x0BVerticalTab\x0EShiftOut",
)
outer = OuterModel(outer_string="Outer\x00Null\x07Bell", middle=middle)
# Wrap in a dict with additional control characters
data = {
"top_level": "Top\x00Level\x08Backspace",
"nested_model": outer,
"list_with_strings": [
"List\x00Item1",
"List\x0CItem2\x1F",
{"dict_in_list": "Dict\x08Value"},
],
}
# Process with SafeJson
result = SafeJson(data)
assert isinstance(result, Json)
# Verify all control characters are removed at every level
import json
json_string = json.dumps(result.data)
# Check that NO control characters remain anywhere
control_chars = [
"\x00",
"\x01",
"\x02",
"\x03",
"\x04",
"\x05",
"\x06",
"\x07",
"\x08",
"\x0B",
"\x0C",
"\x0E",
"\x0F",
"\x10",
"\x11",
"\x12",
"\x13",
"\x14",
"\x15",
"\x16",
"\x17",
"\x18",
"\x19",
"\x1A",
"\x1B",
"\x1C",
"\x1D",
"\x1E",
"\x1F",
"\x7F",
]
for char in control_chars:
assert (
char not in json_string
), f"Control character {repr(char)} found in result"
# Verify specific sanitized content is present (control chars removed but text preserved)
result_data = cast(dict[str, Any], result.data)
# Top level
assert "TopLevelBackspace" in json_string
# Outer model level
assert "OuterNullBell" in json_string
# Middle model level
assert "MiddleStartOfHeadingUnitSeparator" in json_string
assert "SomeVerticalTabShiftOut" in json_string
# Inner model level (deepest nesting)
assert "DeepestLevelControlChars" in json_string
# Nested dict in model
assert "NestedValueDelete" in json_string
# List items
assert "ListItem1" in json_string
assert "ListItem2" in json_string
assert "DictValue" in json_string
# Verify structure is preserved (not just converted to string)
assert isinstance(result_data, dict)
assert isinstance(result_data["nested_model"], dict)
assert isinstance(result_data["nested_model"]["middle"], dict)
assert isinstance(result_data["nested_model"]["middle"]["inner"], dict)
assert isinstance(result_data["list_with_strings"], list)
# Verify specific deep values are accessible and sanitized
nested_model = cast(dict[str, Any], result_data["nested_model"])
middle = cast(dict[str, Any], nested_model["middle"])
inner = cast(dict[str, Any], middle["inner"])
deep_string = inner["deep_string"]
assert deep_string == "DeepestLevelControlChars"
metadata = cast(dict[str, Any], inner["metadata"])
nested_metadata = metadata["nested_key"]
assert nested_metadata == "NestedValueDelete"

View File

@@ -3,6 +3,7 @@ import logging
import bleach
from bleach.css_sanitizer import CSSSanitizer
from jinja2 import BaseLoader
from jinja2.exceptions import TemplateError
from jinja2.sandbox import SandboxedEnvironment
from markupsafe import Markup
@@ -101,8 +102,11 @@ class TextFormatter:
def format_string(self, template_str: str, values=None, **kwargs) -> str:
"""Regular template rendering with escaping"""
template = self.env.from_string(template_str)
return template.render(values or {}, **kwargs)
try:
template = self.env.from_string(template_str)
return template.render(values or {}, **kwargs)
except TemplateError as e:
raise ValueError(e) from e
def format_email(
self,

View File

@@ -0,0 +1,16 @@
-- Create UserBalance table for atomic credit operations
-- This replaces the need for User.balance column and provides better separation of concerns
-- UserBalance records are automatically created by the application when users interact with the credit system
-- CreateTable (only if it doesn't exist)
CREATE TABLE IF NOT EXISTS "UserBalance" (
"userId" TEXT NOT NULL,
"balance" INTEGER NOT NULL DEFAULT 0,
"updatedAt" TIMESTAMP(3) NOT NULL,
CONSTRAINT "UserBalance_pkey" PRIMARY KEY ("userId"),
CONSTRAINT "UserBalance_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE
);
-- CreateIndex (only if it doesn't exist)
CREATE INDEX IF NOT EXISTS "UserBalance_userId_idx" ON "UserBalance"("userId");

View File

@@ -0,0 +1,100 @@
-- AlterTable
ALTER TABLE "StoreListingVersion" ADD COLUMN "search" tsvector DEFAULT ''::tsvector;
-- Add trigger to update the search column with the tsvector of the agent
-- Function to be invoked by trigger
-- Drop the trigger first
DROP TRIGGER IF EXISTS "update_tsvector" ON "StoreListingVersion";
-- Drop the function completely
DROP FUNCTION IF EXISTS update_tsvector_column();
-- Now recreate it fresh
CREATE OR REPLACE FUNCTION update_tsvector_column() RETURNS TRIGGER AS $$
BEGIN
NEW.search := to_tsvector('english',
COALESCE(NEW.name, '') || ' ' ||
COALESCE(NEW.description, '') || ' ' ||
COALESCE(NEW."subHeading", '')
);
RETURN NEW;
END;
$$ LANGUAGE plpgsql SECURITY DEFINER SET search_path = platform, pg_temp;
-- Recreate the trigger
CREATE TRIGGER "update_tsvector"
BEFORE INSERT OR UPDATE ON "StoreListingVersion"
FOR EACH ROW
EXECUTE FUNCTION update_tsvector_column();
UPDATE "StoreListingVersion"
SET search = to_tsvector('english',
COALESCE(name, '') || ' ' ||
COALESCE(description, '') || ' ' ||
COALESCE("subHeading", '')
)
WHERE search IS NULL;
-- Drop and recreate the StoreAgent view with isAvailable field
DROP VIEW IF EXISTS "StoreAgent";
CREATE OR REPLACE VIEW "StoreAgent" AS
WITH latest_versions AS (
SELECT
"storeListingId",
MAX(version) AS max_version
FROM "StoreListingVersion"
WHERE "submissionStatus" = 'APPROVED'
GROUP BY "storeListingId"
),
agent_versions AS (
SELECT
"storeListingId",
array_agg(DISTINCT version::text ORDER BY version::text) AS versions
FROM "StoreListingVersion"
WHERE "submissionStatus" = 'APPROVED'
GROUP BY "storeListingId"
)
SELECT
sl.id AS listing_id,
slv.id AS "storeListingVersionId",
slv."createdAt" AS updated_at,
sl.slug,
COALESCE(slv.name, '') AS agent_name,
slv."videoUrl" AS agent_video,
COALESCE(slv."imageUrls", ARRAY[]::text[]) AS agent_image,
slv."isFeatured" AS featured,
p.username AS creator_username, -- Allow NULL for malformed sub-agents
p."avatarUrl" AS creator_avatar, -- Allow NULL for malformed sub-agents
slv."subHeading" AS sub_heading,
slv.description,
slv.categories,
slv.search,
COALESCE(ar.run_count, 0::bigint) AS runs,
COALESCE(rs.avg_rating, 0.0)::double precision AS rating,
COALESCE(av.versions, ARRAY[slv.version::text]) AS versions,
COALESCE(sl."useForOnboarding", false) AS "useForOnboarding",
slv."isAvailable" AS is_available -- Add isAvailable field to filter sub-agents
FROM "StoreListing" sl
JOIN latest_versions lv
ON sl.id = lv."storeListingId"
JOIN "StoreListingVersion" slv
ON slv."storeListingId" = lv."storeListingId"
AND slv.version = lv.max_version
AND slv."submissionStatus" = 'APPROVED'
JOIN "AgentGraph" a
ON slv."agentGraphId" = a.id
AND slv."agentGraphVersion" = a.version
LEFT JOIN "Profile" p
ON sl."owningUserId" = p."userId"
LEFT JOIN "mv_review_stats" rs
ON sl.id = rs."storeListingId"
LEFT JOIN "mv_agent_run_counts" ar
ON a.id = ar."agentGraphId"
LEFT JOIN agent_versions av
ON sl.id = av."storeListingId"
WHERE sl."isDeleted" = false
AND sl."hasApprovedVersion" = true;
COMMIT;

View File

@@ -0,0 +1,21 @@
-- Migrate Claude 3.5 models to Claude 4.5 models
-- This updates all AgentNode blocks that use deprecated Claude 3.5 models to the new 4.5 models
-- See: https://docs.anthropic.com/en/docs/about-claude/models/legacy-model-guide
-- Update Claude 3.5 Sonnet to Claude 4.5 Sonnet
UPDATE "AgentNode"
SET "constantInput" = JSONB_SET(
"constantInput"::jsonb,
'{model}',
'"claude-sonnet-4-5-20250929"'::jsonb
)
WHERE "constantInput"::jsonb->>'model' = 'claude-3-5-sonnet-latest';
-- Update Claude 3.5 Haiku to Claude 4.5 Haiku
UPDATE "AgentNode"
SET "constantInput" = JSONB_SET(
"constantInput"::jsonb,
'{model}',
'"claude-haiku-4-5-20251001"'::jsonb
)
WHERE "constantInput"::jsonb->>'model' = 'claude-3-5-haiku-latest';

View File

@@ -5,10 +5,11 @@ datasource db {
}
generator client {
provider = "prisma-client-py"
recursive_type_depth = -1
interface = "asyncio"
previewFeatures = ["views"]
provider = "prisma-client-py"
recursive_type_depth = -1
interface = "asyncio"
previewFeatures = ["views", "fullTextSearch"]
partial_type_generator = "backend/data/partial_types.py"
}
// User model to mirror Auth provider users
@@ -45,6 +46,7 @@ model User {
AnalyticsDetails AnalyticsDetails[]
AnalyticsMetrics AnalyticsMetrics[]
CreditTransactions CreditTransaction[]
UserBalance UserBalance?
AgentPresets AgentPreset[]
LibraryAgents LibraryAgent[]
@@ -663,6 +665,7 @@ view StoreAgent {
sub_heading String
description String
categories String[]
search Unsupported("tsvector")? @default(dbgenerated("''::tsvector"))
runs Int
rating Float
versions String[]
@@ -746,7 +749,7 @@ model StoreListing {
slug String
// Allow this agent to be used during onboarding
useForOnboarding Boolean @default(false)
useForOnboarding Boolean @default(false)
// The currently active version that should be shown to users
activeVersionId String? @unique
@@ -797,6 +800,8 @@ model StoreListingVersion {
// Old versions can be made unavailable by the author if desired
isAvailable Boolean @default(true)
search Unsupported("tsvector")? @default(dbgenerated("''::tsvector"))
// Version workflow state
submissionStatus SubmissionStatus @default(DRAFT)
submittedAt DateTime?
@@ -887,6 +892,16 @@ model APIKey {
@@index([userId, status])
}
model UserBalance {
userId String @id
balance Int @default(0)
updatedAt DateTime @updatedAt
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
@@index([userId])
}
enum APIKeyStatus {
ACTIVE
REVOKED

View File

@@ -1,5 +1,5 @@
{
"email": "test@example.com",
"id": "test-user-id",
"id": "3e53486c-cf57-477e-ba2a-cb02dc828e1a",
"name": "Test User"
}

View File

@@ -28,6 +28,6 @@
"recommended_schedule_cron": null,
"sub_graphs": [],
"trigger_setup_info": null,
"user_id": "test-user-id",
"user_id": "3e53486c-cf57-477e-ba2a-cb02dc828e1a",
"version": 1
}

View File

@@ -26,7 +26,7 @@
"recommended_schedule_cron": null,
"sub_graphs": [],
"trigger_setup_info": null,
"user_id": "test-user-id",
"user_id": "3e53486c-cf57-477e-ba2a-cb02dc828e1a",
"version": 1
}
]

View File

@@ -0,0 +1,140 @@
from unittest.mock import Mock, patch
import pytest
from youtube_transcript_api._errors import NoTranscriptFound
from youtube_transcript_api._transcripts import FetchedTranscript, Transcript
from backend.blocks.youtube import TranscribeYoutubeVideoBlock
class TestTranscribeYoutubeVideoBlock:
"""Test cases for TranscribeYoutubeVideoBlock language fallback functionality."""
def setup_method(self):
"""Set up test fixtures."""
self.youtube_block = TranscribeYoutubeVideoBlock()
def test_extract_video_id_standard_url(self):
"""Test extracting video ID from standard YouTube URL."""
url = "https://www.youtube.com/watch?v=dQw4w9WgXcQ"
video_id = self.youtube_block.extract_video_id(url)
assert video_id == "dQw4w9WgXcQ"
def test_extract_video_id_short_url(self):
"""Test extracting video ID from shortened youtu.be URL."""
url = "https://youtu.be/dQw4w9WgXcQ"
video_id = self.youtube_block.extract_video_id(url)
assert video_id == "dQw4w9WgXcQ"
def test_extract_video_id_embed_url(self):
"""Test extracting video ID from embed URL."""
url = "https://www.youtube.com/embed/dQw4w9WgXcQ"
video_id = self.youtube_block.extract_video_id(url)
assert video_id == "dQw4w9WgXcQ"
@patch("backend.blocks.youtube.YouTubeTranscriptApi")
def test_get_transcript_english_available(self, mock_api_class):
"""Test getting transcript when English is available."""
# Setup mock
mock_api = Mock()
mock_api_class.return_value = mock_api
mock_transcript = Mock(spec=FetchedTranscript)
mock_api.fetch.return_value = mock_transcript
# Execute
result = TranscribeYoutubeVideoBlock.get_transcript("test_video_id")
# Assert
assert result == mock_transcript
mock_api.fetch.assert_called_once_with(video_id="test_video_id")
mock_api.list.assert_not_called()
@patch("backend.blocks.youtube.YouTubeTranscriptApi")
def test_get_transcript_fallback_to_first_available(self, mock_api_class):
"""Test fallback to first available language when English is not available."""
# Setup mock
mock_api = Mock()
mock_api_class.return_value = mock_api
# Create mock transcript list with Hungarian transcript
mock_transcript_list = Mock()
mock_transcript_hu = Mock(spec=Transcript)
mock_fetched_transcript = Mock(spec=FetchedTranscript)
mock_transcript_hu.fetch.return_value = mock_fetched_transcript
# Set up the transcript list to have manually created transcripts empty
# and generated transcripts with Hungarian
mock_transcript_list._manually_created_transcripts = {}
mock_transcript_list._generated_transcripts = {"hu": mock_transcript_hu}
# Mock API to raise NoTranscriptFound for English, then return list
mock_api.fetch.side_effect = NoTranscriptFound(
"test_video_id", ("en",), mock_transcript_list
)
mock_api.list.return_value = mock_transcript_list
# Execute
result = TranscribeYoutubeVideoBlock.get_transcript("test_video_id")
# Assert
assert result == mock_fetched_transcript
mock_api.fetch.assert_called_once_with(video_id="test_video_id")
mock_api.list.assert_called_once_with("test_video_id")
mock_transcript_hu.fetch.assert_called_once()
@patch("backend.blocks.youtube.YouTubeTranscriptApi")
def test_get_transcript_prefers_manually_created(self, mock_api_class):
"""Test that manually created transcripts are preferred over generated ones."""
# Setup mock
mock_api = Mock()
mock_api_class.return_value = mock_api
# Create mock transcript list with both manual and generated transcripts
mock_transcript_list = Mock()
mock_transcript_manual = Mock(spec=Transcript)
mock_transcript_generated = Mock(spec=Transcript)
mock_fetched_manual = Mock(spec=FetchedTranscript)
mock_transcript_manual.fetch.return_value = mock_fetched_manual
# Set up the transcript list
mock_transcript_list._manually_created_transcripts = {
"es": mock_transcript_manual
}
mock_transcript_list._generated_transcripts = {"hu": mock_transcript_generated}
# Mock API to raise NoTranscriptFound for English
mock_api.fetch.side_effect = NoTranscriptFound(
"test_video_id", ("en",), mock_transcript_list
)
mock_api.list.return_value = mock_transcript_list
# Execute
result = TranscribeYoutubeVideoBlock.get_transcript("test_video_id")
# Assert - should use manually created transcript first
assert result == mock_fetched_manual
mock_transcript_manual.fetch.assert_called_once()
mock_transcript_generated.fetch.assert_not_called()
@patch("backend.blocks.youtube.YouTubeTranscriptApi")
def test_get_transcript_no_transcripts_available(self, mock_api_class):
"""Test that exception is re-raised when no transcripts are available at all."""
# Setup mock
mock_api = Mock()
mock_api_class.return_value = mock_api
# Create mock transcript list with no transcripts
mock_transcript_list = Mock()
mock_transcript_list._manually_created_transcripts = {}
mock_transcript_list._generated_transcripts = {}
# Mock API to raise NoTranscriptFound
original_exception = NoTranscriptFound(
"test_video_id", ("en",), mock_transcript_list
)
mock_api.fetch.side_effect = original_exception
mock_api.list.return_value = mock_transcript_list
# Execute and assert exception is raised
with pytest.raises(NoTranscriptFound):
TranscribeYoutubeVideoBlock.get_transcript("test_video_id")

View File

@@ -749,10 +749,11 @@ class TestDataCreator:
"""Add credits to users."""
print("Adding credits to users...")
credit_model = get_user_credit_model()
for user in self.users:
try:
# Get user-specific credit model
credit_model = await get_user_credit_model(user["id"])
# Skip credits for disabled credit model to avoid errors
if (
hasattr(credit_model, "__class__")

View File

@@ -21,6 +21,7 @@ import random
from datetime import datetime
import prisma.enums
import pytest
from autogpt_libs.api_key.keysmith import APIKeySmith
from faker import Faker
from prisma import Json, Prisma
@@ -498,9 +499,6 @@ async def main():
if store_listing_versions and random.random() < 0.5
else None
),
"agentInput": (
Json({"test": "data"}) if random.random() < 0.3 else None
),
"onboardingAgentExecutionId": (
random.choice(agent_graph_executions).id
if agent_graph_executions and random.random() < 0.3
@@ -570,5 +568,11 @@ async def main():
print("Test data creation completed successfully!")
@pytest.mark.asyncio
@pytest.mark.integration
async def test_main_function_runs_without_errors():
await main()
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -11,7 +11,6 @@
NEXT_PUBLIC_LAUNCHDARKLY_ENABLED=false
NEXT_PUBLIC_LAUNCHDARKLY_CLIENT_ID=687ab1372f497809b131e06e
NEXT_PUBLIC_SHOW_BILLING_PAGE=false
NEXT_PUBLIC_TURNSTILE=disabled
NEXT_PUBLIC_REACT_QUERY_DEVTOOL=true

View File

@@ -0,0 +1,765 @@
<div align="center">
<h1>AutoGPT Frontend • Contributing ⌨️</h1>
<p>Next.js App Router • Client-first • Type-safe generated API hooks • Tailwind + shadcn/ui</p>
</div>
---
## ☕️ Summary
This document is your reference for contributing to the AutoGPT Frontend. It adapts legacy guidelines to our current stack and practices.
- Architecture and stack
- Component structure and design system
- Data fetching (generated API hooks)
- Feature flags
- Naming and code conventions
- Tooling, scripts, and testing
- PR process and checklist
This is a living document. Open a pull request any time to improve it.
---
## 🚀 Quick Start FAQ
New to the codebase? Here are shortcuts to common tasks:
### I need to make a new page
1. Create page in `src/app/(platform)/your-feature/page.tsx`
2. If it has logic, create `usePage.ts` hook next to it
3. Create sub-components in `components/` folder
4. Use generated API hooks for data fetching
5. If page needs auth, ensure it's in the `(platform)` route group
**Example structure:**
```
app/(platform)/dashboard/
page.tsx
useDashboardPage.ts
components/
StatsPanel/
StatsPanel.tsx
useStatsPanel.ts
```
See [Component structure](#-component-structure) and [Styling](#-styling) and [Data fetching patterns](#-data-fetching-patterns) sections.
### I need to update an existing component in a page
1. Find the page `src/app/(platform)/your-feature/page.tsx`
2. Check its `components/` folder
3. If needing to update its logic, check the `use[Component].ts` hook
4. If the update is related to rendering, check `[Component].tsx` file
See [Component structure](#-component-structure) and [Styling](#-styling) sections.
### I need to make a new API call and show it on the UI
1. Ensure the backend endpoint exists in the OpenAPI spec
2. Regenerate API client: `pnpm generate:api`
3. Import the generated hook by typing the operation name (auto-import)
4. Use the hook in your component/custom hook
5. Handle loading, error, and success states
**Example:**
```tsx
import { useGetV2ListLibraryAgents } from "@/app/api/__generated__/endpoints/library/library";
export function useAgentList() {
const { data, isLoading, isError, error } = useGetV2ListLibraryAgents();
return {
agents: data?.data || [],
isLoading,
isError,
error,
};
}
```
See [Data fetching patterns](#-data-fetching-patterns) for more examples.
### I need to create a new component in the Design System
1. Determine the atomic level: atom, molecule, or organism
2. Create folder: `src/components/[level]/ComponentName/`
3. Create `ComponentName.tsx` (render logic)
4. If logic exists, create `useComponentName.ts`
5. Create `ComponentName.stories.tsx` for Storybook
6. Use Tailwind + design tokens (avoid hardcoded values)
7. Only use Phosphor icons
8. Test in Storybook: `pnpm storybook`
9. Verify in Chromatic after PR
**Example structure:**
```
src/components/molecules/DataCard/
DataCard.tsx
DataCard.stories.tsx
useDataCard.ts
```
See [Component structure](#-component-structure) and [Styling](#-styling) sections.
---
## 📟 Contribution process
### 1) Branch off `dev`
- Branch from `dev` for features and fixes
- Keep PRs focused (aim for one ticket per PR)
- Use conventional commit messages with a scope (e.g., `feat(frontend): add X`)
### 2) Feature flags
If a feature will ship across multiple PRs, guard it with a flag so we can merge iteratively.
- Use [LaunchDarkly](https://www.launchdarkly.com) based flags (see Feature Flags below)
- Avoid long-lived feature branches
### 3) Open PR and get reviews ✅
Before requesting review:
- [x] Code follows architecture and conventions here
- [x] `pnpm format && pnpm lint && pnpm types` pass
- [x] Relevant tests pass locally: `pnpm test` (and/or Storybook tests)
- [x] If touching UI, validate against our design system and stories
### 4) Merge to `dev`
- Use squash merges
- Follow conventional commit message format for the squash title
---
## 📂 Architecture & Stack
### Next.js App Router
- We use the [Next.js App Router](https://nextjs.org/docs/app) in `src/app`
- Use [route segments](https://nextjs.org/docs/app/building-your-application/routing) with semantic URLs; no `pages/`
### Component good practices
- Default to client components
- Use server components only when:
- SEO requires server-rendered HTML, or
- Extreme first-byte performance justifies it
- If you render server-side data, prefer server-side prefetch + client hydration (see examples below and [React Query SSR & Hydration](https://tanstack.com/query/latest/docs/framework/react/guides/ssr))
- Prefer using [Next.js API routes](https://nextjs.org/docs/pages/building-your-application/routing/api-routes) when possible over [server actions](https://nextjs.org/docs/14/app/building-your-application/data-fetching/server-actions-and-mutations)
- Keep components small and simple
- favour composition and splitting large components into smaller bits of UI
- [colocate state](https://kentcdodds.com/blog/state-colocation-will-make-your-react-app-faster) when possible
- keep render/side-effects split for [separation of concerns](https://en.wikipedia.org/wiki/Separation_of_concerns)
- do not over-complicate or re-invent the wheel
**❓ Why a client-side first design vs server components/actions?**
While server components and actions are cool and cutting-edge, they introduce a layer of complexity which not always justified by the benefits they deliver. Defaulting to client-first keeps things simple in the mental model of the developer, specially for those developers less familiar with Next.js or heavy Front-end development.
### Data fetching: prefer generated API hooks
- We generate a type-safe client and React Query hooks from the backend OpenAPI spec via [Orval](https://orval.dev/)
- Prefer the generated hooks under `src/app/api/__generated__/endpoints/...`
- Treat `BackendAPI` and code under `src/lib/autogpt-server-api/*` as deprecated; do not introduce new usages
- Use [Zod](https://zod.dev/) schemas from the generated client where applicable
### State management
- Prefer [React Query](https://tanstack.com/query/latest/docs/framework/react/overview) for server state, colocated near consumers (see [state colocation](https://kentcdodds.com/blog/state-colocation-will-make-your-react-app-faster))
- Co-locate UI state inside components/hooks; keep global state minimal
### Styling and components
- [Tailwind CSS](https://tailwindcss.com/docs) + [shadcn/ui](https://ui.shadcn.com/) ([Radix Primitives](https://www.radix-ui.com/docs/primitives/overview/introduction) under the hood)
- Use the design system under `src/components` for primitives and building blocks
- Do not use anything under `src/components/_legacy__`; migrate away from it when touching old code
- Reference the design system catalog on Chromatic: [`https://dev--670f94474adee5e32c896b98.chromatic.com/`](https://dev--670f94474adee5e32c896b98.chromatic.com/)
- Use the [`tailwind-scrollbar`](https://www.npmjs.com/package/tailwind-scrollbar) plugin utilities for scrollbar styling
---
## 🧱 Component structure
For components, separate render logic from data/behavior, and keep implementation details local.
**Most components should follow this structure.** Pages are just bigger components made of smaller ones, and sub-components can have their own nested sub-components when dealing with complex features.
### Basic structure
When a component has non-trivial logic:
```
FeatureX/
FeatureX.tsx (render logic only)
useFeatureX.ts (hook; data fetching, behavior, state)
helpers.ts (pure helpers used by the hook)
components/ (optional, subcomponents local to FeatureX)
```
### Example: Page with nested components
```tsx
// Page composition
app/(platform)/dashboard/
page.tsx
useDashboardPage.ts
components/ # (Sub-components the dashboard page is made of)
StatsPanel/
StatsPanel.tsx
useStatsPanel.ts
helpers.ts
components/ # (Sub-components belonging to StatsPanel)
StatCard/
StatCard.tsx
ActivityFeed/
ActivityFeed.tsx
useActivityFeed.ts
```
### Guidelines
- Prefer function declarations for components and handlers
- Only use arrow functions for small inline lambdas (e.g., in `map`)
- Avoid barrel files and `index.ts` re-exports
- Keep component files focused and readable; push complex logic to `helpers.ts`
- Abstract reusable, cross-feature logic into `src/services/` or `src/lib/utils.ts` as appropriate
- Build components encapsulated so they can be easily reused and abstracted elsewhere
- Nest sub-components within a `components/` folder when they're local to the parent feature
### Exceptions
When to simplify the structure:
**Small hook logic (3-4 lines)**
If the hook logic is minimal, keep it inline with the render function:
```tsx
export function ActivityAlert() {
const [isVisible, setIsVisible] = useState(true);
if (!isVisible) return null;
return (
<Alert onClose={() => setIsVisible(false)}>New activity detected</Alert>
);
}
```
**Render-only components**
Components with no hook logic can be direct files in `components/` without a folder:
```
components/
ActivityAlert.tsx (render-only, no folder needed)
StatsPanel/ (has hook logic, needs folder)
StatsPanel.tsx
useStatsPanel.ts
```
### Hook file structure
When separating logic into a custom hook:
```tsx
// useStatsPanel.ts
export function useStatsPanel() {
const [data, setData] = useState<Stats[]>([]);
const [isLoading, setIsLoading] = useState(true);
useEffect(() => {
fetchStats().then(setData);
}, []);
return {
data,
isLoading,
refresh: () => fetchStats().then(setData),
};
}
```
Rules:
- **Always return an object** that exposes data and methods to the view
- **Export a single function** named after the component (e.g., `useStatsPanel` for `StatsPanel.tsx`)
- **Abstract into helpers.ts** when hook logic grows large, so the hook file remains readable by scanning without diving into implementation details
---
## 🔄 Data fetching patterns
All API hooks are generated from the backend OpenAPI specification using [Orval](https://orval.dev/). The hooks are type-safe and follow the operation names defined in the backend API.
### How to discover hooks
Most of the time you can rely on auto-import by typing the endpoint or operation name. Your IDE will suggest the generated hooks based on the OpenAPI operation IDs.
**Examples of hook naming patterns:**
- `GET /api/v1/notifications``useGetV1GetNotificationPreferences`
- `POST /api/v2/store/agents``usePostV2CreateStoreAgent`
- `DELETE /api/v2/store/submissions/{id}``useDeleteV2DeleteStoreSubmission`
- `GET /api/v2/library/agents``useGetV2ListLibraryAgents`
**Pattern**: `use{Method}{Version}{OperationName}`
You can also explore the generated hooks by browsing `src/app/api/__generated__/endpoints/` which is organized by API tags (e.g., `auth`, `store`, `library`).
**OpenAPI specs:**
- Production: [https://backend.agpt.co/openapi.json](https://backend.agpt.co/openapi.json)
- Staging: [https://dev-server.agpt.co/openapi.json](https://dev-server.agpt.co/openapi.json)
### Generated hooks (client)
Prefer the generated React Query hooks (via Orval + React Query):
```tsx
import { useGetV1GetNotificationPreferences } from "@/app/api/__generated__/endpoints/auth/auth";
export function PreferencesPanel() {
const { data, isLoading, isError } = useGetV1GetNotificationPreferences({
query: {
select: (res) => res.data,
},
});
if (isLoading) return null;
if (isError) throw new Error("Failed to load preferences");
return <pre>{JSON.stringify(data, null, 2)}</pre>;
}
```
### Generated mutations (client)
```tsx
import { useQueryClient } from "@tanstack/react-query";
import {
useDeleteV2DeleteStoreSubmission,
getGetV2ListMySubmissionsQueryKey,
} from "@/app/api/__generated__/endpoints/store/store";
export function DeleteSubmissionButton({
submissionId,
}: {
submissionId: string;
}) {
const queryClient = useQueryClient();
const { mutateAsync: deleteSubmission, isPending } =
useDeleteV2DeleteStoreSubmission({
mutation: {
onSuccess: () => {
queryClient.invalidateQueries({
queryKey: getGetV2ListMySubmissionsQueryKey(),
});
},
},
});
async function onClick() {
await deleteSubmission({ submissionId });
}
return (
<button disabled={isPending} onClick={onClick}>
Delete
</button>
);
}
```
### Server-side prefetch + client hydration
Use server-side prefetch to improve TTFB while keeping the component tree client-first (see [React Query SSR & Hydration](https://tanstack.com/query/latest/docs/framework/react/guides/ssr)):
```tsx
// in a server component
import { getQueryClient } from "@/lib/tanstack-query/getQueryClient";
import { HydrationBoundary, dehydrate } from "@tanstack/react-query";
import {
prefetchGetV2ListStoreAgentsQuery,
prefetchGetV2ListStoreCreatorsQuery,
} from "@/app/api/__generated__/endpoints/store/store";
export default async function MarketplacePage() {
const queryClient = getQueryClient();
await Promise.all([
prefetchGetV2ListStoreAgentsQuery(queryClient, { featured: true }),
prefetchGetV2ListStoreAgentsQuery(queryClient, { sorted_by: "runs" }),
prefetchGetV2ListStoreCreatorsQuery(queryClient, {
featured: true,
sorted_by: "num_agents",
}),
]);
return (
<HydrationBoundary state={dehydrate(queryClient)}>
{/* Client component tree goes here */}
</HydrationBoundary>
);
}
```
Notes:
- Do not introduce new usages of `BackendAPI` or `src/lib/autogpt-server-api/*`
- Keep transformations and mapping logic close to the consumer (hook), not in the view
---
## ⚠️ Error handling
The app has multiple error handling strategies depending on the type of error:
### Render/runtime errors
Use `<ErrorCard />` to display render or runtime errors gracefully:
```tsx
import { ErrorCard } from "@/components/molecules/ErrorCard";
export function DataPanel() {
const { data, isLoading, isError, error } = useGetData();
if (isLoading) return <Skeleton />;
if (isError) return <ErrorCard error={error} />;
return <div>{data.content}</div>;
}
```
### API mutation errors
Display mutation errors using toast notifications:
```tsx
import { useToast } from "@/components/ui/use-toast";
export function useUpdateSettings() {
const { toast } = useToast();
const { mutateAsync: updateSettings } = useUpdateSettingsMutation({
mutation: {
onError: (error) => {
toast({
title: "Failed to update settings",
description: error.message,
variant: "destructive",
});
},
},
});
return { updateSettings };
}
```
### Manual Sentry capture
When needed, you can manually capture exceptions to Sentry:
```tsx
import * as Sentry from "@sentry/nextjs";
try {
await riskyOperation();
} catch (error) {
Sentry.captureException(error, {
tags: { context: "feature-x" },
extra: { metadata: additionalData },
});
throw error;
}
```
### Global error boundaries
The app has error boundaries already configured to:
- Capture uncaught errors globally and send them to Sentry
- Display a user-friendly error UI when something breaks
- Prevent the entire app from crashing
You don't need to wrap components in error boundaries manually unless you need custom error recovery logic.
---
## 🚩 Feature Flags
- Flags are powered by [LaunchDarkly](https://docs.launchdarkly.com/)
- Use the helper APIs under `src/services/feature-flags`
Check a flag in a client component:
```tsx
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
export function AgentActivityPanel() {
const enabled = useGetFlag(Flag.AGENT_ACTIVITY);
if (!enabled) return null;
return <div>Feature is enabled!</div>;
}
```
Protect a route or page component:
```tsx
import { withFeatureFlag } from "@/services/feature-flags/with-feature-flag";
export const MyFeaturePage = withFeatureFlag(function Page() {
return <div>My feature page</div>;
}, "my-feature-flag");
```
Local dev and Playwright:
- Set `NEXT_PUBLIC_PW_TEST=true` to use mocked flag values during local development and tests
Adding new flags:
1. Add the flag to the `Flag` enum and `FlagValues` type
2. Provide a mock value in the mock map
3. Configure the flag in LaunchDarkly
---
## 📙 Naming conventions
General:
- Variables and functions should read like plain English
- Prefer `const` over `let` unless reassignment is required
- Use searchable constants instead of magic numbers
Files:
- Components and hooks: `PascalCase` for component files, `camelCase` for hooks
- Other files: `kebab-case`
- Do not create barrel files or `index.ts` re-exports
Types:
- Prefer `interface` for object shapes
- Component props should be `interface Props { ... }`
- Use precise types; avoid `any` and unsafe casts
Parameters:
- If more than one parameter is needed, pass a single `Args` object for clarity
Comments:
- Keep comments minimal; code should be clear by itself
- Only document non-obvious intent, invariants, or caveats
Functions:
- Prefer function declarations for components and handlers
- Only use arrow functions for small inline callbacks
Control flow:
- Use early returns to reduce nesting
- Avoid catching errors unless you handle them meaningfully
---
## 🎨 Styling
- Use Tailwind utilities; prefer semantic, composable class names
- Use shadcn/ui components as building blocks when available
- Use the `tailwind-scrollbar` utilities for scrollbar styling
- Keep responsive and dark-mode behavior consistent with the design system
Additional requirements:
- Do not import shadcn primitives directly in feature code; only use components exposed in our design system under `src/components`. shadcn is a low-level skeleton we style on top of and is not meant to be consumed directly.
- Prefer design tokens over Tailwind's default theme whenever possible (e.g., color, spacing, radius, and typography tokens). Avoid hardcoded values and default palette if a token exists.
---
## ⚠️ Errors and ⏳ Loading
- **Errors**: Use the `ErrorCard` component from the design system to display API/HTTP errors and retry actions. Keep error derivation/mapping in hooks; pass the final message to the component.
- Component: `src/components/molecules/ErrorCard/ErrorCard.tsx`
- **Loading**: Use the `Skeleton` component(s) from the design system for loading states. Favor domain-appropriate skeleton layouts (lists, cards, tables) over spinners.
- See Storybook examples under Atoms/Skeleton for patterns.
---
## 🧭 Responsive and mobile-first
- Build mobile-first. Ensure new UI looks great from a 375px viewport width (iPhone SE) upwards.
- Validate layouts at common breakpoints (375, 768, 1024, 1280). Prefer stacking and progressive disclosure on small screens.
---
## 🧰 State for complex flows
For components/flows with complex state, multi-step wizards, or cross-component coordination, prefer a small co-located store using [Zustand](https://github.com/pmndrs/zustand).
Guidelines:
- Co-locate the store with the feature (e.g., `FeatureX/store.ts`).
- Expose typed selectors to minimize re-renders.
- Keep effects and API calls in hooks; stores hold state and pure actions.
Example: simple store with selectors
```ts
import { create } from "zustand";
interface WizardState {
step: number;
data: Record<string, unknown>;
next(): void;
back(): void;
setField(args: { key: string; value: unknown }): void;
}
export const useWizardStore = create<WizardState>((set) => ({
step: 0,
data: {},
next() {
set((state) => ({ step: state.step + 1 }));
},
back() {
set((state) => ({ step: Math.max(0, state.step - 1) }));
},
setField({ key, value }) {
set((state) => ({ data: { ...state.data, [key]: value } }));
},
}));
// Usage in a component (selectors keep updates scoped)
function WizardFooter() {
const step = useWizardStore((s) => s.step);
const next = useWizardStore((s) => s.next);
const back = useWizardStore((s) => s.back);
return (
<div className="flex items-center gap-2">
<button onClick={back} disabled={step === 0}>Back</button>
<button onClick={next}>Next</button>
</div>
);
}
```
Example: async action coordinated via hook + store
```ts
// FeatureX/useFeatureX.ts
import { useMutation } from "@tanstack/react-query";
import { useWizardStore } from "./store";
export function useFeatureX() {
const setField = useWizardStore((s) => s.setField);
const next = useWizardStore((s) => s.next);
const { mutateAsync: save, isPending } = useMutation({
mutationFn: async (payload: unknown) => {
// call API here
return payload;
},
onSuccess(data) {
setField({ key: "result", value: data });
next();
},
});
return { save, isSaving: isPending };
}
```
---
## 🖼 Icons
- Only use Phosphor Icons. Treat all other icon libraries as deprecated for new code.
- Package: `@phosphor-icons/react`
- Site: [`https://phosphoricons.com/`](https://phosphoricons.com/)
Example usage:
```tsx
import { Plus } from "@phosphor-icons/react";
export function CreateButton() {
return (
<button type="button" className="inline-flex items-center gap-2">
<Plus size={16} />
Create
</button>
);
}
```
---
## 🧪 Testing & Storybook
- End-to-end: [Playwright](https://playwright.dev/docs/intro) (`pnpm test`, `pnpm test-ui`)
- [Storybook](https://storybook.js.org/docs) for isolated UI development (`pnpm storybook` / `pnpm build-storybook`)
- For Storybook tests in CI, see [`@storybook/test-runner`](https://storybook.js.org/docs/writing-tests/test-runner) (`test-storybook:ci`)
- When changing components in `src/components`, update or add stories and visually verify in Storybook/Chromatic
---
## 🛠 Tooling & Scripts
Common scripts (see `package.json` for full list):
- `pnpm dev` — Start Next.js dev server (generates API client first)
- `pnpm build` — Build for production
- `pnpm start` — Start production server
- `pnpm lint` — ESLint + Prettier check
- `pnpm format` — Format code
- `pnpm types` — Type-check
- `pnpm storybook` — Run Storybook
- `pnpm test` — Run Playwright tests
Generated API client:
- `pnpm generate:api` — Fetch OpenAPI spec and regenerate the client
---
## ✅ PR checklist (Frontend)
- Client-first: server components only for SEO or extreme TTFB needs
- Uses generated API hooks; no new `BackendAPI` usages
- UI uses `src/components` primitives; no new `_legacy__` components
- Logic is separated into `use*.ts` and `helpers.ts` when non-trivial
- Reusable logic extracted to `src/services/` or `src/lib/utils.ts` when appropriate
- Navigation uses the Next.js router
- Lint, format, type-check, and tests pass locally
- Stories updated/added if UI changed; verified in Storybook
---
## ♻️ Migration guidance
When touching legacy code:
- Replace usages of `src/components/_legacy__/*` with the modern design system components under `src/components`
- Replace `BackendAPI` or `src/lib/autogpt-server-api/*` with generated API hooks
- Move presentational logic into render files and data/behavior into hooks
- Keep one-off transformations in local `helpers.ts`; move reusable logic to `src/services/` or `src/lib/utils.ts`
---
## 📚 References
- Design system (Chromatic): [`https://dev--670f94474adee5e32c896b98.chromatic.com/`](https://dev--670f94474adee5e32c896b98.chromatic.com/)
- Project README for setup and API client examples: `autogpt_platform/frontend/README.md`
- Conventional Commits: [conventionalcommits.org](https://www.conventionalcommits.org/)

View File

@@ -4,20 +4,12 @@ This is the frontend for AutoGPT's next generation
This project uses [**pnpm**](https://pnpm.io/) as the package manager via **corepack**. [Corepack](https://github.com/nodejs/corepack) is a Node.js tool that automatically manages package managers without requiring global installations.
For architecture, conventions, data fetching, feature flags, design system usage, state management, and PR process, see [CONTRIBUTING.md](./CONTRIBUTING.md).
### Prerequisites
Make sure you have Node.js 16.10+ installed. Corepack is included with Node.js by default.
### ⚠️ Migrating from yarn
> This project was previously using yarn1, make sure to clean up the old files if you set it up previously with yarn:
>
> ```bash
> rm -f yarn.lock && rm -rf node_modules
> ```
>
> Then follow the setup steps below.
## Setup
### 1. **Enable corepack** (run this once on your system):
@@ -96,184 +88,13 @@ Every time a new Front-end dependency is added by you or others, you will need t
This project uses [`next/font`](https://nextjs.org/docs/basic-features/font-optimization) to automatically optimize and load Inter, a custom Google Font.
## 🔄 Data Fetching Strategy
## 🔄 Data Fetching
> [!NOTE]
> You don't need to run the OpenAPI commands below to run the Front-end. You will only need to run them when adding or modifying endpoints on the Backend API and wanting to use those on the Frontend.
This project uses an auto-generated API client powered by [**Orval**](https://orval.dev/), which creates type-safe API clients from OpenAPI specifications.
### How It Works
1. **Backend Requirements**: Each API endpoint needs a summary and tag in the OpenAPI spec
2. **Operation ID Generation**: FastAPI generates operation IDs using the pattern `{method}{tag}{summary}`
3. **Spec Fetching**: The OpenAPI spec is fetched from `http://localhost:8006/openapi.json` and saved to the frontend
4. **Spec Transformation**: The OpenAPI spec is cleaned up using a custom transformer (see `autogpt_platform/frontend/src/app/api/transformers`)
5. **Client Generation**: Auto-generated client includes TypeScript types, API endpoints, and Zod schemas, organized by tags
### API Client Commands
```bash
# Fetch OpenAPI spec from backend and generate client
pnpm generate:api
# Only fetch the OpenAPI spec
pnpm fetch:openapi
# Only generate the client (after spec is fetched)
pnpm generate:api-client
```
### Using the Generated Client
The generated client provides React Query hooks for both queries and mutations:
#### Queries (GET requests)
```typescript
import { useGetV1GetNotificationPreferences } from "@/app/api/__generated__/endpoints/auth/auth";
const { data, isLoading, isError } = useGetV1GetNotificationPreferences({
query: {
select: (res) => res.data,
// Other React Query options
},
});
```
#### Mutations (POST, PUT, DELETE requests)
```typescript
import { useDeleteV2DeleteStoreSubmission } from "@/app/api/__generated__/endpoints/store/store";
import { getGetV2ListMySubmissionsQueryKey } from "@/app/api/__generated__/endpoints/store/store";
import { useQueryClient } from "@tanstack/react-query";
const queryClient = useQueryClient();
const { mutateAsync: deleteSubmission } = useDeleteV2DeleteStoreSubmission({
mutation: {
onSuccess: () => {
// Invalidate related queries to refresh data
queryClient.invalidateQueries({
queryKey: getGetV2ListMySubmissionsQueryKey(),
});
},
},
});
// Usage
await deleteSubmission({
submissionId: submission_id,
});
```
#### Server Actions
For server-side operations, you can also use the generated client functions directly:
```typescript
import { postV1UpdateNotificationPreferences } from "@/app/api/__generated__/endpoints/auth/auth";
// In a server action
const preferences = {
email: "user@example.com",
preferences: {
AGENT_RUN: true,
ZERO_BALANCE: false,
// ... other preferences
},
daily_limit: 0,
};
await postV1UpdateNotificationPreferences(preferences);
```
#### Server-Side Prefetching
For server-side components, you can prefetch data on the server and hydrate it in the client cache. This allows immediate access to cached data when queries are called:
```typescript
import { getQueryClient } from "@/lib/tanstack-query/getQueryClient";
import {
prefetchGetV2ListStoreAgentsQuery,
prefetchGetV2ListStoreCreatorsQuery
} from "@/app/api/__generated__/endpoints/store/store";
import { HydrationBoundary, dehydrate } from "@tanstack/react-query";
// In your server component
const queryClient = getQueryClient();
await Promise.all([
prefetchGetV2ListStoreAgentsQuery(queryClient, {
featured: true,
}),
prefetchGetV2ListStoreAgentsQuery(queryClient, {
sorted_by: "runs",
}),
prefetchGetV2ListStoreCreatorsQuery(queryClient, {
featured: true,
sorted_by: "num_agents",
}),
]);
return (
<HydrationBoundary state={dehydrate(queryClient)}>
<MainMarkeplacePage />
</HydrationBoundary>
);
```
This pattern improves performance by serving pre-fetched data from the server while maintaining the benefits of client-side React Query features.
### Configuration
The Orval configuration is located in `autogpt_platform/frontend/orval.config.ts`. It generates two separate clients:
1. **autogpt_api_client**: React Query hooks for client-side data fetching
2. **autogpt_zod_schema**: Zod schemas for validation
For more details, see the [Orval documentation](https://orval.dev/) or check the configuration file.
See [CONTRIBUTING.md](./CONTRIBUTING.md) for guidance on generated API hooks, SSR + hydration patterns, and usage examples. You generally do not need to run OpenAPI commands unless adding/modifying backend endpoints.
## 🚩 Feature Flags
This project uses [LaunchDarkly](https://launchdarkly.com/) for feature flags, allowing us to control feature rollouts and A/B testing.
### Using Feature Flags
#### Check if a feature is enabled
```typescript
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
function MyComponent() {
const isAgentActivityEnabled = useGetFlag(Flag.AGENT_ACTIVITY);
if (!isAgentActivityEnabled) {
return null; // Hide feature
}
return <div>Feature is enabled!</div>;
}
```
#### Protect entire components
```typescript
import { withFeatureFlag } from "@/services/feature-flags/with-feature-flag";
const MyFeaturePage = withFeatureFlag(MyPageComponent, "my-feature-flag");
```
### Testing with Feature Flags
For local development or running Playwright tests locally, use mocked feature flags by setting `NEXT_PUBLIC_PW_TEST=true` in your `.env` file. This bypasses LaunchDarkly and uses the mock values defined in the code.
### Adding New Flags
1. Add the flag to the `Flag` enum in `use-get-flag.ts`
2. Add the flag type to `FlagValues` type
3. Add mock value to `mockFlags` for testing
4. Configure the flag in LaunchDarkly dashboard
See [CONTRIBUTING.md](./CONTRIBUTING.md) for feature flag usage patterns, local development with mocks, and how to add new flags.
## 🚚 Deploy
@@ -333,7 +154,7 @@ By integrating Storybook into our development workflow, we can streamline UI dev
- [**Tailwind CSS**](https://tailwindcss.com/) - Utility-first CSS framework
- [**shadcn/ui**](https://ui.shadcn.com/) - Re-usable components built with Radix UI and Tailwind CSS
- [**Radix UI**](https://www.radix-ui.com/) - Headless UI components for accessibility
- [**Lucide React**](https://lucide.dev/guide/packages/lucide-react) - Beautiful & consistent icons
- [**Phosphor Icons**](https://phosphoricons.com/) - Icon set used across the app
- [**Framer Motion**](https://motion.dev/) - Animation library for React
### Development & Testing

View File

@@ -2,18 +2,11 @@
// The config you add here will be used whenever a users loads a page in their browser.
// https://docs.sentry.io/platforms/javascript/guides/nextjs/
import {
AppEnv,
BehaveAs,
getAppEnv,
getBehaveAs,
getEnvironmentStr,
} from "@/lib/utils";
import { environment } from "@/services/environment";
import * as Sentry from "@sentry/nextjs";
const isProdOrDev = [AppEnv.PROD, AppEnv.DEV].includes(getAppEnv());
const isCloud = getBehaveAs() === BehaveAs.CLOUD;
const isProdOrDev = environment.isProd() || environment.isDev();
const isCloud = environment.isCloud();
const isDisabled = process.env.DISABLE_SENTRY === "true";
const shouldEnable = !isDisabled && isProdOrDev && isCloud;
@@ -21,7 +14,7 @@ const shouldEnable = !isDisabled && isProdOrDev && isCloud;
Sentry.init({
dsn: "https://fe4e4aa4a283391808a5da396da20159@o4505260022104064.ingest.us.sentry.io/4507946746380288",
environment: getEnvironmentStr(),
environment: environment.getEnvironmentStr(),
enabled: shouldEnable,

View File

@@ -55,7 +55,7 @@
"@sentry/nextjs": "10.15.0",
"@supabase/ssr": "0.6.1",
"@supabase/supabase-js": "2.55.0",
"@tanstack/react-query": "5.85.3",
"@tanstack/react-query": "5.87.1",
"@tanstack/react-table": "8.21.3",
"@types/jaro-winkler": "0.2.4",
"@vercel/analytics": "1.5.0",
@@ -103,7 +103,7 @@
"shepherd.js": "14.5.1",
"sonner": "2.0.7",
"tailwind-merge": "2.6.0",
"tailwind-scrollbar": "4.0.2",
"tailwind-scrollbar": "3.1.0",
"tailwindcss-animate": "1.0.7",
"uuid": "11.1.0",
"vaul": "1.1.2",

View File

@@ -99,8 +99,8 @@ importers:
specifier: 2.55.0
version: 2.55.0
'@tanstack/react-query':
specifier: 5.85.3
version: 5.85.3(react@18.3.1)
specifier: 5.87.1
version: 5.87.1(react@18.3.1)
'@tanstack/react-table':
specifier: 8.21.3
version: 8.21.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
@@ -243,8 +243,8 @@ importers:
specifier: 2.6.0
version: 2.6.0
tailwind-scrollbar:
specifier: 4.0.2
version: 4.0.2(react@18.3.1)(tailwindcss@3.4.17)
specifier: 3.1.0
version: 3.1.0(tailwindcss@3.4.17)
tailwindcss-animate:
specifier: 1.0.7
version: 1.0.7(tailwindcss@3.4.17)
@@ -287,7 +287,7 @@ importers:
version: 5.86.0(eslint@8.57.1)(typescript@5.9.2)
'@tanstack/react-query-devtools':
specifier: 5.87.3
version: 5.87.3(@tanstack/react-query@5.85.3(react@18.3.1))(react@18.3.1)
version: 5.87.3(@tanstack/react-query@5.87.1(react@18.3.1))(react@18.3.1)
'@types/canvas-confetti':
specifier: 1.9.0
version: 1.9.0
@@ -947,10 +947,6 @@ packages:
peerDependencies:
'@babel/core': ^7.0.0-0
'@babel/runtime@7.28.3':
resolution: {integrity: sha512-9uIQ10o0WGdpP6GDhXcdOJPJuDgFtIDtN/9+ArJQ2NAfAmiuhTQdzkaTGR33v43GYS2UrSA0eX2pPPHoFVvpxA==}
engines: {node: '>=6.9.0'}
'@babel/runtime@7.28.4':
resolution: {integrity: sha512-Q/N6JNWvIvPnLDvjlE1OUBLPQHH6l3CltCEsHIujp45zQUSSh8K+gHnaEX45yAT1nyngnINhvWtzN+Nb9D8RAQ==}
engines: {node: '>=6.9.0'}
@@ -985,9 +981,6 @@ packages:
'@emnapi/core@1.5.0':
resolution: {integrity: sha512-sbP8GzB1WDzacS8fgNPpHlp6C9VZe+SJP3F90W9rLemaQj2PzIuTEl1qDOYQf58YIpyjViI24y9aPWCjEzY2cg==}
'@emnapi/runtime@1.4.5':
resolution: {integrity: sha512-++LApOtY0pEEz1zrd9vy1/zXVaVJJ/EbAF3u0fXIzPJEDtnITsBGbbK0EkM72amhl/R5b+5xx0Y/QhcVOpuulg==}
'@emnapi/runtime@1.5.0':
resolution: {integrity: sha512-97/BJ3iXHww3djw6hYIfErCZFee7qCtrneuLa20UXFCOTCfBM2cvQHjWJ2EG0s0MtdNwInarqCTz35i4wWXHsQ==}
@@ -1159,12 +1152,6 @@ packages:
cpu: [x64]
os: [win32]
'@eslint-community/eslint-utils@4.7.0':
resolution: {integrity: sha512-dyybb3AcajC7uha6CvhdVRJqaKyn7w2YKqKyAN37NKYgZT36w+iRb0Dymmc5qEJ549c/S31cMMSFd75bteCpCw==}
engines: {node: ^12.22.0 || ^14.17.0 || >=16.0.0}
peerDependencies:
eslint: ^6.0.0 || ^7.0.0 || >=8.0.0
'@eslint-community/eslint-utils@4.9.0':
resolution: {integrity: sha512-ayVFHdtZ+hsq1t2Dy24wCmGXGe4q9Gu3smhLYALJrr473ZH27MsnSL+LKUlimp4BWJqMDMLmPpx/Q9R3OAlL4g==}
engines: {node: ^12.22.0 || ^14.17.0 || >=16.0.0}
@@ -2856,8 +2843,8 @@ packages:
peerDependencies:
eslint: ^8.57.0 || ^9.0.0
'@tanstack/query-core@5.85.3':
resolution: {integrity: sha512-9Ne4USX83nHmRuEYs78LW+3lFEEO2hBDHu7mrdIgAFx5Zcrs7ker3n/i8p4kf6OgKExmaDN5oR0efRD7i2J0DQ==}
'@tanstack/query-core@5.87.1':
resolution: {integrity: sha512-HOFHVvhOCprrWvtccSzc7+RNqpnLlZ5R6lTmngb8aq7b4rc2/jDT0w+vLdQ4lD9bNtQ+/A4GsFXy030Gk4ollA==}
'@tanstack/query-devtools@5.87.3':
resolution: {integrity: sha512-LkzxzSr2HS1ALHTgDmJH5eGAVsSQiuwz//VhFW5OqNk0OQ+Fsqba0Tsf+NzWRtXYvpgUqwQr4b2zdFZwxHcGvg==}
@@ -2868,8 +2855,8 @@ packages:
'@tanstack/react-query': ^5.87.1
react: ^18 || ^19
'@tanstack/react-query@5.85.3':
resolution: {integrity: sha512-AqU8TvNh5GVIE8I+TUU0noryBRy7gOY0XhSayVXmOPll4UkZeLWKDwi0rtWOZbwLRCbyxorfJ5DIjDqE7GXpcQ==}
'@tanstack/react-query@5.87.1':
resolution: {integrity: sha512-YKauf8jfMowgAqcxj96AHs+Ux3m3bWT1oSVKamaRPXSnW2HqSznnTCEkAVqctF1e/W9R/mPcyzzINIgpOH94qg==}
peerDependencies:
react: ^18 || ^19
@@ -3045,9 +3032,6 @@ packages:
'@types/phoenix@1.6.6':
resolution: {integrity: sha512-PIzZZlEppgrpoT2QgbnDU+MMzuR6BbCjllj0bM70lWoejMeNJAxCchxnv7J3XFkI8MpygtRpzXrIlmWUBclP5A==}
'@types/prismjs@1.26.5':
resolution: {integrity: sha512-AUZTa7hQ2KY5L7AmtSiqxlhWxb4ina0yd8hNbl4TWuqnv/pFP0nDMb3YrfSBf4hJVGLh2YEIBfKaBW/9UEl6IQ==}
'@types/prop-types@15.7.15':
resolution: {integrity: sha512-F6bEyamV9jKGAFBEmlQnesRPGOQqS2+Uwi0Em15xenOxHaf2hv6L8YCVn3rPdPJOiJfPiCnLIRyvwVaqMY3MIw==}
@@ -3740,9 +3724,6 @@ packages:
camelize@1.0.1:
resolution: {integrity: sha512-dU+Tx2fsypxTgtLoE36npi3UqcjSSMNYfkqgmoEhtZrraP5VWq0K7FkWVTYa8eMPtnU/G2txVsfdCJTn9uzpuQ==}
caniuse-lite@1.0.30001735:
resolution: {integrity: sha512-EV/laoX7Wq2J9TQlyIXRxTJqIw4sxfXS4OYgudGxBYRuTv0q7AM6yMEpU/Vo1I94thg9U6EZ2NfZx9GJq83u7w==}
caniuse-lite@1.0.30001741:
resolution: {integrity: sha512-QGUGitqsc8ARjLdgAfxETDhRbJ0REsP6O3I96TAth/mVjh2cYzN2u+3AzPP3aVSm2FehEItaJw1xd+IGBXWeSw==}
@@ -4108,15 +4089,6 @@ packages:
supports-color:
optional: true
debug@4.4.1:
resolution: {integrity: sha512-KcKCqiftBJcZr++7ykoDIEwSa3XWowTfNPo92BYxjXiyYEVrUQh2aLyhxBCwww+heortUFxEJYcRzosstTEBYQ==}
engines: {node: '>=6.0'}
peerDependencies:
supports-color: '*'
peerDependenciesMeta:
supports-color:
optional: true
debug@4.4.3:
resolution: {integrity: sha512-RGwwWnwQvkVfavKVt22FGLw+xYSdzARwm0ru6DhTVA3umU5hZc28V3kO4stgYryrTlLpuvgI9GiijltAjNbcqA==}
engines: {node: '>=6.0'}
@@ -6220,11 +6192,6 @@ packages:
resolution: {integrity: sha512-Qb1gy5OrP5+zDf2Bvnzdl3jsTf1qXVMazbvCoKhtKqVs4/YK4ozX4gKQJJVyNe+cajNPn0KoC0MC3FUmaHWEmQ==}
engines: {node: ^10.13.0 || ^12.13.0 || ^14.15.0 || >=15.0.0}
prism-react-renderer@2.4.1:
resolution: {integrity: sha512-ey8Ls/+Di31eqzUxC46h8MksNuGx/n0AAC8uKpwFau4RPDYLuE3EXTp8N8G2vX2N7UC/+IXeNUnlWBGGcAG+Ig==}
peerDependencies:
react: '>=16.0.0'
process-nextick-args@2.0.1:
resolution: {integrity: sha512-3ouUOpQhtgrbOa17J7+uxOTpITYWaGP7/AhoR3+A+/1e9skrzelGi/dXzEYyvbxubEF6Wn2ypscTKiKJFFn1ag==}
@@ -6949,11 +6916,11 @@ packages:
tailwind-merge@2.6.0:
resolution: {integrity: sha512-P+Vu1qXfzediirmHOC3xKGAYeZtPcV9g76X+xg2FD4tYgR71ewMA35Y3sCz3zhiN/dwefRpJX0yBcgwi1fXNQA==}
tailwind-scrollbar@4.0.2:
resolution: {integrity: sha512-wAQiIxAPqk0MNTPptVe/xoyWi27y+NRGnTwvn4PQnbvB9kp8QUBiGl/wsfoVBHnQxTmhXJSNt9NHTmcz9EivFA==}
tailwind-scrollbar@3.1.0:
resolution: {integrity: sha512-pmrtDIZeHyu2idTejfV59SbaJyvp1VRjYxAjZBH0jnyrPRo6HL1kD5Glz8VPagasqr6oAx6M05+Tuw429Z8jxg==}
engines: {node: '>=12.13.0'}
peerDependencies:
tailwindcss: 4.x
tailwindcss: 3.x
tailwindcss-animate@1.0.7:
resolution: {integrity: sha512-bl6mpH3T7I3UFxuvDEXLxy/VuFxBk5bbzplh7tXI68mwMokNYd1t9qPBHlnyTwfa4JGC4zP516I1hYYtQ/vspA==}
@@ -7567,7 +7534,7 @@ snapshots:
'@babel/types': 7.28.4
'@jridgewell/remapping': 2.3.5
convert-source-map: 2.0.0
debug: 4.4.1
debug: 4.4.3
gensync: 1.0.0-beta.2
json5: 2.2.3
semver: 6.3.1
@@ -7619,7 +7586,7 @@ snapshots:
'@babel/core': 7.28.4
'@babel/helper-compilation-targets': 7.27.2
'@babel/helper-plugin-utils': 7.27.1
debug: 4.4.1
debug: 4.4.3
lodash.debounce: 4.0.8
resolve: 1.22.10
transitivePeerDependencies:
@@ -8270,8 +8237,6 @@ snapshots:
transitivePeerDependencies:
- supports-color
'@babel/runtime@7.28.3': {}
'@babel/runtime@7.28.4': {}
'@babel/template@7.27.2':
@@ -8288,7 +8253,7 @@ snapshots:
'@babel/parser': 7.28.4
'@babel/template': 7.27.2
'@babel/types': 7.28.4
debug: 4.4.1
debug: 4.4.3
transitivePeerDependencies:
- supports-color
@@ -8325,11 +8290,6 @@ snapshots:
tslib: 2.8.1
optional: true
'@emnapi/runtime@1.4.5':
dependencies:
tslib: 2.8.1
optional: true
'@emnapi/runtime@1.5.0':
dependencies:
tslib: 2.8.1
@@ -8426,11 +8386,6 @@ snapshots:
'@esbuild/win32-x64@0.25.9':
optional: true
'@eslint-community/eslint-utils@4.7.0(eslint@8.57.1)':
dependencies:
eslint: 8.57.1
eslint-visitor-keys: 3.4.3
'@eslint-community/eslint-utils@4.9.0(eslint@8.57.1)':
dependencies:
eslint: 8.57.1
@@ -8441,7 +8396,7 @@ snapshots:
'@eslint/eslintrc@2.1.4':
dependencies:
ajv: 6.12.6
debug: 4.4.1
debug: 4.4.3
espree: 9.6.1
globals: 13.24.0
ignore: 5.3.2
@@ -8491,7 +8446,7 @@ snapshots:
'@humanwhocodes/config-array@0.13.0':
dependencies:
'@humanwhocodes/object-schema': 2.0.3
debug: 4.4.1
debug: 4.4.3
minimatch: 3.1.2
transitivePeerDependencies:
- supports-color
@@ -8592,7 +8547,7 @@ snapshots:
'@img/sharp-wasm32@0.34.3':
dependencies:
'@emnapi/runtime': 1.4.5
'@emnapi/runtime': 1.5.0
optional: true
'@img/sharp-win32-arm64@0.34.3':
@@ -9041,7 +8996,7 @@ snapshots:
ajv: 8.17.1
chalk: 4.1.2
compare-versions: 6.1.1
debug: 4.4.1
debug: 4.4.3
esbuild: 0.25.9
esutils: 2.0.3
fs-extra: 11.3.1
@@ -10373,7 +10328,7 @@ snapshots:
'@storybook/react-docgen-typescript-plugin@1.0.6--canary.9.0c3f3b7.0(typescript@5.9.2)(webpack@5.101.3(esbuild@0.25.9))':
dependencies:
debug: 4.4.1
debug: 4.4.3
endent: 2.1.0
find-cache-dir: 3.3.2
flat-cache: 3.2.0
@@ -10460,19 +10415,19 @@ snapshots:
- supports-color
- typescript
'@tanstack/query-core@5.85.3': {}
'@tanstack/query-core@5.87.1': {}
'@tanstack/query-devtools@5.87.3': {}
'@tanstack/react-query-devtools@5.87.3(@tanstack/react-query@5.85.3(react@18.3.1))(react@18.3.1)':
'@tanstack/react-query-devtools@5.87.3(@tanstack/react-query@5.87.1(react@18.3.1))(react@18.3.1)':
dependencies:
'@tanstack/query-devtools': 5.87.3
'@tanstack/react-query': 5.85.3(react@18.3.1)
'@tanstack/react-query': 5.87.1(react@18.3.1)
react: 18.3.1
'@tanstack/react-query@5.85.3(react@18.3.1)':
'@tanstack/react-query@5.87.1(react@18.3.1)':
dependencies:
'@tanstack/query-core': 5.85.3
'@tanstack/query-core': 5.87.1
react: 18.3.1
'@tanstack/react-table@8.21.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1)':
@@ -10664,8 +10619,6 @@ snapshots:
'@types/phoenix@1.6.6': {}
'@types/prismjs@1.26.5': {}
'@types/prop-types@15.7.15': {}
'@types/react-dom@18.3.5(@types/react@18.3.17)':
@@ -10734,7 +10687,7 @@ snapshots:
'@typescript-eslint/types': 8.43.0
'@typescript-eslint/typescript-estree': 8.43.0(typescript@5.9.2)
'@typescript-eslint/visitor-keys': 8.43.0
debug: 4.4.1
debug: 4.4.3
eslint: 8.57.1
typescript: 5.9.2
transitivePeerDependencies:
@@ -10744,7 +10697,7 @@ snapshots:
dependencies:
'@typescript-eslint/tsconfig-utils': 8.43.0(typescript@5.9.2)
'@typescript-eslint/types': 8.43.0
debug: 4.4.1
debug: 4.4.3
typescript: 5.9.2
transitivePeerDependencies:
- supports-color
@@ -10763,7 +10716,7 @@ snapshots:
'@typescript-eslint/types': 8.43.0
'@typescript-eslint/typescript-estree': 8.43.0(typescript@5.9.2)
'@typescript-eslint/utils': 8.43.0(eslint@8.57.1)(typescript@5.9.2)
debug: 4.4.1
debug: 4.4.3
eslint: 8.57.1
ts-api-utils: 2.1.0(typescript@5.9.2)
typescript: 5.9.2
@@ -10778,7 +10731,7 @@ snapshots:
'@typescript-eslint/tsconfig-utils': 8.43.0(typescript@5.9.2)
'@typescript-eslint/types': 8.43.0
'@typescript-eslint/visitor-keys': 8.43.0
debug: 4.4.1
debug: 4.4.3
fast-glob: 3.3.3
is-glob: 4.0.3
minimatch: 9.0.5
@@ -11395,8 +11348,6 @@ snapshots:
camelize@1.0.1: {}
caniuse-lite@1.0.30001735: {}
caniuse-lite@1.0.30001741: {}
case-sensitive-paths-webpack-plugin@2.4.0: {}
@@ -11598,7 +11549,7 @@ snapshots:
dependencies:
cipher-base: 1.0.6
inherits: 2.0.4
ripemd160: 2.0.1
ripemd160: 2.0.2
sha.js: 2.4.12
create-hash@1.2.0:
@@ -11612,9 +11563,9 @@ snapshots:
create-hmac@1.1.7:
dependencies:
cipher-base: 1.0.6
create-hash: 1.1.3
create-hash: 1.2.0
inherits: 2.0.4
ripemd160: 2.0.1
ripemd160: 2.0.2
safe-buffer: 5.2.1
sha.js: 2.4.12
@@ -11772,10 +11723,6 @@ snapshots:
dependencies:
ms: 2.1.3
debug@4.4.1:
dependencies:
ms: 2.1.3
debug@4.4.3:
dependencies:
ms: 2.1.3
@@ -12077,7 +12024,7 @@ snapshots:
esbuild-register@3.6.0(esbuild@0.25.9):
dependencies:
debug: 4.4.1
debug: 4.4.3
esbuild: 0.25.9
transitivePeerDependencies:
- supports-color
@@ -12148,7 +12095,7 @@ snapshots:
eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0)(eslint@8.57.1):
dependencies:
'@nolyfill/is-core-module': 1.0.39
debug: 4.4.1
debug: 4.4.3
eslint: 8.57.1
get-tsconfig: 4.10.1
is-bun-module: 2.0.0
@@ -12270,7 +12217,7 @@ snapshots:
eslint@8.57.1:
dependencies:
'@eslint-community/eslint-utils': 4.7.0(eslint@8.57.1)
'@eslint-community/eslint-utils': 4.9.0(eslint@8.57.1)
'@eslint-community/regexpp': 4.12.1
'@eslint/eslintrc': 2.1.4
'@eslint/js': 8.57.1
@@ -12281,7 +12228,7 @@ snapshots:
ajv: 6.12.6
chalk: 4.1.2
cross-spawn: 7.0.6
debug: 4.4.1
debug: 4.4.3
doctrine: 3.0.0
escape-string-regexp: 4.0.0
eslint-scope: 7.2.2
@@ -13654,7 +13601,7 @@ snapshots:
micromark@4.0.2:
dependencies:
'@types/debug': 4.1.12
debug: 4.4.1
debug: 4.4.3
decode-named-character-reference: 1.2.0
devlop: 1.1.0
micromark-core-commonmark: 2.0.3
@@ -13790,7 +13737,7 @@ snapshots:
dependencies:
'@next/env': 15.4.7
'@swc/helpers': 0.5.15
caniuse-lite: 1.0.30001735
caniuse-lite: 1.0.30001741
postcss: 8.4.31
react: 18.3.1
react-dom: 18.3.1(react@18.3.1)
@@ -14311,12 +14258,6 @@ snapshots:
ansi-styles: 5.2.0
react-is: 17.0.2
prism-react-renderer@2.4.1(react@18.3.1):
dependencies:
'@types/prismjs': 1.26.5
clsx: 2.1.1
react: 18.3.1
process-nextick-args@2.0.1: {}
process@0.11.10: {}
@@ -14495,7 +14436,7 @@ snapshots:
react-window@1.8.11(react-dom@18.3.1(react@18.3.1))(react@18.3.1):
dependencies:
'@babel/runtime': 7.28.3
'@babel/runtime': 7.28.4
memoize-one: 5.2.1
react: 18.3.1
react-dom: 18.3.1(react@18.3.1)
@@ -14716,7 +14657,7 @@ snapshots:
require-in-the-middle@7.5.2:
dependencies:
debug: 4.4.1
debug: 4.4.3
module-details-from-path: 1.0.4
resolve: 1.22.10
transitivePeerDependencies:
@@ -15259,12 +15200,9 @@ snapshots:
tailwind-merge@2.6.0: {}
tailwind-scrollbar@4.0.2(react@18.3.1)(tailwindcss@3.4.17):
tailwind-scrollbar@3.1.0(tailwindcss@3.4.17):
dependencies:
prism-react-renderer: 2.4.1(react@18.3.1)
tailwindcss: 3.4.17
transitivePeerDependencies:
- react
tailwindcss-animate@1.0.7(tailwindcss@3.4.17):
dependencies:

View File

@@ -1,16 +1,16 @@
#!/usr/bin/env node
import { getAgptServerBaseUrl } from "@/lib/env-config";
import { execSync } from "child_process";
import * as path from "path";
import * as fs from "fs";
import * as os from "os";
import { environment } from "@/services/environment";
function fetchOpenApiSpec(): void {
const args = process.argv.slice(2);
const forceFlag = args.includes("--force");
const baseUrl = getAgptServerBaseUrl();
const baseUrl = environment.getAGPTServerBaseUrl();
const openApiUrl = `${baseUrl}/openapi.json`;
const outputPath = path.join(
__dirname,

View File

@@ -3,18 +3,11 @@
// Note that this config is unrelated to the Vercel Edge Runtime and is also required when running locally.
// https://docs.sentry.io/platforms/javascript/guides/nextjs/
import { environment } from "@/services/environment";
import * as Sentry from "@sentry/nextjs";
import {
AppEnv,
BehaveAs,
getAppEnv,
getBehaveAs,
getEnvironmentStr,
} from "./src/lib/utils";
const isProdOrDev = [AppEnv.PROD, AppEnv.DEV].includes(getAppEnv());
const isCloud = getBehaveAs() === BehaveAs.CLOUD;
const isProdOrDev = environment.isProd() || environment.isDev();
const isCloud = environment.isCloud();
const isDisabled = process.env.DISABLE_SENTRY === "true";
const shouldEnable = !isDisabled && isProdOrDev && isCloud;
@@ -22,7 +15,7 @@ const shouldEnable = !isDisabled && isProdOrDev && isCloud;
Sentry.init({
dsn: "https://fe4e4aa4a283391808a5da396da20159@o4505260022104064.ingest.us.sentry.io/4507946746380288",
environment: getEnvironmentStr(),
environment: environment.getEnvironmentStr(),
enabled: shouldEnable,
@@ -40,7 +33,7 @@ Sentry.init({
enableLogs: true,
integrations: [
Sentry.captureConsoleIntegration(),
Sentry.captureConsoleIntegration({ levels: ["fatal", "error", "warn"] }),
Sentry.extraErrorDataIntegration(),
],
});

View File

@@ -2,19 +2,12 @@
// The config you add here will be used whenever the server handles a request.
// https://docs.sentry.io/platforms/javascript/guides/nextjs/
import {
AppEnv,
BehaveAs,
getAppEnv,
getBehaveAs,
getEnvironmentStr,
} from "@/lib/utils";
import { environment } from "@/services/environment";
import * as Sentry from "@sentry/nextjs";
// import { NodeProfilingIntegration } from "@sentry/profiling-node";
const isProdOrDev = [AppEnv.PROD, AppEnv.DEV].includes(getAppEnv());
const isCloud = getBehaveAs() === BehaveAs.CLOUD;
const isProdOrDev = environment.isProd() || environment.isDev();
const isCloud = environment.isCloud();
const isDisabled = process.env.DISABLE_SENTRY === "true";
const shouldEnable = !isDisabled && isProdOrDev && isCloud;
@@ -22,7 +15,7 @@ const shouldEnable = !isDisabled && isProdOrDev && isCloud;
Sentry.init({
dsn: "https://fe4e4aa4a283391808a5da396da20159@o4505260022104064.ingest.us.sentry.io/4507946746380288",
environment: getEnvironmentStr(),
environment: environment.getEnvironmentStr(),
enabled: shouldEnable,

View File

@@ -10,9 +10,9 @@ import OnboardingAgentCard from "../components/OnboardingAgentCard";
import { useEffect, useState } from "react";
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
import { StoreAgentDetails } from "@/lib/autogpt-server-api";
import { finishOnboarding } from "../6-congrats/actions";
import { isEmptyOrWhitespace } from "@/lib/utils";
import { useOnboarding } from "../../../../providers/onboarding/onboarding-provider";
import { finishOnboarding } from "../6-congrats/actions";
export default function Page() {
const { state, updateState } = useOnboarding(4, "INTEGRATIONS");
@@ -24,6 +24,7 @@ export default function Page() {
if (agents.length < 2) {
finishOnboarding();
}
setAgents(agents);
});
}, [api, setAgents]);

View File

@@ -0,0 +1,62 @@
import { CredentialsInput } from "@/app/(platform)/library/agents/[id]/components/AgentRunsView/components/CredentialsInputs/CredentialsInputs";
import { CredentialsMetaInput } from "@/app/api/__generated__/models/credentialsMetaInput";
import { GraphMeta } from "@/app/api/__generated__/models/graphMeta";
import { useState } from "react";
import { getSchemaDefaultCredentials } from "../../helpers";
import { areAllCredentialsSet, getCredentialFields } from "./helpers";
type Credential = CredentialsMetaInput | undefined;
type Credentials = Record<string, Credential>;
type Props = {
agent: GraphMeta | null;
siblingInputs?: Record<string, any>;
onCredentialsChange: (
credentials: Record<string, CredentialsMetaInput>,
) => void;
onValidationChange: (isValid: boolean) => void;
onLoadingChange: (isLoading: boolean) => void;
};
export function AgentOnboardingCredentials(props: Props) {
const [inputCredentials, setInputCredentials] = useState<Credentials>({});
const fields = getCredentialFields(props.agent);
const required = Object.keys(fields || {}).length > 0;
if (!required) return null;
function handleSelectCredentials(key: string, value: Credential) {
const updated = { ...inputCredentials, [key]: value };
setInputCredentials(updated);
const sanitized: Record<string, CredentialsMetaInput> = {};
for (const [k, v] of Object.entries(updated)) {
if (v) sanitized[k] = v;
}
props.onCredentialsChange(sanitized);
const isValid = !required || areAllCredentialsSet(fields, updated);
props.onValidationChange(isValid);
}
return (
<>
{Object.entries(fields).map(([key, inputSubSchema]) => (
<div key={key} className="mt-4">
<CredentialsInput
schema={inputSubSchema}
selectedCredentials={
inputCredentials[key] ??
getSchemaDefaultCredentials(inputSubSchema)
}
onSelectCredentials={(value) => handleSelectCredentials(key, value)}
siblingInputs={props.siblingInputs}
onLoaded={(loaded) => props.onLoadingChange(!loaded)}
/>
</div>
))}
</>
);
}

View File

@@ -0,0 +1,32 @@
import { CredentialsMetaInput } from "@/app/api/__generated__/models/credentialsMetaInput";
import { GraphMeta } from "@/app/api/__generated__/models/graphMeta";
import { BlockIOCredentialsSubSchema } from "@/lib/autogpt-server-api/types";
export function getCredentialFields(
agent: GraphMeta | null,
): AgentCredentialsFields {
if (!agent) return {};
const hasNoInputs =
!agent.credentials_input_schema ||
typeof agent.credentials_input_schema !== "object" ||
!("properties" in agent.credentials_input_schema) ||
!agent.credentials_input_schema.properties;
if (hasNoInputs) return {};
return agent.credentials_input_schema.properties as AgentCredentialsFields;
}
export type AgentCredentialsFields = Record<
string,
BlockIOCredentialsSubSchema
>;
export function areAllCredentialsSet(
fields: AgentCredentialsFields,
inputs: Record<string, CredentialsMetaInput | undefined>,
) {
const required = Object.keys(fields || {});
return required.every((k) => Boolean(inputs[k]));
}

View File

@@ -0,0 +1,45 @@
import { cn } from "@/lib/utils";
import { OnboardingText } from "../../components/OnboardingText";
type RunAgentHintProps = {
handleNewRun: () => void;
};
export function RunAgentHint(props: RunAgentHintProps) {
return (
<div className="ml-[104px] w-[481px] pl-5">
<div className="flex flex-col">
<OnboardingText variant="header">Run your first agent</OnboardingText>
<span className="mt-9 text-base font-normal leading-normal text-zinc-600">
A &apos;run&apos; is when your agent starts working on a task
</span>
<span className="mt-4 text-base font-normal leading-normal text-zinc-600">
Click on <b>New Run</b> below to try it out
</span>
<div
onClick={props.handleNewRun}
className={cn(
"mt-16 flex h-[68px] w-[330px] items-center justify-center rounded-xl border-2 border-violet-700 bg-neutral-50",
"cursor-pointer transition-all duration-200 ease-in-out hover:bg-violet-50",
)}
>
<svg
width="38"
height="38"
viewBox="0 0 32 32"
xmlns="http://www.w3.org/2000/svg"
>
<g stroke="#6d28d9" strokeWidth="1.2" strokeLinecap="round">
<line x1="16" y1="8" x2="16" y2="24" />
<line x1="8" y1="16" x2="24" y2="16" />
</g>
</svg>
<span className="ml-3 font-sans text-[19px] font-medium leading-normal text-violet-700">
New run
</span>
</div>
</div>
</div>
);
}

View File

@@ -0,0 +1,52 @@
import { StoreAgentDetails } from "@/app/api/__generated__/models/storeAgentDetails";
import StarRating from "../../components/StarRating";
import SmartImage from "@/components/__legacy__/SmartImage";
type Props = {
storeAgent: StoreAgentDetails | null;
};
export function SelectedAgentCard(props: Props) {
return (
<div className="fixed left-1/4 top-1/2 w-[481px] -translate-x-1/2 -translate-y-1/2">
<div className="h-[156px] w-[481px] rounded-xl bg-white px-6 pb-5 pt-4">
<span className="font-sans text-xs font-medium tracking-wide text-zinc-500">
SELECTED AGENT
</span>
{props.storeAgent ? (
<div className="mt-4 flex h-20 rounded-lg bg-violet-50 p-3">
{/* Left image */}
<SmartImage
src={props.storeAgent.agent_image[0]}
alt="Agent cover"
className="w-[350px] rounded-lg"
/>
{/* Right content */}
<div className="ml-3 flex flex-1 flex-col">
<div className="mb-2 flex flex-col items-start">
<span className="w-[292px] truncate font-sans text-[14px] font-medium leading-tight text-zinc-800">
{props.storeAgent.agent_name}
</span>
<span className="font-norma w-[292px] truncate font-sans text-xs text-zinc-600">
by {props.storeAgent.creator}
</span>
</div>
<div className="flex w-[292px] items-center justify-between">
<span className="truncate font-sans text-xs font-normal leading-tight text-zinc-600">
{props.storeAgent.runs.toLocaleString("en-US")} runs
</span>
<StarRating
className="font-sans text-xs font-normal leading-tight text-zinc-600"
starSize={12}
rating={props.storeAgent.rating || 0}
/>
</div>
</div>
</div>
) : (
<div className="mt-4 flex h-20 animate-pulse rounded-lg bg-gray-300 p-2" />
)}
</div>
</div>
);
}

View File

@@ -1,9 +1,9 @@
import type { GraphMeta } from "@/lib/autogpt-server-api";
import type {
BlockIOCredentialsSubSchema,
CredentialsMetaInput,
} from "@/lib/autogpt-server-api/types";
import type { InputValues } from "./types";
import { GraphMeta } from "@/app/api/__generated__/models/graphMeta";
export function computeInitialAgentInputs(
agent: GraphMeta | null,
@@ -21,7 +21,6 @@ export function computeInitialAgentInputs(
result[key] = existingInputs[key];
return;
}
// GraphIOSubSchema.default is typed as string, but server may return other primitives
const def = (subSchema as unknown as { default?: string | number }).default;
result[key] = def ?? "";
});
@@ -29,40 +28,20 @@ export function computeInitialAgentInputs(
return result;
}
export function getAgentCredentialsInputFields(agent: GraphMeta | null) {
const hasNoInputs =
!agent?.credentials_input_schema ||
typeof agent.credentials_input_schema !== "object" ||
!("properties" in agent.credentials_input_schema) ||
!agent.credentials_input_schema.properties;
if (hasNoInputs) return {};
return agent.credentials_input_schema.properties;
}
export function areAllCredentialsSet(
fields: Record<string, BlockIOCredentialsSubSchema>,
inputs: Record<string, CredentialsMetaInput | undefined>,
) {
const required = Object.keys(fields || {});
return required.every((k) => Boolean(inputs[k]));
}
type IsRunDisabledParams = {
agent: GraphMeta | null;
isRunning: boolean;
agentInputs: InputValues | null | undefined;
credentialsRequired: boolean;
credentialsSatisfied: boolean;
credentialsValid: boolean;
credentialsLoaded: boolean;
};
export function isRunDisabled({
agent,
isRunning,
agentInputs,
credentialsRequired,
credentialsSatisfied,
credentialsValid,
credentialsLoaded,
}: IsRunDisabledParams) {
const hasEmptyInput = Object.values(agentInputs || {}).some(
(value) => String(value).trim() === "",
@@ -71,7 +50,8 @@ export function isRunDisabled({
if (hasEmptyInput) return true;
if (!agent) return true;
if (isRunning) return true;
if (credentialsRequired && !credentialsSatisfied) return true;
if (!credentialsValid) return true;
if (!credentialsLoaded) return true;
return false;
}
@@ -81,13 +61,3 @@ export function getSchemaDefaultCredentials(
): CredentialsMetaInput | undefined {
return schema.default as CredentialsMetaInput | undefined;
}
export function sanitizeCredentials(
map: Record<string, CredentialsMetaInput | undefined>,
): Record<string, CredentialsMetaInput> {
const sanitized: Record<string, CredentialsMetaInput> = {};
for (const [key, value] of Object.entries(map)) {
if (value) sanitized[key] = value;
}
return sanitized;
}

View File

@@ -1,224 +1,66 @@
"use client";
import SmartImage from "@/components/__legacy__/SmartImage";
import { useOnboarding } from "../../../../providers/onboarding/onboarding-provider";
import OnboardingButton from "../components/OnboardingButton";
import { OnboardingHeader, OnboardingStep } from "../components/OnboardingStep";
import { OnboardingText } from "../components/OnboardingText";
import StarRating from "../components/StarRating";
import {
Card,
CardContent,
CardHeader,
CardTitle,
} from "@/components/__legacy__/ui/card";
import { useToast } from "@/components/molecules/Toast/use-toast";
import { GraphMeta, StoreAgentDetails } from "@/lib/autogpt-server-api";
import type { InputValues } from "./types";
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
import { cn } from "@/lib/utils";
import { Play } from "lucide-react";
import { useRouter } from "next/navigation";
import { useEffect, useState } from "react";
import { RunAgentInputs } from "@/app/(platform)/library/agents/[id]/components/AgentRunsView/components/RunAgentInputs/RunAgentInputs";
import { InformationTooltip } from "@/components/molecules/InformationTooltip/InformationTooltip";
import { CredentialsInput } from "@/app/(platform)/library/agents/[id]/components/AgentRunsView/components/CredentialsInputs/CredentialsInputs";
import type { CredentialsMetaInput } from "@/lib/autogpt-server-api/types";
import {
areAllCredentialsSet,
computeInitialAgentInputs,
getAgentCredentialsInputFields,
isRunDisabled,
getSchemaDefaultCredentials,
sanitizeCredentials,
} from "./helpers";
import { isRunDisabled } from "./helpers";
import { useOnboardingRunStep } from "./useOnboardingRunStep";
import { RunAgentHint } from "./components/RunAgentHint";
import { SelectedAgentCard } from "./components/SelectedAgentCard";
import { AgentOnboardingCredentials } from "./components/AgentOnboardingCredentials/AgentOnboardingCredentials";
import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard";
import { Skeleton } from "@/components/__legacy__/ui/skeleton";
export default function Page() {
const { state, updateState, setStep } = useOnboarding(
undefined,
"AGENT_CHOICE",
);
const [showInput, setShowInput] = useState(false);
const [agent, setAgent] = useState<GraphMeta | null>(null);
const [storeAgent, setStoreAgent] = useState<StoreAgentDetails | null>(null);
const [runningAgent, setRunningAgent] = useState(false);
const [inputCredentials, setInputCredentials] = useState<
Record<string, CredentialsMetaInput | undefined>
>({});
const { toast } = useToast();
const router = useRouter();
const api = useBackendAPI();
const {
ready,
error,
showInput,
agent,
onboarding,
storeAgent,
runningAgent,
credentialsValid,
credentialsLoaded,
handleSetAgentInput,
handleRunAgent,
handleNewRun,
handleCredentialsChange,
handleCredentialsValidationChange,
handleCredentialsLoadingChange,
} = useOnboardingRunStep();
useEffect(() => {
setStep(5);
}, []);
useEffect(() => {
if (!state?.selectedStoreListingVersionId) {
return;
}
api
.getStoreAgentByVersionId(state?.selectedStoreListingVersionId)
.then((storeAgent) => {
setStoreAgent(storeAgent);
});
api
.getGraphMetaByStoreListingVersionID(state.selectedStoreListingVersionId)
.then((meta) => {
setAgent(meta);
const update = computeInitialAgentInputs(
meta,
(state.agentInput as unknown as InputValues) || null,
);
updateState({ agentInput: update });
});
}, [api, setAgent, updateState, state?.selectedStoreListingVersionId]);
const agentCredentialsInputFields = getAgentCredentialsInputFields(agent);
const credentialsRequired =
Object.keys(agentCredentialsInputFields || {}).length > 0;
const allCredentialsAreSet = areAllCredentialsSet(
agentCredentialsInputFields,
inputCredentials,
);
function setAgentInput(key: string, value: string) {
updateState({
agentInput: {
...state?.agentInput,
[key]: value,
},
});
if (error) {
return <ErrorCard responseError={error} />;
}
async function runAgent() {
if (!agent) {
return;
}
setRunningAgent(true);
try {
const libraryAgent = await api.addMarketplaceAgentToLibrary(
storeAgent?.store_listing_version_id || "",
);
const { id: runID } = await api.executeGraph(
libraryAgent.graph_id,
libraryAgent.graph_version,
state?.agentInput || {},
sanitizeCredentials(inputCredentials),
);
updateState({
onboardingAgentExecutionId: runID,
agentRuns: (state?.agentRuns || 0) + 1,
});
router.push("/onboarding/6-congrats");
} catch (error) {
console.error("Error running agent:", error);
toast({
title: "Error running agent",
description:
"There was an error running your agent. Please try again or try choosing a different agent if it still fails.",
variant: "destructive",
});
setRunningAgent(false);
}
}
const runYourAgent = (
<div className="ml-[104px] w-[481px] pl-5">
<div className="flex flex-col">
<OnboardingText variant="header">Run your first agent</OnboardingText>
<span className="mt-9 text-base font-normal leading-normal text-zinc-600">
A &apos;run&apos; is when your agent starts working on a task
</span>
<span className="mt-4 text-base font-normal leading-normal text-zinc-600">
Click on <b>New Run</b> below to try it out
</span>
<div
onClick={() => {
setShowInput(true);
setStep(6);
updateState({
completedSteps: [
...(state?.completedSteps || []),
"AGENT_NEW_RUN",
],
});
}}
className={cn(
"mt-16 flex h-[68px] w-[330px] items-center justify-center rounded-xl border-2 border-violet-700 bg-neutral-50",
"cursor-pointer transition-all duration-200 ease-in-out hover:bg-violet-50",
)}
>
<svg
width="38"
height="38"
viewBox="0 0 32 32"
xmlns="http://www.w3.org/2000/svg"
>
<g stroke="#6d28d9" strokeWidth="1.2" strokeLinecap="round">
<line x1="16" y1="8" x2="16" y2="24" />
<line x1="8" y1="16" x2="24" y2="16" />
</g>
</svg>
<span className="ml-3 font-sans text-[19px] font-medium leading-normal text-violet-700">
New run
</span>
</div>
if (!ready) {
return (
<div className="flex flex-col gap-4">
<Skeleton className="h-10 w-full" />
<Skeleton className="h-10 w-full" />
</div>
</div>
);
);
}
return (
<OnboardingStep dotted>
<OnboardingHeader backHref={"/onboarding/4-agent"} transparent />
{/* Agent card */}
<div className="fixed left-1/4 top-1/2 w-[481px] -translate-x-1/2 -translate-y-1/2">
<div className="h-[156px] w-[481px] rounded-xl bg-white px-6 pb-5 pt-4">
<span className="font-sans text-xs font-medium tracking-wide text-zinc-500">
SELECTED AGENT
</span>
{storeAgent ? (
<div className="mt-4 flex h-20 rounded-lg bg-violet-50 p-2">
{/* Left image */}
<SmartImage
src={storeAgent?.agent_image[0]}
alt="Agent cover"
imageContain
className="w-[350px] rounded-lg"
/>
{/* Right content */}
<div className="ml-2 flex flex-1 flex-col">
<span className="w-[292px] truncate font-sans text-[14px] font-medium leading-normal text-zinc-800">
{storeAgent?.agent_name}
</span>
<span className="mt-[5px] w-[292px] truncate font-sans text-xs font-normal leading-tight text-zinc-600">
by {storeAgent?.creator}
</span>
<div className="mt-auto flex w-[292px] justify-between">
<span className="mt-1 truncate font-sans text-xs font-normal leading-tight text-zinc-600">
{storeAgent?.runs.toLocaleString("en-US")} runs
</span>
<StarRating
className="font-sans text-xs font-normal leading-tight text-zinc-600"
starSize={12}
rating={storeAgent?.rating || 0}
/>
</div>
</div>
</div>
) : (
<div className="mt-4 flex h-20 animate-pulse rounded-lg bg-gray-300 p-2" />
)}
</div>
</div>
<div className="flex min-h-[80vh] items-center justify-center">
{/* Left side */}
<SelectedAgentCard storeAgent={storeAgent} />
<div className="w-[481px]" />
{/* Right side */}
{!showInput ? (
runYourAgent
<RunAgentHint handleNewRun={handleNewRun} />
) : (
<div className="ml-[104px] w-[481px] pl-5">
<div className="flex flex-col">
@@ -232,30 +74,7 @@ export default function Page() {
<span className="mt-4 text-base font-normal leading-normal text-zinc-600">
When you&apos;re done, click <b>Run Agent</b>.
</span>
{Object.entries(agentCredentialsInputFields || {}).map(
([key, inputSubSchema]) => (
<div key={key} className="mt-4">
<CredentialsInput
schema={inputSubSchema}
selectedCredentials={
inputCredentials[key] ??
getSchemaDefaultCredentials(inputSubSchema)
}
onSelectCredentials={(value) =>
setInputCredentials((prev) => ({
...prev,
[key]: value,
}))
}
siblingInputs={
(state?.agentInput || undefined) as
| Record<string, any>
| undefined
}
/>
</div>
),
)}
<Card className="agpt-box mt-4">
<CardHeader>
<CardTitle className="font-poppins text-lg">Input</CardTitle>
@@ -272,13 +91,23 @@ export default function Page() {
</label>
<RunAgentInputs
schema={inputSubSchema}
value={state?.agentInput?.[key]}
value={onboarding.state?.agentInput?.[key]}
placeholder={inputSubSchema.description}
onChange={(value) => setAgentInput(key, value)}
onChange={(value) => handleSetAgentInput(key, value)}
/>
</div>
),
)}
<AgentOnboardingCredentials
agent={agent}
siblingInputs={
(onboarding.state?.agentInput as Record<string, any>) ||
undefined
}
onCredentialsChange={handleCredentialsChange}
onValidationChange={handleCredentialsValidationChange}
onLoadingChange={handleCredentialsLoadingChange}
/>
</CardContent>
</Card>
<OnboardingButton
@@ -289,11 +118,12 @@ export default function Page() {
agent,
isRunning: runningAgent,
agentInputs:
(state?.agentInput as unknown as InputValues) || null,
credentialsRequired,
credentialsSatisfied: allCredentialsAreSet,
(onboarding.state?.agentInput as unknown as InputValues) ||
null,
credentialsValid,
credentialsLoaded,
})}
onClick={runAgent}
onClick={handleRunAgent}
icon={<Play className="mr-2" size={18} />}
>
Run agent

View File

@@ -0,0 +1,162 @@
import { CredentialsMetaInput } from "@/app/api/__generated__/models/credentialsMetaInput";
import { GraphMeta } from "@/app/api/__generated__/models/graphMeta";
import { StoreAgentDetails } from "@/app/api/__generated__/models/storeAgentDetails";
import { useToast } from "@/components/molecules/Toast/use-toast";
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
import { useOnboarding } from "@/providers/onboarding/onboarding-provider";
import { useRouter } from "next/navigation";
import { useEffect, useState } from "react";
import { computeInitialAgentInputs } from "./helpers";
import { InputValues } from "./types";
import {
useGetV2GetAgentByVersion,
useGetV2GetAgentGraph,
} from "@/app/api/__generated__/endpoints/store/store";
export function useOnboardingRunStep() {
const onboarding = useOnboarding(undefined, "AGENT_CHOICE");
const [showInput, setShowInput] = useState(false);
const [agent, setAgent] = useState<GraphMeta | null>(null);
const [storeAgent, setStoreAgent] = useState<StoreAgentDetails | null>(null);
const [runningAgent, setRunningAgent] = useState(false);
const [inputCredentials, setInputCredentials] = useState<
Record<string, CredentialsMetaInput>
>({});
const [credentialsValid, setCredentialsValid] = useState(true);
const [credentialsLoaded, setCredentialsLoaded] = useState(false);
const { toast } = useToast();
const router = useRouter();
const api = useBackendAPI();
const currentAgentVersion =
onboarding.state?.selectedStoreListingVersionId ?? "";
const storeAgentQuery = useGetV2GetAgentByVersion(currentAgentVersion, {
query: { enabled: !!currentAgentVersion },
});
const graphMetaQuery = useGetV2GetAgentGraph(currentAgentVersion, {
query: { enabled: !!currentAgentVersion },
});
useEffect(() => {
onboarding.setStep(5);
}, []);
useEffect(() => {
if (storeAgentQuery.data && storeAgentQuery.data.status === 200) {
setStoreAgent(storeAgentQuery.data.data);
}
}, [storeAgentQuery.data]);
useEffect(() => {
if (
graphMetaQuery.data &&
graphMetaQuery.data.status === 200 &&
onboarding.state
) {
const graphMeta = graphMetaQuery.data.data as GraphMeta;
setAgent(graphMeta);
const update = computeInitialAgentInputs(
graphMeta,
(onboarding.state.agentInput as unknown as InputValues) || null,
);
onboarding.updateState({ agentInput: update });
}
}, [graphMetaQuery.data]);
function handleNewRun() {
if (!onboarding.state) return;
setShowInput(true);
onboarding.setStep(6);
onboarding.updateState({
completedSteps: [
...(onboarding.state.completedSteps || []),
"AGENT_NEW_RUN",
],
});
}
function handleSetAgentInput(key: string, value: string) {
if (!onboarding.state) return;
onboarding.updateState({
agentInput: {
...onboarding.state.agentInput,
[key]: value,
},
});
}
async function handleRunAgent() {
if (!agent || !storeAgent || !onboarding.state) {
toast({
title: "Error getting agent",
description:
"Either the agent is not available or there was an error getting it.",
variant: "destructive",
});
return;
}
setRunningAgent(true);
try {
const libraryAgent = await api.addMarketplaceAgentToLibrary(
storeAgent?.store_listing_version_id || "",
);
const { id: runID } = await api.executeGraph(
libraryAgent.graph_id,
libraryAgent.graph_version,
onboarding.state.agentInput || {},
inputCredentials,
);
onboarding.updateState({
onboardingAgentExecutionId: runID,
agentRuns: (onboarding.state.agentRuns || 0) + 1,
});
router.push("/onboarding/6-congrats");
} catch (error) {
console.error("Error running agent:", error);
toast({
title: "Error running agent",
description:
"There was an error running your agent. Please try again or try choosing a different agent if it still fails.",
variant: "destructive",
});
setRunningAgent(false);
}
}
return {
ready: graphMetaQuery.isSuccess && storeAgentQuery.isSuccess,
error: graphMetaQuery.error || storeAgentQuery.error,
agent,
onboarding,
showInput,
storeAgent,
runningAgent,
credentialsValid,
credentialsLoaded,
handleSetAgentInput,
handleRunAgent,
handleNewRun,
handleCredentialsChange: setInputCredentials,
handleCredentialsValidationChange: setCredentialsValid,
handleCredentialsLoadingChange: (v: boolean) => setCredentialsLoaded(!v),
};
}

View File

@@ -46,7 +46,7 @@ export default function StarRating({
)}
>
{/* Display numerical rating */}
<span className="mr-1 mt-1">{roundedRating}</span>
<span className="mr-1 mt-0.5">{roundedRating}</span>
{/* Display stars */}
{stars.map((starType, index) => {

View File

@@ -19,6 +19,7 @@ import WalletRefill from "./components/WalletRefill";
import { OnboardingStep } from "@/lib/autogpt-server-api";
import { storage, Key as StorageKey } from "@/services/storage/local-storage";
import { WalletIcon } from "@phosphor-icons/react";
import { useGetFlag, Flag } from "@/services/feature-flags/use-get-flag";
export interface Task {
id: OnboardingStep;
@@ -40,6 +41,7 @@ export interface TaskGroup {
export default function Wallet() {
const { state, updateState } = useOnboarding();
const isPaymentEnabled = useGetFlag(Flag.ENABLE_PLATFORM_PAYMENT);
const groups = useMemo<TaskGroup[]>(() => {
return [
@@ -379,9 +381,7 @@ export default function Wallet() {
</div>
<ScrollArea className="max-h-[85vh] overflow-y-auto">
{/* Top ups */}
{process.env.NEXT_PUBLIC_SHOW_BILLING_PAGE === "true" && (
<WalletRefill />
)}
{isPaymentEnabled && <WalletRefill />}
{/* Tasks */}
<p className="mx-1 my-3 font-sans text-xs font-normal text-zinc-400">
Complete the following tasks to earn more credits!

View File

@@ -1,16 +1,23 @@
"use client";
import { isServerSide } from "@/lib/utils/is-server-side";
import { useEffect, useState } from "react";
import { Text } from "@/components/atoms/Text/Text";
import { Card } from "@/components/atoms/Card/Card";
import { WaitlistErrorContent } from "@/components/auth/WaitlistErrorContent";
import { isWaitlistError } from "@/app/api/auth/utils";
import { useRouter } from "next/navigation";
import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard";
import { environment } from "@/services/environment";
export default function AuthErrorPage() {
const [errorType, setErrorType] = useState<string | null>(null);
const [errorCode, setErrorCode] = useState<string | null>(null);
const [errorDescription, setErrorDescription] = useState<string | null>(null);
const router = useRouter();
useEffect(() => {
// This code only runs on the client side
if (!isServerSide()) {
if (!environment.isServerSide()) {
const hash = window.location.hash.substring(1); // Remove the leading '#'
const params = new URLSearchParams(hash);
@@ -23,15 +30,45 @@ export default function AuthErrorPage() {
}, []);
if (!errorType && !errorCode && !errorDescription) {
return <div>Loading...</div>;
return (
<div className="flex h-screen items-center justify-center">
<Text variant="body">Loading...</Text>
</div>
);
}
// Check if this is a waitlist/not allowed error
const isWaitlistErr = isWaitlistError(errorCode, errorDescription);
if (isWaitlistErr) {
return (
<div className="flex h-screen items-center justify-center">
<Card className="w-full max-w-md p-8">
<WaitlistErrorContent onBackToLogin={() => router.push("/login")} />
</Card>
</div>
);
}
// Use ErrorCard for consistent error display
const errorMessage = errorDescription
? `${errorDescription}. If this error persists, please contact support at contact@agpt.co`
: "An authentication error occurred. Please contact support at contact@agpt.co";
return (
<div>
<h1>Authentication Error</h1>
{errorType && <p>Error Type: {errorType}</p>}
{errorCode && <p>Error Code: {errorCode}</p>}
{errorDescription && <p>Error Description: {errorDescription}</p>}
<div className="flex h-screen items-center justify-center p-4">
<div className="w-full max-w-md">
<ErrorCard
responseError={{
message: errorMessage,
detail: errorCode
? `Error code: ${errorCode}${errorType ? ` (${errorType})` : ""}`
: undefined,
}}
context="authentication"
onRetry={() => router.push("/login")}
/>
</div>
</div>
);
}

View File

@@ -0,0 +1,11 @@
import { RunGraph } from "./components/RunGraph";
export const BuilderActions = () => {
return (
<div className="absolute bottom-4 left-[50%] z-[100] -translate-x-1/2">
{/* TODO: Add Agent Output */}
<RunGraph />
{/* TODO: Add Schedule run button */}
</div>
);
};

View File

@@ -0,0 +1,32 @@
import { Button } from "@/components/atoms/Button/Button";
import { PlayIcon } from "lucide-react";
import { useRunGraph } from "./useRunGraph";
import { useGraphStore } from "@/app/(platform)/build/stores/graphStore";
import { useShallow } from "zustand/react/shallow";
import { StopIcon } from "@phosphor-icons/react";
import { cn } from "@/lib/utils";
export const RunGraph = () => {
const { runGraph, isSaving } = useRunGraph();
const isGraphRunning = useGraphStore(
useShallow((state) => state.isGraphRunning),
);
return (
<Button
variant="primary"
size="large"
className={cn(
"relative min-w-44 border-none bg-gradient-to-r from-purple-500 to-pink-500 text-lg",
)}
onClick={() => runGraph()}
>
{!isGraphRunning && !isSaving ? (
<PlayIcon className="mr-1 size-5" />
) : (
<StopIcon className="mr-1 size-5" />
)}
{isGraphRunning || isSaving ? "Stop Agent" : "Run Agent"}
</Button>
);
};

View File

@@ -0,0 +1,62 @@
import { usePostV1ExecuteGraphAgent } from "@/app/api/__generated__/endpoints/graphs/graphs";
import { useNewSaveControl } from "../../../NewControlPanel/NewSaveControl/useNewSaveControl";
import { useToast } from "@/components/molecules/Toast/use-toast";
import { parseAsInteger, parseAsString, useQueryStates } from "nuqs";
import { GraphExecutionMeta } from "@/app/(platform)/library/agents/[id]/components/OldAgentLibraryView/use-agent-runs";
import { useGraphStore } from "@/app/(platform)/build/stores/graphStore";
import { useShallow } from "zustand/react/shallow";
export const useRunGraph = () => {
const { onSubmit: onSaveGraph, isLoading: isSaving } = useNewSaveControl({
showToast: false,
});
const { toast } = useToast();
const setIsGraphRunning = useGraphStore(
useShallow((state) => state.setIsGraphRunning),
);
const [{ flowID, flowVersion }, setQueryStates] = useQueryStates({
flowID: parseAsString,
flowVersion: parseAsInteger,
flowExecutionID: parseAsString,
});
const { mutateAsync: executeGraph } = usePostV1ExecuteGraphAgent({
mutation: {
onSuccess: (response) => {
const { id } = response.data as GraphExecutionMeta;
setQueryStates({
flowExecutionID: id,
});
},
onError: (error) => {
setIsGraphRunning(false);
toast({
title: (error.detail as string) ?? "An unexpected error occurred.",
description: "An unexpected error occurred.",
variant: "destructive",
});
},
},
});
const runGraph = async () => {
setIsGraphRunning(true);
await onSaveGraph(undefined);
// Todo : We need to save graph which has inputs and credentials inputs
await executeGraph({
graphId: flowID ?? "",
graphVersion: flowVersion || null,
data: {
inputs: {},
credentials_inputs: {},
},
});
};
return {
runGraph,
isSaving,
};
};

Some files were not shown because too many files have changed in this diff Show More