Compare commits

..

75 Commits

Author SHA1 Message Date
Abhimanyu Yadav
c3d92a4e06 Merge branch 'dev' into abhi-9274/postgres-integration 2025-05-05 13:44:14 +05:30
Nicholas Tindle
afb66f75ec fix: disable google sheets in prod based on oauth review (#9906)

Our OAuth review wants us to drop this in favor of a different scope that
will require additional work

### Changes 🏗️
Disables the oauth sheets scopes in prod


### Checklist 📋

#### For code changes:
- [ ] I have clearly listed my changes in the PR description
- [ ] I have made a test plan
- [ ] I have tested my changes according to the test plan:
  - [ ] set env locally
2025-05-02 19:40:51 +00:00
Krzysztof Czerwinski
59ec61ef98 feat(platform): Onboarding design&UX update (#9905)
A collection of updates regarding onboarding and wallet.

### Changes 🏗️

- `try-except` instead of `if` when rewarding (skip unnecessary db call)
- Make external services question onboarding step optional
- Add `SmartImage` component to lazy load images with pulse animation
and use it throughout onboarding
- Use store agent name instead of graph name (run page)
- Fix some images breaking layout on the agent card (run page)
- Center agent card vertically and horizontally (center on the left half
of page) (run page)
- Delay and tweak confetti when opening wallet and when task finished
(wallet)
- Flash wallet when credits change value
- Make tutorial video grayscale on completed steps (wallet)
- Fix confetti triggering on page refresh (wallet)
- Redirect to agent run page instead of Library after onboarding
- Expand task groups by default (wallet) - this means tutorial videos
are visible by default

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Services step is optional and skipping it doesn't break onboarding
  - [x] `SmartImage` works properly
  - [x] Agent card is aligned properly, including on page scroll
  - [x] Wallet flashes when the credits value changes
  - [x] User is redirected to the agent runs page after onboarding
2025-05-02 14:42:01 +00:00
Zamil Majdy
d7077b5161 feat(backend): Continue instead of retrying aborted/broken agent execution (#9903)
Currently, the agent/graph execution engine consumes the execution
queue and acknowledges each message only after fully completing or
failing its execution.

However, if the agent executor fails due to a hardware/resource issue,
or does not manage to acknowledge the execution message, another agent
executor will pick it up and start the execution again from the
beginning.

The scope of this PR is to make the next executor continue the
pre-existing execution instead of starting it all over from the
beginning.

### Changes 🏗️

* Removed `start_node_execs` from `GraphExecutionEntry`
* Populate the starting graph nodes from the DB query instead (fetching
Running & Queued node executions).
* Removed `get_incomplete_node_executions` from the DB manager.
* Use `get_node_executions` with a status filter instead.
* Allow graph execution to end in a non-FAILED/COMPLETED status, e.g.,
when the executor is interrupted, the execution stays in the RUNNING
status so other executors can continue the task.
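
The resume logic described above can be sketched roughly like this (a hypothetical, dict-shaped illustration, not the actual DB-manager code): instead of seeding the run from `start_node_execs`, the next executor derives its starting point from node executions persisted as Running or Queued.

```python
def pick_resume_nodes(node_executions: list[dict]) -> list[dict]:
    # Illustrative sketch: resume from nodes the previous executor left
    # in-flight (RUNNING) or pending (QUEUED), instead of restarting
    # the graph from its entry nodes.
    return [ne for ne in node_executions
            if ne["status"] in ("RUNNING", "QUEUED")]
```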

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Run an agent, stop the executor midway, re-run the executor; the
execution should continue instead of restarting.
2025-05-01 16:02:03 +00:00
Zamil Majdy
475c5a5cc3 fix(backend): Avoid executing any agent with zero balance (#9901)
### Changes 🏗️

* Avoid executing any agent with a zero balance.
* Make node execution count global across agents for a single user.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] Run agents by tweaking the `execution_cost_count_threshold` &
`execution_cost_per_threshold` values.
2025-05-01 15:11:38 +00:00
Zamil Majdy
86d5cfe60b feat(backend): Support flexible RPC client (#9842)
Using sync code in an async route often introduces blocking
event-loop code that impacts stability.

The current RPC system only provides a synchronous client to call the
service endpoints.
The scope of this PR is to provide an entirely decoupled signature
between client and server, allowing the client to mix & match async &
sync options in the client code without changing the async/sync nature
of the server.

### Changes 🏗️

* Add support for flexible async/sync RPC client.
* Migrate scheduler client to all-async client.
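
A minimal sketch of what such a mix-and-match client could look like (the class and method names are hypothetical, not the platform's actual RPC classes): the transport stays synchronous while callers choose a sync or async entry point.

```python
import asyncio

class FlexibleClient:
    """Hypothetical sketch: one client exposing both sync and async
    call styles over the same synchronous transport callable."""

    def __init__(self, transport):
        self._transport = transport  # plain sync callable

    def call(self, *args, **kwargs):
        # Sync callers block as before.
        return self._transport(*args, **kwargs)

    async def acall(self, *args, **kwargs):
        # Async callers off-load the blocking transport to a worker
        # thread so the event loop is never blocked.
        return await asyncio.to_thread(self._transport, *args, **kwargs)
```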

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Scheduler route test.
  - [x] Modified service_test.py
  - [x] Run normal agent executions
2025-05-01 04:38:06 +00:00
Bently
602f887623 feat(frontend): fix admin add dollars (#9898)
Fixes the admin add-dollars flow. In the ``add-money-button.tsx`` file,
the `handleApproveSubmit` action was trying to use `formatCredits` for
the value, which is wrong; this fix changes it:

```diff
 <form action={handleApproveSubmit}>
   <input type="hidden" name="id" value={userId} />
   <input
     type="hidden"
     name="amount"
-    value={formatCredits(Number(dollarAmount))}
+    value={Math.round(parseFloat(dollarAmount) * 100)}
   />
```
I was able to add $1, $0.10, and $0.01.
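
The same conversion, sketched in Python for clarity (a hypothetical helper mirroring the TSX fix above): rounding guards against float artifacts when multiplying by 100.

```python
def dollars_to_cents(dollar_amount: str) -> int:
    # Convert a user-entered dollar string to an integer number of
    # cents; round() guards against float artifacts such as
    # 0.29 * 100 == 28.999999999999996.
    return round(float(dollar_amount) * 100)
```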

![image](https://github.com/user-attachments/assets/3a3126c2-5f17-4c9b-8657-4372332a0ea3)
2025-04-30 17:24:26 +00:00
Bentlybro
1edde778c5 Merge branch 'master' into dev 2025-04-30 16:46:50 +01:00
Zamil Majdy
3526986f98 fix(backend): Failing test on a new Pydantic version (#9897)
```
FAILED test/model_test.py::test_agent_preset_from_db - pydantic_core._pydantic_core.ValidationError: 1 validation error for AgentNodeExecutionInputOutput

E       pydantic_core._pydantic_core.ValidationError: 1 validation error for AgentNodeExecutionInputOutput
E       data
E         JSON input should be string, bytes or bytearray [type=json_type, input_value=Json, input_type=Json]
E           For further information visit https://errors.pydantic.dev/2.11/v/json_type
```

### Changes 🏗️

Manually creating a Prisma model often breaks, and we have such an
instance in the test.
This PR fixes the test to make the new Pydantic happy.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] CI
2025-04-30 13:59:17 +00:00
Nicholas Tindle
04c4340ee3 feat(frontend,backend): user spending admin dashboard (#9751)
We need a way to refund people who spend money on agents without
performing manual db actions

### Changes 🏗️
- Adds a bunch of functionality for refunding users
- Adds reasons and admin id for actions
- Add admin to db manager
- Add UI for this for the admin panel
- Clean up pagination controls

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Test by importing dev db as baseline
  - [x] Add transactions on top for "refund", and make sure all existing
transactions work

---------

Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
2025-04-29 17:39:25 +00:00
Zamil Majdy
9fa62c03f6 feat(backend): Improve cancel execution reliability (#9889)
When an executor dies, an ongoing execution will not be retried and will
just get stuck in the RUNNING status.
This change avoids such a scenario by allowing execution of an entry
that is not in QUEUED status, with the low-probability risk of double
execution.

### Changes 🏗️

* Allow non-QUEUED status to be re-executed.
* Improve cleanup of the node & graph executor.
* Move cancellation request consumption to a separate thread to avoid
being blocked by other messages.
* Remove the unused retry loop in the execution manager.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Run agent, kill the server, re-run it, agent restarted.
2025-04-29 17:06:03 +00:00
Mareddy Lohith Reddy
d5dc687484 fix: handle empty 204 responses in SendWebRequestBlock (#9887)
This PR fixes [Issue
#9883](https://github.com/Significant-Gravitas/AutoGPT/issues/9883),
where the SendWebRequestBlock crashes when receiving a 204 No Content
response, such as when posting to a Discord webhook. The fix ensures
that empty responses are handled gracefully, and the block does not
crash.

### Changes 🏗️
- Added a check to handle empty HTTP responses (like 204 status) in
SendWebRequestBlock
- Fall back to an empty string or None if there is no response content
- Prevents server errors when parsing non-existent response bodies
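
A hedged sketch of the kind of check described (not the actual SendWebRequestBlock code): treat a 204 or empty body as `None`, and fall back to raw text when the body isn't JSON.

```python
import json

def parse_response_body(status_code: int, text: str):
    # A 204 No Content (or any empty body) yields None instead of
    # crashing the JSON parser.
    if status_code == 204 or not text.strip():
        return None
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        # Non-JSON bodies fall back to the raw text.
        return text
```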


### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Send a POST request to an endpoint that returns 204 No Content
  - [x] Confirm that SendWebRequestBlock handles it without crashing
  - [x] Confirm that regular 200 OK JSON responses still work

---------

Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
Co-authored-by: Lohith-11 <lohithr011@gamil.com>
Co-authored-by: Toran Bruce Richards <toran.richards@gmail.com>
Co-authored-by: Nicholas Tindle <nicholas.tindle@agpt.co>
2025-04-28 19:16:04 +00:00
Japh
fb5ce0a16d Add Note to "Getting Started" page for Raspberry Pi 5 page size issue (#9888)
Add Note to "Getting Started" page for Raspberry Pi 5 page size issue
with `supabase-vector` that prevents `docker compose up` from running
successfully.


### Changes 🏗️

- Added a Note to the "Getting Started" page that explains a change in
Raspberry Pi OS for Raspberry Pi 5s, and how to revert the change to
avoid an issue running the backend on Docker.


### Checklist 📋

#### For code changes:
- [x] No code changes

#### For configuration changes:
- [x] No configuration changes

---------

Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
Co-authored-by: Nicholas Tindle <nicholas.tindle@agpt.co>
2025-04-28 19:07:44 +00:00
Nicholas Tindle
a1f17ca797 fix: use subheading for agent info not description (#9891)
We accidentally used the wrong attribute for the short description.
### Changes 🏗️
Uses the subheading instead now

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] check the expected text shows
2025-04-28 18:38:43 +00:00
Nicholas Tindle
8fdfd75cc4 feat: allow admins to download agents for review (#9881)
For admins to approve agents for the marketplace, we need to be able to
run them. This is a quick workaround for downloading them so you can put
them in your own marketplace to check.

### Changes 🏗️
- Clones various endpoints related to downloading into an admin side,
with logging and admin checks
- Adds a download button and removes the open-in-builder action

### Checklist 📋

#### For code changes:
- [ ] I have clearly listed my changes in the PR description
- [ ] I have made a test plan
- [ ] I have tested my changes according to the test plan:
  - [ ] Test downloading agents from local marketplace
2025-04-28 17:58:23 +00:00
Abhimanyu Yadav
5b5b2043e8 fix(frontend): Add support to optional multiselect (#9885)
- fix #9882 

We're currently using the optional multi-select, and it's working great;
we're able to correctly determine its data type. However, there's a
small issue: we're not using the correct subSchema inside `anyOf` on the
multi-select input. This is why we're getting the problem on the Twitter
block. It's the only one using this type of input, so it's the only one
affected.

![Screenshot 2025-04-26 at 5 39
51 PM](https://github.com/user-attachments/assets/834d64d8-84dc-4dbd-a03a-df03172ecee5)

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-04-28 15:18:29 +00:00
Nicholas Tindle
7d83f1db05 feat(block): bring back PrintConsoleBlock (#9850)
### Changes 🏗️

Bring back PrintConsoleBlock

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Print console block

---------

Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
2025-04-28 09:55:57 +00:00
Zamil Majdy
f07696e3c1 fix(backend): Fix top-up with zero transaction flow (#9886)
A transaction with a zero payment amount will not generate a payment ID,
so the checkout failed in this scenario.

### Changes 🏗️

Don't use the payment ID as the transaction key on a top-up with a zero
payment amount.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Top-up with stripe coupon
2025-04-27 14:22:08 +00:00
Zamil Majdy
96a173a85f fix(backend): Avoid releasing lock that is no longer owned by the current thread (#9878)
There are instances of node executions that failed and ended up stuck
in the RUNNING status because the execution failed to release the lock:
```
2025-04-24 20:53:31,573 INFO  [ExecutionManager|uid:25eba2d1-e9c1-44bc-88c7-43e0f4fbad5a|gid:01f8c315-c163-4dd1-a8a0-d396477c5a9f|nid:f8bf84ae-b1f0-4434-8f04-80f43852bc30]|geid:2e1b35c6-0d2f-4e97-adea-f6fe0d9965d0|neid:590b29ea-63ee-4e24-a429-de5a3e191e72|-] Failed node execution 590b29ea-63ee-4e24-a429-de5a3e191e72: Cannot release a lock that's no longer owned
```

### Changes 🏗️

Check the ownership of the lock before releasing.
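
The pattern can be sketched as follows (a simplified stand-in, not the platform's actual lock implementation): track the owner and release only while ownership still holds.

```python
import threading

class OwnedLock:
    """Simplified sketch of a lock that tracks its owning thread so a
    release can be skipped once ownership has been lost."""

    def __init__(self):
        self._lock = threading.Lock()
        self._owner = None

    def acquire(self):
        self._lock.acquire()
        self._owner = threading.get_ident()

    def owned(self) -> bool:
        return self._owner == threading.get_ident()

    def release_if_owned(self) -> bool:
        # Releasing an un-owned lock raises; checking first avoids the
        # "Cannot release a lock that's no longer owned" failure mode.
        if self.owned():
            self._owner = None
            self._lock.release()
            return True
        return False
```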

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Existing CI tests.

(cherry picked from commit ef022720d5)
2025-04-25 23:55:18 +07:00
sentry-autofix[bot]
9715ea5313 fix: handle token limits and estimate token count for llm calls (#9880)
👋 Hi there! This PR was automatically generated by Autofix 🤖

This fix was triggered by Toran Bruce Richards.

Fixes
[AUTOGPT-SERVER-1ZY](https://sentry.io/organizations/significant-gravitas/issues/6386687527/).
The issue was that: `llm_call` calculates `max_tokens` without
considering `input_tokens`, causing OpenRouter API errors when the
context window is exceeded.

- Implements a function `estimate_token_count` to estimate the number of
tokens in a list of messages.
- Calculates available tokens based on the context window, estimated
input tokens, and user-defined max tokens.
- Adjusts `max_tokens` for LLM calls to prevent exceeding context window
limits.
- Reduces `max_tokens` by 15% and retries if a token limit error is
encountered during LLM calls.

If you have any questions or feedback for the Sentry team about this
fix, please email [autofix@sentry.io](mailto:autofix@sentry.io) with the
Run ID: 32838.

---------

Co-authored-by: sentry-autofix[bot] <157164994+sentry-autofix[bot]@users.noreply.github.com>
Co-authored-by: Krzysztof Czerwinski <kpczerwinski@gmail.com>
2025-04-25 13:45:47 +00:00
Zamil Majdy
ef022720d5 fix(backend): Avoid releasing lock that is no longer owned by the current thread (#9878)
There are instances of node executions that failed and ended up stuck
in the RUNNING status because the execution failed to release the lock:
```
2025-04-24 20:53:31,573 INFO  [ExecutionManager|uid:25eba2d1-e9c1-44bc-88c7-43e0f4fbad5a|gid:01f8c315-c163-4dd1-a8a0-d396477c5a9f|nid:f8bf84ae-b1f0-4434-8f04-80f43852bc30]|geid:2e1b35c6-0d2f-4e97-adea-f6fe0d9965d0|neid:590b29ea-63ee-4e24-a429-de5a3e191e72|-] Failed node execution 590b29ea-63ee-4e24-a429-de5a3e191e72: Cannot release a lock that's no longer owned
```

### Changes 🏗️

Check the ownership of the lock before releasing.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Existing CI tests.
2025-04-25 07:39:10 +00:00
Zamil Majdy
4ddb206f86 feat(frontend): Add billing page toggle (#9877)
### Changes 🏗️

Provide a system toggle for disabling the billing page:
`NEXT_PUBLIC_SHOW_BILLING_PAGE`

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Toggle `NEXT_PUBLIC_SHOW_BILLING_PAGE` value.
2025-04-24 19:33:20 +00:00
Zamil Majdy
91f34966c8 fix(block): Fix Smart Decision Block missing input beads & incompatibility with inputs containing special characters (#9875)
The Smart Decision Block was not able to work with a sub-agent with a
custom-named input, and the beads were not properly propagated in the
execution UI. The scope of this PR is to fix both.

### Changes 🏗️

* Introduce an easy-to-parse tool edge format:
`{tool}_^_{func}_~_{arg}`. Graphs using SmartDecisionBlock need to be
re-saved before execution to work.
* Reduce clutter in the smart decision block logic.
* Fix beads not being shown for a smart decision block tool calling.
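
Parsing the new edge format is straightforward; a sketch (assuming tool names themselves never contain the `_^_`/`_~_` separators):

```python
def parse_tool_edge(edge: str) -> tuple[str, str, str]:
    # Split `{tool}_^_{func}_~_{arg}` into its three parts; the unusual
    # separators avoid clashing with special characters in argument
    # names.
    tool, rest = edge.split("_^_", 1)
    func, arg = rest.split("_~_", 1)
    return tool, func, arg
```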

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Execute an SDM with some special character input as a tool

<img width="672" alt="image"
src="https://github.com/user-attachments/assets/873556b3-c16a-4dd1-ad84-bc86c636c406"
/>
2025-04-24 19:24:41 +00:00
Krzysztof Czerwinski
11a69170b5 feat(frontend): Update "Edit a copy" modal and buttons (#9876)
Update "Edit a copy" modal text when copying marketplace agent in
Library. Update agent action buttons to reflect the design accurately.

### Changes 🏗️

- Update modal text
- Disable copying owned agents (only marketplace allowed)
- `Open in Builder` -> `Customize agent`
- Disable `Customize agent` instead of hiding it
- Change `Delete agent` to non-destructive design

### Checklist 📋

#### For code changes:
- [ ] I have clearly listed my changes in the PR description
- [ ] I have made a test plan
- [ ] I have tested my changes according to the test plan:
  - [ ] ...
2025-04-24 16:30:43 +00:00
Krzysztof Czerwinski
0675a41e42 fix(backend): Strip secrets, credentials when forking agent (#9874)
Strip secrets and credentials when forking an agent.

### Changes 🏗️


### Checklist 📋

#### For code changes:
- [ ] I have clearly listed my changes in the PR description
- [ ] I have made a test plan
- [ ] I have tested my changes according to the test plan:
  - [ ] ...
2025-04-24 15:09:41 +00:00
Bentlybro
56ce1a0c1c Merge branch 'master' into dev 2025-04-24 14:21:34 +01:00
Zamil Majdy
7fbe135ec8 feat(backend): Expose execution prometheus metrics (#9866)
Currently, we have no visibility into the state of the execution
manager; the scope of this PR is to open up its observability by
exposing Prometheus metrics.

### Changes 🏗️

Re-use the execution manager port to expose the Prometheus metrics.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Hit /metrics on 8002 port
2025-04-24 07:48:38 +00:00
Zamil Majdy
eb6a0b34e1 feat(backend): Use forkserver on process creation if possible (#9864)
### Changes 🏗️

Set the process start method to forkserver instead of spawn, when
possible, for performance benefits.
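
A sketch of the selection logic (hedged, not the exact platform code): try forkserver, which amortizes interpreter startup across workers, and fall back to spawn where it's unavailable.

```python
import multiprocessing

def get_start_context():
    # forkserver forks workers from a small pre-warmed server process,
    # which is cheaper per worker than spawn's full interpreter boot.
    try:
        return multiprocessing.get_context("forkserver")
    except ValueError:
        # Platforms without fork (e.g. Windows) only support spawn.
        return multiprocessing.get_context("spawn")
```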

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Existing tests
2025-04-24 07:36:36 +00:00
Zamil Majdy
1e3236a041 feat(backend): Add retry on executor process initialization (#9865)
Executor process initialization can fail and cause this error:
```
concurrent.futures.process.BrokenProcessPool: A child process terminated abruptly, the process pool is not usable anymore
```

### Changes 🏗️

Add a retry to reduce the chance of the initialization error happening.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Existing tests
2025-04-23 21:24:17 +00:00
Krzysztof Czerwinski
160a622ba4 feat(platform): Forking agent in Library (#9870)
This PR introduces an agent-copying feature in the Library. Users can
copy and download their library agents, but they can edit only the ones
they own (including copied ones).

### Changes 🏗️

- DB migration: add relation in `AgentGraph`: `forked_from_id` and
`forked_from_version`
- Add `fork_graph` function that makes a hard copy of an agent graph and
its nodes (all with new ids)
- Add `fork_library_agent` that copies library agent and its graph for a
user
- Add endpoint `/library/agents/{libraryAgentId}/fork`
- Add UI to `library/agents/[id]/page.tsx`: `Edit a copy` button with
dialog confirmation
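
The copy step might look roughly like this (a hypothetical sketch with dict-shaped models, not the actual Prisma code): every node gets a fresh id, and the fork origin is recorded.

```python
import uuid

def fork_graph(graph: dict) -> dict:
    # Hard-copy the graph and its nodes with brand-new ids, while
    # remembering where the fork came from.
    new_node_ids = {n["id"]: str(uuid.uuid4()) for n in graph["nodes"]}
    return {
        "id": str(uuid.uuid4()),
        "forked_from_id": graph["id"],
        "forked_from_version": graph.get("version", 1),
        "nodes": [{**n, "id": new_node_ids[n["id"]]}
                  for n in graph["nodes"]],
    }
```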

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Agent can be copied, edited and runs
2025-04-23 16:28:42 +00:00
Toran Bruce Richards
e2a226dc49 Update repo-close-stale-issues.yml 2025-04-23 14:51:18 +01:00
Zamil Majdy
5047e99fd1 fix(frontend): Hide Google Maps Key ID filter (#9861)
### Changes 🏗️


![image](https://github.com/user-attachments/assets/d6b9f971-d914-4ff1-9319-a903707a2c72)

Hide Google Maps system id key on the frontend UI.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan
2025-04-22 16:50:05 +00:00
Krzysztof Czerwinski
c80d357149 feat(frontend): Use route groups (#9855)
Navbar sometimes disappears outside `/onboarding`.

### Changes 🏗️

This PR solves the problem of disappearing Navbar outside `/onboarding`
by introducing `app/(platform)` route group.

- Move all routes requiring Navbar to `app/(platform)`
- Move `<Navbar>` to `app/(platform)/layout.tsx`
- Move `/onboarding` to `app/(no-navbar)/`
- Remove pathname injection into the header from middleware and stop
relying on it to hide the navbar

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Common routes work properly
2025-04-22 09:10:12 +00:00
Zamil Majdy
20d39f6d44 fix(platform): Fix Google Maps API Key setting through env (#9848)
Setting the Google Maps API key through the environment variable has
never worked on the platform.

### Changes 🏗️

Set the default API key from the environment variable.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Test GoogleMapsBlock
2025-04-22 03:00:47 +07:00
Bently
d5b82c01e0 feat(backend): Adds latest llm models (#9856)
This PR adds the following models:
OpenAI's o3: https://platform.openai.com/docs/models/o3
OpenAI's GPT-4.1: https://platform.openai.com/docs/models/gpt-4.1
Anthropic's Claude 3.7: https://www.anthropic.com/news/claude-3-7-sonnet
Google's Gemini 2.5 Pro:
https://openrouter.ai/google/gemini-2.5-pro-preview-03-25
2025-04-21 19:26:21 +00:00
Abhimanyu Yadav
69b8d96516 fix(library/run): Replace credits to cents (#9845)
Replacing credits with cents (100 credits = $1).

I haven’t touched anything internally, just changed the UI.

Everything is working great.

On the frontend, there’s no other place where we use credits instead of
dollars.

![Screenshot 2025-04-19 at 11 36
00 AM](https://github.com/user-attachments/assets/de799b5c-094e-4c96-a7da-273ce60b2125)
<img width="1503" alt="Screenshot 2025-04-19 at 11 33 24 AM"
src="https://github.com/user-attachments/assets/87d7e218-f8f5-4e2e-92ef-70c81735db6b"
/>
2025-04-21 12:31:48 +00:00
Krzysztof Czerwinski
67af77e179 fix(backend): Fix migrate llm models in existing agents (#9810)
https://github.com/Significant-Gravitas/AutoGPT/pull/9452 was throwing
`operator does not exist: text ? unknown` on deployed dev, so the
function call was commented out as a hotfix.
This PR fixes and re-enables the llm model migration function.

### Changes 🏗️

- Uncomment and fix `migrate_llm_models` function

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Migrate nodes with non-existing models
  - [x] Don't migrate nodes without any model or with correct models

---------

Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
2025-04-19 12:52:36 +00:00
Abhimanyu Yadav
2a92970a5f fix(marketplace/library): Removing white borders from Avatar (#9818)
There are some white borders around the avatar in the store card, but
they are not present in the design, so I'm removing them.

![Screenshot 2025-04-15 at 3 58
05 PM](https://github.com/user-attachments/assets/f8c98076-9cc3-46f1-b4f3-41d4e48f6127)
2025-04-19 05:36:36 +00:00
Zamil Majdy
9052ee7b95 fix(backend): Clear RabbitMQ connection cache on execution-manager retry 2025-04-19 07:50:04 +02:00
Zamil Majdy
c783f64b33 fix(backend): Handle add execution API request failure (#9838)
There are cases where publishing the agent execution fails, making the
execution appear to be stuck in a queue even though it was never in the
queue in the first place.

### Changes 🏗️

On publishing failure, we set the graph & starting node execution status
to FAILED and let the UI bubble up the error so the user can try again.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Normal add execution flow
2025-04-18 18:35:43 +00:00
Zamil Majdy
055a231aed feat(backend): Add retry mechanism for pika publish_message (#9839)
For unknown reasons, publishing a message can sometimes fail due to the
connection being broken: the message queue suddenly becoming
unavailable, the connection simply breaking, the connection being reset,
etc.

### Changes 🏗️

Add a tenacity retry on AMQP or ConnectionError, which hopefully
alleviates the issue.
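
The behavior tenacity provides here can be illustrated with a hand-rolled equivalent (a sketch, not the actual decorator configuration used in the PR):

```python
import functools
import time

def retry_on(exc_types, attempts: int = 3, base_delay: float = 0.0):
    """Re-run the wrapped call when it raises one of exc_types, up to
    `attempts` tries, with a linearly growing back-off."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            for attempt in range(1, attempts + 1):
                try:
                    return fn(*args, **kwargs)
                except exc_types:
                    if attempt == attempts:
                        raise  # out of retries: surface the error
                    time.sleep(base_delay * attempt)
        return wrapper
    return decorator
```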

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Simple add execution
2025-04-18 17:56:27 +00:00
Reinier van der Leer
417d7732af feat(platform/library): Add credentials UX on /library/agents/[id] (#9789)
- Resolves #9771
- ... in a non-persistent way, so it won't work for webhook-triggered
agents
    For webhooks: #9541

### Changes 🏗️

Frontend:
- Add credentials inputs in Library "New run" screen (based on
`graph.credentials_input_schema`)
- Refactor `CredentialsInput` and `useCredentials` to not rely on XYFlow
context

- Unsplit lists of saved credentials in `CredentialsProvider` state

- Move logic that was being executed at component render to `useEffect`
hooks in `CredentialsInput`

Backend:
- Implement logic to aggregate credentials input requirements to one per
provider per graph
- Add `BaseGraph.credentials_input_schema` (JSON schema) computed field
    Underlying added logic:
- `BaseGraph._credentials_input_schema` - makes a `BlockSchema` from a
graph's aggregated credentials inputs
- `BaseGraph.aggregate_credentials_inputs()` - aggregates a graph's
nodes' credentials inputs using `CredentialsFieldInfo.combine(..)`
- `BlockSchema.get_credentials_fields_info() -> dict[str,
CredentialsFieldInfo]`
- `CredentialsFieldInfo` model (created from
`_CredentialsFieldSchemaExtra`)

- Implement logic to inject explicitly passed credentials into graph
execution
  - Add `credentials_inputs` parameter to `execute_graph` endpoint
- Add `graph_credentials_input` parameter to
`.executor.utils.add_graph_execution(..)`
  - Implement `.executor.utils.make_node_credentials_input_map(..)`
  - Amend `.executor.utils.construct_node_execution_input`
  - Add `GraphExecutionEntry.node_credentials_input_map` attribute
  - Amend validation to allow injecting credentials
    - Amend `GraphModel._validate_graph(..)`
    - Amend `.executor.utils._validate_node_input_credentials`
- Add `node_credentials_map` parameter to
`ExecutionManager.add_execution(..)`
    - Amend execution validation to handle side-loaded credentials
    - Add `GraphExecutionEntry.node_execution_map` attribute
- Add mechanism to inject passed credentials into node execution data
- Add credentials injection mechanism to node execution queueing logic
in `Executor._on_graph_execution(..)`

- Replace boilerplate logic in `v1.execute_graph` endpoint with call to
existing `.executor.utils.add_graph_execution(..)`
- Replace calls to `.server.routers.v1.execute_graph` with
`add_graph_execution`

Also:
- Address tech debt in `GraphModel._validate_graph(..)`
- Fix type checking in `BaseGraph._generate_schema(..)`

#### TODO
- [ ] ~~Make "Run again" work with credentials in
`AgentRunDetailsView`~~
- [ ] Prohibit saving a graph if it has nodes with missing discriminator
value for discriminated credentials inputs

### Checklist 📋

#### For code changes:
- [ ] I have clearly listed my changes in the PR description
- [ ] I have made a test plan
- [ ] I have tested my changes according to the test plan:
  - [ ] ...
2025-04-18 14:27:13 +00:00
Krzysztof Czerwinski
f16a398a8e feat(frontend): Update completed task group design in Wallet (#9820)
This redesigns how the task group is displayed when finished, for both
the expanded and folded states.

### Changes 🏗️

- Folded state now displays `Done` badge and hides tasks
- Expanded state shows only task names and hides details and video

Screenshot:
1. Expanded unfinished group
2. Expanded finished group
3. Folded finished group

<img width="463" alt="Screenshot 2025-04-15 at 2 05 31 PM"
src="https://github.com/user-attachments/assets/40152073-fc0e-47c2-9fd4-a6b0161280e6"
/>

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Finished group displays correctly
  - [x] Unfinished group displays correctly
2025-04-18 09:45:35 +00:00
Krzysztof Czerwinski
e8bbd945f2 feat(frontend): Wallet top-up and auto-refill (#9819)
### Changes 🏗️

- Add top-up and auto-refill tabs in the Wallet
- Add shadcn `tabs` component
- Disable increase/decrease spinner buttons on number inputs across the
Platform (moved CSS from `customnode.css` to `globals.css`)

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Incorrect values are detected properly
  - [x] Top-up works
  - [x] Setting auto-refill works
2025-04-18 09:44:54 +00:00
Krzysztof Czerwinski
d1730d7b1d fix(frontend): Fix onboarding agent execution (#9822)
Onboarding executes the original agent graph directly, without waiting for
the marketplace agent to be added to the user's library.

### Changes 🏗️

- Execute the library agent only after it has been added to the library

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Onboarding agent executes properly
2025-04-18 09:36:40 +00:00
Krzysztof Czerwinski
8ea64327a1 fix(backend): Fix array types in database (#9828)
Array fields in `schema.prisma` are non-nullable, but generated
migrations don’t add `NOT NULL` constraints. This causes existing rows
to get `NULL` values when new array columns are added, breaking schema
expectations and leading to bugs.

### Changes 🏗️

- Backfill all `NULL` rows on non-nullable array columns to empty arrays
- Set `NOT NULL` constraint on all array columns
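A migration for this follows a backfill-then-constrain pattern. The helper below is a sketch that generates the statements for a list of columns; the table and column names in the test are placeholders, not the actual schema.

```python
def array_not_null_migration(columns: list[tuple[str, str]]) -> list[str]:
    """Emit backfill + NOT NULL statements for non-nullable array columns."""
    statements = []
    for table, column in columns:
        # Backfill existing NULLs to an empty array first, so the
        # constraint can be added without failing on old rows.
        statements.append(
            f'UPDATE "{table}" SET "{column}" = \'{{}}\' WHERE "{column}" IS NULL;'
        )
        statements.append(
            f'ALTER TABLE "{table}" ALTER COLUMN "{column}" SET NOT NULL;'
        )
    return statements
```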

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Existing `NULL` rows are properly backfilled
  - [x] Existing arrays are not set to default empty arrays
  - [x] Affected columns became non-nullable in the db
2025-04-18 07:43:54 +00:00
Bently
3cf30c22fb update(docs): Remove outdated submodule command from docs (#9836)
### Changes 🏗️

Updates the setup docs to remove the old, unneeded ``git submodule
update --init --recursive --progress`` command, plus some other small tweaks
around it
2025-04-17 16:45:07 +00:00
Reinier van der Leer
05c670eef9 fix(frontend/library): Prevent execution updates mixing between library agents (#9835)
If the websocket doesn't disconnect when the user switches to viewing a
different agent, the client isn't unsubscribed. If execution updates *from a
different agent* are then adopted into the page state, that can cause
crashes.

### Changes 🏗️

- Filter incoming execution updates by `graph_id`

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [ ] I have tested my changes according to the test plan:
- Go to an agent and initiate a run that will take a while (long enough
to navigate to a different agent)
  - Navigate: Library -> [another agent]
- [ ] Runs from the first agent don't show up in the runs list of the
other agent
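The filter itself is a single guard (the real code is TypeScript in the library page; this Python sketch just shows the check):

```python
def should_adopt(page_graph_id: str, update: dict) -> bool:
    # Ignore execution updates that belong to a different agent's graph,
    # so a stale websocket subscription can't pollute this page's state.
    return update.get("graph_id") == page_graph_id
```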
2025-04-17 14:11:09 +00:00
Zamil Majdy
f6a4b036c7 fix(block): Disable LLM blocks parallel tool calls (#9834)
SmartDecisionBlock sometimes tried to be smart by making multiple tool
calls at once, which our platform does not support yet.

### Changes 🏗️

Disable parallel tool calls for LLM blocks using the OpenAI & OpenRouter
providers.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  <!-- Put your test plan here: -->
  - [x] Tested SmartDecisionBlock & AITextGeneratorBlock
2025-04-17 12:58:05 +00:00
Zamil Majdy
c43924cd4e feat(backend): Add RabbitMQ connection cleanup on executor shutdown hook 2025-04-17 01:28:15 +02:00
Zamil Majdy
e3846c22bd fix(backend): Avoid multithreaded pika access (#9832)
### Changes 🏗️

Avoid other threads accessing the channel within the same process.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Manual agent runs
2025-04-16 22:06:07 +00:00
Toran Bruce Richards
9a7a838418 fix(backend): Change node output logging type from info to debug (#9831)
<!-- Clearly explain the need for these changes: -->

### Changes 🏗️
This PR changes the logging level of node outputs in the agent.py file
from info to debug.
<!-- Concisely describe all of the changes made in this pull request:
-->

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [ ] I have tested my changes according to the test plan:
  <!-- Put your test plan here: -->
  - [x] ...

<details>
  <summary>Example test plan</summary>
  
  - [ ] Create from scratch and execute an agent with at least 3 blocks
- [ ] Import an agent from file upload, and confirm it executes
correctly
  - [ ] Upload agent to marketplace
- [ ] Import an agent from marketplace and confirm it executes correctly
  - [ ] Edit an agent from monitor, and confirm it executes correctly
</details>

#### For configuration changes:
- [x] `.env.example` is updated or already compatible with my changes
- [x] `docker-compose.yml` is updated or already compatible with my
changes
- [x] I have included a list of my configuration changes in the PR
description (under **Changes**)

<details>
  <summary>Examples of configuration changes</summary>

  - Changing ports
  - Adding new services that need to communicate with each other
  - Secrets or environment variable changes
  - New or infrastructure changes such as databases
</details>

---------

Co-authored-by: Bentlybro <Github@bentlybro.com>
2025-04-16 20:45:51 +00:00
Toran Bruce Richards
d61d815208 fix(logging): Change node data logging to debug level from info (#9830)
<!-- Clearly explain the need for these changes: -->

### Changes 🏗️
This change lowers the logging level of node inputs and outputs to
debug. It is needed because logging all node data currently produces log
entries too large for the logger, which prevents nodes from running.

<!-- Concisely describe all of the changes made in this pull request:
-->

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [ ] I have made a test plan
- [ ] I have tested my changes according to the test plan:
  <!-- Put your test plan here: -->
  - [x] ...

<details>
  <summary>Example test plan</summary>
  
  - [ ] Create from scratch and execute an agent with at least 3 blocks
- [ ] Import an agent from file upload, and confirm it executes
correctly
  - [ ] Upload agent to marketplace
- [ ] Import an agent from marketplace and confirm it executes correctly
  - [ ] Edit an agent from monitor, and confirm it executes correctly
</details>

#### For configuration changes:
- [x] `.env.example` is updated or already compatible with my changes
- [x] `docker-compose.yml` is updated or already compatible with my
changes
- [x] I have included a list of my configuration changes in the PR
description (under **Changes**)

<details>
  <summary>Examples of configuration changes</summary>

  - Changing ports
  - Adding new services that need to communicate with each other
  - Secrets or environment variable changes
  - New or infrastructure changes such as databases
</details>
2025-04-16 19:22:52 +00:00
Zamil Majdy
44e3770003 fix(backend): Fix execution manager message consuming pattern (#9829)
We have seen instances where the executor gets stuck in a failing
message-consuming loop because the upstream RabbitMQ is down. The
current message-consuming pattern does not handle this well.

### Changes 🏗️

* Add a retry limit to the execution loop.
* Use `basic_consume` instead of `basic_get` for handling message
consumption.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Run agents and cancel them
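The retry-limited loop can be sketched as below; names are illustrative (the real code wires `basic_consume` callbacks into its loop), and `consume_once` stands in for one iteration of message handling, returning `False` when consumption should stop:

```python
def consume_with_retry(consume_once, max_retries=5):
    """Run consume_once until it returns False; abort after max_retries
    consecutive broker failures instead of looping forever."""
    failures = 0
    while True:
        try:
            if not consume_once():
                return
            failures = 0  # reset the counter after a healthy iteration
        except ConnectionError:
            failures += 1
            if failures >= max_retries:
                raise RuntimeError("broker unavailable; giving up")
```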
2025-04-16 22:54:26 +07:00
Zamil Majdy
c0ee71fb27 fix(frontend/builder): Fix key-value pair input for any non-string types (#9826)
- Resolves #9823 

Key-value pair inputs, like those used in CreateDictionaryBlock, are
assumed to be either numeric or string typed.
When an input has the `any` type, it was arbitrarily assumed to be numeric.

### Changes 🏗️

Only convert a value to a number when the key-value pair input is
explicitly defined as numeric.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] Tried two different key-value pair input: AiTextGenerator &
CreateDictionary
2025-04-16 11:10:50 +00:00
Zamil Majdy
71cdc18674 fix(backend): Fix cancel_execution can only work once (#9825)
### Changes 🏗️

The recent execution-cancellation fix turned out to only work on the
first request.
This PR fixes that by reworking how `thread_cached` works on async
functions.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  <!-- Put your test plan here: -->
  - [x] Cancel agent executions multiple times
2025-04-16 10:33:49 +00:00
Zamil Majdy
dc9348ec26 fix(frontend): Fix Input value mixup on Library page (#9821)
### Changes 🏗️

Fix this broken behavior:
input data mix-up caused by running two different executions of the same
agent with the same input.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Run agent with old user
- [x] Running two different executions of the same agent with the same
input.
2025-04-16 09:31:07 +00:00
Zamil Majdy
3ccbc31705 Revert: fix(frontend): Fix Input value mixup on Library page & broken marketplace on no onboarding data 2025-04-15 21:28:43 +02:00
Zamil Majdy
7cf0c6fe46 fix(frontend): Fix Input value mixup on Library page & broken marketplace on no onboarding data 2025-04-15 21:25:25 +02:00
Zamil Majdy
c69faa2a94 fix(frontend): Fix Input value mixup on Library page & broken marketplace on no onboarding data 2025-04-15 21:24:39 +02:00
Nicholas Tindle
0c9dbbbe24 Merge branch 'master' into dev 2025-04-15 12:00:02 -05:00
Nicholas Tindle
3e0742f9c5 Spike/infra pooling (#9812)
<!-- Clearly explain the need for these changes: -->
Swap to pooled Supabase connections rather than depending on a fixed
number of max open connections

### Changes 🏗️
Adds direct connect URL to be used throughout the system
<!-- Concisely describe all of the changes made in this pull request:
-->

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  <!-- Put your test plan here: -->
- [x] Test thoroughly all of the endpoints in the dev env with switched
infra matching pr
  - [x] Follow the new release plan tests
  - [x] Follow the old release plan tests

#### For configuration changes:
- [x] `.env.example` is updated or already compatible with my changes
- [x] `docker-compose.yml` is updated or already compatible with my
changes
- [x] I have included a list of my configuration changes in the PR
description (under **Changes**)

<details>
  <summary>configuration changes</summary>

- Change how we connect to the database to use direct when configured
and database URL when not
  - update prisma for this
  - have default matching database and default
</details>
2025-04-15 15:40:15 +00:00
Krzysztof Czerwinski
d791cdea76 feat(platform): Onboarding Phase 2 (#9736)
### Changes 🏗️

- Update onboarding to give user rewards for completing steps
- Remove `canvas-confetti` lib and add `party-js` instead; the former
didn't allow playing confetti from a component
- Add onboarding videos in `frontend/public/onboarding/`
- Remove Balance (`CreditsCard.tsx`) and add openable `Wallet.tsx` (and
accompanying `WalletTaskGroup.tsx`) instead that displays grouped
onboarding tasks with descriptions and short instructional videos
- Further relevant updates to `useOnboarding`, `types.ts`
- Implement onboarding rewards
- Add `onboarding_reward` function in `credit.py` that is used to reward
user for finished onboarding tasks safely - transaction key is
deterministic, so the same user won't be rewarded twice for the same
step.
  - Add `reward_user` in `onboarding.py`
- Update `UserOnboarding` model and add a migration
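The idempotency idea can be sketched as follows: because the transaction key is deterministic per (user, step), replaying a reward is a no-op. The function name and in-memory ledger below are illustrative, not the real `credit.py` API:

```python
_ledger: set[str] = set()  # stand-in for the transaction table's unique key

def onboarding_reward(user_id: str, step: str, amount: int) -> int:
    """Credit a user for an onboarding step at most once."""
    key = f"onboarding-{user_id}-{step}"  # deterministic transaction key
    if key in _ledger:
        return 0  # same step already rewarded: skip safely
    _ledger.add(key)
    return amount
```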

<img width="464" alt="Screenshot 2025-04-05 at 6 06 29 PM"
src="https://github.com/user-attachments/assets/fca8d09e-0139-466b-b679-d24117ad01f0"
/>

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Onboarding works
  - [x] Tasks can be completed
  - [x] Rewards are added correctly for all completed tasks
2025-04-12 10:56:59 +00:00
Zamil Majdy
bb92226f5d feat(backend): Remove RPC service from Agent Executor (#9804)
Currently, execution tasks are not properly distributed between
executors because we need to send each execution request to a specific
execution server.

The execution manager now accepts execution requests from the message
queue. Thus, we can remove the synchronous RPC system from this service,
letting it focus on executing agents without devoting any process to an
HTTP API interface.

This will also reduce the risk of the execution service being too busy
to accept new add-execution requests.

### Changes 🏗️

* Remove the RPC system in Agent Executor
* Allow cancellation of executions still waiting in the queue (by
preventing them from being executed).
* Make a unified helper for adding an execution request to the system
and move other execution-related helper functions into
`executor/utils.py`.
* Remove non-db connections (Redis / RabbitMQ) from the Database Manager
and let clients manage these themselves.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Existing CI, some agent runs
2025-04-11 19:03:47 +00:00
Zamil Majdy
f7ca5ac1ba feat(backend/executor): Move execution queue + cancel mechanism to RabbitMQ (#9759)
The graph execution queue is not disk-persisted; when the executor dies,
the executions are lost.

The scope of this issue is migrating the execution queue from an
inter-process queue to a RabbitMQ message queue. A sync client should be
used for this.

- Resolves #9746
- Resolves #9714

### Changes 🏗️

Move the execution manager from multiprocess.Queue to a persisted
RabbitMQ queue.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Execute agents.

<details>
  <summary>Example test plan</summary>
  
  - [ ] Create from scratch and execute an agent with at least 3 blocks
- [ ] Import an agent from file upload, and confirm it executes
correctly
  - [ ] Upload agent to marketplace
- [ ] Import an agent from marketplace and confirm it executes correctly
  - [ ] Edit an agent from monitor, and confirm it executes correctly
</details>

#### For configuration changes:
- [ ] `.env.example` is updated or already compatible with my changes
- [ ] `docker-compose.yml` is updated or already compatible with my
changes
- [ ] I have included a list of my configuration changes in the PR
description (under **Changes**)

<details>
  <summary>Examples of configuration changes</summary>

  - Changing ports
  - Adding new services that need to communicate with each other
  - Secrets or environment variable changes
  - New or infrastructure changes such as databases
</details>
2025-04-11 14:15:39 +00:00
Abhimanyu Yadav
4621a95bf3 fix(marketplace): Fix small UI bugs (#9800)
Resolving the bugs listed below
- #9796 
- #9797 
- #9798 
- #8998 
- #9799 

### Changes I have made 
- Removed border and set border-radius to `24px` in FeaturedCard
- Removed `white` background from breadcrumbs
- Changed distance between featured section arrow from `28px` to `12px`
- Added `1.5rem` spacing and changed color to `gray-200` on the
creator’s page separator
- Removed focus ring from the Search Library input
- And some small UI changes on marketplace

### Screenshots

<img width="658" alt="Screenshot 2025-04-10 at 3 26 56 PM"
src="https://github.com/user-attachments/assets/22bef6f0-19b9-42a6-8227-fedca33141ba"
/>

<img width="505" alt="Screenshot 2025-04-10 at 3 27 07 PM"
src="https://github.com/user-attachments/assets/2a5409a1-94c6-4d15-a35d-e4ed9b075055"
/>

<img width="1373" alt="Screenshot 2025-04-10 at 3 28 39 PM"
src="https://github.com/user-attachments/assets/046ea726-2a98-4000-abc8-9139fffe80dc"
/>

<img width="368" alt="Screenshot 2025-04-10 at 3 29 07 PM"
src="https://github.com/user-attachments/assets/4e0510ad-f535-4760-a703-651766ff522b"
/>
2025-04-11 13:09:35 +00:00
Abhimanyu Yadav
8d8a6e450f fix(marketplace): Render newline in marketplace description text (#9808)
- fix #9177 

Add `whitespace-pre-line` tailwind property to allow newline rendering
in marketplace description text

### Before

![Screenshot 2025-04-11 at 10 32
23 AM](https://github.com/user-attachments/assets/b07f58b6-218e-4b33-a018-93757e59cd8d)

### After

![Screenshot 2025-04-11 at 10 32
59 AM](https://github.com/user-attachments/assets/f1086ee4-aef3-491a-ba81-cf681086f67b)
2025-04-11 10:50:32 +00:00
Nicholas Tindle
cda07e81d1 feat(frontend, backend): track sentry environment on frontend + sentry init in app services (#9773)
<!-- Clearly explain the need for these changes: -->
We want to be able to filter errors in Sentry according to where they
occur, so we need to track and include that data. We are also not
logging everything from app services correctly, so fix that up

### Changes 🏗️

<!-- Concisely describe all of the changes made in this pull request:
-->
- Adds env tracking for frontend
- adds sentry init in app service spawn

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  <!-- Put your test plan here: -->
- [x] Tested by running and making sure all events + logs are inserted
into sentry correctly
2025-04-10 16:28:07 +01:00
Abhimanyu Yadav
6156fbb731 fix(marketplace): Fixing margins between headers, divider and content (#9757)
- fix #9003 
- fix - #8969 
- fix #8970 

Add correct margins between headers, divider, and content.

### Changes made

- Remove any vertical padding or margin from the section.
- Add top and bottom margins to the separator, so the spacing between
sections is handled only by the separator.
- Also, add a size prop to AvatarFallback because its sizing is currently
broken: it can't extract the size properly from the className.
2025-04-10 16:28:01 +01:00
Abhimanyu Yadav
07a09d802c fix(marketplace): Fix store card style (#9769)
- fix #9222 
- fix #9221 
- fix #8966

### Changes made
- Standardized the height of store cards.
- Corrected spacing and responsiveness behavior.
- Removed horizontal margin and max-width from the featured section.
- Fixed the aspect ratio of the agent image in the store card.
- Now, a normal desktop screen displays 3 columns of agents instead of
4.

<img width="1512" alt="Screenshot 2025-04-07 at 7 09 40 AM"
src="https://github.com/user-attachments/assets/50d3b5c9-4e7c-456e-b5f1-7c0093509bd3"
/>
2025-04-10 12:01:42 +01:00
Nicholas Tindle
88b81f8cb2 Merge branch 'dev' into abhi-9274/postgres-integration 2025-04-03 11:25:56 -05:00
Nicholas Tindle
7294741001 Merge branch 'dev' into abhi-9274/postgres-integration 2025-04-03 11:14:19 -05:00
abhi1992002
ea03c404b1 Add test_mock to Postgres blocks 2025-03-29 18:27:01 +05:30
abhi1992002
79651855c2 Remove command type descriptions in postgres.py 2025-03-29 18:19:29 +05:30
abhi1992002
7977d1b1e5 Add PostgreSQL integration with CRUD operations 2025-03-29 18:09:57 +05:30
194 changed files with 6114 additions and 2230 deletions

View File

@@ -34,6 +34,7 @@ jobs:
python -m prisma migrate deploy
env:
DATABASE_URL: ${{ secrets.BACKEND_DATABASE_URL }}
DIRECT_URL: ${{ secrets.BACKEND_DATABASE_URL }}
trigger:

View File

@@ -36,6 +36,7 @@ jobs:
python -m prisma migrate deploy
env:
DATABASE_URL: ${{ secrets.BACKEND_DATABASE_URL }}
DIRECT_URL: ${{ secrets.BACKEND_DATABASE_URL }}
trigger:
needs: migrate

View File

@@ -135,6 +135,7 @@ jobs:
run: poetry run prisma migrate dev --name updates
env:
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}
- id: lint
name: Run Linter
@@ -151,12 +152,13 @@ jobs:
env:
LOG_LEVEL: ${{ runner.debug && 'DEBUG' || 'INFO' }}
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}
SUPABASE_URL: ${{ steps.supabase.outputs.API_URL }}
SUPABASE_SERVICE_ROLE_KEY: ${{ steps.supabase.outputs.SERVICE_ROLE_KEY }}
SUPABASE_JWT_SECRET: ${{ steps.supabase.outputs.JWT_SECRET }}
REDIS_HOST: 'localhost'
REDIS_PORT: '6379'
REDIS_PASSWORD: 'testpassword'
REDIS_HOST: "localhost"
REDIS_PORT: "6379"
REDIS_PASSWORD: "testpassword"
env:
CI: true
@@ -169,8 +171,8 @@ jobs:
# If you want to replace this, you can do so by making our entire system generate
# new credentials for each local user and update the environment variables in
# the backend service, docker composes, and examples
RABBITMQ_DEFAULT_USER: 'rabbitmq_user_default'
RABBITMQ_DEFAULT_PASS: 'k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7'
RABBITMQ_DEFAULT_USER: "rabbitmq_user_default"
RABBITMQ_DEFAULT_PASS: "k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7"
# - name: Upload coverage reports to Codecov
# uses: codecov/codecov-action@v4

View File

@@ -16,7 +16,7 @@ jobs:
# operations-per-run: 5000
stale-issue-message: >
This issue has automatically been marked as _stale_ because it has not had
any activity in the last 50 days. You can _unstale_ it by commenting or
any activity in the last 170 days. You can _unstale_ it by commenting or
removing the label. Otherwise, this issue will be closed in 10 days.
stale-pr-message: >
This pull request has automatically been marked as _stale_ because it has
@@ -25,7 +25,7 @@ jobs:
close-issue-message: >
This issue was closed automatically because it has been stale for 10 days
with no activity.
days-before-stale: 100
days-before-stale: 170
days-before-close: 10
# Do not touch meta issues:
exempt-issue-labels: meta,fridge,project management

View File

@@ -1,20 +1,59 @@
import inspect
import threading
from typing import Callable, ParamSpec, TypeVar
from typing import Awaitable, Callable, ParamSpec, TypeVar, cast, overload
P = ParamSpec("P")
R = TypeVar("R")
def thread_cached(func: Callable[P, R]) -> Callable[P, R]:
@overload
def thread_cached(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: ...
@overload
def thread_cached(func: Callable[P, R]) -> Callable[P, R]: ...
def thread_cached(
func: Callable[P, R] | Callable[P, Awaitable[R]],
) -> Callable[P, R] | Callable[P, Awaitable[R]]:
thread_local = threading.local()
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
cache = getattr(thread_local, "cache", None)
if cache is None:
cache = thread_local.cache = {}
key = (args, tuple(sorted(kwargs.items())))
if key not in cache:
cache[key] = func(*args, **kwargs)
return cache[key]
def _clear():
if hasattr(thread_local, "cache"):
del thread_local.cache
return wrapper
if inspect.iscoroutinefunction(func):
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
cache = getattr(thread_local, "cache", None)
if cache is None:
cache = thread_local.cache = {}
key = (args, tuple(sorted(kwargs.items())))
if key not in cache:
cache[key] = await cast(Callable[P, Awaitable[R]], func)(
*args, **kwargs
)
return cache[key]
setattr(async_wrapper, "clear_cache", _clear)
return async_wrapper
else:
def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
cache = getattr(thread_local, "cache", None)
if cache is None:
cache = thread_local.cache = {}
key = (args, tuple(sorted(kwargs.items())))
if key not in cache:
cache[key] = func(*args, **kwargs)
return cache[key]
setattr(sync_wrapper, "clear_cache", _clear)
return sync_wrapper
def clear_thread_cache(func: Callable) -> None:
if clear := getattr(func, "clear_cache", None):
clear()
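For illustration, the sync path above reduces to the following minimal re-implementation, whose cache key is the call's arguments and whose store lives in `threading.local()` so each thread memoizes independently (the real helper also handles async functions and cache clearing):

```python
import threading

def thread_cached(func):
    thread_local = threading.local()

    def wrapper(*args, **kwargs):
        # Lazily create a per-thread cache dict
        cache = getattr(thread_local, "cache", None)
        if cache is None:
            cache = thread_local.cache = {}
        key = (args, tuple(sorted(kwargs.items())))
        if key not in cache:
            cache[key] = func(*args, **kwargs)
        return cache[key]

    return wrapper

calls = {"n": 0}

@thread_cached
def double(x):
    calls["n"] += 1  # count actual invocations of the wrapped function
    return x * 2
```

Repeated calls with the same arguments on the same thread hit the cache instead of re-running the function.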

View File

@@ -31,7 +31,7 @@ class RedisKeyedMutex:
try:
yield
finally:
if lock.locked():
if lock.locked() and lock.owned():
lock.release()
def acquire(self, key: Any) -> "RedisLock":
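The extra `owned()` check matters because after the lock's TTL expires, another client may hold it; releasing then would steal (or error on) their lock. A minimal sketch of the guard, with a stand-in lock object for illustration:

```python
class FakeLock:
    """Minimal stand-in for a Redis lock (illustrative only)."""

    def __init__(self, held=False, by_us=False):
        self._held, self._ours = held, by_us

    def locked(self):
        return self._held

    def owned(self):
        return self._ours

    def release(self):
        self._held = self._ours = False

def safe_release(lock):
    # Release only when the lock is held *and* held by this client; after
    # a TTL expiry another client may own it, and releasing would steal it.
    if lock.locked() and lock.owned():
        lock.release()
```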

View File

@@ -8,6 +8,7 @@ DB_CONNECT_TIMEOUT=60
DB_POOL_TIMEOUT=300
DB_SCHEMA=platform
DATABASE_URL="postgresql://${DB_USER}:${DB_PASS}@${DB_HOST}:${DB_PORT}/${DB_NAME}?schema=${DB_SCHEMA}&connect_timeout=${DB_CONNECT_TIMEOUT}"
DIRECT_URL="postgresql://${DB_USER}:${DB_PASS}@${DB_HOST}:${DB_PORT}/${DB_NAME}?schema=${DB_SCHEMA}&connect_timeout=${DB_CONNECT_TIMEOUT}"
PRISMA_SCHEMA="postgres/schema.prisma"
# EXECUTOR

View File

@@ -73,7 +73,6 @@ FROM server_dependencies AS server
COPY autogpt_platform/backend /app/autogpt_platform/backend
RUN poetry install --no-ansi --only-root
ENV DATABASE_URL=""
ENV PORT=8000
CMD ["poetry", "run", "rest"]

View File

@@ -1,8 +1,6 @@
import logging
from typing import Any
from autogpt_libs.utils.cache import thread_cached
from backend.data.block import (
Block,
BlockCategory,
@@ -19,21 +17,6 @@ from backend.util import json
logger = logging.getLogger(__name__)
@thread_cached
def get_executor_manager_client():
from backend.executor import ExecutionManager
from backend.util.service import get_service_client
return get_service_client(ExecutionManager)
@thread_cached
def get_event_bus():
from backend.data.execution import RedisExecutionEventBus
return RedisExecutionEventBus()
class AgentExecutorBlock(Block):
class Input(BlockSchema):
user_id: str = SchemaField(description="User ID")
@@ -76,23 +59,23 @@ class AgentExecutorBlock(Block):
def run(self, input_data: Input, **kwargs) -> BlockOutput:
from backend.data.execution import ExecutionEventType
from backend.executor import utils as execution_utils
executor_manager = get_executor_manager_client()
event_bus = get_event_bus()
event_bus = execution_utils.get_execution_event_bus()
graph_exec = executor_manager.add_execution(
graph_exec = execution_utils.add_graph_execution(
graph_id=input_data.graph_id,
graph_version=input_data.graph_version,
user_id=input_data.user_id,
data=input_data.data,
inputs=input_data.data,
)
log_id = f"Graph #{input_data.graph_id}-V{input_data.graph_version}, exec-id: {graph_exec.graph_exec_id}"
log_id = f"Graph #{input_data.graph_id}-V{input_data.graph_version}, exec-id: {graph_exec.id}"
logger.info(f"Starting execution of {log_id}")
for event in event_bus.listen(
user_id=graph_exec.user_id,
graph_id=graph_exec.graph_id,
graph_exec_id=graph_exec.graph_exec_id,
graph_exec_id=graph_exec.id,
):
if event.event_type == ExecutionEventType.GRAPH_EXEC_UPDATE:
if event.status in [
@@ -105,7 +88,7 @@ class AgentExecutorBlock(Block):
else:
continue
logger.info(
logger.debug(
f"Execution {log_id} produced input {event.input_data} output {event.output_data}"
)
@@ -123,5 +106,7 @@ class AgentExecutorBlock(Block):
continue
for output_data in event.output_data.get("output", []):
logger.info(f"Execution {log_id} produced {output_name}: {output_data}")
logger.debug(
f"Execution {log_id} produced {output_name}: {output_data}"
)
yield output_name, output_data

View File

@@ -88,6 +88,33 @@ class StoreValueBlock(Block):
yield "output", input_data.data or input_data.input
class PrintToConsoleBlock(Block):
    class Input(BlockSchema):
        text: Any = SchemaField(description="The data to print to the console.")

    class Output(BlockSchema):
        output: Any = SchemaField(description="The data printed to the console.")
        status: str = SchemaField(description="The status of the print operation.")

    def __init__(self):
        super().__init__(
            id="f3b1c1b2-4c4f-4f0d-8d2f-4c4f0d8d2f4c",
            description="Print the given text to the console, this is used for a debugging purpose.",
            categories={BlockCategory.BASIC},
            input_schema=PrintToConsoleBlock.Input,
            output_schema=PrintToConsoleBlock.Output,
            test_input={"text": "Hello, World!"},
            test_output=[
                ("output", "Hello, World!"),
                ("status", "printed"),
            ],
        )

    def run(self, input_data: Input, **kwargs) -> BlockOutput:
        yield "output", input_data.text
        yield "status", "printed"
class FindInDictionaryBlock(Block):
class Input(BlockSchema):
input: Any = SchemaField(description="Dictionary to lookup from")

View File

@@ -3,7 +3,7 @@ from googleapiclient.discovery import build
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.settings import Settings
from backend.util.settings import AppEnvironment, Settings
from ._auth import (
GOOGLE_OAUTH_IS_CONFIGURED,
@@ -36,13 +36,15 @@ class GoogleSheetsReadBlock(Block):
)
def __init__(self):
settings = Settings()
super().__init__(
id="5724e902-3635-47e9-a108-aaa0263a4988",
description="This block reads data from a Google Sheets spreadsheet.",
categories={BlockCategory.DATA},
input_schema=GoogleSheetsReadBlock.Input,
output_schema=GoogleSheetsReadBlock.Output,
disabled=not GOOGLE_OAUTH_IS_CONFIGURED,
disabled=not GOOGLE_OAUTH_IS_CONFIGURED
or settings.config.app_env == AppEnvironment.PRODUCTION,
test_input={
"spreadsheet_id": "1BxiMVs0XRA5nFMdKvBdBZjgmUUqptlbs74OgvE2upms",
"range": "Sheet1!A1:B2",

View File

@@ -82,7 +82,15 @@ class SendWebRequestBlock(Block):
json=body if input_data.json_format else None,
data=body if not input_data.json_format else None,
)
result = response.json() if input_data.json_format else response.text
if input_data.json_format:
if response.status_code == 204 or not response.content.strip():
result = None
else:
result = response.json()
else:
result = response.text
yield "response", result
except HTTPError as e:
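The guard in the diff above can be shown standalone; `Response` here is a minimal stand-in for the real HTTP response object, not the actual `requests` type:

```python
import json

class Response:
    def __init__(self, status_code: int, content: bytes):
        self.status_code = status_code
        self.content = content

    def json(self):
        return json.loads(self.content)

def parse_body(response, json_format=True):
    if not json_format:
        return response.content.decode()
    # A 204 No Content or an empty/whitespace body has no JSON to parse;
    # calling .json() on it would raise a decode error.
    if response.status_code == 204 or not response.content.strip():
        return None
    return response.json()
```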

View File

@@ -9,7 +9,6 @@ from typing import Any, Iterable, List, Literal, NamedTuple, Optional
import anthropic
import ollama
import openai
from anthropic import NotGiven
from anthropic.types import ToolParam
from groq import Groq
from pydantic import BaseModel, SecretStr
@@ -90,14 +89,17 @@ class LlmModelMeta(EnumMeta):
class LlmModel(str, Enum, metaclass=LlmModelMeta):
# OpenAI models
O3_MINI = "o3-mini"
O3 = "o3-2025-04-16"
O1 = "o1"
O1_PREVIEW = "o1-preview"
O1_MINI = "o1-mini"
GPT41 = "gpt-4.1-2025-04-14"
GPT4O_MINI = "gpt-4o-mini"
GPT4O = "gpt-4o"
GPT4_TURBO = "gpt-4-turbo"
GPT3_5_TURBO = "gpt-3.5-turbo"
# Anthropic models
CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219"
CLAUDE_3_5_SONNET = "claude-3-5-sonnet-latest"
CLAUDE_3_5_HAIKU = "claude-3-5-haiku-latest"
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
@@ -118,6 +120,7 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
OLLAMA_DOLPHIN = "dolphin-mistral:latest"
# OpenRouter models
GEMINI_FLASH_1_5 = "google/gemini-flash-1.5"
GEMINI_2_5_PRO = "google/gemini-2.5-pro-preview-03-25"
GROK_BETA = "x-ai/grok-beta"
MISTRAL_NEMO = "mistralai/mistral-nemo"
COHERE_COMMAND_R_08_2024 = "cohere/command-r-08-2024"
@@ -157,12 +160,14 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
MODEL_METADATA = {
# https://platform.openai.com/docs/models
LlmModel.O3: ModelMetadata("openai", 200000, 100000),
LlmModel.O3_MINI: ModelMetadata("openai", 200000, 100000), # o3-mini-2025-01-31
LlmModel.O1: ModelMetadata("openai", 200000, 100000), # o1-2024-12-17
LlmModel.O1_PREVIEW: ModelMetadata(
"openai", 128000, 32768
), # o1-preview-2024-09-12
LlmModel.O1_MINI: ModelMetadata("openai", 128000, 65536), # o1-mini-2024-09-12
LlmModel.GPT41: ModelMetadata("openai", 1047576, 32768),
LlmModel.GPT4O_MINI: ModelMetadata(
"openai", 128000, 16384
), # gpt-4o-mini-2024-07-18
@@ -172,6 +177,9 @@ MODEL_METADATA = {
), # gpt-4-turbo-2024-04-09
LlmModel.GPT3_5_TURBO: ModelMetadata("openai", 16385, 4096), # gpt-3.5-turbo-0125
# https://docs.anthropic.com/en/docs/about-claude/models
LlmModel.CLAUDE_3_7_SONNET: ModelMetadata(
"anthropic", 200000, 8192
), # claude-3-7-sonnet-20250219
LlmModel.CLAUDE_3_5_SONNET: ModelMetadata(
"anthropic", 200000, 8192
), # claude-3-5-sonnet-20241022
@@ -197,6 +205,7 @@ MODEL_METADATA = {
LlmModel.OLLAMA_DOLPHIN: ModelMetadata("ollama", 32768, None),
# https://openrouter.ai/models
LlmModel.GEMINI_FLASH_1_5: ModelMetadata("open_router", 1000000, 8192),
LlmModel.GEMINI_2_5_PRO: ModelMetadata("open_router", 1050000, 8192),
LlmModel.GROK_BETA: ModelMetadata("open_router", 131072, 131072),
LlmModel.MISTRAL_NEMO: ModelMetadata("open_router", 128000, 4096),
LlmModel.COHERE_COMMAND_R_08_2024: ModelMetadata("open_router", 128000, 4096),
@@ -249,7 +258,7 @@ class LLMResponse(BaseModel):
def convert_openai_tool_fmt_to_anthropic(
openai_tools: list[dict] | None = None,
) -> Iterable[ToolParam] | NotGiven:
) -> Iterable[ToolParam] | anthropic.NotGiven:
"""
Convert OpenAI tool format to Anthropic tool format.
"""
@@ -279,6 +288,13 @@ def convert_openai_tool_fmt_to_anthropic(
return anthropic_tools
def estimate_token_count(prompt_messages: list[dict]) -> int:
char_count = sum(len(str(msg.get("content", ""))) for msg in prompt_messages)
message_overhead = len(prompt_messages) * 4
estimated_tokens = (char_count // 4) + message_overhead
return int(estimated_tokens * 1.2)
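The heuristic above assumes roughly 4 characters per token plus 4 tokens of per-message framing, then adds a 20% safety margin. It can be checked in isolation:

```python
def estimate_token_count(prompt_messages: list[dict]) -> int:
    # ~4 chars per token, +4 tokens of per-message framing, +20% margin
    char_count = sum(len(str(msg.get("content", ""))) for msg in prompt_messages)
    message_overhead = len(prompt_messages) * 4
    estimated_tokens = (char_count // 4) + message_overhead
    return int(estimated_tokens * 1.2)


# 40 chars → 10 tokens, +4 overhead = 14, ×1.2 = 16.8 → 16
print(estimate_token_count([{"role": "user", "content": "x" * 40}]))  # → 16
```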
def llm_call(
credentials: APIKeyCredentials,
llm_model: LlmModel,
@@ -287,6 +303,7 @@ def llm_call(
max_tokens: int | None,
tools: list[dict] | None = None,
ollama_host: str = "localhost:11434",
parallel_tool_calls: bool | None = None,
) -> LLMResponse:
"""
Make a call to a language model.
@@ -309,7 +326,14 @@ def llm_call(
- completion_tokens: The number of tokens used in the completion.
"""
provider = llm_model.metadata.provider
max_tokens = max_tokens or llm_model.max_output_tokens or 4096
# Calculate available tokens based on context window and input length
estimated_input_tokens = estimate_token_count(prompt)
context_window = llm_model.context_window
model_max_output = llm_model.max_output_tokens or 4096
user_max = max_tokens or model_max_output
available_tokens = max(context_window - estimated_input_tokens, 0)
max_tokens = max(min(available_tokens, model_max_output, user_max), 0)
if provider == "openai":
tools_param = tools if tools else openai.NOT_GIVEN
@@ -332,6 +356,9 @@ def llm_call(
response_format=response_format, # type: ignore
max_completion_tokens=max_tokens,
tools=tools_param, # type: ignore
parallel_tool_calls=(
openai.NOT_GIVEN if parallel_tool_calls is None else parallel_tool_calls
),
)
if response.choices[0].message.tool_calls:
@@ -462,6 +489,7 @@ def llm_call(
model=llm_model.value,
prompt=f"{sys_messages}\n\n{usr_messages}",
stream=False,
options={"num_ctx": max_tokens},
)
return LLMResponse(
raw_response=response.get("response") or "",
@@ -487,6 +515,9 @@ def llm_call(
messages=prompt, # type: ignore
max_tokens=max_tokens,
tools=tools_param, # type: ignore
parallel_tool_calls=(
openai.NOT_GIVEN if parallel_tool_calls is None else parallel_tool_calls
),
)
# If there's no response, raise an error
@@ -757,6 +788,16 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
prompt.append({"role": "user", "content": retry_prompt})
except Exception as e:
logger.exception(f"Error calling LLM: {e}")
if (
"maximum context length" in str(e).lower()
or "token limit" in str(e).lower()
):
if input_data.max_tokens is None:
input_data.max_tokens = llm_model.max_output_tokens or 4096
input_data.max_tokens = int(input_data.max_tokens * 0.85)
logger.debug(
f"Reducing max_tokens to {input_data.max_tokens} for next attempt"
)
retry_prompt = f"Error calling LLM: {e}"
finally:
self.merge_stats(
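On a context-length error, the retry path above shrinks the token budget by 15% per attempt. The progression from the 4096 default (helper name hypothetical):

```python
def shrink_max_tokens(max_tokens: int, attempts: int) -> int:
    # Mirrors the 0.85 reduction applied after each "maximum context length" error.
    for _ in range(attempts):
        max_tokens = int(max_tokens * 0.85)
    return max_tokens


print([shrink_max_tokens(4096, n) for n in (1, 2, 3)])  # → [3481, 2958, 2514]
```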

@@ -0,0 +1,734 @@
from enum import Enum
from typing import Any, List, Literal, Optional
import psycopg2
from psycopg2.extras import RealDictCursor
from pydantic import BaseModel, SecretStr
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import (
CredentialsField,
CredentialsMetaInput,
SchemaField,
UserPasswordCredentials,
)
from backend.integrations.providers import ProviderName
PostgresCredentials = UserPasswordCredentials
PostgresCredentialsInput = CredentialsMetaInput[
Literal[ProviderName.POSTGRES],
Literal["user_password"],
]
def PostgresCredentialsField() -> PostgresCredentialsInput:
"""Creates a Postgres credentials input on a block."""
return CredentialsField(
description="The Postgres integration requires a username and password.",
)
TEST_POSTGRES_CREDENTIALS = UserPasswordCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="postgres",
username=SecretStr("mock-postgres-username"),
password=SecretStr("mock-postgres-password"),
title="Mock Postgres credentials",
)
TEST_POSTGRES_CREDENTIALS_INPUT = {
"provider": TEST_POSTGRES_CREDENTIALS.provider,
"id": TEST_POSTGRES_CREDENTIALS.id,
"type": TEST_POSTGRES_CREDENTIALS.type,
"title": TEST_POSTGRES_CREDENTIALS.title,
}
class CommandType(str, Enum):
TRUNCATE = "TRUNCATE"
DELETE = "DELETE"
DROP = "DROP"
class ConditionOperator(str, Enum):
EQUALS = "="
NOT_EQUALS = "<>"
GREATER_THAN = ">"
LESS_THAN = "<"
GREATER_EQUALS = ">="
LESS_EQUALS = "<="
LIKE = "LIKE"
IN = "IN"
class Condition(BaseModel):
column: str
operator: ConditionOperator
value: Any
class CombineCondition(str, Enum):
AND = "AND"
OR = "OR"
class PostgresDeleteBlock(Block):
class Input(BlockSchema):
credentials: PostgresCredentialsInput = PostgresCredentialsField()
host: str = SchemaField(description="Database host", advanced=False)
port: int = SchemaField(description="Database port", advanced=False)
database: str = SchemaField(description="Database name", default="postgres", advanced=False)
schema_: str = SchemaField(description="Schema name", default="public", advanced=False)
table: str = SchemaField(description="Table name")
command: CommandType = SchemaField(
description="Command type to execute",
default=CommandType.DELETE,
advanced=False
)
conditions: List[Condition] = SchemaField(
description="Conditions for DELETE command",
default=[],
advanced=False
)
combine_conditions: CombineCondition = SchemaField(
description="How to combine multiple conditions",
default=CombineCondition.AND,
advanced=False
)
restart_sequences: bool = SchemaField(
description="Restart any auto-incrementing counters associated with the table after truncate",
default=False
)
cascade: bool = SchemaField(
description="Automatically truncate any tables that reference the target table via foreign keys. Only used for TRUNCATE and DROP commands.",
default=False
)
class Output(BlockSchema):
success: bool = SchemaField(description="Operation succeeded")
rows_affected: Optional[int] = SchemaField(description="Number of rows affected")
error: str = SchemaField(description="Error message if operation failed")
def __init__(self):
super().__init__(
id="81b103ad-0fa9-47d3-a18f-2ea96579e3bb",
description="Delete, truncate or drop data from a PostgreSQL table",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=PostgresDeleteBlock.Input,
output_schema=PostgresDeleteBlock.Output,
test_credentials=TEST_POSTGRES_CREDENTIALS,
test_input={
"credentials": TEST_POSTGRES_CREDENTIALS_INPUT,
"host": "localhost",
"port": 5432,
"database": "test_db",
"schema_": "public",
"table": "users",
"command": CommandType.DELETE,
"conditions": [
{"column": "id", "operator": ConditionOperator.EQUALS, "value": 1}
]
},
test_output=[
("success", True),
("rows_affected", 1)
],
test_mock={
"run": lambda *args, **kwargs: [
("success", True),
("rows_affected", 1)
]
},
)
def run(
self, input_data: Input, *, credentials: PostgresCredentials, **kwargs
) -> BlockOutput:
conn = None
try:
conn = psycopg2.connect(
host=input_data.host,
port=input_data.port,
database=input_data.database,
user=credentials.username.get_secret_value(),
password=credentials.password.get_secret_value()
)
with conn.cursor() as cursor:
rows_affected = 0
if input_data.command == CommandType.TRUNCATE:
sql = f"TRUNCATE TABLE {input_data.schema_}.{input_data.table}"
if input_data.restart_sequences:
sql += " RESTART IDENTITY"
if input_data.cascade:
sql += " CASCADE"
cursor.execute(sql)
elif input_data.command == CommandType.DELETE:
if input_data.conditions:
where_clauses = []
values = []
for condition in input_data.conditions:
if condition.operator == ConditionOperator.IN:
placeholders = ", ".join(["%s"] * len(condition.value))
where_clauses.append(f"{condition.column} IN ({placeholders})")
values.extend(condition.value)
else:
where_clauses.append(f"{condition.column} {condition.operator.value} %s")
values.append(condition.value)
where_clause = f" {input_data.combine_conditions.value} ".join(where_clauses)
sql = f"DELETE FROM {input_data.schema_}.{input_data.table} WHERE {where_clause}"
cursor.execute(sql, values)
else:
sql = f"DELETE FROM {input_data.schema_}.{input_data.table}"
cursor.execute(sql)
rows_affected = cursor.rowcount
elif input_data.command == CommandType.DROP:
sql = f"DROP TABLE {input_data.schema_}.{input_data.table}"
if input_data.cascade:
sql += " CASCADE"
cursor.execute(sql)
conn.commit()
yield "success", True
yield "rows_affected", rows_affected
except Exception as e:
if conn:
conn.rollback()
yield "success", False
yield "error", str(e)
finally:
if conn:
conn.close() # Always release the connection
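The condition handling shared by these blocks routes values through `%s` placeholders (with `IN` lists expanded to one placeholder per element), while column and table names are interpolated directly and must be treated as trusted input. A standalone sketch of the clause builder (function name hypothetical):

```python
def build_where(conditions: list[dict], combine: str = "AND") -> tuple[str, list]:
    """Build a parameterized WHERE clause; values are bound via %s placeholders."""
    clauses, values = [], []
    for cond in conditions:
        if cond["operator"] == "IN":
            placeholders = ", ".join(["%s"] * len(cond["value"]))
            clauses.append(f"{cond['column']} IN ({placeholders})")
            values.extend(cond["value"])
        else:
            clauses.append(f"{cond['column']} {cond['operator']} %s")
            values.append(cond["value"])
    return f" {combine} ".join(clauses), values


sql, vals = build_where([
    {"column": "id", "operator": "IN", "value": [1, 2]},
    {"column": "name", "operator": "=", "value": "Test"},
])
print(sql)   # → id IN (%s, %s) AND name = %s
print(vals)  # → [1, 2, 'Test']
```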
class PostgresExecuteQueryBlock(Block):
class Input(BlockSchema):
credentials: PostgresCredentialsInput = PostgresCredentialsField()
host: str = SchemaField(description="Database host", advanced=False)
port: int = SchemaField(description="Database port", advanced=False)
database: str = SchemaField(description="Database name", default="postgres", advanced=False)
schema_: str = SchemaField(description="Schema name", default="public", advanced=False)
query: str = SchemaField(description="SQL query to execute")
parameters: List[Any] = SchemaField(description="Query parameters", default=[], advanced=False)
class Output(BlockSchema):
success: bool = SchemaField(description="Operation succeeded")
result: Any = SchemaField(description="Query results or affected rows")
error: str = SchemaField(description="Error message if operation failed")
def __init__(self):
super().__init__(
id="c5d18dc8-ee3c-4366-ba99-a3996b7a4e78",
description="Executes an SQL query on a PostgreSQL database.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=PostgresExecuteQueryBlock.Input,
output_schema=PostgresExecuteQueryBlock.Output,
test_credentials=TEST_POSTGRES_CREDENTIALS,
test_input={
"credentials": TEST_POSTGRES_CREDENTIALS_INPUT,
"host": "localhost",
"port": 5432,
"database": "test_db",
"schema_": "public",
"query": "SELECT * FROM users WHERE id = %s",
"parameters": [1]
},
test_output=[
("success", True),
("result", [{"id": 1, "name": "Test User"}])
],
test_mock={
"run": lambda *args, **kwargs: [
("success", True),
("result", [{"id": 1, "name": "Test User"}])
]
},
)
def run(
self, input_data: Input, *, credentials: PostgresCredentials, **kwargs
) -> BlockOutput:
conn = None
try:
conn = psycopg2.connect(
host=input_data.host,
port=input_data.port,
database=input_data.database,
user=credentials.username.get_secret_value(),
password=credentials.password.get_secret_value()
)
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
# RealDictCursor returns each row as a dict; the default cursor returns tuples
cursor.execute(input_data.query, input_data.parameters)
if cursor.description:
result = cursor.fetchall()
result = [dict(row) for row in result]
else:
# Query doesn't return data (INSERT, UPDATE, DELETE)
result = cursor.rowcount # Number of rows affected by executing this query
conn.commit()
yield "success", True
yield "result", result
except Exception as e:
if conn:
conn.rollback()
yield "error", str(e)
finally:
if conn:
conn.close()
class PostgresInsertBlock(Block):
class Input(BlockSchema):
credentials: PostgresCredentialsInput = PostgresCredentialsField()
host: str = SchemaField(description="Database host", advanced=False)
port: int = SchemaField(description="Database port", advanced=False)
database: str = SchemaField(description="Database name", default="postgres", advanced=False)
schema_: str = SchemaField(description="Schema name", default="public", advanced=False)
table: str = SchemaField(description="Table name")
data: List[dict] = SchemaField(description="Data to insert", default=[])
return_inserted_rows: bool = SchemaField(description="Return inserted rows", default=False)
class Output(BlockSchema):
success: bool = SchemaField(description="Operation succeeded")
inserted_rows: List[dict] = SchemaField(description="Inserted rows if requested")
rows_affected: int = SchemaField(description="Number of rows affected")
error: str = SchemaField(description="Error message if operation failed")
def __init__(self):
super().__init__(
id="82a6c2d5-4c6f-4e3a-aba2-feae15c03cbe",
description="Inserts rows into a PostgreSQL table",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=PostgresInsertBlock.Input,
output_schema=PostgresInsertBlock.Output,
test_credentials=TEST_POSTGRES_CREDENTIALS,
test_input={
"credentials": TEST_POSTGRES_CREDENTIALS_INPUT,
"host": "localhost",
"port": 5432,
"database": "test_db",
"schema_": "public",
"table": "users",
"data": [{"name": "Test User", "email": "test@example.com"}],
"return_inserted_rows": True
},
test_output=[
("success", True),
("rows_affected", 1),
("inserted_rows", [{"id": 1, "name": "Test User", "email": "test@example.com"}])
],
test_mock={
"run": lambda *args, **kwargs: [
("success", True),
("rows_affected", 1),
("inserted_rows", [{"id": 1, "name": "Test User", "email": "test@example.com"}])
]
},
)
def run(
self, input_data: Input, *, credentials: PostgresCredentials, **kwargs
) -> BlockOutput:
conn = None
try:
conn = psycopg2.connect(
host=input_data.host,
port=input_data.port,
database=input_data.database,
user=credentials.username.get_secret_value(),
password=credentials.password.get_secret_value()
)
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
if not input_data.data:
yield "success", True
yield "rows_affected", 0
yield "inserted_rows", []
return
columns = list(input_data.data[0].keys())
cols_str = ", ".join(columns)
placeholders = ", ".join(["%s"] * len(columns))
sql = f"INSERT INTO {input_data.schema_}.{input_data.table} ({cols_str}) VALUES ({placeholders})"
if input_data.return_inserted_rows:
sql += " RETURNING *"
inserted_rows = []
rows_affected = 0
for row in input_data.data:
values = [row[col] for col in columns]
cursor.execute(sql, values)
rows_affected += cursor.rowcount
if input_data.return_inserted_rows:
inserted_rows.extend([dict(row) for row in cursor.fetchall()])
conn.commit()
yield "success", True
yield "rows_affected", rows_affected
yield "inserted_rows", inserted_rows
except Exception as e:
if conn:
conn.rollback()
yield "success", False
yield "error", str(e)
finally:
if conn:
conn.close()
class PostgresInsertOrUpdateBlock(Block):
class Input(BlockSchema):
credentials: PostgresCredentialsInput = PostgresCredentialsField()
host: str = SchemaField(description="Database host", advanced=False)
port: int = SchemaField(description="Database port", advanced=False)
database: str = SchemaField(description="Database name", default="postgres", advanced=False)
schema_: str = SchemaField(description="Schema name", default="public", advanced=False)
table: str = SchemaField(description="Table name")
data: List[dict] = SchemaField(description="Data to insert or update", default=[])
key_columns: List[str] = SchemaField(description="Columns to use as unique constraint", default=[])
return_affected_rows: bool = SchemaField(description="Return affected rows", default=False)
class Output(BlockSchema):
success: bool = SchemaField(description="Operation succeeded")
affected_rows: List[dict] = SchemaField(description="Affected rows if requested")
rows_affected: int = SchemaField(description="Number of rows affected")
error: str = SchemaField(description="Error message if operation failed")
def __init__(self):
super().__init__(
id="fa8e0ce3-5b8c-49e2-a3b7-dca21f5c4a72",
description="Inserts or updates rows in a PostgreSQL table using ON CONFLICT",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=PostgresInsertOrUpdateBlock.Input,
output_schema=PostgresInsertOrUpdateBlock.Output,
test_credentials=TEST_POSTGRES_CREDENTIALS,
test_input={
"credentials": TEST_POSTGRES_CREDENTIALS_INPUT,
"host": "localhost",
"port": 5432,
"database": "test_db",
"schema_": "public",
"table": "users",
"data": [{"id": 1, "name": "Updated User", "email": "updated@example.com"}],
"key_columns": ["id"],
"return_affected_rows": True
},
test_output=[
("success", True),
("rows_affected", 1),
("affected_rows", [{"id": 1, "name": "Updated User", "email": "updated@example.com"}])
],
test_mock={
"run": lambda *args, **kwargs: [
("success", True),
("rows_affected", 1),
("affected_rows", [{"id": 1, "name": "Updated User", "email": "updated@example.com"}])
]
},
)
def run(
self, input_data: Input, *, credentials: PostgresCredentials, **kwargs
) -> BlockOutput:
conn = None
try:
conn = psycopg2.connect(
host=input_data.host,
port=input_data.port,
database=input_data.database,
user=credentials.username.get_secret_value(),
password=credentials.password.get_secret_value()
)
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
if not input_data.data or not input_data.key_columns:
yield "success", True
yield "rows_affected", 0
yield "affected_rows", []
return
affected_rows = []
rows_affected = 0
for row in input_data.data:
columns = list(row.keys())
cols_str = ", ".join(columns)
placeholders = ", ".join(["%s"] * len(columns))
conflict_cols = ", ".join(input_data.key_columns)
update_cols = ", ".join(
f"{col} = EXCLUDED.{col}" for col in columns if col not in input_data.key_columns
)
sql = (
f"INSERT INTO {input_data.schema_}.{input_data.table} ({cols_str}) "
f"VALUES ({placeholders}) ON CONFLICT ({conflict_cols}) DO UPDATE SET {update_cols}"
)
if input_data.return_affected_rows:
sql += " RETURNING *"
values = [row[col] for col in columns]
cursor.execute(sql, values)
rows_affected += cursor.rowcount
if input_data.return_affected_rows:
affected_rows.extend([dict(row) for row in cursor.fetchall()])
conn.commit()
yield "success", True
yield "rows_affected", rows_affected
yield "affected_rows", affected_rows
except Exception as e:
if conn:
conn.rollback()
yield "success", False
yield "error", str(e)
finally:
if conn:
conn.close()
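The generated upsert follows PostgreSQL's `INSERT ... ON CONFLICT (...) DO UPDATE SET` form, with non-key columns updated from `EXCLUDED`. A sketch of the SQL assembly (function name hypothetical; values are still bound via `%s`):

```python
def build_upsert(schema: str, table: str, row: dict, key_columns: list[str]) -> tuple[str, list]:
    # Non-key columns are refreshed from EXCLUDED (the row that failed to insert).
    columns = list(row.keys())
    cols_str = ", ".join(columns)
    placeholders = ", ".join(["%s"] * len(columns))
    conflict_cols = ", ".join(key_columns)
    update_cols = ", ".join(
        f"{col} = EXCLUDED.{col}" for col in columns if col not in key_columns
    )
    sql = (
        f"INSERT INTO {schema}.{table} ({cols_str}) VALUES ({placeholders}) "
        f"ON CONFLICT ({conflict_cols}) DO UPDATE SET {update_cols}"
    )
    return sql, [row[col] for col in columns]


sql, values = build_upsert("public", "users", {"id": 1, "name": "Updated User"}, ["id"])
print(sql)
# → INSERT INTO public.users (id, name) VALUES (%s, %s) ON CONFLICT (id) DO UPDATE SET name = EXCLUDED.name
```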
class PostgresSelectBlock(Block):
class Input(BlockSchema):
credentials: PostgresCredentialsInput = PostgresCredentialsField()
host: str = SchemaField(description="Database host", advanced=False)
port: int = SchemaField(description="Database port", advanced=False)
database: str = SchemaField(description="Database name", default="postgres", advanced=False)
schema_: str = SchemaField(description="Schema name", default="public", advanced=False)
table: str = SchemaField(description="Table name")
columns: List[str] = SchemaField(description="Columns to select (empty for all columns)", default=[])
conditions: List[Condition] = SchemaField(description="Conditions for WHERE clause", default=[], advanced=False)
combine_conditions: CombineCondition = SchemaField(
description="How to combine multiple conditions",
default=CombineCondition.AND,
advanced=False
)
limit: Optional[int] = SchemaField(description="Maximum number of rows to return", default=None)
class Output(BlockSchema):
success: bool = SchemaField(description="Operation succeeded")
rows: List[dict] = SchemaField(description="Selected rows")
error: str = SchemaField(description="Error message if operation failed")
def __init__(self):
super().__init__(
id="e7c92ea5-1d2a-4e9c-bb89-376dfcbea342",
description="Selects rows from a PostgreSQL table",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=PostgresSelectBlock.Input,
output_schema=PostgresSelectBlock.Output,
test_credentials=TEST_POSTGRES_CREDENTIALS,
test_input={
"credentials": TEST_POSTGRES_CREDENTIALS_INPUT,
"host": "localhost",
"port": 5432,
"database": "test_db",
"schema_": "public",
"table": "users",
"columns": ["id", "name", "email"],
"conditions": [
{"column": "id", "operator": ConditionOperator.GREATER_THAN, "value": 0}
],
"limit": 100
},
test_output=[
("success", True),
("rows", [{"id": 1, "name": "Test User", "email": "test@example.com"}])
],
test_mock={
"run": lambda *args, **kwargs: [
("success", True),
("rows", [{"id": 1, "name": "Test User", "email": "test@example.com"}])
]
},
)
def run(
self, input_data: Input, *, credentials: PostgresCredentials, **kwargs
) -> BlockOutput:
conn = None
try:
conn = psycopg2.connect(
host=input_data.host,
port=input_data.port,
database=input_data.database,
user=credentials.username.get_secret_value(),
password=credentials.password.get_secret_value()
)
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
cols = ", ".join(input_data.columns) if input_data.columns else "*"
sql = f"SELECT {cols} FROM {input_data.schema_}.{input_data.table}"
values = []
if input_data.conditions:
where_clauses = []
for condition in input_data.conditions:
if condition.operator == ConditionOperator.IN:
placeholders = ", ".join(["%s"] * len(condition.value))
where_clauses.append(f"{condition.column} IN ({placeholders})")
values.extend(condition.value)
else:
where_clauses.append(f"{condition.column} {condition.operator.value} %s")
values.append(condition.value)
where_clause = f" {input_data.combine_conditions.value} ".join(where_clauses)
sql += f" WHERE {where_clause}"
if input_data.limit is not None:
sql += f" LIMIT {input_data.limit}"
cursor.execute(sql, values)
rows = [dict(row) for row in cursor.fetchall()]
yield "success", True
yield "rows", rows
except Exception as e:
if conn:
conn.rollback()
yield "success", False
yield "error", str(e)
finally:
if conn:
conn.close()
class PostgresUpdateBlock(Block):
class Input(BlockSchema):
credentials: PostgresCredentialsInput = PostgresCredentialsField()
host: str = SchemaField(description="Database host", advanced=False)
port: int = SchemaField(description="Database port", advanced=False)
database: str = SchemaField(description="Database name", default="postgres", advanced=False)
schema_: str = SchemaField(description="Schema name", default="public", advanced=False)
table: str = SchemaField(description="Table name")
set_data: dict = SchemaField(description="Column-value pairs to update", default={})
conditions: List[Condition] = SchemaField(description="Conditions for WHERE clause", default=[], advanced=False)
combine_conditions: CombineCondition = SchemaField(
description="How to combine multiple conditions",
default=CombineCondition.AND,
advanced=False
)
return_updated_rows: bool = SchemaField(description="Return updated rows", default=False)
class Output(BlockSchema):
success: bool = SchemaField(description="Operation succeeded")
rows_affected: int = SchemaField(description="Number of rows affected")
updated_rows: List[dict] = SchemaField(description="Updated rows if requested")
error: str = SchemaField(description="Error message if operation failed")
def __init__(self):
super().__init__(
id="a4e3d8c2-7f1b-49d0-8bc6-e479ea3d5752",
description="Updates rows in a PostgreSQL table",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=PostgresUpdateBlock.Input,
output_schema=PostgresUpdateBlock.Output,
test_credentials=TEST_POSTGRES_CREDENTIALS,
test_input={
"credentials": TEST_POSTGRES_CREDENTIALS_INPUT,
"host": "localhost",
"port": 5432,
"database": "test_db",
"schema_": "public",
"table": "users",
"set_data": {"name": "Updated User", "email": "updated@example.com"},
"conditions": [
{"column": "id", "operator": ConditionOperator.EQUALS, "value": 1}
],
"return_updated_rows": True
},
test_output=[
("success", True),
("rows_affected", 1),
("updated_rows", [{"id": 1, "name": "Updated User", "email": "updated@example.com"}])
],
test_mock={
"run": lambda *args, **kwargs: [
("success", True),
("rows_affected", 1),
("updated_rows", [{"id": 1, "name": "Updated User", "email": "updated@example.com"}])
]
},
)
def run(
self, input_data: Input, *, credentials: PostgresCredentials, **kwargs
) -> BlockOutput:
conn = None
try:
conn = psycopg2.connect(
host=input_data.host,
port=input_data.port,
database=input_data.database,
user=credentials.username.get_secret_value(),
password=credentials.password.get_secret_value()
)
with conn.cursor(cursor_factory=RealDictCursor) as cursor:
if not input_data.set_data:
yield "success", True
yield "rows_affected", 0
yield "updated_rows", []
return
set_clause = ", ".join(f"{k} = %s" for k in input_data.set_data.keys())
sql = f"UPDATE {input_data.schema_}.{input_data.table} SET {set_clause}"
values = list(input_data.set_data.values())
if input_data.conditions:
where_clauses = []
for condition in input_data.conditions:
if condition.operator == ConditionOperator.IN:
placeholders = ", ".join(["%s"] * len(condition.value))
where_clauses.append(f"{condition.column} IN ({placeholders})")
values.extend(condition.value)
else:
where_clauses.append(f"{condition.column} {condition.operator.value} %s")
values.append(condition.value)
where_clause = f" {input_data.combine_conditions.value} ".join(where_clauses)
sql += f" WHERE {where_clause}"
if input_data.return_updated_rows:
sql += " RETURNING *"
cursor.execute(sql, values)
rows_affected = cursor.rowcount
updated_rows = []
if input_data.return_updated_rows:
updated_rows = [dict(row) for row in cursor.fetchall()]
conn.commit()
yield "success", True
yield "rows_affected", rows_affected
yield "updated_rows", updated_rows
except Exception as e:
if conn:
conn.rollback()
yield "success", False
yield "error", str(e)
finally:
if conn:
conn.close()

@@ -26,10 +26,10 @@ logger = logging.getLogger(__name__)
@thread_cached
def get_database_manager_client():
from backend.executor import DatabaseManager
from backend.executor import DatabaseManagerClient
from backend.util.service import get_service_client
return get_service_client(DatabaseManager)
return get_service_client(DatabaseManagerClient)
def _get_tool_requests(entry: dict[str, Any]) -> list[str]:
@@ -246,6 +246,10 @@ class SmartDecisionMakerBlock(Block):
test_credentials=llm.TEST_CREDENTIALS,
)
@staticmethod
def cleanup(s: str):
return re.sub(r"[^a-zA-Z0-9_-]", "_", s).lower()
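The new `cleanup` helper centralizes the name-sanitization regex that was previously duplicated inline below. Applied directly:

```python
import re


def cleanup(s: str) -> str:
    # Replace anything outside [a-zA-Z0-9_-] with "_", then lowercase.
    return re.sub(r"[^a-zA-Z0-9_-]", "_", s).lower()


print(cleanup("My Block!"))    # → my_block_
print(cleanup("Send E-mail"))  # → send_e-mail
```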
@staticmethod
def _create_block_function_signature(
sink_node: "Node", links: list["Link"]
@@ -266,7 +270,7 @@ class SmartDecisionMakerBlock(Block):
block = sink_node.block
tool_function: dict[str, Any] = {
"name": re.sub(r"[^a-zA-Z0-9_-]", "_", block.name).lower(),
"name": SmartDecisionMakerBlock.cleanup(block.name),
"description": block.description,
}
@@ -281,7 +285,7 @@ class SmartDecisionMakerBlock(Block):
and sink_block_input_schema.model_fields[link.sink_name].description
else f"The {link.sink_name} of the tool"
)
properties[link.sink_name.lower()] = {
properties[SmartDecisionMakerBlock.cleanup(link.sink_name)] = {
"type": "string",
"description": description,
}
@@ -326,7 +330,7 @@ class SmartDecisionMakerBlock(Block):
)
tool_function: dict[str, Any] = {
"name": re.sub(r"[^a-zA-Z0-9_-]", "_", sink_graph_meta.name).lower(),
"name": SmartDecisionMakerBlock.cleanup(sink_graph_meta.name),
"description": sink_graph_meta.description,
}
@@ -341,7 +345,7 @@ class SmartDecisionMakerBlock(Block):
in sink_block_input_schema["properties"][link.sink_name]
else f"The {link.sink_name} of the tool"
)
properties[link.sink_name.lower()] = {
properties[SmartDecisionMakerBlock.cleanup(link.sink_name)] = {
"type": "string",
"description": description,
}
@@ -491,6 +495,7 @@ class SmartDecisionMakerBlock(Block):
max_tokens=input_data.max_tokens,
tools=tool_functions,
ollama_host=input_data.ollama_host,
parallel_tool_calls=False,
)
if not response.tool_calls:
@@ -502,7 +507,7 @@ class SmartDecisionMakerBlock(Block):
tool_args = json.loads(tool_call.function.arguments)
for arg_name, arg_value in tool_args.items():
yield f"tools_^_{tool_name}_{arg_name}".lower(), arg_value
yield f"tools_^_{tool_name}_~_{arg_name}", arg_value
response.prompt.append(response.raw_response)
yield "conversations", response.prompt
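The `_~_` delimiter inserted between tool name and argument name makes the boundary unambiguous even when both contain underscores, and the previous `.lower()` is dropped so argument-name case is preserved. A sketch of the key format (helper name hypothetical):

```python
def tool_output_key(tool_name: str, arg_name: str) -> str:
    # Mirrors the yield key above: "_~_" separates tool name from argument name.
    return f"tools_^_{tool_name}_~_{arg_name}"


print(tool_output_key("send_email", "to_address"))  # → tools_^_send_email_~_to_address
```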

@@ -28,6 +28,7 @@ from backend.util.settings import Config
from .model import (
ContributorDetails,
Credentials,
CredentialsFieldInfo,
CredentialsMetaInput,
is_credentials_field_name,
)
@@ -203,6 +204,15 @@ class BlockSchema(BaseModel):
)
}
@classmethod
def get_credentials_fields_info(cls) -> dict[str, CredentialsFieldInfo]:
return {
field_name: CredentialsFieldInfo.model_validate(
cls.get_field_schema(field_name), by_alias=True
)
for field_name in cls.get_credentials_fields().keys()
}
@classmethod
def get_input_defaults(cls, data: BlockInput) -> BlockInput:
return data # Return as is, by default.
@@ -509,6 +519,7 @@ async def initialize_blocks() -> None:
)
def get_block(block_id: str) -> Block | None:
# Note on the return type annotation: https://github.com/microsoft/pyright/issues/10281
def get_block(block_id: str) -> Block[BlockSchema, BlockSchema] | None:
cls = get_blocks().get(block_id)
return cls() if cls else None

@@ -36,14 +36,17 @@ from backend.integrations.credentials_store import (
# =============== Configure the cost for each LLM Model call =============== #
MODEL_COST: dict[LlmModel, int] = {
LlmModel.O3: 7,
LlmModel.O3_MINI: 2, # $1.10 / $4.40
LlmModel.O1: 16, # $15 / $60
LlmModel.O1_PREVIEW: 16,
LlmModel.O1_MINI: 4,
LlmModel.GPT41: 2,
LlmModel.GPT4O_MINI: 1,
LlmModel.GPT4O: 3,
LlmModel.GPT4_TURBO: 10,
LlmModel.GPT3_5_TURBO: 1,
LlmModel.CLAUDE_3_7_SONNET: 5,
LlmModel.CLAUDE_3_5_SONNET: 4,
LlmModel.CLAUDE_3_5_HAIKU: 1, # $0.80 / $4.00
LlmModel.CLAUDE_3_HAIKU: 1,
@@ -60,6 +63,7 @@ MODEL_COST: dict[LlmModel, int] = {
LlmModel.DEEPSEEK_LLAMA_70B: 1, # ? / ?
LlmModel.OLLAMA_DOLPHIN: 1,
LlmModel.GEMINI_FLASH_1_5: 1,
LlmModel.GEMINI_2_5_PRO: 4,
LlmModel.GROK_BETA: 5,
LlmModel.MISTRAL_NEMO: 1,
LlmModel.COHERE_COMMAND_R_08_2024: 1,

@@ -1,8 +1,8 @@
import asyncio
import logging
from abc import ABC, abstractmethod
from collections import defaultdict
from datetime import datetime, timezone
from typing import Any, cast
import stripe
from autogpt_libs.utils.cache import thread_cached
@@ -11,6 +11,7 @@ from prisma.enums import (
CreditRefundRequestStatus,
CreditTransactionType,
NotificationType,
OnboardingStep,
)
from prisma.errors import UniqueViolationError
from prisma.models import CreditRefundRequest, CreditTransaction, User
@@ -19,7 +20,7 @@ from prisma.types import (
CreditTransactionCreateInput,
CreditTransactionWhereInput,
)
from tenacity import retry, stop_after_attempt, wait_exponential
from pydantic import BaseModel
from backend.data import db
from backend.data.block_cost_config import BLOCK_COSTS
@@ -27,14 +28,17 @@ from backend.data.cost import BlockCost
from backend.data.model import (
AutoTopUpConfig,
RefundRequest,
TopUpType,
TransactionHistory,
UserTransaction,
)
from backend.data.notifications import NotificationEventDTO, RefundRequestData
from backend.data.user import get_user_by_id
from backend.executor.utils import UsageTransactionMetadata
from backend.notifications import NotificationManager
from backend.data.user import get_user_by_id, get_user_email_by_id
from backend.notifications import NotificationManagerClient
from backend.server.model import Pagination
from backend.server.v2.admin.model import UserHistoryResponse
from backend.util.exceptions import InsufficientBalanceError
from backend.util.retry import func_retry
from backend.util.service import get_service_client
from backend.util.settings import Settings
@@ -44,6 +48,17 @@ logger = logging.getLogger(__name__)
base_url = settings.config.frontend_base_url or settings.config.platform_base_url
class UsageTransactionMetadata(BaseModel):
graph_exec_id: str | None = None
graph_id: str | None = None
node_id: str | None = None
node_exec_id: str | None = None
block_id: str | None = None
block: str | None = None
input: dict[str, Any] | None = None
reason: str | None = None
class UserCreditBase(ABC):
@abstractmethod
async def get_credits(self, user_id: str) -> int:
@@ -121,6 +136,18 @@ class UserCreditBase(ABC):
"""
pass
@abstractmethod
async def onboarding_reward(self, user_id: str, credits: int, step: OnboardingStep):
"""
Reward the user with credits for completing an onboarding step.
Won't reward if the user has already received credits for the step.
Args:
user_id (str): The user ID.
step (OnboardingStep): The onboarding step.
"""
pass
@abstractmethod
async def top_up_intent(self, user_id: str, amount: int) -> str:
"""
@@ -249,11 +276,7 @@ class UserCreditBase(ABC):
)
return transaction_balance, transaction_time
@retry(
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, min=1, max=10),
reraise=True,
)
@func_retry
async def _enable_transaction(
self,
transaction_key: str,
@@ -352,21 +375,19 @@ class UserCreditBase(ABC):
class UserCredit(UserCreditBase):
@thread_cached
def notification_client(self) -> NotificationManager:
return get_service_client(NotificationManager)
def notification_client(self) -> NotificationManagerClient:
return get_service_client(NotificationManagerClient)
async def _send_refund_notification(
self,
notification_request: RefundRequestData,
notification_type: NotificationType,
):
await asyncio.to_thread(
lambda: self.notification_client().queue_notification(
NotificationEventDTO(
user_id=notification_request.user_id,
type=notification_type,
data=notification_request.model_dump(),
)
await self.notification_client().queue_notification_async(
NotificationEventDTO(
user_id=notification_request.user_id,
type=notification_type,
data=notification_request.model_dump(),
)
)
@@ -396,6 +417,7 @@ class UserCredit(UserCreditBase):
# Avoid multiple auto top-ups within the same graph execution.
key=f"AUTO-TOP-UP-{user_id}-{metadata.graph_exec_id}",
ceiling_balance=auto_top_up.threshold,
top_up_type=TopUpType.AUTO,
)
except Exception as e:
# Failed top-up is not critical, we can move on.
@@ -405,8 +427,30 @@ class UserCredit(UserCreditBase):
return balance
async def top_up_credits(self, user_id: str, amount: int):
await self._top_up_credits(user_id, amount)
async def top_up_credits(
self,
user_id: str,
amount: int,
top_up_type: TopUpType = TopUpType.UNCATEGORIZED,
):
await self._top_up_credits(
user_id=user_id, amount=amount, top_up_type=top_up_type
)
async def onboarding_reward(self, user_id: str, credits: int, step: OnboardingStep):
try:
await self._add_transaction(
user_id=user_id,
amount=credits,
transaction_type=CreditTransactionType.GRANT,
transaction_key=f"REWARD-{user_id}-{step.value}",
metadata=Json(
{"reason": f"Reward for completing {step.value} onboarding step."}
),
)
except UniqueViolationError:
# Already rewarded for this step
pass
async def top_up_refund(
self, user_id: str, transaction_key: str, metadata: dict[str, str]
@@ -571,7 +615,7 @@ class UserCredit(UserCreditBase):
evidence_text += (
f"- {tx.description}: Amount ${tx.amount / 100:.2f} on {tx.transaction_time.isoformat()}, "
f"resulting balance ${tx.balance / 100:.2f} {additional_comment}\n"
f"resulting balance ${tx.running_balance / 100:.2f} {additional_comment}\n"
)
evidence_text += (
"\nThis evidence demonstrates that the transaction was authorized and that the charged amount was used to render the service as agreed."
@@ -590,7 +634,24 @@ class UserCredit(UserCreditBase):
amount: int,
key: str | None = None,
ceiling_balance: int | None = None,
top_up_type: TopUpType = TopUpType.UNCATEGORIZED,
metadata: dict | None = None,
):
# Default the top-up reason if the caller didn't provide one
metadata = metadata or {}
if not metadata.get("reason"):
match top_up_type:
case TopUpType.MANUAL:
metadata["reason"] = f"Top up credits for {user_id}"
case TopUpType.AUTO:
metadata["reason"] = f"Auto top up credits for {user_id}"
case _:
metadata["reason"] = f"Top up reason unknown for {user_id}"
if amount < 0:
raise ValueError(f"Top up amount must not be negative: {amount}")
@@ -613,6 +674,7 @@ class UserCredit(UserCreditBase):
is_active=False,
transaction_key=key,
ceiling_balance=ceiling_balance,
metadata=(Json(metadata)),
)
customer_id = await get_stripe_customer_id(user_id)
@@ -755,10 +817,15 @@ class UserCredit(UserCreditBase):
# Check the Checkout Session's payment_status property
# to determine if fulfillment should be performed
if checkout_session.payment_status in ["paid", "no_payment_required"]:
assert isinstance(checkout_session.payment_intent, stripe.PaymentIntent)
if payment_intent := checkout_session.payment_intent:
assert isinstance(payment_intent, stripe.PaymentIntent)
new_transaction_key = payment_intent.id
else:
new_transaction_key = None
await self._enable_transaction(
transaction_key=credit_transaction.transactionKey,
new_transaction_key=checkout_session.payment_intent.id,
new_transaction_key=new_transaction_key,
user_id=credit_transaction.userId,
metadata=Json(checkout_session),
)
@@ -791,8 +858,9 @@ class UserCredit(UserCreditBase):
take=transaction_count_limit,
)
# doesn't fill current_balance, reason, user_email, admin_email, or extra_data
grouped_transactions: dict[str, UserTransaction] = defaultdict(
lambda: UserTransaction()
lambda: UserTransaction(user_id=user_id)
)
tx_time = None
for t in transactions:
@@ -822,7 +890,7 @@ class UserCredit(UserCreditBase):
if tx_time > gt.transaction_time:
gt.transaction_time = tx_time
gt.balance = t.runningBalance or 0
gt.running_balance = t.runningBalance or 0
return TransactionHistory(
transactions=list(grouped_transactions.values()),
@@ -872,6 +940,7 @@ class BetaUserCredit(UserCredit):
amount=max(self.num_user_credits_refill - balance, 0),
transaction_type=CreditTransactionType.GRANT,
transaction_key=f"MONTHLY-CREDIT-TOP-UP-{cur_time}",
metadata=Json({"reason": "Monthly credit refill"}),
)
return balance
except UniqueViolationError:
@@ -881,7 +950,7 @@ class BetaUserCredit(UserCredit):
class DisabledUserCredit(UserCreditBase):
async def get_credits(self, *args, **kwargs) -> int:
return 0
return 100
async def get_transaction_history(self, *args, **kwargs) -> TransactionHistory:
return TransactionHistory(transactions=[], next_transaction_time=None)
@@ -895,6 +964,9 @@ class DisabledUserCredit(UserCreditBase):
async def top_up_credits(self, *args, **kwargs):
pass
async def onboarding_reward(self, *args, **kwargs):
pass
async def top_up_intent(self, *args, **kwargs) -> str:
return ""
@@ -956,3 +1028,81 @@ async def get_auto_top_up(user_id: str) -> AutoTopUpConfig:
return AutoTopUpConfig(threshold=0, amount=0)
return AutoTopUpConfig.model_validate(user.topUpConfig)
async def admin_get_user_history(
page: int = 1,
page_size: int = 20,
search: str | None = None,
transaction_filter: CreditTransactionType | None = None,
) -> UserHistoryResponse:
if page < 1 or page_size < 1:
raise ValueError("Invalid pagination input")
where_clause: CreditTransactionWhereInput = {}
if transaction_filter:
where_clause["type"] = transaction_filter
if search:
where_clause["OR"] = [
{"userId": {"contains": search, "mode": "insensitive"}},
{"User": {"is": {"email": {"contains": search, "mode": "insensitive"}}}},
{"User": {"is": {"name": {"contains": search, "mode": "insensitive"}}}},
]
transactions = await CreditTransaction.prisma().find_many(
where=where_clause,
skip=(page - 1) * page_size,
take=page_size,
include={"User": True},
order={"createdAt": "desc"},
)
total = await CreditTransaction.prisma().count(where=where_clause)
total_pages = (total + page_size - 1) // page_size
history = []
for tx in transactions:
admin_id = ""
admin_email = ""
reason = ""
metadata: dict = cast(dict, tx.metadata) or {}
if metadata:
admin_id = metadata.get("admin_id", "")
admin_email = (
(await get_user_email_by_id(admin_id) or f"Unknown Admin: {admin_id}")
if admin_id
else ""
)
reason = metadata.get("reason", "No reason provided")
balance, last_update = await get_user_credit_model()._get_credits(tx.userId)
history.append(
UserTransaction(
transaction_key=tx.transactionKey,
transaction_time=tx.createdAt,
transaction_type=tx.type,
amount=tx.amount,
current_balance=balance,
running_balance=tx.runningBalance or 0,
user_id=tx.userId,
user_email=(
tx.User.email
if tx.User
else (await get_user_by_id(tx.userId)).email
),
reason=reason,
admin_email=admin_email,
extra_data=str(metadata),
)
)
return UserHistoryResponse(
history=history,
pagination=Pagination(
total_items=total,
total_pages=total_pages,
current_page=page,
page_size=page_size,
),
)
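The `onboarding_reward` change above relies on a database uniqueness constraint for idempotency: the transaction key `REWARD-{user_id}-{step}` can exist at most once, so a repeat attempt raises `UniqueViolationError` and is swallowed. A minimal standalone sketch of that pattern, with an in-memory stand-in for the unique index (`Ledger` and `grant_reward` are hypothetical names, not part of the codebase):

```python
class UniqueViolationError(Exception):
    """Stand-in for the DB driver's unique-constraint error."""


class Ledger:
    def __init__(self):
        self._keys = set()   # simulates a UNIQUE index on transaction_key
        self.balances = {}   # user_id -> credits

    def add_transaction(self, user_id: str, amount: int, key: str):
        if key in self._keys:
            raise UniqueViolationError(key)
        self._keys.add(key)
        self.balances[user_id] = self.balances.get(user_id, 0) + amount


def grant_reward(ledger: Ledger, user_id: str, credits: int, step: str):
    """Grant at most once per (user, step); retries are no-ops."""
    try:
        ledger.add_transaction(user_id, credits, key=f"REWARD-{user_id}-{step}")
    except UniqueViolationError:
        pass  # already rewarded for this step


ledger = Ledger()
grant_reward(ledger, "u1", 300, "WELCOME")
grant_reward(ledger, "u1", 300, "WELCOME")  # duplicate: silently ignored
print(ledger.balances["u1"])  # → 300
```

Pushing the dedup into the unique index (rather than a read-then-write check) keeps the grant safe under concurrent requests.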


@@ -62,10 +62,10 @@ async def connect():
# A connection acquired through a pooler like Supabase can be handed to the
# db client and yet still reject queries afterward, so probe it explicitly.
try:
await prisma.execute_raw("SELECT 1")
except Exception as e:
raise ConnectionError("Failed to connect to Prisma.") from e
# try:
# await prisma.execute_raw("SELECT 1")
# except Exception as e:
# raise ConnectionError("Failed to connect to Prisma.") from e
@conn_retry("Prisma", "Releasing connection")


@@ -34,18 +34,17 @@ from pydantic import BaseModel
from pydantic.fields import Field
from backend.server.v2.store.exceptions import DatabaseError
from backend.util import mock
from backend.util import type as type_utils
from backend.util.settings import Config
from .block import BlockData, BlockInput, BlockType, CompletedBlockOutput, get_block
from .block import BlockInput, BlockType, CompletedBlockOutput, get_block
from .db import BaseDbModel
from .includes import (
EXECUTION_RESULT_INCLUDE,
GRAPH_EXECUTION_INCLUDE,
GRAPH_EXECUTION_INCLUDE_WITH_NODES,
)
from .model import GraphExecutionStats, NodeExecutionStats
from .model import CredentialsMetaInput, GraphExecutionStats, NodeExecutionStats
from .queue import AsyncRedisEventBus, RedisEventBus
T = TypeVar("T")
@@ -203,6 +202,15 @@ class GraphExecutionWithNodes(GraphExecution):
node_executions=node_executions,
)
def to_graph_execution_entry(self):
return GraphExecutionEntry(
user_id=self.user_id,
graph_id=self.graph_id,
graph_version=self.graph_version or 0,
graph_exec_id=self.id,
node_credentials_input_map={}, # FIXME
)
class NodeExecutionResult(BaseModel):
user_id: str
@@ -260,6 +268,17 @@ class NodeExecutionResult(BaseModel):
end_time=_node_exec.endedTime,
)
def to_node_execution_entry(self) -> "NodeExecutionEntry":
return NodeExecutionEntry(
user_id=self.user_id,
graph_exec_id=self.graph_exec_id,
graph_id=self.graph_id,
node_exec_id=self.node_exec_id,
node_id=self.node_id,
block_id=self.block_id,
data=self.input_data,
)
# --------------------- Model functions --------------------- #
@@ -342,7 +361,7 @@ async def get_graph_execution(
async def create_graph_execution(
graph_id: str,
graph_version: int,
nodes_input: list[tuple[str, BlockInput]],
starting_nodes_input: list[tuple[str, BlockInput]],
user_id: str,
preset_id: str | None = None,
) -> GraphExecutionWithNodes:
@@ -369,7 +388,7 @@ async def create_graph_execution(
]
},
)
for node_id, node_input in nodes_input
for node_id, node_input in starting_nodes_input
]
},
userId=user_id,
@@ -469,7 +488,9 @@ async def upsert_execution_output(
)
async def update_graph_execution_start_time(graph_exec_id: str) -> GraphExecution:
async def update_graph_execution_start_time(
graph_exec_id: str,
) -> GraphExecution | None:
res = await AgentGraphExecution.prisma().update(
where={"id": graph_exec_id},
data={
@@ -478,10 +499,7 @@ async def update_graph_execution_start_time(graph_exec_id: str) -> GraphExecutio
},
include=GRAPH_EXECUTION_INCLUDE,
)
if not res:
raise ValueError(f"Graph execution #{graph_exec_id} not found")
return GraphExecution.from_db(res)
return GraphExecution.from_db(res) if res else None
async def update_graph_execution_stats(
@@ -597,8 +615,9 @@ async def delete_graph_execution(
)
async def get_node_execution_results(
async def get_node_executions(
graph_exec_id: str,
node_id: str | None = None,
block_ids: list[str] | None = None,
statuses: list[ExecutionStatus] | None = None,
limit: int | None = None,
@@ -606,6 +625,8 @@ async def get_node_execution_results(
where_clause: AgentNodeExecutionWhereInput = {
"agentGraphExecutionId": graph_exec_id,
}
if node_id:
where_clause["agentNodeId"] = node_id
if block_ids:
where_clause["Node"] = {"is": {"agentBlockId": {"in": block_ids}}}
if statuses:
@@ -662,20 +683,6 @@ async def get_latest_node_execution(
return NodeExecutionResult.from_db(execution)
async def get_incomplete_node_executions(
node_id: str, graph_eid: str
) -> list[NodeExecutionResult]:
executions = await AgentNodeExecution.prisma().find_many(
where={
"agentNodeId": node_id,
"agentGraphExecutionId": graph_eid,
"executionStatus": ExecutionStatus.INCOMPLETE,
},
include=EXECUTION_RESULT_INCLUDE,
)
return [NodeExecutionResult.from_db(execution) for execution in executions]
# ----------------- Execution Infrastructure ----------------- #
@@ -684,7 +691,7 @@ class GraphExecutionEntry(BaseModel):
graph_exec_id: str
graph_id: str
graph_version: int
start_node_execs: list["NodeExecutionEntry"]
node_credentials_input_map: Optional[dict[str, dict[str, CredentialsMetaInput]]]
class NodeExecutionEntry(BaseModel):
@@ -717,144 +724,6 @@ class ExecutionQueue(Generic[T]):
return self.queue.empty()
# ------------------- Execution Utilities -------------------- #
LIST_SPLIT = "_$_"
DICT_SPLIT = "_#_"
OBJC_SPLIT = "_@_"
def parse_execution_output(output: BlockData, name: str) -> Any | None:
"""
Extracts partial output data by name from a given BlockData.
The function supports extracting data from lists, dictionaries, and objects
using specific naming conventions:
- For lists: <output_name>_$_<index>
- For dictionaries: <output_name>_#_<key>
- For objects: <output_name>_@_<attribute>
Args:
output (BlockData): A tuple containing the output name and data.
name (str): The name used to extract specific data from the output.
Returns:
Any | None: The extracted data if found, otherwise None.
Examples:
>>> output = ("result", [10, 20, 30])
>>> parse_execution_output(output, "result_$_1")
20
>>> output = ("config", {"key1": "value1", "key2": "value2"})
>>> parse_execution_output(output, "config_#_key1")
'value1'
>>> class Sample:
... attr1 = "value1"
... attr2 = "value2"
>>> output = ("object", Sample())
>>> parse_execution_output(output, "object_@_attr1")
'value1'
"""
output_name, output_data = output
if name == output_name:
return output_data
if name.startswith(f"{output_name}{LIST_SPLIT}"):
index = int(name.split(LIST_SPLIT)[1])
if not isinstance(output_data, list) or len(output_data) <= index:
return None
return output_data[index]
if name.startswith(f"{output_name}{DICT_SPLIT}"):
index = name.split(DICT_SPLIT)[1]
if not isinstance(output_data, dict) or index not in output_data:
return None
return output_data[index]
if name.startswith(f"{output_name}{OBJC_SPLIT}"):
index = name.split(OBJC_SPLIT)[1]
if isinstance(output_data, object) and hasattr(output_data, index):
return getattr(output_data, index)
return None
return None
def merge_execution_input(data: BlockInput) -> BlockInput:
"""
Merges dynamic input pins into a single list, dictionary, or object based on naming patterns.
This function processes input keys that follow specific patterns to merge them into a unified structure:
- `<input_name>_$_<index>` for list inputs.
- `<input_name>_#_<index>` for dictionary inputs.
- `<input_name>_@_<index>` for object inputs.
Args:
data (BlockInput): A dictionary containing input keys and their corresponding values.
Returns:
BlockInput: A dictionary with merged inputs.
Raises:
ValueError: If a list index is not an integer.
Examples:
>>> data = {
... "list_$_0": "a",
... "list_$_1": "b",
... "dict_#_key1": "value1",
... "dict_#_key2": "value2",
... "object_@_attr1": "value1",
... "object_@_attr2": "value2"
... }
>>> merge_execution_input(data)
{
"list": ["a", "b"],
"dict": {"key1": "value1", "key2": "value2"},
"object": <MockObject attr1="value1" attr2="value2">
}
"""
# Merge all input with <input_name>_$_<index> into a single list.
items = list(data.items())
for key, value in items:
if LIST_SPLIT not in key:
continue
name, index = key.split(LIST_SPLIT)
if not index.isdigit():
raise ValueError(f"Invalid key: {key}, #{index} index must be an integer.")
data[name] = data.get(name, [])
if int(index) >= len(data[name]):
# Pad list with empty string on missing indices.
data[name].extend([""] * (int(index) - len(data[name]) + 1))
data[name][int(index)] = value
# Merge all input with <input_name>_#_<index> into a single dict.
for key, value in items:
if DICT_SPLIT not in key:
continue
name, index = key.split(DICT_SPLIT)
data[name] = data.get(name, {})
data[name][index] = value
# Merge all input with <input_name>_@_<index> into a single object.
for key, value in items:
if OBJC_SPLIT not in key:
continue
name, index = key.split(OBJC_SPLIT)
if name not in data or not isinstance(data[name], object):
data[name] = mock.MockObject()
setattr(data[name], index, value)
return data
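The dynamic-pin conventions above can be exercised end to end. A standalone sketch of the list and dict merge rules, reimplemented here so it runs without the backend package (the delimiter constants mirror `LIST_SPLIT` and `DICT_SPLIT`; `merge_pins` is a simplified stand-in for `merge_execution_input`, omitting the object case):

```python
LIST_SPLIT, DICT_SPLIT = "_$_", "_#_"


def merge_pins(data: dict) -> dict:
    """Merge <name>_$_<i> pins into a list and <name>_#_<k> pins into a dict."""
    merged: dict = {}
    for key, value in data.items():
        if LIST_SPLIT in key:
            name, index = key.split(LIST_SPLIT)
            lst = merged.setdefault(name, [])
            # Pad missing indices with empty strings (no-op if long enough)
            lst.extend([""] * (int(index) - len(lst) + 1))
            lst[int(index)] = value
        elif DICT_SPLIT in key:
            name, dkey = key.split(DICT_SPLIT)
            merged.setdefault(name, {})[dkey] = value
        else:
            merged[key] = value
    return merged


result = merge_pins({"list_$_0": "a", "list_$_2": "c", "dict_#_k": "v"})
print(result)  # → {'list': ['a', '', 'c'], 'dict': {'k': 'v'}}
```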
# --------------------- Event Bus --------------------- #


@@ -1,7 +1,7 @@
import logging
import uuid
from collections import defaultdict
from typing import Any, Literal, Optional, Type, cast
from typing import Any, Literal, Optional, cast
import prisma
from prisma import Json
@@ -13,12 +13,19 @@ from prisma.types import (
AgentNodeCreateInput,
AgentNodeLinkCreateInput,
)
from pydantic import create_model
from pydantic.fields import computed_field
from backend.blocks.agent import AgentExecutorBlock
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
from backend.blocks.llm import LlmModel
from backend.data.db import prisma as db
from backend.data.model import (
CredentialsField,
CredentialsFieldInfo,
CredentialsMetaInput,
is_credentials_field_name,
)
from backend.util import type as type_utils
from .block import Block, BlockInput, BlockSchema, BlockType, get_block, get_blocks
@@ -165,6 +172,8 @@ class BaseGraph(BaseDbModel):
description: str
nodes: list[Node] = []
links: list[Link] = []
forked_from_id: str | None = None
forked_from_version: int | None = None
@computed_field
@property
@@ -190,14 +199,19 @@ class BaseGraph(BaseDbModel):
)
)
@computed_field
@property
def credentials_input_schema(self) -> dict[str, Any]:
return self._credentials_input_schema.jsonschema()
@staticmethod
def _generate_schema(
*props: tuple[Type[AgentInputBlock.Input] | Type[AgentOutputBlock.Input], dict],
*props: tuple[type[AgentInputBlock.Input] | type[AgentOutputBlock.Input], dict],
) -> dict[str, Any]:
schema = []
schema_fields: list[AgentInputBlock.Input | AgentOutputBlock.Input] = []
for type_class, input_default in props:
try:
schema.append(type_class(**input_default))
schema_fields.append(type_class(**input_default))
except Exception as e:
logger.warning(f"Invalid {type_class}: {input_default}, {e}")
@@ -217,9 +231,93 @@ class BaseGraph(BaseDbModel):
**({"description": p.description} if p.description else {}),
**({"default": p.value} if p.value is not None else {}),
}
for p in schema
for p in schema_fields
},
"required": [p.name for p in schema if p.value is None],
"required": [p.name for p in schema_fields if p.value is None],
}
@property
def _credentials_input_schema(self) -> type[BlockSchema]:
graph_credentials_inputs = self.aggregate_credentials_inputs()
logger.debug(
f"Combined credentials input fields for graph #{self.id} ({self.name}): "
f"{graph_credentials_inputs}"
)
# Warn if same-provider credentials inputs can't be combined (= bad UX)
graph_cred_fields = list(graph_credentials_inputs.values())
for i, (field, keys) in enumerate(graph_cred_fields):
for other_field, other_keys in list(graph_cred_fields)[i + 1 :]:
if field.provider != other_field.provider:
continue
# If this happens, that means a block implementation probably needs
# to be updated.
logger.warning(
"Multiple combined credentials fields "
f"for provider {field.provider} "
f"on graph #{self.id} ({self.name}); "
f"fields: {field} <> {other_field}; "
f"keys: {keys} <> {other_keys}."
)
fields: dict[str, tuple[type[CredentialsMetaInput], CredentialsMetaInput]] = {
agg_field_key: (
CredentialsMetaInput[
Literal[tuple(field_info.provider)], # type: ignore
Literal[tuple(field_info.supported_types)], # type: ignore
],
CredentialsField(
required_scopes=set(field_info.required_scopes or []),
discriminator=field_info.discriminator,
discriminator_mapping=field_info.discriminator_mapping,
),
)
for agg_field_key, (field_info, _) in graph_credentials_inputs.items()
}
return create_model(
self.name.replace(" ", "") + "CredentialsInputSchema",
__base__=BlockSchema,
**fields, # type: ignore
)
def aggregate_credentials_inputs(
self,
) -> dict[str, tuple[CredentialsFieldInfo, set[tuple[str, str]]]]:
"""
Returns:
dict[aggregated_field_key, tuple(
CredentialsFieldInfo: A spec for one aggregated credentials field
set[(node_id, field_name)]: Node credentials fields that are
compatible with this aggregated field spec
)]
"""
return {
"_".join(sorted(agg_field_info.provider))
+ "_"
+ "_".join(sorted(agg_field_info.supported_types))
+ "_credentials": (agg_field_info, node_fields)
for agg_field_info, node_fields in CredentialsFieldInfo.combine(
*(
(
# Apply discrimination before aggregating credentials inputs
(
field_info.discriminate(
node.input_default[field_info.discriminator]
)
if (
field_info.discriminator
and node.input_default.get(field_info.discriminator)
)
else field_info
),
(node.id, field_name),
)
for node in self.nodes
for field_name, field_info in node.block.input_schema.get_credentials_fields_info().items()
)
)
}
@@ -313,15 +411,16 @@ class GraphModel(Graph):
@staticmethod
def _validate_graph(graph: BaseGraph, for_run: bool = False):
def is_tool_pin(name: str) -> bool:
return name.startswith("tools_^_")
def sanitize(name):
sanitized_name = name.split("_#_")[0].split("_@_")[0].split("_$_")[0]
if sanitized_name.startswith("tools_^_"):
return sanitized_name.split("_^_")[0]
if is_tool_pin(sanitized_name):
return "tools"
return sanitized_name
# Validate smart decision maker nodes
smart_decision_maker_nodes = set()
agent_nodes = set()
nodes_block = {
node.id: block
for node in graph.nodes
@@ -332,13 +431,6 @@ class GraphModel(Graph):
if (block := nodes_block.get(node.id)) is None:
raise ValueError(f"Invalid block {node.block_id} for node #{node.id}")
# Smart decision maker nodes
if block.block_type == BlockType.AI:
smart_decision_maker_nodes.add(node.id)
# Agent nodes
elif block.block_type == BlockType.AGENT:
agent_nodes.add(node.id)
input_links = defaultdict(list)
for link in graph.links:
@@ -353,16 +445,21 @@ class GraphModel(Graph):
[sanitize(name) for name in node.input_default]
+ [sanitize(link.sink_name) for link in input_links.get(node.id, [])]
)
for name in block.input_schema.get_required_fields():
input_schema = block.input_schema
for name in (required_fields := input_schema.get_required_fields()):
if (
name not in provided_inputs
# Webhook payload is passed in by ExecutionManager
and not (
name == "payload"
and block.block_type
in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
)
# Checking availability of credentials is done by ExecutionManager
and name not in input_schema.get_credentials_fields()
# Validate only I/O nodes, or validate everything when executing
and (
for_run # Skip input completion validation, unless when executing.
for_run
or block.block_type
in [
BlockType.INPUT,
@@ -375,9 +472,18 @@ class GraphModel(Graph):
f"Node {block.name} #{node.id} required input missing: `{name}`"
)
if (
block.block_type == BlockType.INPUT
and (input_key := node.input_default.get("name"))
and is_credentials_field_name(input_key)
):
raise ValueError(
f"Agent input node uses reserved name '{input_key}'; "
"'credentials' and `*_credentials` are reserved input names"
)
# Get input schema properties and check dependencies
input_schema = block.input_schema.model_fields
required_fields = block.input_schema.get_required_fields()
input_fields = input_schema.model_fields
def has_value(name):
return (
@@ -385,14 +491,21 @@ class GraphModel(Graph):
and name in node.input_default
and node.input_default[name] is not None
and str(node.input_default[name]).strip() != ""
) or (name in input_schema and input_schema[name].default is not None)
) or (name in input_fields and input_fields[name].default is not None)
# Validate dependencies between fields
for field_name, field_info in input_schema.items():
for field_name, field_info in input_fields.items():
# Apply input dependency validation only on run & field with depends_on
json_schema_extra = field_info.json_schema_extra or {}
dependencies = json_schema_extra.get("depends_on", [])
if not for_run or not dependencies:
if not (
for_run
and isinstance(json_schema_extra, dict)
and (
dependencies := cast(
list[str], json_schema_extra.get("depends_on", [])
)
)
):
continue
# Check if dependent field has value in input_default
@@ -445,7 +558,7 @@ class GraphModel(Graph):
if block.block_type not in [BlockType.AGENT]
else vals.get("input_schema", {}).get("properties", {}).keys()
)
if sanitized_name not in fields and not name.startswith("tools_^_"):
if sanitized_name not in fields and not is_tool_pin(name):
fields_msg = f"Allowed fields: {fields}"
raise ValueError(f"{prefix}, `{name}` invalid, {fields_msg}")
@@ -462,6 +575,8 @@ class GraphModel(Graph):
id=graph.id,
user_id=graph.userId if not for_export else "",
version=graph.version,
forked_from_id=graph.forkedFromId,
forked_from_version=graph.forkedFromVersion,
is_active=graph.isActive,
name=graph.name or "",
description=graph.description or "",
@@ -621,6 +736,58 @@ async def get_graph(
return GraphModel.from_db(graph, for_export)
async def get_graph_as_admin(
graph_id: str,
version: int | None = None,
user_id: str | None = None,
for_export: bool = False,
) -> GraphModel | None:
"""
Intentionally parallels `get_graph`, but should only be used for admin tasks, because it can return any graph that has been submitted to the store.
Retrieves a graph from the DB.
Defaults to the version with `is_active` if `version` is not passed.
Returns `None` if the record is not found.
"""
logger.warning(f"Getting {graph_id=} {version=} as ADMIN {user_id=} {for_export=}")
where_clause: AgentGraphWhereInput = {
"id": graph_id,
}
if version is not None:
where_clause["version"] = version
graph = await AgentGraph.prisma().find_first(
where=where_clause,
include=AGENT_GRAPH_INCLUDE,
order={"version": "desc"},
)
# For access, the graph must be owned by the user or listed in the store
if graph is None or (
graph.userId != user_id
and not (
await StoreListingVersion.prisma().find_first(
where={
"agentGraphId": graph_id,
"agentGraphVersion": version or graph.version,
}
)
)
):
return None
if for_export:
sub_graphs = await get_sub_graphs(graph)
return GraphModel.from_db(
graph=graph,
sub_graphs=sub_graphs,
for_export=for_export,
)
return GraphModel.from_db(graph, for_export)
async def get_sub_graphs(graph: AgentGraph) -> list[AgentGraph]:
"""
Iteratively fetches all sub-graphs of a given graph, and flattens them into a list.
@@ -739,6 +906,27 @@ async def create_graph(graph: Graph, user_id: str) -> GraphModel:
raise ValueError(f"Created graph {graph.id} v{graph.version} is not in DB")
async def fork_graph(graph_id: str, graph_version: int, user_id: str) -> GraphModel:
"""
Forks a graph by copying it and all its nodes and links to a new graph.
"""
async with transaction() as tx:
graph = await get_graph(graph_id, graph_version, user_id, True)
if not graph:
raise ValueError(f"Graph {graph_id} v{graph_version} not found")
# Set forked-from ID and version to the source graph, as it's about to be copied
graph.forked_from_id = graph.id
graph.forked_from_version = graph.version
graph.name = f"{graph.name} (copy)"
graph.reassign_ids(user_id=user_id, reassign_graph_id=True)
graph.validate_graph(for_run=False)
await __create_graph(tx, graph, user_id)
return graph
async def __create_graph(tx, graph: Graph, user_id: str):
graphs = [graph] + graph.sub_graphs
@@ -751,6 +939,8 @@ async def __create_graph(tx, graph: Graph, user_id: str):
description=graph.description,
isActive=graph.is_active,
userId=user_id,
forkedFromId=graph.forked_from_id,
forkedFromVersion=graph.forked_from_version,
)
for graph in graphs
]
@@ -914,24 +1104,24 @@ async def migrate_llm_models(migrate_to: LlmModel):
if field.annotation == LlmModel:
llm_model_fields[block.id] = field_name
# Convert enum values to a list of strings for the SQL query
enum_values = [v.value for v in LlmModel]
escaped_enum_values = repr(tuple(enum_values)) # hack but works
# Update each block
for id, path in llm_model_fields.items():
# Convert enum values to a list of strings for the SQL query
enum_values = [v.value for v in LlmModel.__members__.values()]
escaped_enum_values = repr(tuple(enum_values)) # hack but works
query = f"""
UPDATE "AgentNode"
SET "constantInput" = jsonb_set("constantInput", $1, $2, true)
UPDATE platform."AgentNode"
SET "constantInput" = jsonb_set("constantInput", $1, to_jsonb($2), true)
WHERE "agentBlockId" = $3
AND "constantInput" ? $4
AND "constantInput"->>$4 NOT IN {escaped_enum_values}
AND "constantInput" ? ($4)::text
AND "constantInput"->>($4)::text NOT IN {escaped_enum_values}
"""
await db.execute_raw(
query, # type: ignore - is supposed to be LiteralString
"{" + path + "}",
f'"{migrate_to.value}"',
[path],
migrate_to.value,
id,
path,
)


@@ -1,7 +1,9 @@
from __future__ import annotations
import base64
import enum
import logging
from collections import defaultdict
from datetime import datetime, timezone
from typing import (
TYPE_CHECKING,
@@ -12,6 +14,7 @@ from typing import (
Generic,
Literal,
Optional,
Sequence,
TypedDict,
TypeVar,
get_args,
@@ -300,9 +303,7 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
)
field_schema = model.jsonschema()["properties"][field_name]
try:
schema_extra = _CredentialsFieldSchemaExtra[CP, CT].model_validate(
field_schema
)
schema_extra = CredentialsFieldInfo[CP, CT].model_validate(field_schema)
except ValidationError as e:
if "Field required [type=missing" not in str(e):
raise
@@ -328,14 +329,90 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
)
class _CredentialsFieldSchemaExtra(BaseModel, Generic[CP, CT]):
class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
# TODO: move discrimination mechanism out of CredentialsField (frontend + backend)
credentials_provider: list[CP]
credentials_scopes: Optional[list[str]] = None
credentials_types: list[CT]
provider: frozenset[CP] = Field(..., alias="credentials_provider")
supported_types: frozenset[CT] = Field(..., alias="credentials_types")
required_scopes: Optional[frozenset[str]] = Field(None, alias="credentials_scopes")
discriminator: Optional[str] = None
discriminator_mapping: Optional[dict[str, CP]] = None
@classmethod
def combine(
cls, *fields: tuple[CredentialsFieldInfo[CP, CT], T]
) -> Sequence[tuple[CredentialsFieldInfo[CP, CT], set[T]]]:
"""
Combines multiple CredentialsFieldInfo objects into as few as possible.
Rules:
- Items can only be combined if they have the same supported credentials types
and the same supported providers.
- When combining items, the `required_scopes` of the result is a join
of the `required_scopes` of the original items.
Params:
*fields: (CredentialsFieldInfo, key) objects to group and combine
Returns:
A sequence of tuples containing combined CredentialsFieldInfo objects and
the set of keys of the respective original items that were grouped together.
"""
if not fields:
return []
# Group fields by their provider and supported_types
grouped_fields: defaultdict[
tuple[frozenset[CP], frozenset[CT]],
list[tuple[T, CredentialsFieldInfo[CP, CT]]],
] = defaultdict(list)
for field, key in fields:
group_key = (frozenset(field.provider), frozenset(field.supported_types))
grouped_fields[group_key].append((key, field))
# Combine fields within each group
result: list[tuple[CredentialsFieldInfo[CP, CT], set[T]]] = []
for group in grouped_fields.values():
# Start with the first field in the group
_, combined = group[0]
# Track the keys that were combined
combined_keys = {key for key, _ in group}
# Combine required_scopes from all fields in the group
all_scopes = set()
for _, field in group:
if field.required_scopes:
all_scopes.update(field.required_scopes)
# Create a new combined field
result.append(
(
CredentialsFieldInfo[CP, CT](
credentials_provider=combined.provider,
credentials_types=combined.supported_types,
credentials_scopes=frozenset(all_scopes) or None,
discriminator=combined.discriminator,
discriminator_mapping=combined.discriminator_mapping,
),
combined_keys,
)
)
return result
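The grouping logic in `combine` can be illustrated with a minimal standalone sketch; `Info` and its field names here are hypothetical stand-ins for `CredentialsFieldInfo`, but the group-by-key-then-union-scopes behavior matches the method above:

```python
from collections import defaultdict
from typing import NamedTuple, Optional

class Info(NamedTuple):
    providers: frozenset[str]
    types: frozenset[str]
    scopes: Optional[frozenset[str]]

def combine(*fields: tuple[Info, str]) -> list[tuple[Info, set[str]]]:
    # Group by (providers, types); only identical groups may be merged.
    grouped: defaultdict[tuple, list[tuple[str, Info]]] = defaultdict(list)
    for info, key in fields:
        grouped[(info.providers, info.types)].append((key, info))

    result: list[tuple[Info, set[str]]] = []
    for group in grouped.values():
        keys = {k for k, _ in group}
        # The merged scopes are the union of all member scopes.
        scopes: set[str] = set()
        for _, info in group:
            if info.scopes:
                scopes.update(info.scopes)
        first = group[0][1]
        result.append(
            (Info(first.providers, first.types, frozenset(scopes) or None), keys)
        )
    return result
```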
def discriminate(self, discriminator_value: Any) -> CredentialsFieldInfo:
if not (self.discriminator and self.discriminator_mapping):
return self
discriminator_value = self.discriminator_mapping[discriminator_value]
return CredentialsFieldInfo(
credentials_provider=frozenset([discriminator_value]),
credentials_types=self.supported_types,
credentials_scopes=self.required_scopes,
)
def CredentialsField(
required_scopes: set[str] = set(),
@@ -373,6 +450,12 @@ class ContributorDetails(BaseModel):
name: str = Field(title="Name", description="The name of the contributor.")
class TopUpType(enum.Enum):
AUTO = "AUTO"
MANUAL = "MANUAL"
UNCATEGORIZED = "UNCATEGORIZED"
class AutoTopUpConfig(BaseModel):
amount: int
"""Amount of credits to top up."""
@@ -385,12 +468,18 @@ class UserTransaction(BaseModel):
transaction_time: datetime = datetime.min.replace(tzinfo=timezone.utc)
transaction_type: CreditTransactionType = CreditTransactionType.USAGE
amount: int = 0
balance: int = 0
running_balance: int = 0
current_balance: int = 0
description: str | None = None
usage_graph_id: str | None = None
usage_execution_id: str | None = None
usage_node_count: int = 0
usage_start_time: datetime = datetime.max.replace(tzinfo=timezone.utc)
user_id: str
user_email: str | None = None
reason: str | None = None
admin_email: str | None = None
extra_data: str | None = None
class TransactionHistory(BaseModel):


@@ -8,7 +8,9 @@ from prisma.enums import OnboardingStep
from prisma.models import UserOnboarding
from prisma.types import UserOnboardingCreateInput, UserOnboardingUpdateInput
from backend.data import db
from backend.data.block import get_blocks
from backend.data.credit import get_user_credit_model
from backend.data.graph import GraphModel
from backend.data.model import CredentialsMetaInput
from backend.server.v2.store.model import StoreAgentDetails
@@ -24,14 +26,19 @@ REASON_MAPPING: dict[str, list[str]] = {
POINTS_AGENT_COUNT = 50 # Number of agents to calculate points for
MIN_AGENT_COUNT = 2 # Minimum number of marketplace agents to enable onboarding
user_credit = get_user_credit_model()
class UserOnboardingUpdate(pydantic.BaseModel):
completedSteps: Optional[list[OnboardingStep]] = None
notificationDot: Optional[bool] = None
notified: Optional[list[OnboardingStep]] = None
usageReason: Optional[str] = None
integrations: Optional[list[str]] = None
otherIntegrations: Optional[str] = None
selectedStoreListingVersionId: Optional[str] = None
agentInput: Optional[dict[str, Any]] = None
onboardingAgentExecutionId: Optional[str] = None
async def get_user_onboarding(user_id: str):
@@ -48,6 +55,20 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
update: UserOnboardingUpdateInput = {}
if data.completedSteps is not None:
update["completedSteps"] = list(set(data.completedSteps))
for step in (
OnboardingStep.AGENT_NEW_RUN,
OnboardingStep.GET_RESULTS,
OnboardingStep.MARKETPLACE_ADD_AGENT,
OnboardingStep.MARKETPLACE_RUN_AGENT,
OnboardingStep.BUILDER_SAVE_AGENT,
OnboardingStep.BUILDER_RUN_AGENT,
):
if step in data.completedSteps:
await reward_user(user_id, step)
if data.notificationDot is not None:
update["notificationDot"] = data.notificationDot
if data.notified is not None:
update["notified"] = list(set(data.notified))
if data.usageReason is not None:
update["usageReason"] = data.usageReason
if data.integrations is not None:
@@ -58,6 +79,8 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
update["selectedStoreListingVersionId"] = data.selectedStoreListingVersionId
if data.agentInput is not None:
update["agentInput"] = Json(data.agentInput)
if data.onboardingAgentExecutionId is not None:
update["onboardingAgentExecutionId"] = data.onboardingAgentExecutionId
return await UserOnboarding.prisma().upsert(
where={"userId": user_id},
@@ -68,6 +91,45 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):
)
async def reward_user(user_id: str, step: OnboardingStep):
async with db.locked_transaction(f"usr_trx_{user_id}-reward"):
reward = 0
match step:
# Reward the user when they click New Run during onboarding,
# because they need credits before scheduling a run (the next step)
case OnboardingStep.AGENT_NEW_RUN:
reward = 300
case OnboardingStep.GET_RESULTS:
reward = 300
case OnboardingStep.MARKETPLACE_ADD_AGENT:
reward = 100
case OnboardingStep.MARKETPLACE_RUN_AGENT:
reward = 100
case OnboardingStep.BUILDER_SAVE_AGENT:
reward = 100
case OnboardingStep.BUILDER_RUN_AGENT:
reward = 100
if reward == 0:
return
onboarding = await get_user_onboarding(user_id)
# Skip if already rewarded
if step in onboarding.rewardedFor:
return
onboarding.rewardedFor.append(step)
await user_credit.onboarding_reward(user_id, reward, step)
await UserOnboarding.prisma().update(
where={"userId": user_id},
data={
"completedSteps": list(set(onboarding.completedSteps + [step])),
"rewardedFor": onboarding.rewardedFor,
},
)
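The reward flow above boils down to a per-step amount lookup plus an idempotency check. A minimal sketch of that pattern, with a hypothetical in-memory `rewarded` set standing in for `UserOnboarding.rewardedFor` and a trimmed-down `Step` enum:

```python
import enum

class Step(enum.Enum):
    AGENT_NEW_RUN = "AGENT_NEW_RUN"
    GET_RESULTS = "GET_RESULTS"
    OTHER = "OTHER"

# Credit amount per step; steps without an entry grant nothing.
REWARDS = {Step.AGENT_NEW_RUN: 300, Step.GET_RESULTS: 300}

def reward_user(step: Step, rewarded: set) -> int:
    """Return credits granted; 0 if the step is unrewarded or already claimed."""
    amount = REWARDS.get(step, 0)
    if amount == 0 or step in rewarded:
        return 0
    rewarded.add(step)  # record before crediting so retries stay idempotent
    return amount
```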
def clean_and_split(text: str) -> list[str]:
"""
Removes all special characters from a string, truncates it to 100 characters,


@@ -4,10 +4,18 @@ from enum import Enum
from typing import Awaitable, Optional
import aio_pika
import aio_pika.exceptions as aio_ex
import pika
import pika.adapters.blocking_connection
from pika.exceptions import AMQPError
from pika.spec import BasicProperties
from pydantic import BaseModel
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_random_exponential,
)
from backend.util.retry import conn_retry
from backend.util.settings import Settings
@@ -161,6 +169,12 @@ class SyncRabbitMQ(RabbitMQBase):
routing_key=queue.routing_key or queue.name,
)
@retry(
retry=retry_if_exception_type((AMQPError, ConnectionError)),
wait=wait_random_exponential(multiplier=1, max=5),
stop=stop_after_attempt(5),
reraise=True,
)
def publish_message(
self,
routing_key: str,
@@ -258,6 +272,12 @@ class AsyncRabbitMQ(RabbitMQBase):
exchange, routing_key=queue.routing_key or queue.name
)
@retry(
retry=retry_if_exception_type((aio_ex.AMQPError, ConnectionError)),
wait=wait_random_exponential(multiplier=1, max=5),
stop=stop_after_attempt(5),
reraise=True,
)
async def publish_message(
self,
routing_key: str,

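The `tenacity` decorators added to `publish_message` above retry transient failures with jittered exponential backoff, stopping after five attempts and re-raising the original error. A stdlib sketch equivalent in spirit (the function and parameter names are illustrative, not from the codebase):

```python
import random
import time

def retry_call(fn, *, exceptions=(ConnectionError,), attempts=5,
               multiplier=1.0, max_wait=5.0, sleep=time.sleep):
    """Re-invoke fn on the given exceptions, waiting a random exponential
    backoff (capped at max_wait) between attempts; re-raise the last
    error once attempts are exhausted."""
    for attempt in range(attempts):
        try:
            return fn()
        except exceptions:
            if attempt == attempts - 1:
                raise
            # jitter in [0, multiplier * 2**attempt], capped at max_wait
            sleep(min(max_wait, random.uniform(0, multiplier * 2 ** attempt)))
```

The `sleep` parameter is injectable so backoff can be disabled in tests.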

@@ -1,9 +1,10 @@
from .database import DatabaseManager
from .database import DatabaseManager, DatabaseManagerClient
from .manager import ExecutionManager
from .scheduler import Scheduler
__all__ = [
"DatabaseManager",
"DatabaseManagerClient",
"ExecutionManager",
"Scheduler",
]


@@ -1,16 +1,14 @@
import logging
from typing import Callable, Concatenate, ParamSpec, TypeVar, cast
from backend.data import db, redis
from backend.data import db
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
from backend.data.execution import (
GraphExecution,
NodeExecutionResult,
RedisExecutionEventBus,
create_graph_execution,
get_graph_execution,
get_incomplete_node_executions,
get_graph_execution_meta,
get_latest_node_execution,
get_node_execution_results,
get_node_executions,
update_graph_execution_start_time,
update_graph_execution_stats,
update_node_execution_stats,
@@ -42,12 +40,14 @@ from backend.data.user import (
update_user_integrations,
update_user_metadata,
)
from backend.util.service import AppService, expose, exposed_run_and_wait
from backend.util.service import AppService, AppServiceClient, endpoint_to_sync, expose
from backend.util.settings import Config
config = Config()
_user_credit_model = get_user_credit_model()
logger = logging.getLogger(__name__)
P = ParamSpec("P")
R = TypeVar("R")
async def _spend_credits(
@@ -56,22 +56,19 @@ async def _spend_credits(
return await _user_credit_model.spend_credits(user_id, cost, metadata)
async def _get_credits(user_id: str) -> int:
return await _user_credit_model.get_credits(user_id)
class DatabaseManager(AppService):
def __init__(self):
super().__init__()
self.execution_event_bus = RedisExecutionEventBus()
def run_service(self) -> None:
logger.info(f"[{self.service_name}] ⏳ Connecting to Database...")
self.run_and_wait(db.connect())
logger.info(f"[{self.service_name}] ⏳ Connecting to Redis...")
redis.connect()
super().run_service()
def cleanup(self):
super().cleanup()
logger.info(f"[{self.service_name}] ⏳ Disconnecting Redis...")
redis.disconnect()
logger.info(f"[{self.service_name}] ⏳ Disconnecting Database...")
self.run_and_wait(db.disconnect())
@@ -79,64 +76,113 @@ class DatabaseManager(AppService):
def get_port(cls) -> int:
return config.database_api_port
@expose
def send_execution_update(
self, execution_result: GraphExecution | NodeExecutionResult
):
self.execution_event_bus.publish(execution_result)
@staticmethod
def _(
f: Callable[P, R], name: str | None = None
) -> Callable[Concatenate[object, P], R]:
if name is not None:
f.__name__ = name
return cast(Callable[Concatenate[object, P], R], expose(f))
# Executions
get_graph_execution = exposed_run_and_wait(get_graph_execution)
create_graph_execution = exposed_run_and_wait(create_graph_execution)
get_node_execution_results = exposed_run_and_wait(get_node_execution_results)
get_incomplete_node_executions = exposed_run_and_wait(
get_incomplete_node_executions
)
get_latest_node_execution = exposed_run_and_wait(get_latest_node_execution)
update_node_execution_status = exposed_run_and_wait(update_node_execution_status)
update_node_execution_status_batch = exposed_run_and_wait(
update_node_execution_status_batch
)
update_graph_execution_start_time = exposed_run_and_wait(
update_graph_execution_start_time
)
update_graph_execution_stats = exposed_run_and_wait(update_graph_execution_stats)
update_node_execution_stats = exposed_run_and_wait(update_node_execution_stats)
upsert_execution_input = exposed_run_and_wait(upsert_execution_input)
upsert_execution_output = exposed_run_and_wait(upsert_execution_output)
get_graph_execution = _(get_graph_execution)
get_graph_execution_meta = _(get_graph_execution_meta)
create_graph_execution = _(create_graph_execution)
get_node_executions = _(get_node_executions)
get_latest_node_execution = _(get_latest_node_execution)
update_node_execution_status = _(update_node_execution_status)
update_node_execution_status_batch = _(update_node_execution_status_batch)
update_graph_execution_start_time = _(update_graph_execution_start_time)
update_graph_execution_stats = _(update_graph_execution_stats)
update_node_execution_stats = _(update_node_execution_stats)
upsert_execution_input = _(upsert_execution_input)
upsert_execution_output = _(upsert_execution_output)
# Graphs
get_node = exposed_run_and_wait(get_node)
get_graph = exposed_run_and_wait(get_graph)
get_connected_output_nodes = exposed_run_and_wait(get_connected_output_nodes)
get_graph_metadata = exposed_run_and_wait(get_graph_metadata)
get_node = _(get_node)
get_graph = _(get_graph)
get_connected_output_nodes = _(get_connected_output_nodes)
get_graph_metadata = _(get_graph_metadata)
# Credits
spend_credits = exposed_run_and_wait(_spend_credits)
spend_credits = _(_spend_credits, name="spend_credits")
get_credits = _(_get_credits, name="get_credits")
# User + User Metadata + User Integrations
get_user_metadata = exposed_run_and_wait(get_user_metadata)
update_user_metadata = exposed_run_and_wait(update_user_metadata)
get_user_integrations = exposed_run_and_wait(get_user_integrations)
update_user_integrations = exposed_run_and_wait(update_user_integrations)
get_user_metadata = _(get_user_metadata)
update_user_metadata = _(update_user_metadata)
get_user_integrations = _(get_user_integrations)
update_user_integrations = _(update_user_integrations)
# User Comms - async
get_active_user_ids_in_timerange = exposed_run_and_wait(
get_active_user_ids_in_timerange
)
get_user_email_by_id = exposed_run_and_wait(get_user_email_by_id)
get_user_email_verification = exposed_run_and_wait(get_user_email_verification)
get_user_notification_preference = exposed_run_and_wait(
get_user_notification_preference
)
get_active_user_ids_in_timerange = _(get_active_user_ids_in_timerange)
get_user_email_by_id = _(get_user_email_by_id)
get_user_email_verification = _(get_user_email_verification)
get_user_notification_preference = _(get_user_notification_preference)
# Notifications - async
create_or_add_to_user_notification_batch = exposed_run_and_wait(
create_or_add_to_user_notification_batch = _(
create_or_add_to_user_notification_batch
)
empty_user_notification_batch = exposed_run_and_wait(empty_user_notification_batch)
get_all_batches_by_type = exposed_run_and_wait(get_all_batches_by_type)
get_user_notification_batch = exposed_run_and_wait(get_user_notification_batch)
get_user_notification_oldest_message_in_batch = exposed_run_and_wait(
empty_user_notification_batch = _(empty_user_notification_batch)
get_all_batches_by_type = _(get_all_batches_by_type)
get_user_notification_batch = _(get_user_notification_batch)
get_user_notification_oldest_message_in_batch = _(
get_user_notification_oldest_message_in_batch
)
class DatabaseManagerClient(AppServiceClient):
d = DatabaseManager
_ = endpoint_to_sync
@classmethod
def get_service_type(cls):
return DatabaseManager
# Executions
get_graph_execution = _(d.get_graph_execution)
get_graph_execution_meta = _(d.get_graph_execution_meta)
create_graph_execution = _(d.create_graph_execution)
get_node_executions = _(d.get_node_executions)
get_latest_node_execution = _(d.get_latest_node_execution)
update_node_execution_status = _(d.update_node_execution_status)
update_node_execution_status_batch = _(d.update_node_execution_status_batch)
update_graph_execution_start_time = _(d.update_graph_execution_start_time)
update_graph_execution_stats = _(d.update_graph_execution_stats)
update_node_execution_stats = _(d.update_node_execution_stats)
upsert_execution_input = _(d.upsert_execution_input)
upsert_execution_output = _(d.upsert_execution_output)
# Graphs
get_node = _(d.get_node)
get_graph = _(d.get_graph)
get_connected_output_nodes = _(d.get_connected_output_nodes)
get_graph_metadata = _(d.get_graph_metadata)
# Credits
spend_credits = _(d.spend_credits)
get_credits = _(d.get_credits)
# User + User Metadata + User Integrations
get_user_metadata = _(d.get_user_metadata)
update_user_metadata = _(d.update_user_metadata)
get_user_integrations = _(d.get_user_integrations)
update_user_integrations = _(d.update_user_integrations)
# User Comms - async
get_active_user_ids_in_timerange = _(d.get_active_user_ids_in_timerange)
get_user_email_by_id = _(d.get_user_email_by_id)
get_user_email_verification = _(d.get_user_email_verification)
get_user_notification_preference = _(d.get_user_notification_preference)
# Notifications - async
create_or_add_to_user_notification_batch = _(
d.create_or_add_to_user_notification_batch
)
empty_user_notification_batch = _(d.empty_user_notification_batch)
get_all_batches_by_type = _(d.get_all_batches_by_type)
get_user_notification_batch = _(d.get_user_notification_batch)
get_user_notification_oldest_message_in_batch = _(
d.get_user_notification_oldest_message_in_batch
)


@@ -8,8 +8,10 @@ import threading
from concurrent.futures import Future, ProcessPoolExecutor
from contextlib import contextmanager
from multiprocessing.pool import AsyncResult, Pool
from typing import TYPE_CHECKING, Any, Generator, Optional, TypeVar, cast
from typing import TYPE_CHECKING, Any, Generator, TypeVar, cast
from pika.adapters.blocking_connection import BlockingChannel
from pika.spec import Basic, BasicProperties
from redis.lock import Lock as RedisLock
from backend.blocks.io import AgentOutputBlock
@@ -20,57 +22,65 @@ from backend.data.notifications import (
NotificationEventDTO,
NotificationType,
)
from backend.data.rabbitmq import SyncRabbitMQ
from backend.executor.utils import create_execution_queue_config
from backend.util.exceptions import InsufficientBalanceError
if TYPE_CHECKING:
from backend.executor import DatabaseManager
from backend.notifications.notifications import NotificationManager
from backend.executor import DatabaseManagerClient
from backend.notifications.notifications import NotificationManagerClient
from autogpt_libs.utils.cache import thread_cached
from prometheus_client import Gauge, start_http_server
from backend.blocks.agent import AgentExecutorBlock
from backend.data import redis
from backend.data.block import (
Block,
BlockData,
BlockInput,
BlockSchema,
BlockType,
get_block,
)
from backend.data.block import BlockData, BlockInput, BlockSchema, get_block
from backend.data.credit import UsageTransactionMetadata
from backend.data.execution import (
ExecutionQueue,
ExecutionStatus,
GraphExecution,
GraphExecutionEntry,
NodeExecutionEntry,
NodeExecutionResult,
merge_execution_input,
parse_execution_output,
)
from backend.data.graph import GraphModel, Link, Node
from backend.data.graph import Link, Node
from backend.executor.utils import (
UsageTransactionMetadata,
GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
GRAPH_EXECUTION_QUEUE_NAME,
CancelExecutionEvent,
block_usage_cost,
execution_usage_cost,
get_execution_event_bus,
get_execution_queue,
parse_execution_output,
validate_exec,
)
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util import json
from backend.util.decorator import error_logged, time_measured
from backend.util.file import clean_exec_files
from backend.util.logging import configure_logging
from backend.util.process import set_service_name
from backend.util.service import (
AppService,
close_service_client,
expose,
get_service_client,
)
from backend.util.process import AppProcess, set_service_name
from backend.util.retry import func_retry
from backend.util.service import get_service_client
from backend.util.settings import Settings
from backend.util.type import convert
logger = logging.getLogger(__name__)
settings = Settings()
active_runs_gauge = Gauge(
"execution_manager_active_runs", "Number of active graph runs"
)
pool_size_gauge = Gauge(
"execution_manager_pool_size", "Maximum number of graph workers"
)
utilization_gauge = Gauge(
"execution_manager_utilization_ratio",
"Ratio of active graph runs to max graph workers",
)
class LogMetadata:
def __init__(
@@ -91,7 +101,7 @@ class LogMetadata:
"node_id": node_id,
"block_name": block_name,
}
self.prefix = f"[ExecutionManager|uid:{user_id}|gid:{graph_id}|nid:{node_id}]|geid:{graph_eid}|nid:{node_eid}|{block_name}]"
self.prefix = f"[ExecutionManager|uid:{user_id}|gid:{graph_id}|nid:{node_id}]|geid:{graph_eid}|neid:{node_eid}|{block_name}]"
def info(self, msg: str, **extra):
msg = self._wrap(msg, **extra)
@@ -125,7 +135,7 @@ ExecutionStream = Generator[NodeExecutionEntry, None, None]
def execute_node(
db_client: "DatabaseManager",
db_client: "DatabaseManagerClient",
creds_manager: IntegrationCredentialsManager,
data: NodeExecutionEntry,
execution_stats: NodeExecutionStats | None = None,
@@ -152,7 +162,7 @@ def execute_node(
def update_execution_status(status: ExecutionStatus) -> NodeExecutionResult:
"""Sets status and fetches+broadcasts the latest state of the node execution"""
exec_update = db_client.update_node_execution_status(node_exec_id, status)
db_client.send_execution_update(exec_update)
send_execution_update(exec_update)
return exec_update
node = db_client.get_node(node_id)
@@ -192,7 +202,7 @@ def execute_node(
# Execute the node
input_data_str = json.dumps(input_data)
input_size = len(input_data_str)
log_metadata.info("Executed node with input", input=input_data_str)
log_metadata.debug("Executed node with input", input=input_data_str)
update_execution_status(ExecutionStatus.RUNNING)
# Inject extra execution arguments for the blocks via kwargs
@@ -223,7 +233,7 @@ def execute_node(
):
output_data = json.convert_pydantic_to_json(output_data)
output_size += len(json.dumps(output_data))
log_metadata.info("Node produced output", **{output_name: output_data})
log_metadata.debug("Node produced output", **{output_name: output_data})
push_output(output_name, output_data)
outputs[output_name] = output_data
for execution in _enqueue_next_nodes(
@@ -258,7 +268,7 @@ def execute_node(
raise e
finally:
# Ensure credentials are released even if execution fails
if creds_lock and creds_lock.locked():
if creds_lock and creds_lock.locked() and creds_lock.owned():
try:
creds_lock.release()
except Exception as e:
@@ -274,7 +284,7 @@ def execute_node(
def _enqueue_next_nodes(
db_client: "DatabaseManager",
db_client: "DatabaseManagerClient",
node: Node,
output: BlockData,
user_id: str,
@@ -288,7 +298,7 @@ def _enqueue_next_nodes(
exec_update = db_client.update_node_execution_status(
node_exec_id, ExecutionStatus.QUEUED, data
)
db_client.send_execution_update(exec_update)
send_execution_update(exec_update)
return NodeExecutionEntry(
user_id=user_id,
graph_exec_id=graph_exec_id,
@@ -363,8 +373,10 @@ def _enqueue_next_nodes(
# If the link is static, there could be some incomplete executions waiting for it.
# Load them, fill in the missing input data, and try to re-enqueue them.
for iexec in db_client.get_incomplete_node_executions(
next_node_id, graph_exec_id
for iexec in db_client.get_node_executions(
node_id=next_node_id,
graph_exec_id=graph_exec_id,
statuses=[ExecutionStatus.INCOMPLETE],
):
idata = iexec.input_data
ineid = iexec.node_exec_id
@@ -400,60 +412,6 @@ def _enqueue_next_nodes(
]
def validate_exec(
node: Node,
data: BlockInput,
resolve_input: bool = True,
) -> tuple[BlockInput | None, str]:
"""
Validate the input data for a node execution.
Args:
node: The node to execute.
data: The input data for the node execution.
resolve_input: Whether to resolve dynamic pins into dict/list/object.
Returns:
A tuple of the validated data and the block name.
If the data is invalid, the first element will be None, and the second element
will be an error message.
If the data is valid, the first element will be the resolved input data, and
the second element will be the block name.
"""
node_block: Block | None = get_block(node.block_id)
if not node_block:
return None, f"Block for {node.block_id} not found."
schema = node_block.input_schema
# Convert non-matching data types to the expected input schema.
for name, data_type in schema.__annotations__.items():
if (value := data.get(name)) and (type(value) is not data_type):
data[name] = convert(value, data_type)
# Input data (without default values) should contain all required fields.
error_prefix = f"Input data missing or mismatch for `{node_block.name}`:"
if missing_links := schema.get_missing_links(data, node.input_links):
return None, f"{error_prefix} unpopulated links {missing_links}"
# Merge input data with default values and resolve dynamic dict/list/object pins.
input_default = schema.get_input_defaults(node.input_default)
data = {**input_default, **data}
if resolve_input:
data = merge_execution_input(data)
# Input data post-merge should contain all required fields from the schema.
if missing_input := schema.get_missing_input(data):
return None, f"{error_prefix} missing input {missing_input}"
# Last validation: Validate the input values against the schema.
if error := schema.get_mismatch_error(data):
error_message = f"{error_prefix} {error}"
logger.error(error_message)
return None, error_message
return data, node_block.name
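The coercion loop at the top of `validate_exec` follows a common pattern: compare each input value's runtime type against the schema annotation and convert on mismatch. A minimal standalone version, where the one-line `convert` only covers plain scalar casts (unlike the real `backend.util.type.convert`):

```python
def convert(value, target):
    # Hypothetical minimal converter: only plain scalar casts.
    return target(value)

class InputSchema:
    name: str
    count: int

def coerce_inputs(schema: type, data: dict) -> dict:
    """Cast truthy values whose runtime type differs from the annotation."""
    for field, data_type in schema.__annotations__.items():
        if (value := data.get(field)) and type(value) is not data_type:
            data[field] = convert(value, data_type)
    return data
```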
class Executor:
"""
This class contains event handlers for the process pool executor events.
@@ -480,6 +438,7 @@ class Executor:
"""
@classmethod
@func_retry
def on_node_executor_start(cls):
configure_logging()
set_service_name("NodeExecutor")
@@ -490,36 +449,28 @@ class Executor:
# Set up shutdown handlers
cls.shutdown_lock = threading.Lock()
atexit.register(cls.on_node_executor_stop) # handle regular shutdown
signal.signal( # handle termination
signal.SIGTERM, lambda _, __: cls.on_node_executor_sigterm()
)
atexit.register(cls.on_node_executor_stop)
signal.signal(signal.SIGTERM, lambda _, __: cls.on_node_executor_sigterm())
signal.signal(signal.SIGINT, lambda _, __: cls.on_node_executor_sigterm())
@classmethod
def on_node_executor_stop(cls):
def on_node_executor_stop(cls, log=logger.info):
if not cls.shutdown_lock.acquire(blocking=False):
return # already shutting down
logger.info(f"[on_node_executor_stop {cls.pid}] ⏳ Releasing locks...")
log(f"[on_node_executor_stop {cls.pid}] ⏳ Releasing locks...")
cls.creds_manager.release_all_locks()
logger.info(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting Redis...")
log(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting Redis...")
redis.disconnect()
logger.info(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting DB manager...")
close_service_client(cls.db_client)
logger.info(f"[on_node_executor_stop {cls.pid}] ✅ Finished cleanup")
log(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting DB manager...")
cls.db_client.close()
log(f"[on_node_executor_stop {cls.pid}] ✅ Finished NodeExec cleanup")
sys.exit(0)
@classmethod
def on_node_executor_sigterm(cls):
llprint(f"[on_node_executor_sigterm {cls.pid}] ⚠️ SIGTERM received")
if not cls.shutdown_lock.acquire(blocking=False):
return # already shutting down
llprint(f"[on_node_executor_stop {cls.pid}] ⏳ Releasing locks...")
cls.creds_manager.release_all_locks()
llprint(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting Redis...")
redis.disconnect()
llprint(f"[on_node_executor_stop {cls.pid}] ✅ Finished cleanup")
sys.exit(0)
llprint(f"[on_node_executor_sigterm {cls.pid}] ⚠️ NodeExec SIGTERM received")
cls.on_node_executor_stop(log=llprint)
@classmethod
@error_logged
@@ -585,6 +536,7 @@ class Executor:
stats.error = e
@classmethod
@func_retry
def on_graph_executor_start(cls):
configure_logging()
set_service_name("GraphExecutor")
@@ -594,21 +546,7 @@ class Executor:
cls.pid = os.getpid()
cls.notification_service = get_notification_service()
cls._init_node_executor_pool()
logger.info(
f"Graph executor {cls.pid} started with {cls.pool_size} node workers"
)
# Set up shutdown handler
atexit.register(cls.on_graph_executor_stop)
@classmethod
def on_graph_executor_stop(cls):
prefix = f"[on_graph_executor_stop {cls.pid}]"
logger.info(f"{prefix} ⏳ Terminating node executor pool...")
cls.executor.terminate()
logger.info(f"{prefix} ⏳ Disconnecting DB manager...")
close_service_client(cls.db_client)
logger.info(f"{prefix} ✅ Finished cleanup")
logger.info(f"GraphExec {cls.pid} started with {cls.pool_size} node workers")
@classmethod
def _init_node_executor_pool(cls):
@@ -630,10 +568,35 @@ class Executor:
node_eid="*",
block_name="-",
)
exec_meta = cls.db_client.update_graph_execution_start_time(
graph_exec.graph_exec_id
exec_meta = cls.db_client.get_graph_execution_meta(
user_id=graph_exec.user_id,
execution_id=graph_exec.graph_exec_id,
)
cls.db_client.send_execution_update(exec_meta)
if exec_meta is None:
log_metadata.warning(
f"Skipped graph execution #{graph_exec.graph_exec_id}: the graph execution was not found."
)
return
if exec_meta.status == ExecutionStatus.QUEUED:
log_metadata.info(f"⚙️ Starting graph execution #{graph_exec.graph_exec_id}")
exec_meta.status = ExecutionStatus.RUNNING
send_execution_update(
cls.db_client.update_graph_execution_start_time(
graph_exec.graph_exec_id
)
)
elif exec_meta.status == ExecutionStatus.RUNNING:
log_metadata.info(
f"⚙️ Graph execution #{graph_exec.graph_exec_id} is already running, continuing where it left off."
)
else:
log_metadata.warning(
f"Skipped graph execution {graph_exec.graph_exec_id}, the graph execution status is `{exec_meta.status}`."
)
return
timing_info, (exec_stats, status, error) = cls._on_graph_execution(
graph_exec, cancel, log_metadata
)
@@ -646,7 +609,7 @@ class Executor:
status=status,
stats=exec_stats,
):
cls.db_client.send_execution_update(graph_exec_result)
send_execution_update(graph_exec_result)
cls._handle_agent_run_notif(graph_exec, exec_stats)
@@ -656,11 +619,11 @@ class Executor:
node_exec: NodeExecutionEntry,
execution_count: int,
execution_stats: GraphExecutionStats,
) -> int:
):
block = get_block(node_exec.block_id)
if not block:
logger.error(f"Block {node_exec.block_id} not found.")
return execution_count
return
cost, matching_filter = block_usage_cost(block=block, input_data=node_exec.data)
if cost > 0:
@@ -675,11 +638,12 @@ class Executor:
block_id=node_exec.block_id,
block=block.name,
input=matching_filter,
reason=f"Ran block {node_exec.block_id} {block.name}",
),
)
execution_stats.cost += cost
cost, execution_count = execution_usage_cost(execution_count)
cost, usage_count = execution_usage_cost(execution_count)
if cost > 0:
cls.db_client.spend_credits(
user_id=node_exec.user_id,
@@ -688,15 +652,14 @@ class Executor:
graph_exec_id=node_exec.graph_exec_id,
graph_id=node_exec.graph_id,
input={
"execution_count": execution_count,
"execution_count": usage_count,
"charge": "Execution Cost",
},
reason=f"Execution Cost for {usage_count} blocks of ex_id:{node_exec.graph_exec_id} g_id:{node_exec.graph_id}",
),
)
execution_stats.cost += cost
return execution_count
@classmethod
@time_measured
def _on_graph_execution(
@@ -711,7 +674,6 @@ class Executor:
ExecutionStatus: The final status of the graph execution.
Exception | None: The error that occurred during the execution, if any.
"""
log_metadata.info(f"Start graph execution {graph_exec.graph_exec_id}")
execution_stats = GraphExecutionStats()
execution_status = ExecutionStatus.RUNNING
error = None
@@ -733,11 +695,21 @@ class Executor:
cancel_thread.start()
try:
queue = ExecutionQueue[NodeExecutionEntry]()
for node_exec in graph_exec.start_node_execs:
queue.add(node_exec)
if cls.db_client.get_credits(graph_exec.user_id) <= 0:
raise InsufficientBalanceError(
user_id=graph_exec.user_id,
message="You have no credits left to run an agent.",
balance=0,
amount=1,
)
queue = ExecutionQueue[NodeExecutionEntry]()
for node_exec in cls.db_client.get_node_executions(
graph_exec.graph_exec_id,
statuses=[ExecutionStatus.RUNNING, ExecutionStatus.QUEUED],
):
queue.add(node_exec.to_node_execution_entry())
exec_cost_counter = 0
running_executions: dict[str, AsyncResult] = {}
def make_exec_callback(exec_data: NodeExecutionEntry):
@@ -759,7 +731,7 @@ class Executor:
status=execution_status,
stats=execution_stats,
):
cls.db_client.send_execution_update(_graph_exec)
send_execution_update(_graph_exec)
else:
logger.error(
"Callback for "
@@ -776,10 +748,10 @@ class Executor:
execution_status = ExecutionStatus.TERMINATED
return execution_stats, execution_status, error
exec_data = queue.get()
queued_node_exec = queue.get()
# Avoid parallel execution of the same node.
execution = running_executions.get(exec_data.node_id)
execution = running_executions.get(queued_node_exec.node_id)
if execution and not execution.ready():
# TODO (performance improvement):
# Wait for the completion of the same node execution is blocking.
@@ -788,18 +760,18 @@ class Executor:
execution.wait()
log_metadata.debug(
f"Dispatching node execution {exec_data.node_exec_id} "
f"for node {exec_data.node_id}",
f"Dispatching node execution {queued_node_exec.node_exec_id} "
f"for node {queued_node_exec.node_id}",
)
try:
exec_cost_counter = cls._charge_usage(
node_exec=exec_data,
execution_count=exec_cost_counter + 1,
cls._charge_usage(
node_exec=queued_node_exec,
execution_count=increment_execution_count(graph_exec.user_id),
execution_stats=execution_stats,
)
except InsufficientBalanceError as error:
node_exec_id = exec_data.node_exec_id
node_exec_id = queued_node_exec.node_exec_id
cls.db_client.upsert_execution_output(
node_exec_id=node_exec_id,
output_name="error",
@@ -810,7 +782,7 @@ class Executor:
exec_update = cls.db_client.update_node_execution_status(
node_exec_id, execution_status
)
cls.db_client.send_execution_update(exec_update)
send_execution_update(exec_update)
cls._handle_low_balance_notif(
graph_exec.user_id,
@@ -820,10 +792,23 @@ class Executor:
)
raise
running_executions[exec_data.node_id] = cls.executor.apply_async(
# Add credentials input overrides
node_id = queued_node_exec.node_id
if (node_creds_map := graph_exec.node_credentials_input_map) and (
node_field_creds_map := node_creds_map.get(node_id)
):
queued_node_exec.data.update(
{
field_name: creds_meta.model_dump()
for field_name, creds_meta in node_field_creds_map.items()
}
)
# Initiate node execution
running_executions[queued_node_exec.node_id] = cls.executor.apply_async(
cls.on_node_execution,
(queue, exec_data),
callback=make_exec_callback(exec_data),
(queue, queued_node_exec),
callback=make_exec_callback(queued_node_exec),
)
# Avoid terminating graph execution when some nodes are still running.
@@ -843,24 +828,21 @@ class Executor:
execution.wait(3)
log_metadata.info(f"Finished graph execution {graph_exec.graph_exec_id}")
execution_status = ExecutionStatus.COMPLETED
except Exception as e:
error = e
finally:
if error:
log_metadata.error(
f"Failed graph execution {graph_exec.graph_exec_id}: {error}"
)
execution_status = ExecutionStatus.FAILED
else:
execution_status = ExecutionStatus.COMPLETED
log_metadata.error(
f"Failed graph execution {graph_exec.graph_exec_id}: {error}"
)
execution_status = ExecutionStatus.FAILED
finally:
if not cancel.is_set():
finished = True
cancel.set()
cancel_thread.join()
clean_exec_files(graph_exec.graph_exec_id)
return execution_stats, execution_status, error
@classmethod
@@ -872,7 +854,7 @@ class Executor:
metadata = cls.db_client.get_graph_metadata(
graph_exec.graph_id, graph_exec.graph_version
)
outputs = cls.db_client.get_node_execution_results(
outputs = cls.db_client.get_node_executions(
graph_exec.graph_exec_id,
block_ids=[AgentOutputBlock().id],
)
@@ -927,22 +909,31 @@ class Executor:
)
class ExecutionManager(AppService):
class ExecutionManager(AppProcess):
def __init__(self):
super().__init__()
self.pool_size = settings.config.num_graph_workers
self.queue = ExecutionQueue[GraphExecutionEntry]()
self.running = True
self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {}
atexit.register(self._on_cleanup)
signal.signal(signal.SIGTERM, lambda sig, frame: self._on_sigterm())
signal.signal(signal.SIGINT, lambda sig, frame: self._on_sigterm())
@classmethod
def get_port(cls) -> int:
return settings.config.execution_manager_port
def run(self):
pool_size_gauge.set(self.pool_size)
active_runs_gauge.set(0)
utilization_gauge.set(0)
def run_service(self):
from backend.integrations.credentials_store import IntegrationCredentialsStore
self.credentials_store = IntegrationCredentialsStore()
self.metrics_server = threading.Thread(
target=start_http_server,
args=(settings.config.execution_manager_port,),
daemon=True,
)
self.metrics_server.start()
logger.info(f"[{self.service_name}] Starting execution manager...")
self._run()
def _run(self):
logger.info(f"[{self.service_name}] ⏳ Spawn max-{self.pool_size} workers...")
self.executor = ProcessPoolExecutor(
max_workers=self.pool_size,
@@ -952,220 +943,174 @@ class ExecutionManager(AppService):
logger.info(f"[{self.service_name}] ⏳ Connecting to Redis...")
redis.connect()
sync_manager = multiprocessing.Manager()
while True:
graph_exec_data = self.queue.get()
graph_exec_id = graph_exec_data.graph_exec_id
logger.debug(
f"[ExecutionManager] Dispatching graph execution {graph_exec_id}"
)
cancel_event = sync_manager.Event()
future = self.executor.submit(
Executor.on_graph_execution, graph_exec_data, cancel_event
)
self.active_graph_runs[graph_exec_id] = (future, cancel_event)
future.add_done_callback(
lambda _: self.active_graph_runs.pop(graph_exec_id, None)
cancel_client = SyncRabbitMQ(create_execution_queue_config())
cancel_client.connect()
cancel_channel = cancel_client.get_channel()
logger.info(f"[{self.service_name}] ⏳ Starting cancel message consumer...")
threading.Thread(
target=lambda: (
cancel_channel.basic_consume(
queue=GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
on_message_callback=self._handle_cancel_message,
auto_ack=True,
),
cancel_channel.start_consuming(),
),
daemon=True,
).start()
run_client = SyncRabbitMQ(create_execution_queue_config())
run_client.connect()
run_channel = run_client.get_channel()
run_channel.basic_qos(prefetch_count=self.pool_size)
run_channel.basic_consume(
queue=GRAPH_EXECUTION_QUEUE_NAME,
on_message_callback=self._handle_run_message,
auto_ack=False,
)
logger.info(f"[{self.service_name}] ⏳ Starting to consume run messages...")
run_channel.start_consuming()
def _handle_cancel_message(
self,
channel: BlockingChannel,
method: Basic.Deliver,
properties: BasicProperties,
body: bytes,
):
"""
Called whenever we receive a CANCEL message from the queue.
(With auto_ack=True, message is considered 'acked' automatically.)
"""
try:
request = CancelExecutionEvent.model_validate_json(body)
graph_exec_id = request.graph_exec_id
if not graph_exec_id:
logger.warning(
f"[{self.service_name}] Cancel message missing 'graph_exec_id'"
)
return
if graph_exec_id not in self.active_graph_runs:
logger.debug(
f"[{self.service_name}] Cancel received for {graph_exec_id} but not active."
)
return
_, cancel_event = self.active_graph_runs[graph_exec_id]
logger.info(f"[{self.service_name}] Received cancel for {graph_exec_id}")
if not cancel_event.is_set():
cancel_event.set()
else:
logger.debug(
f"[{self.service_name}] Cancel already set for {graph_exec_id}"
)
except Exception as e:
logger.exception(f"Error handling cancel message: {e}")
def _handle_run_message(
self,
channel: BlockingChannel,
method: Basic.Deliver,
properties: BasicProperties,
body: bytes,
):
delivery_tag = method.delivery_tag
try:
graph_exec_entry = GraphExecutionEntry.model_validate_json(body)
except Exception as e:
logger.error(f"[{self.service_name}] Could not parse run message: {e}")
channel.basic_nack(delivery_tag, requeue=False)
return
graph_exec_id = graph_exec_entry.graph_exec_id
logger.info(
f"[{self.service_name}] Received RUN for graph_exec_id={graph_exec_id}"
)
if graph_exec_id in self.active_graph_runs:
logger.warning(
f"[{self.service_name}] Graph {graph_exec_id} already running; rejecting duplicate run."
)
channel.basic_nack(delivery_tag, requeue=False)
return
cancel_event = multiprocessing.Manager().Event()
future = self.executor.submit(
Executor.on_graph_execution, graph_exec_entry, cancel_event
)
self.active_graph_runs[graph_exec_id] = (future, cancel_event)
active_runs_gauge.set(len(self.active_graph_runs))
utilization_gauge.set(len(self.active_graph_runs) / self.pool_size)
def _on_run_done(f: Future):
logger.info(f"[{self.service_name}] Run completed for {graph_exec_id}")
try:
self.active_graph_runs.pop(graph_exec_id, None)
active_runs_gauge.set(len(self.active_graph_runs))
utilization_gauge.set(len(self.active_graph_runs) / self.pool_size)
if f.exception():
logger.error(
f"[{self.service_name}] Execution for {graph_exec_id} failed: {f.exception()}"
)
channel.connection.add_callback_threadsafe(
lambda: channel.basic_nack(delivery_tag, requeue=False)
)
else:
channel.connection.add_callback_threadsafe(
lambda: channel.basic_ack(delivery_tag)
)
except Exception as e:
logger.error(f"[{self.service_name}] Error acknowledging message: {e}")
future.add_done_callback(_on_run_done)
def cleanup(self):
super().cleanup()
self._on_cleanup()
logger.info(f"[{self.service_name}] ⏳ Shutting down graph executor pool...")
self.executor.shutdown(cancel_futures=True)
def _on_sigterm(self):
llprint(f"[{self.service_name}] ⚠️ GraphExec SIGTERM received")
self._on_cleanup(log=llprint)
logger.info(f"[{self.service_name}] ⏳ Disconnecting Redis...")
def _on_cleanup(self, log=logger.info):
prefix = f"[{self.service_name}][on_graph_executor_stop {os.getpid()}]"
log(f"{prefix} ⏳ Shutting down service loop...")
self.running = False
log(f"{prefix} ⏳ Shutting down RabbitMQ channel...")
get_execution_queue().get_channel().stop_consuming()
if hasattr(self, "executor"):
log(f"{prefix} ⏳ Shutting down GraphExec pool...")
self.executor.shutdown(cancel_futures=False, wait=True)
log(f"{prefix} ⏳ Disconnecting Redis...")
redis.disconnect()
@property
def db_client(self) -> "DatabaseManager":
return get_db_client()
@expose
def add_execution(
self,
graph_id: str,
data: BlockInput,
user_id: str,
graph_version: Optional[int] = None,
preset_id: str | None = None,
) -> GraphExecutionEntry:
graph: GraphModel | None = self.db_client.get_graph(
graph_id=graph_id, user_id=user_id, version=graph_version
)
if not graph:
raise ValueError(f"Graph #{graph_id} not found.")
graph.validate_graph(for_run=True)
self._validate_node_input_credentials(graph, user_id)
nodes_input = []
for node in graph.starting_nodes:
input_data = {}
block = node.block
# Note blocks are annotations only and should never be executed.
if block.block_type == BlockType.NOTE:
continue
# Extract request input data, and assign it to the input pin.
if block.block_type == BlockType.INPUT:
input_name = node.input_default.get("name")
if input_name and input_name in data:
input_data = {"value": data[input_name]}
# Extract webhook payload, and assign it to the input pin
webhook_payload_key = f"webhook_{node.webhook_id}_payload"
if (
block.block_type in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
and node.webhook_id
):
if webhook_payload_key not in data:
raise ValueError(
f"Node {block.name} #{node.id} webhook payload is missing"
)
input_data = {"payload": data[webhook_payload_key]}
input_data, error = validate_exec(node, input_data)
if input_data is None:
raise ValueError(error)
else:
nodes_input.append((node.id, input_data))
if not nodes_input:
raise ValueError(
"No starting nodes found for the graph, make sure an AgentInput or blocks with no inbound links are present as starting nodes."
)
graph_exec = self.db_client.create_graph_execution(
graph_id=graph_id,
graph_version=graph.version,
nodes_input=nodes_input,
user_id=user_id,
preset_id=preset_id,
)
self.db_client.send_execution_update(graph_exec)
graph_exec_entry = GraphExecutionEntry(
user_id=user_id,
graph_id=graph_id,
graph_version=graph_version or 0,
graph_exec_id=graph_exec.id,
start_node_execs=[
NodeExecutionEntry(
user_id=user_id,
graph_exec_id=node_exec.graph_exec_id,
graph_id=node_exec.graph_id,
node_exec_id=node_exec.node_exec_id,
node_id=node_exec.node_id,
block_id=node_exec.block_id,
data=node_exec.input_data,
)
for node_exec in graph_exec.node_executions
],
)
self.queue.add(graph_exec_entry)
return graph_exec_entry
@expose
def cancel_execution(self, graph_exec_id: str) -> None:
"""
Mechanism:
1. Set the cancel event
2. Graph executor's cancel handler thread detects the event, terminates workers,
reinitializes worker pool, and returns.
3. Update execution statuses in DB and set `error` outputs to `"TERMINATED"`.
"""
if graph_exec_id not in self.active_graph_runs:
logger.warning(
f"Graph execution #{graph_exec_id} not active/running: "
"possibly already completed/cancelled."
)
else:
future, cancel_event = self.active_graph_runs[graph_exec_id]
if not cancel_event.is_set():
cancel_event.set()
future.result()
# Update the status of the graph & node executions
self.db_client.update_graph_execution_stats(
graph_exec_id,
ExecutionStatus.TERMINATED,
)
node_execs = self.db_client.get_node_execution_results(
graph_exec_id=graph_exec_id,
statuses=[
ExecutionStatus.QUEUED,
ExecutionStatus.RUNNING,
ExecutionStatus.INCOMPLETE,
],
)
self.db_client.update_node_execution_status_batch(
[node_exec.node_exec_id for node_exec in node_execs],
ExecutionStatus.TERMINATED,
)
for node_exec in node_execs:
node_exec.status = ExecutionStatus.TERMINATED
self.db_client.send_execution_update(node_exec)
def _validate_node_input_credentials(self, graph: GraphModel, user_id: str):
"""Checks all credentials for all nodes of the graph"""
for node in graph.nodes:
block = node.block
# Find any fields of type CredentialsMetaInput
credentials_fields = cast(
type[BlockSchema], block.input_schema
).get_credentials_fields()
if not credentials_fields:
continue
for field_name, credentials_meta_type in credentials_fields.items():
credentials_meta = credentials_meta_type.model_validate(
node.input_default[field_name]
)
# Fetch the corresponding Credentials and perform sanity checks
credentials = self.credentials_store.get_creds_by_id(
user_id, credentials_meta.id
)
if not credentials:
raise ValueError(
f"Unknown credentials #{credentials_meta.id} "
f"for node #{node.id} input '{field_name}'"
)
if (
credentials.provider != credentials_meta.provider
or credentials.type != credentials_meta.type
):
logger.warning(
f"Invalid credentials #{credentials.id} for node #{node.id}: "
"type/provider mismatch: "
f"{credentials_meta.type}<>{credentials.type};"
f"{credentials_meta.provider}<>{credentials.provider}"
)
raise ValueError(
f"Invalid credentials #{credentials.id} for node #{node.id}: "
"type/provider mismatch"
)
log(f"{prefix} ✅ Finished GraphExec cleanup")
# ------- UTILITIES ------- #
@thread_cached
def get_db_client() -> "DatabaseManager":
from backend.executor import DatabaseManager
def get_db_client() -> "DatabaseManagerClient":
from backend.executor import DatabaseManagerClient
return get_service_client(DatabaseManager)
return get_service_client(DatabaseManagerClient)
@thread_cached
def get_notification_service() -> "NotificationManager":
from backend.notifications import NotificationManager
def get_notification_service() -> "NotificationManagerClient":
from backend.notifications import NotificationManagerClient
return get_service_client(NotificationManager)
return get_service_client(NotificationManagerClient)
def send_execution_update(entry: GraphExecution | NodeExecutionResult | None):
if entry is None:
return
return get_execution_event_bus().publish(entry)
@contextmanager
@@ -1175,14 +1120,26 @@ def synchronized(key: str, timeout: int = 60):
lock.acquire()
yield
finally:
if lock.locked():
if lock.locked() and lock.owned():
lock.release()
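The `lock.locked() and lock.owned()` guard matters for shared (e.g. Redis-backed) locks: `locked()` is true whenever *anyone* holds the lock, so releasing on `locked()` alone could free a lock another process currently owns. A minimal sketch of the idea with a hypothetical token-owned lock (not the real redis-py `Lock` class):

```python
import threading
import uuid


class TokenLock:
    """Minimal owned-lock sketch: release succeeds only for the token holder."""

    def __init__(self):
        self._lock = threading.Lock()
        self._owner_token = None

    def acquire(self, token: str) -> bool:
        got = self._lock.acquire(blocking=False)
        if got:
            self._owner_token = token
        return got

    def locked(self) -> bool:
        return self._lock.locked()

    def owned(self, token: str) -> bool:
        return self._owner_token == token

    def release(self, token: str) -> bool:
        # Mirror the guard in `synchronized`: only the owner may release.
        if self.locked() and self.owned(token):
            self._owner_token = None
            self._lock.release()
            return True
        return False


lock = TokenLock()
a, b = uuid.uuid4().hex, uuid.uuid4().hex
assert lock.acquire(a)
assert not lock.release(b)  # a non-owner's release is refused
assert lock.release(a)      # the owner's release succeeds
```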
def increment_execution_count(user_id: str) -> int:
"""
Increment the execution count for a given user;
this is used to charge the user for the execution cost.
"""
r = redis.get_redis()
k = f"uec:{user_id}" # User Execution Count global key
counter = cast(int, r.incr(k))
if counter == 1:
r.expire(k, settings.config.execution_counter_expiration_time)
return counter
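The counter above leans on Redis `INCR` semantics: the first increment creates the key, and only then is the TTL set, so the count resets once the expiration window lapses. A sketch of the same flow using an in-memory stand-in for the two Redis calls (the real code uses the `redis` client):

```python
import time


class FakeRedis:
    """In-memory stand-in for the two Redis calls used above (INCR, EXPIRE)."""

    def __init__(self):
        self._data = {}  # key -> (value, deadline or None)

    def incr(self, key: str) -> int:
        value, deadline = self._data.get(key, (0, None))
        if deadline is not None and time.monotonic() >= deadline:
            value, deadline = 0, None  # key expired: start a fresh window
        value += 1
        self._data[key] = (value, deadline)
        return value

    def expire(self, key: str, seconds: float) -> None:
        value, _ = self._data[key]
        self._data[key] = (value, time.monotonic() + seconds)


def increment_execution_count(r: FakeRedis, user_id: str, ttl: float = 60.0) -> int:
    key = f"uec:{user_id}"  # same key scheme as the real helper
    counter = r.incr(key)
    if counter == 1:
        r.expire(key, ttl)  # TTL is set once, when the window opens
    return counter


r = FakeRedis()
assert [increment_execution_count(r, "u1") for _ in range(3)] == [1, 2, 3]
```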
def llprint(message: str):
"""
Low-level print/log helper function for use in signal handlers.
Regular log/print calls are not async-signal-safe and must not be used in signal handlers.
"""
if logger.getEffectiveLevel() == logging.DEBUG:
os.write(sys.stdout.fileno(), (message + "\n").encode())
os.write(sys.stdout.fileno(), (message + "\n").encode())
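`llprint` writes with `os.write` because it is a single unbuffered syscall, while `print`/`logging` take internal locks and can deadlock if a signal interrupts a thread mid-write. A small POSIX-only demonstration of the pattern, using `SIGUSR1` and a pipe purely as illustrative stand-ins:

```python
import os
import signal

read_fd, write_fd = os.pipe()


def handler(signum, frame):
    # A single unbuffered syscall: safe to call from a signal handler.
    os.write(write_fd, b"caught\n")


signal.signal(signal.SIGUSR1, handler)
os.kill(os.getpid(), signal.SIGUSR1)  # deliver the signal to ourselves

message = os.read(read_fd, 64)
assert message == b"caught\n"
```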


@@ -16,9 +16,15 @@ from pydantic import BaseModel
from sqlalchemy import MetaData, create_engine
from backend.data.block import BlockInput
from backend.executor.manager import ExecutionManager
from backend.notifications.notifications import NotificationManager
from backend.util.service import AppService, expose, get_service_client
from backend.executor import utils as execution_utils
from backend.notifications.notifications import NotificationManagerClient
from backend.util.service import (
AppService,
AppServiceClient,
endpoint_to_async,
expose,
get_service_client,
)
from backend.util.settings import Config
@@ -57,25 +63,18 @@ def job_listener(event):
log(f"Job {event.job_id} completed successfully.")
@thread_cached
def get_execution_client() -> ExecutionManager:
return get_service_client(ExecutionManager)
@thread_cached
def get_notification_client():
from backend.notifications import NotificationManager
return get_service_client(NotificationManager)
return get_service_client(NotificationManagerClient)
def execute_graph(**kwargs):
args = ExecutionJobArgs(**kwargs)
try:
log(f"Executing recurring job for graph #{args.graph_id}")
get_execution_client().add_execution(
execution_utils.add_graph_execution(
graph_id=args.graph_id,
data=args.input_data,
inputs=args.input_data,
user_id=args.user_id,
graph_version=args.graph_version,
)
@@ -164,19 +163,9 @@ class Scheduler(AppService):
def db_pool_size(cls) -> int:
return config.scheduler_db_pool_size
@property
@thread_cached
def execution_client(self) -> ExecutionManager:
return get_service_client(ExecutionManager)
@property
@thread_cached
def notification_client(self) -> NotificationManager:
return get_service_client(NotificationManager)
def run_service(self):
load_dotenv()
db_schema, db_url = _extract_schema_from_url(os.getenv("DATABASE_URL"))
db_schema, db_url = _extract_schema_from_url(os.getenv("DIRECT_URL"))
self.scheduler = BlockingScheduler(
jobstores={
Jobstores.EXECUTION.value: SQLAlchemyJobStore(
@@ -310,3 +299,15 @@ class Scheduler(AppService):
),
job,
)
class SchedulerClient(AppServiceClient):
@classmethod
def get_service_type(cls):
return Scheduler
add_execution_schedule = endpoint_to_async(Scheduler.add_execution_schedule)
delete_schedule = endpoint_to_async(Scheduler.delete_schedule)
get_execution_schedules = endpoint_to_async(Scheduler.get_execution_schedules)
add_batched_notification_schedule = Scheduler.add_batched_notification_schedule
add_weekly_notification_schedule = Scheduler.add_weekly_notification_schedule


@@ -1,38 +1,113 @@
import logging
from typing import TYPE_CHECKING, Any, Optional, cast
from autogpt_libs.utils.cache import thread_cached
from pydantic import BaseModel
from backend.data.block import Block, BlockInput
from backend.data.block import (
Block,
BlockData,
BlockInput,
BlockSchema,
BlockType,
get_block,
)
from backend.data.block_cost_config import BLOCK_COSTS
from backend.data.cost import BlockCostType
from backend.data.execution import (
AsyncRedisExecutionEventBus,
ExecutionStatus,
GraphExecutionStats,
GraphExecutionWithNodes,
RedisExecutionEventBus,
create_graph_execution,
update_graph_execution_stats,
update_node_execution_status_batch,
)
from backend.data.graph import GraphModel, Node, get_graph
from backend.data.model import CredentialsMetaInput
from backend.data.rabbitmq import (
AsyncRabbitMQ,
Exchange,
ExchangeType,
Queue,
RabbitMQConfig,
SyncRabbitMQ,
)
from backend.util.exceptions import NotFoundError
from backend.util.mock import MockObject
from backend.util.service import get_service_client
from backend.util.settings import Config
from backend.util.type import convert
if TYPE_CHECKING:
from backend.executor import DatabaseManagerClient
from backend.integrations.credentials_store import IntegrationCredentialsStore
config = Config()
logger = logging.getLogger(__name__)
# ============ Resource Helpers ============ #
class UsageTransactionMetadata(BaseModel):
graph_exec_id: str | None = None
graph_id: str | None = None
node_id: str | None = None
node_exec_id: str | None = None
block_id: str | None = None
block: str | None = None
input: BlockInput | None = None
@thread_cached
def get_execution_event_bus() -> RedisExecutionEventBus:
return RedisExecutionEventBus()
@thread_cached
def get_async_execution_event_bus() -> AsyncRedisExecutionEventBus:
return AsyncRedisExecutionEventBus()
@thread_cached
def get_execution_queue() -> SyncRabbitMQ:
client = SyncRabbitMQ(create_execution_queue_config())
client.connect()
return client
@thread_cached
async def get_async_execution_queue() -> AsyncRabbitMQ:
client = AsyncRabbitMQ(create_execution_queue_config())
await client.connect()
return client
@thread_cached
def get_integration_credentials_store() -> "IntegrationCredentialsStore":
from backend.integrations.credentials_store import IntegrationCredentialsStore
return IntegrationCredentialsStore()
@thread_cached
def get_db_client() -> "DatabaseManagerClient":
from backend.executor import DatabaseManagerClient
return get_service_client(DatabaseManagerClient)
# ============ Execution Cost Helpers ============ #
def execution_usage_cost(execution_count: int) -> tuple[int, int]:
"""
Calculate the cost of executing a graph based on the number of executions.
Calculate the cost of executing a graph based on the current number of node executions.
Args:
execution_count: Number of executions
execution_count: Number of node executions
Returns:
Tuple of cost amount and remaining execution count
Tuple of the cost amount and the number of executions covered by that cost.
"""
return (
execution_count
// config.execution_cost_count_threshold
* config.execution_cost_per_threshold,
execution_count % config.execution_cost_count_threshold,
(
config.execution_cost_per_threshold
if execution_count % config.execution_cost_count_threshold == 0
else 0
),
config.execution_cost_count_threshold,
)
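Under the new scheme, the charge fires only when the running node-execution count lands exactly on a multiple of the threshold; between multiples the cost is zero. A standalone sketch with assumed config values (threshold 100, cost 5 are illustrative, not the production settings):

```python
EXECUTION_COST_COUNT_THRESHOLD = 100  # assumed illustrative values,
EXECUTION_COST_PER_THRESHOLD = 5      # not the production settings


def execution_usage_cost(execution_count: int) -> tuple[int, int]:
    """Return (cost to charge now, number of executions covered by that cost)."""
    cost = (
        EXECUTION_COST_PER_THRESHOLD
        if execution_count % EXECUTION_COST_COUNT_THRESHOLD == 0
        else 0
    )
    return cost, EXECUTION_COST_COUNT_THRESHOLD


assert execution_usage_cost(99) == (0, 100)
assert execution_usage_cost(100) == (5, 100)  # every 100th execution is charged
assert execution_usage_cost(101) == (0, 100)
```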
@@ -95,3 +170,571 @@ def _is_cost_filter_match(cost_filter: BlockInput, input_data: BlockInput) -> bo
or (input_data.get(k) and _is_cost_filter_match(v, input_data[k]))
for k, v in cost_filter.items()
)
# ============ Execution Input Helpers ============ #
LIST_SPLIT = "_$_"
DICT_SPLIT = "_#_"
OBJC_SPLIT = "_@_"
def parse_execution_output(output: BlockData, name: str) -> Any | None:
"""
Extracts partial output data by name from a given BlockData.
The function supports extracting data from lists, dictionaries, and objects
using specific naming conventions:
- For lists: <output_name>_$_<index>
- For dictionaries: <output_name>_#_<key>
- For objects: <output_name>_@_<attribute>
Args:
output (BlockData): A tuple containing the output name and data.
name (str): The name used to extract specific data from the output.
Returns:
Any | None: The extracted data if found, otherwise None.
Examples:
>>> output = ("result", [10, 20, 30])
>>> parse_execution_output(output, "result_$_1")
20
>>> output = ("config", {"key1": "value1", "key2": "value2"})
>>> parse_execution_output(output, "config_#_key1")
'value1'
>>> class Sample:
... attr1 = "value1"
... attr2 = "value2"
>>> output = ("object", Sample())
>>> parse_execution_output(output, "object_@_attr1")
'value1'
"""
output_name, output_data = output
if name == output_name:
return output_data
if name.startswith(f"{output_name}{LIST_SPLIT}"):
index = int(name.split(LIST_SPLIT)[1])
if not isinstance(output_data, list) or len(output_data) <= index:
return None
return output_data[index]
if name.startswith(f"{output_name}{DICT_SPLIT}"):
index = name.split(DICT_SPLIT)[1]
if not isinstance(output_data, dict) or index not in output_data:
return None
return output_data[index]
if name.startswith(f"{output_name}{OBJC_SPLIT}"):
index = name.split(OBJC_SPLIT)[1]
if isinstance(output_data, object) and hasattr(output_data, index):
return getattr(output_data, index)
return None
return None
def validate_exec(
node: Node,
data: BlockInput,
resolve_input: bool = True,
) -> tuple[BlockInput | None, str]:
"""
Validate the input data for a node execution.
Args:
node: The node to execute.
data: The input data for the node execution.
resolve_input: Whether to resolve dynamic pins into dict/list/object.
Returns:
A tuple of the validated data and the block name.
If the data is invalid, the first element will be None, and the second element
will be an error message.
If the data is valid, the first element will be the resolved input data, and
the second element will be the block name.
"""
node_block: Block | None = get_block(node.block_id)
if not node_block:
return None, f"Block for {node.block_id} not found."
schema = node_block.input_schema
# Convert non-matching data types to the expected input schema.
for name, data_type in schema.__annotations__.items():
if (value := data.get(name)) and (type(value) is not data_type):
data[name] = convert(value, data_type)
# Input data (without default values) should contain all required fields.
error_prefix = f"Input data missing or mismatch for `{node_block.name}`:"
if missing_links := schema.get_missing_links(data, node.input_links):
return None, f"{error_prefix} unpopulated links {missing_links}"
# Merge input data with default values and resolve dynamic dict/list/object pins.
input_default = schema.get_input_defaults(node.input_default)
data = {**input_default, **data}
if resolve_input:
data = merge_execution_input(data)
# Input data post-merge should contain all required fields from the schema.
if missing_input := schema.get_missing_input(data):
return None, f"{error_prefix} missing input {missing_input}"
# Last validation: Validate the input values against the schema.
if error := schema.get_mismatch_error(data):
error_message = f"{error_prefix} {error}"
logger.error(error_message)
return None, error_message
return data, node_block.name
def merge_execution_input(data: BlockInput) -> BlockInput:
"""
Merges dynamic input pins into a single list, dictionary, or object based on naming patterns.
This function processes input keys that follow specific patterns to merge them into a unified structure:
- `<input_name>_$_<index>` for list inputs.
- `<input_name>_#_<index>` for dictionary inputs.
- `<input_name>_@_<index>` for object inputs.
Args:
data (BlockInput): A dictionary containing input keys and their corresponding values.
Returns:
BlockInput: A dictionary with merged inputs.
Raises:
ValueError: If a list index is not an integer.
Examples:
>>> data = {
... "list_$_0": "a",
... "list_$_1": "b",
... "dict_#_key1": "value1",
... "dict_#_key2": "value2",
... "object_@_attr1": "value1",
... "object_@_attr2": "value2"
... }
>>> merge_execution_input(data)
{
"list": ["a", "b"],
"dict": {"key1": "value1", "key2": "value2"},
"object": <MockObject attr1="value1" attr2="value2">
}
"""
# Merge all input with <input_name>_$_<index> into a single list.
items = list(data.items())
for key, value in items:
if LIST_SPLIT not in key:
continue
name, index = key.split(LIST_SPLIT)
if not index.isdigit():
raise ValueError(f"Invalid key: {key}, list index '{index}' must be an integer.")
data[name] = data.get(name, [])
if int(index) >= len(data[name]):
# Pad list with empty string on missing indices.
data[name].extend([""] * (int(index) - len(data[name]) + 1))
data[name][int(index)] = value
# Merge all input with <input_name>_#_<index> into a single dict.
for key, value in items:
if DICT_SPLIT not in key:
continue
name, index = key.split(DICT_SPLIT)
data[name] = data.get(name, {})
data[name][index] = value
# Merge all input with <input_name>_@_<index> into a single object.
for key, value in items:
if OBJC_SPLIT not in key:
continue
name, index = key.split(OBJC_SPLIT)
if name not in data or not isinstance(data[name], object):
data[name] = MockObject()
setattr(data[name], index, value)
return data
def _validate_node_input_credentials(
graph: GraphModel,
user_id: str,
node_credentials_input_map: Optional[
dict[str, dict[str, CredentialsMetaInput]]
] = None,
):
"""Checks all credentials for all nodes of the graph"""
for node in graph.nodes:
block = node.block
# Find any fields of type CredentialsMetaInput
credentials_fields = cast(
type[BlockSchema], block.input_schema
).get_credentials_fields()
if not credentials_fields:
continue
for field_name, credentials_meta_type in credentials_fields.items():
if (
node_credentials_input_map
and (node_credentials_inputs := node_credentials_input_map.get(node.id))
and field_name in node_credentials_inputs
):
credentials_meta = node_credentials_input_map[node.id][field_name]
elif field_name in node.input_default:
credentials_meta = credentials_meta_type.model_validate(
node.input_default[field_name]
)
else:
raise ValueError(
f"Credentials absent for {block.name} node #{node.id} "
f"input '{field_name}'"
)
# Fetch the corresponding Credentials and perform sanity checks
credentials = get_integration_credentials_store().get_creds_by_id(
user_id, credentials_meta.id
)
if not credentials:
raise ValueError(
f"Unknown credentials #{credentials_meta.id} "
f"for node #{node.id} input '{field_name}'"
)
if (
credentials.provider != credentials_meta.provider
or credentials.type != credentials_meta.type
):
logger.warning(
f"Invalid credentials #{credentials.id} for node #{node.id}: "
"type/provider mismatch: "
f"{credentials_meta.type}<>{credentials.type};"
f"{credentials_meta.provider}<>{credentials.provider}"
)
raise ValueError(
f"Invalid credentials #{credentials.id} for node #{node.id}: "
"type/provider mismatch"
)
def make_node_credentials_input_map(
graph: GraphModel,
graph_credentials_input: dict[str, CredentialsMetaInput],
) -> dict[str, dict[str, CredentialsMetaInput]]:
"""
Maps credentials for an execution to the correct nodes.
Params:
graph: The graph to be executed.
graph_credentials_input: A (graph_input_name, credentials_meta) map.
Returns:
dict[node_id, dict[field_name, CredentialsMetaInput]]: Node credentials input map.
"""
result: dict[str, dict[str, CredentialsMetaInput]] = {}
# Get aggregated credentials fields for the graph
graph_cred_inputs = graph.aggregate_credentials_inputs()
for graph_input_name, (_, compatible_node_fields) in graph_cred_inputs.items():
# Best-effort map: skip missing items
if graph_input_name not in graph_credentials_input:
continue
# Use passed-in credentials for all compatible node input fields
for node_id, node_field_name in compatible_node_fields:
if node_id not in result:
result[node_id] = {}
result[node_id][node_field_name] = graph_credentials_input[graph_input_name]
return result
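`make_node_credentials_input_map` fans one graph-level credentials input out to every compatible (node, field) pair. A simplified sketch of that fan-out using plain dicts in place of `GraphModel` and `CredentialsMetaInput` (the names below are hypothetical):

```python
def fan_out_credentials(
    aggregated_inputs: dict[str, list[tuple[str, str]]],
    graph_credentials_input: dict[str, dict],
) -> dict[str, dict[str, dict]]:
    """aggregated_inputs maps graph_input_name -> [(node_id, field_name), ...]."""
    result: dict[str, dict[str, dict]] = {}
    for input_name, compatible_fields in aggregated_inputs.items():
        if input_name not in graph_credentials_input:
            continue  # best-effort: skip inputs the caller did not provide
        for node_id, field_name in compatible_fields:
            # The same credentials entry is applied to every compatible field.
            result.setdefault(node_id, {})[field_name] = graph_credentials_input[
                input_name
            ]
    return result


aggregated = {"github_creds": [("node-1", "credentials"), ("node-2", "credentials")]}
provided = {"github_creds": {"id": "c1", "provider": "github", "type": "oauth2"}}
mapping = fan_out_credentials(aggregated, provided)
assert mapping == {
    "node-1": {"credentials": provided["github_creds"]},
    "node-2": {"credentials": provided["github_creds"]},
}
```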
def construct_node_execution_input(
graph: GraphModel,
user_id: str,
graph_inputs: BlockInput,
node_credentials_input_map: Optional[
dict[str, dict[str, CredentialsMetaInput]]
] = None,
) -> list[tuple[str, BlockInput]]:
"""
Validates and prepares the input data for executing a graph.
This function checks the graph for starting nodes, validates the input data
against the schema, and resolves dynamic input pins into a single list,
dictionary, or object.
Args:
graph (GraphModel): The graph model to execute.
user_id (str): The ID of the user executing the graph.
graph_inputs (BlockInput): The input data for the graph execution.
node_credentials_input_map: `dict[node_id, dict[field_name, CredentialsMetaInput]]`
Returns:
list[tuple[str, BlockInput]]: A list of tuples, each containing the node ID and
the corresponding input data for that node.
"""
graph.validate_graph(for_run=True)
_validate_node_input_credentials(graph, user_id, node_credentials_input_map)
nodes_input = []
for node in graph.starting_nodes:
input_data = {}
block = node.block
# Note blocks are annotations only and should never be executed.
if block.block_type == BlockType.NOTE:
continue
# Extract request input data, and assign it to the input pin.
if block.block_type == BlockType.INPUT:
input_name = node.input_default.get("name")
if input_name and input_name in graph_inputs:
input_data = {"value": graph_inputs[input_name]}
# Extract webhook payload, and assign it to the input pin
webhook_payload_key = f"webhook_{node.webhook_id}_payload"
if (
block.block_type in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
and node.webhook_id
):
if webhook_payload_key not in graph_inputs:
raise ValueError(
f"Node {block.name} #{node.id} webhook payload is missing"
)
input_data = {"payload": graph_inputs[webhook_payload_key]}
# Apply node credentials overrides
if node_credentials_input_map and (
node_credentials := node_credentials_input_map.get(node.id)
):
input_data.update({k: v.model_dump() for k, v in node_credentials.items()})
input_data, error = validate_exec(node, input_data)
if input_data is None:
raise ValueError(error)
else:
nodes_input.append((node.id, input_data))
if not nodes_input:
raise ValueError(
"No starting nodes found for the graph, make sure an AgentInput or blocks with no inbound links are present as starting nodes."
)
return nodes_input
# ============ Execution Queue Helpers ============ #
class CancelExecutionEvent(BaseModel):
graph_exec_id: str
GRAPH_EXECUTION_EXCHANGE = Exchange(
name="graph_execution",
type=ExchangeType.DIRECT,
durable=True,
auto_delete=False,
)
GRAPH_EXECUTION_QUEUE_NAME = "graph_execution_queue"
GRAPH_EXECUTION_ROUTING_KEY = "graph_execution.run"
GRAPH_EXECUTION_CANCEL_EXCHANGE = Exchange(
name="graph_execution_cancel",
type=ExchangeType.FANOUT,
durable=True,
auto_delete=True,
)
GRAPH_EXECUTION_CANCEL_QUEUE_NAME = "graph_execution_cancel_queue"
def create_execution_queue_config() -> RabbitMQConfig:
"""
Define two exchanges and queues:
- 'graph_execution' (DIRECT) for run tasks.
- 'graph_execution_cancel' (FANOUT) for cancel requests.
"""
run_queue = Queue(
name=GRAPH_EXECUTION_QUEUE_NAME,
exchange=GRAPH_EXECUTION_EXCHANGE,
routing_key=GRAPH_EXECUTION_ROUTING_KEY,
durable=True,
auto_delete=False,
)
cancel_queue = Queue(
name=GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
exchange=GRAPH_EXECUTION_CANCEL_EXCHANGE,
routing_key="", # not used for FANOUT
durable=True,
auto_delete=False,
)
return RabbitMQConfig(
vhost="/",
exchanges=[GRAPH_EXECUTION_EXCHANGE, GRAPH_EXECUTION_CANCEL_EXCHANGE],
queues=[run_queue, cancel_queue],
)
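The config above pairs a DIRECT exchange (run tasks routed by key to one queue) with a FANOUT exchange (cancel requests delivered to every bound queue). A toy in-memory broker, not the real RabbitMQ client, illustrates the routing difference:

```python
# Toy broker showing DIRECT (routing-key match) vs FANOUT (broadcast)
# delivery semantics. Purely illustrative; no networking involved.

class ToyExchange:
    def __init__(self, type_: str):
        self.type = type_
        self.bindings: list[tuple[str, list]] = []  # (routing_key, queue)

    def bind(self, queue: list, routing_key: str = "") -> None:
        self.bindings.append((routing_key, queue))

    def publish(self, message: str, routing_key: str = "") -> None:
        for key, queue in self.bindings:
            if self.type == "fanout" or key == routing_key:
                queue.append(message)

run_exchange = ToyExchange("direct")
cancel_exchange = ToyExchange("fanout")
run_queue: list = []
worker_a: list = []
worker_b: list = []
run_exchange.bind(run_queue, "graph_execution.run")
cancel_exchange.bind(worker_a)
cancel_exchange.bind(worker_b)

run_exchange.publish("exec-1", "graph_execution.run")  # routed to run_queue
run_exchange.publish("ignored", "other.key")           # no matching binding
cancel_exchange.publish("cancel exec-1")               # fans out to both workers
```

This mirrors why cancellation uses fanout: every executor instance must see the cancel event, while a run task should land on exactly one queue.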
async def add_graph_execution_async(
graph_id: str,
user_id: str,
inputs: BlockInput,
preset_id: Optional[str] = None,
graph_version: Optional[int] = None,
graph_credentials_inputs: Optional[dict[str, CredentialsMetaInput]] = None,
) -> GraphExecutionWithNodes:
"""
Adds a graph execution to the queue and returns the execution entry.
Args:
graph_id: The ID of the graph to execute.
user_id: The ID of the user executing the graph.
inputs: The input data for the graph execution.
preset_id: The ID of the preset to use.
graph_version: The version of the graph to execute.
graph_credentials_inputs: Credentials inputs to use in the execution.
Keys should map to the keys generated by `GraphModel.aggregate_credentials_inputs`.
Returns:
GraphExecutionWithNodes: The created graph execution with its node executions.
Raises:
ValueError: If the graph is not found or if there are validation errors.
""" # noqa
graph: GraphModel | None = await get_graph(
graph_id=graph_id, user_id=user_id, version=graph_version
)
if not graph:
raise NotFoundError(f"Graph #{graph_id} not found.")
node_credentials_input_map = (
make_node_credentials_input_map(graph, graph_credentials_inputs)
if graph_credentials_inputs
else None
)
graph_exec = await create_graph_execution(
user_id=user_id,
graph_id=graph_id,
graph_version=graph.version,
starting_nodes_input=construct_node_execution_input(
graph=graph,
user_id=user_id,
graph_inputs=inputs,
node_credentials_input_map=node_credentials_input_map,
),
preset_id=preset_id,
)
try:
queue = await get_async_execution_queue()
graph_exec_entry = graph_exec.to_graph_execution_entry()
if node_credentials_input_map:
graph_exec_entry.node_credentials_input_map = node_credentials_input_map
await queue.publish_message(
routing_key=GRAPH_EXECUTION_ROUTING_KEY,
message=graph_exec_entry.model_dump_json(),
exchange=GRAPH_EXECUTION_EXCHANGE,
)
bus = get_async_execution_event_bus()
await bus.publish(graph_exec)
return graph_exec
except Exception as e:
logger.error(f"Unable to publish graph #{graph_id} exec #{graph_exec.id}: {e}")
await update_node_execution_status_batch(
[node_exec.node_exec_id for node_exec in graph_exec.node_executions],
ExecutionStatus.FAILED,
)
await update_graph_execution_stats(
graph_exec_id=graph_exec.id,
status=ExecutionStatus.FAILED,
stats=GraphExecutionStats(error=str(e)),
)
raise
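The `try/except` above follows a publish-or-mark-failed pattern: the execution record is created first, and if enqueueing it raises, the record is marked FAILED before the error propagates. A dependency-free sketch of that pattern (all names hypothetical):

```python
# Sketch of the publish-or-mark-failed pattern: create the record, try
# to enqueue it, and on failure record FAILED and re-raise so the
# caller still sees the error.

def enqueue_execution(record: dict, publish, mark_failed) -> dict:
    try:
        publish(record)
        return record
    except Exception as e:
        mark_failed(record["id"], str(e))
        raise

statuses: dict = {}

def flaky_publish(record):
    raise ConnectionError("broker unreachable")

try:
    enqueue_execution(
        {"id": "exec-1"},
        flaky_publish,
        lambda i, err: statuses.update({i: ("FAILED", err)}),
    )
except ConnectionError:
    pass
# statuses == {"exec-1": ("FAILED", "broker unreachable")}
```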
def add_graph_execution(
graph_id: str,
user_id: str,
inputs: BlockInput,
preset_id: Optional[str] = None,
graph_version: Optional[int] = None,
graph_credentials_inputs: Optional[dict[str, CredentialsMetaInput]] = None,
) -> GraphExecutionWithNodes:
"""
Adds a graph execution to the queue and returns the execution entry.
Args:
graph_id: The ID of the graph to execute.
user_id: The ID of the user executing the graph.
inputs: The input data for the graph execution.
preset_id: The ID of the preset to use.
graph_version: The version of the graph to execute.
graph_credentials_inputs: Credentials inputs to use in the execution.
Keys should map to the keys generated by `GraphModel.aggregate_credentials_inputs`.
Returns:
GraphExecutionWithNodes: The created graph execution with its node executions.
Raises:
ValueError: If the graph is not found or if there are validation errors.
"""
db = get_db_client()
graph: GraphModel | None = db.get_graph(
graph_id=graph_id, user_id=user_id, version=graph_version
)
if not graph:
raise NotFoundError(f"Graph #{graph_id} not found.")
node_credentials_input_map = (
make_node_credentials_input_map(graph, graph_credentials_inputs)
if graph_credentials_inputs
else None
)
graph_exec = db.create_graph_execution(
user_id=user_id,
graph_id=graph_id,
graph_version=graph.version,
starting_nodes_input=construct_node_execution_input(
graph=graph,
user_id=user_id,
graph_inputs=inputs,
node_credentials_input_map=node_credentials_input_map,
),
preset_id=preset_id,
)
try:
queue = get_execution_queue()
graph_exec_entry = graph_exec.to_graph_execution_entry()
if node_credentials_input_map:
graph_exec_entry.node_credentials_input_map = node_credentials_input_map
queue.publish_message(
routing_key=GRAPH_EXECUTION_ROUTING_KEY,
message=graph_exec_entry.model_dump_json(),
exchange=GRAPH_EXECUTION_EXCHANGE,
)
bus = get_execution_event_bus()
bus.publish(graph_exec)
return graph_exec
except Exception as e:
logger.error(f"Unable to publish graph #{graph_id} exec #{graph_exec.id}: {e}")
db.update_node_execution_status_batch(
[node_exec.node_exec_id for node_exec in graph_exec.node_executions],
ExecutionStatus.FAILED,
)
db.update_graph_execution_stats(
graph_exec_id=graph_exec.id,
status=ExecutionStatus.FAILED,
stats=GraphExecutionStats(error=str(e)),
)
raise

View File

@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Optional
from pydantic import SecretStr
if TYPE_CHECKING:
from backend.executor.database import DatabaseManager
from backend.executor.database import DatabaseManagerClient
from autogpt_libs.utils.cache import thread_cached
from autogpt_libs.utils.synchronize import RedisKeyedMutex
@@ -161,6 +161,14 @@ smartlead_credentials = APIKeyCredentials(
expires_at=None,
)
google_maps_credentials = APIKeyCredentials(
id="9aa1bde0-4947-4a70-a20c-84daa3850d52",
provider="google_maps",
api_key=SecretStr(settings.secrets.google_maps_api_key),
title="Use Credits for Google Maps",
expires_at=None,
)
zerobounce_credentials = APIKeyCredentials(
id="63a6e279-2dc2-448e-bf57-85776f7176dc",
provider="zerobounce",
@@ -190,6 +198,7 @@ DEFAULT_CREDENTIALS = [
apollo_credentials,
smartlead_credentials,
zerobounce_credentials,
google_maps_credentials,
]
@@ -201,11 +210,11 @@ class IntegrationCredentialsStore:
@property
@thread_cached
def db_manager(self) -> "DatabaseManager":
from backend.executor.database import DatabaseManager
def db_manager(self) -> "DatabaseManagerClient":
from backend.executor.database import DatabaseManagerClient
from backend.util.service import get_service_client
return get_service_client(DatabaseManager)
return get_service_client(DatabaseManagerClient)
def add_creds(self, user_id: str, credentials: Credentials) -> None:
with self.locked_user_integrations(user_id):
@@ -263,6 +272,8 @@ class IntegrationCredentialsStore:
all_credentials.append(smartlead_credentials)
if settings.secrets.zerobounce_api_key:
all_credentials.append(zerobounce_credentials)
if settings.secrets.google_maps_api_key:
all_credentials.append(google_maps_credentials)
return all_credentials
def get_creds_by_id(self, user_id: str, credentials_id: str) -> Credentials | None:

View File

@@ -93,7 +93,7 @@ class IntegrationCredentialsManager:
fresh_credentials = oauth_handler.refresh_tokens(credentials)
self.store.update_creds(user_id, fresh_credentials)
if _lock and _lock.locked():
if _lock and _lock.locked() and _lock.owned():
_lock.release()
credentials = fresh_credentials
@@ -145,7 +145,7 @@ class IntegrationCredentialsManager:
try:
yield
finally:
if lock.locked():
if lock.locked() and lock.owned():
lock.release()
def release_all_locks(self):
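The hunk above guards `release()` with `lock.owned()` in addition to `lock.locked()`. With a distributed lock, the lease can expire mid-operation and be re-acquired by another holder; releasing blindly would then raise (or free someone else's lock). A toy stand-in, not the real Redis lock, showing the guard:

```python
# Toy keyed lock illustrating why release is guarded by owned():
# after expiry + re-acquisition elsewhere, locked() is still True
# but the current worker no longer owns the lock.

class ToyLock:
    def __init__(self):
        self.holder = None

    def acquire(self, who: str) -> None:
        self.holder = who

    def locked(self) -> bool:
        return self.holder is not None

    def owned(self, who: str) -> bool:
        return self.holder == who

    def release(self, who: str) -> None:
        if not self.owned(who):
            raise RuntimeError("cannot release a lock owned by someone else")
        self.holder = None

lock = ToyLock()
lock.acquire("worker-1")
lock.holder = "worker-2"  # simulate expiry + re-acquisition elsewhere
safe_to_release = lock.locked() and lock.owned("worker-1")  # False: skip release
```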

View File

@@ -29,6 +29,7 @@ class ProviderName(str, Enum):
OPENWEATHERMAP = "openweathermap"
OPEN_ROUTER = "open_router"
PINECONE = "pinecone"
POSTGRES = "postgres"
REDDIT = "reddit"
REPLICATE = "replicate"
REVID = "revid"

View File

@@ -1,5 +1,6 @@
from .notifications import NotificationManager
from .notifications import NotificationManager, NotificationManagerClient
__all__ = [
"NotificationManager",
"NotificationManagerClient",
]

View File

@@ -31,7 +31,13 @@ from backend.data.notifications import (
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
from backend.data.user import generate_unsubscribe_link
from backend.notifications.email import EmailSender
from backend.util.service import AppService, expose, get_service_client
from backend.util.service import (
AppService,
AppServiceClient,
endpoint_to_async,
expose,
get_service_client,
)
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
@@ -108,16 +114,16 @@ def create_notification_config() -> RabbitMQConfig:
@thread_cached
def get_scheduler():
from backend.executor import Scheduler
from backend.executor.scheduler import SchedulerClient
return get_service_client(Scheduler)
return get_service_client(SchedulerClient)
@thread_cached
def get_db():
from backend.executor.database import DatabaseManager
from backend.executor.database import DatabaseManagerClient
return get_service_client(DatabaseManager)
return get_service_client(DatabaseManagerClient)
class NotificationManager(AppService):
@@ -774,3 +780,14 @@ class NotificationManager(AppService):
super().cleanup()
logger.info(f"[{self.service_name}] ⏳ Disconnecting RabbitMQ...")
self.run_and_wait(self.rabbitmq_service.disconnect())
class NotificationManagerClient(AppServiceClient):
@classmethod
def get_service_type(cls):
return NotificationManager
queue_notification_async = endpoint_to_async(NotificationManager.queue_notification)
queue_notification = NotificationManager.queue_notification
process_existing_batches = NotificationManager.process_existing_batches
queue_weekly_summary = NotificationManager.queue_weekly_summary

View File

@@ -2,7 +2,6 @@ import logging
from collections import defaultdict
from typing import Annotated, Any, Dict, List, Optional, Sequence
from autogpt_libs.utils.cache import thread_cached
from fastapi import APIRouter, Body, Depends, HTTPException
from prisma.enums import AgentExecutionStatus, APIKeyPermission
from typing_extensions import TypedDict
@@ -13,17 +12,10 @@ from backend.data import graph as graph_db
from backend.data.api_key import APIKey
from backend.data.block import BlockInput, CompletedBlockOutput
from backend.data.execution import NodeExecutionResult
from backend.executor import ExecutionManager
from backend.executor.utils import add_graph_execution_async
from backend.server.external.middleware import require_permission
from backend.util.service import get_service_client
from backend.util.settings import Settings
@thread_cached
def execution_manager_client() -> ExecutionManager:
return get_service_client(ExecutionManager)
settings = Settings()
logger = logging.getLogger(__name__)
@@ -98,20 +90,20 @@ def execute_graph_block(
path="/graphs/{graph_id}/execute/{graph_version}",
tags=["graphs"],
)
def execute_graph(
async def execute_graph(
graph_id: str,
graph_version: int,
node_input: Annotated[dict[str, Any], Body(..., embed=True, default_factory=dict)],
api_key: APIKey = Depends(require_permission(APIKeyPermission.EXECUTE_GRAPH)),
) -> dict[str, Any]:
try:
graph_exec = execution_manager_client().add_execution(
graph_id,
graph_version=graph_version,
data=node_input,
graph_exec = await add_graph_execution_async(
graph_id=graph_id,
user_id=api_key.user_id,
inputs=node_input,
graph_version=graph_version,
)
return {"id": graph_exec.graph_exec_id}
return {"id": graph_exec.id}
except Exception as e:
msg = str(e).encode().decode("unicode_escape")
raise HTTPException(status_code=400, detail=msg)
@@ -130,7 +122,7 @@ async def get_graph_execution_results(
if not graph:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
results = await execution_db.get_node_execution_results(graph_exec_id)
results = await execution_db.get_node_executions(graph_exec_id)
last_result = results[-1] if results else None
execution_status = (
last_result.status if last_result else AgentExecutionStatus.INCOMPLETE

View File

@@ -1,5 +1,6 @@
import asyncio
import logging
from typing import TYPE_CHECKING, Annotated, Literal
from typing import TYPE_CHECKING, Annotated, Awaitable, Literal
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, Request
from pydantic import BaseModel, Field
@@ -14,13 +15,12 @@ from backend.data.integrations import (
wait_for_webhook_event,
)
from backend.data.model import Credentials, CredentialsType, OAuth2Credentials
from backend.executor.manager import ExecutionManager
from backend.executor.utils import add_graph_execution_async
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.oauth import HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks import get_webhook_manager
from backend.util.exceptions import NeedConfirmation, NotFoundError
from backend.util.service import get_service_client
from backend.util.settings import Settings
if TYPE_CHECKING:
@@ -309,19 +309,22 @@ async def webhook_ingress_generic(
if not webhook.attached_nodes:
return
executor = get_service_client(ExecutionManager)
executions: list[Awaitable] = []
for node in webhook.attached_nodes:
logger.debug(f"Webhook-attached node: {node}")
if not node.is_triggered_by_event_type(event_type):
logger.debug(f"Node #{node.id} doesn't trigger on event {event_type}")
continue
logger.debug(f"Executing graph #{node.graph_id} node #{node.id}")
executor.add_execution(
graph_id=node.graph_id,
graph_version=node.graph_version,
data={f"webhook_{webhook_id}_payload": payload},
user_id=webhook.user_id,
executions.append(
add_graph_execution_async(
user_id=webhook.user_id,
graph_id=node.graph_id,
graph_version=node.graph_version,
inputs={f"webhook_{webhook_id}_payload": payload},
)
)
await asyncio.gather(*executions)
@router.post("/webhooks/{webhook_id}/ping")

View File

@@ -17,9 +17,9 @@ import backend.data.block
import backend.data.db
import backend.data.graph
import backend.data.user
import backend.server.integrations.router
import backend.server.routers.postmark.postmark
import backend.server.routers.v1
import backend.server.v2.admin.credit_admin_routes
import backend.server.v2.admin.store_admin_routes
import backend.server.v2.library.db
import backend.server.v2.library.model
@@ -29,6 +29,7 @@ import backend.server.v2.store.model
import backend.server.v2.store.routes
import backend.util.service
import backend.util.settings
from backend.blocks.llm import LlmModel
from backend.data.model import Credentials
from backend.integrations.providers import ProviderName
from backend.server.external.api import external_app
@@ -57,8 +58,7 @@ async def lifespan_context(app: fastapi.FastAPI):
await backend.data.block.initialize_blocks()
await backend.data.user.migrate_and_encrypt_user_integrations()
await backend.data.graph.fix_llm_provider_credentials()
# FIXME ERROR: operator does not exist: text ? unknown
# await backend.data.graph.migrate_llm_models(LlmModel.GPT4O)
await backend.data.graph.migrate_llm_models(LlmModel.GPT4O)
with launch_darkly_context():
yield
await backend.data.db.disconnect()
@@ -108,6 +108,11 @@ app.include_router(
tags=["v2", "admin"],
prefix="/api/store",
)
app.include_router(
backend.server.v2.admin.credit_admin_routes.router,
tags=["v2", "admin"],
prefix="/api/credits",
)
app.include_router(
backend.server.v2.library.routes.router, tags=["v2"], prefix="/api/library"
)
@@ -156,11 +161,12 @@ class AgentServer(backend.util.service.AppProcess):
graph_version: Optional[int] = None,
node_input: Optional[dict[str, Any]] = None,
):
return backend.server.routers.v1.execute_graph(
return await backend.server.routers.v1.execute_graph(
user_id=user_id,
graph_id=graph_id,
graph_version=graph_version,
node_input=node_input or {},
inputs=node_input or {},
credentials_inputs={},
)
@staticmethod
@@ -201,7 +207,7 @@ class AgentServer(backend.util.service.AppProcess):
@staticmethod
async def test_get_presets(user_id: str, page: int = 1, page_size: int = 10):
return await backend.server.v2.library.routes.presets.list_presets(
return await backend.server.v2.library.routes.presets.get_presets(
user_id=user_id, page=page, page_size=page_size
)
@@ -213,7 +219,7 @@ class AgentServer(backend.util.service.AppProcess):
@staticmethod
async def test_create_preset(
preset: backend.server.v2.library.model.LibraryAgentPresetCreatable,
preset: backend.server.v2.library.model.CreateLibraryAgentPresetRequest,
user_id: str,
):
return await backend.server.v2.library.routes.presets.create_preset(
@@ -223,7 +229,7 @@ class AgentServer(backend.util.service.AppProcess):
@staticmethod
async def test_update_preset(
preset_id: str,
preset: backend.server.v2.library.model.LibraryAgentPresetUpdatable,
preset: backend.server.v2.library.model.CreateLibraryAgentPresetRequest,
user_id: str,
):
return await backend.server.v2.library.routes.presets.update_preset(
@@ -275,7 +281,9 @@ class AgentServer(backend.util.service.AppProcess):
provider: ProviderName,
credentials: Credentials,
) -> Credentials:
return backend.server.integrations.router.create_credentials(
from backend.server.integrations.router import create_credentials
return create_credentials(
user_id=user_id, provider=provider, credentials=credentials
)

View File

@@ -13,7 +13,6 @@ from fastapi import APIRouter, Body, Depends, HTTPException, Request, Response
from starlette.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND
from typing_extensions import Optional, TypedDict
import backend.data.block
import backend.server.integrations.router
import backend.server.routers.analytics
import backend.server.v2.library.db as library_db
@@ -31,7 +30,7 @@ from backend.data.api_key import (
suspend_api_key,
update_api_key_permissions,
)
from backend.data.block import BlockInput, CompletedBlockOutput
from backend.data.block import BlockInput, CompletedBlockOutput, get_block, get_blocks
from backend.data.credit import (
AutoTopUpConfig,
RefundRequest,
@@ -41,6 +40,8 @@ from backend.data.credit import (
get_user_credit_model,
set_auto_top_up,
)
from backend.data.execution import AsyncRedisExecutionEventBus
from backend.data.model import CredentialsMetaInput
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
from backend.data.onboarding import (
UserOnboardingUpdate,
@@ -49,13 +50,16 @@ from backend.data.onboarding import (
onboarding_enabled,
update_user_onboarding,
)
from backend.data.rabbitmq import AsyncRabbitMQ
from backend.data.user import (
get_or_create_user,
get_user_notification_preference,
update_user_email,
update_user_notification_preference,
)
from backend.executor import ExecutionManager, Scheduler, scheduler
from backend.executor import scheduler
from backend.executor import utils as execution_utils
from backend.executor.utils import create_execution_queue_config
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.webhooks.graph_lifecycle_hooks import (
on_graph_activate,
@@ -79,13 +83,20 @@ if TYPE_CHECKING:
@thread_cached
def execution_manager_client() -> ExecutionManager:
return get_service_client(ExecutionManager)
def execution_scheduler_client() -> scheduler.SchedulerClient:
return get_service_client(scheduler.SchedulerClient)
@thread_cached
def execution_scheduler_client() -> Scheduler:
return get_service_client(Scheduler)
async def execution_queue_client() -> AsyncRabbitMQ:
client = AsyncRabbitMQ(create_execution_queue_config())
await client.connect()
return client
@thread_cached
def execution_event_bus() -> AsyncRedisExecutionEventBus:
return AsyncRedisExecutionEventBus()
settings = Settings()
@@ -206,7 +217,7 @@ async def is_onboarding_enabled():
@v1_router.get(path="/blocks", tags=["blocks"], dependencies=[Depends(auth_middleware)])
def get_graph_blocks() -> Sequence[dict[Any, Any]]:
blocks = [block() for block in backend.data.block.get_blocks().values()]
blocks = [block() for block in get_blocks().values()]
costs = get_block_costs()
return [
{**b.to_dict(), "costs": costs.get(b.id, [])} for b in blocks if not b.disabled
@@ -219,7 +230,7 @@ def get_graph_blocks() -> Sequence[dict[Any, Any]]:
dependencies=[Depends(auth_middleware)],
)
def execute_graph_block(block_id: str, data: BlockInput) -> CompletedBlockOutput:
obj = backend.data.block.get_block(block_id)
obj = get_block(block_id)
if not obj:
raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.")
@@ -308,7 +319,7 @@ async def configure_user_auto_top_up(
dependencies=[Depends(auth_middleware)],
)
async def get_user_auto_top_up(
user_id: Annotated[str, Depends(get_user_id)]
user_id: Annotated[str, Depends(get_user_id)],
) -> AutoTopUpConfig:
return await get_auto_top_up(user_id)
@@ -375,7 +386,7 @@ async def get_credit_history(
@v1_router.get(path="/credits/refunds", dependencies=[Depends(auth_middleware)])
async def get_refund_requests(
user_id: Annotated[str, Depends(get_user_id)]
user_id: Annotated[str, Depends(get_user_id)],
) -> list[RefundRequest]:
return await _user_credit_model.get_refund_requests(user_id)
@@ -391,7 +402,7 @@ class DeleteGraphResponse(TypedDict):
@v1_router.get(path="/graphs", tags=["graphs"], dependencies=[Depends(auth_middleware)])
async def get_graphs(
user_id: Annotated[str, Depends(get_user_id)]
user_id: Annotated[str, Depends(get_user_id)],
) -> Sequence[graph_db.GraphModel]:
return await graph_db.get_graphs(filter_by="active", user_id=user_id)
@@ -580,16 +591,25 @@ async def set_graph_active_version(
tags=["graphs"],
dependencies=[Depends(auth_middleware)],
)
def execute_graph(
async def execute_graph(
graph_id: str,
node_input: Annotated[dict[str, Any], Body(..., default_factory=dict)],
user_id: Annotated[str, Depends(get_user_id)],
inputs: Annotated[dict[str, Any], Body(..., embed=True, default_factory=dict)],
credentials_inputs: Annotated[
dict[str, CredentialsMetaInput], Body(..., embed=True, default_factory=dict)
],
graph_version: Optional[int] = None,
preset_id: Optional[str] = None,
) -> ExecuteGraphResponse:
graph_exec = execution_manager_client().add_execution(
graph_id, node_input, user_id=user_id, graph_version=graph_version
graph_exec = await execution_utils.add_graph_execution_async(
graph_id=graph_id,
user_id=user_id,
inputs=inputs,
preset_id=preset_id,
graph_version=graph_version,
graph_credentials_inputs=credentials_inputs,
)
return ExecuteGraphResponse(graph_exec_id=graph_exec.graph_exec_id)
return ExecuteGraphResponse(graph_exec_id=graph_exec.id)
@v1_router.post(
@@ -605,9 +625,7 @@ async def stop_graph_run(
):
raise HTTPException(404, detail=f"Agent execution #{graph_exec_id} not found")
await asyncio.to_thread(
lambda: execution_manager_client().cancel_execution(graph_exec_id)
)
await _cancel_execution(graph_exec_id)
# Retrieve & return canceled graph execution in its final state
result = await execution_db.get_graph_execution(
@@ -621,6 +639,49 @@ async def stop_graph_run(
return result
async def _cancel_execution(graph_exec_id: str):
"""
Mechanism:
1. Set the cancel event
2. Graph executor's cancel handler thread detects the event, terminates workers,
reinitializes worker pool, and returns.
3. Update execution statuses in DB and set `error` outputs to `"TERMINATED"`.
"""
queue_client = await execution_queue_client()
await queue_client.publish_message(
routing_key="",
message=execution_utils.CancelExecutionEvent(
graph_exec_id=graph_exec_id
).model_dump_json(),
exchange=execution_utils.GRAPH_EXECUTION_CANCEL_EXCHANGE,
)
# Update the status of the graph & node executions
await execution_db.update_graph_execution_stats(
graph_exec_id,
execution_db.ExecutionStatus.TERMINATED,
)
node_execs = [
node_exec.model_copy(update={"status": execution_db.ExecutionStatus.TERMINATED})
for node_exec in await execution_db.get_node_executions(
graph_exec_id=graph_exec_id,
statuses=[
execution_db.ExecutionStatus.QUEUED,
execution_db.ExecutionStatus.RUNNING,
execution_db.ExecutionStatus.INCOMPLETE,
],
)
]
await execution_db.update_node_execution_status_batch(
[node_exec.node_exec_id for node_exec in node_execs],
execution_db.ExecutionStatus.TERMINATED,
)
await asyncio.gather(
*[execution_event_bus().publish(node_exec) for node_exec in node_execs]
)
@v1_router.get(
path="/executions",
tags=["graphs"],
@@ -718,14 +779,12 @@ async def create_schedule(
detail=f"Graph #{schedule.graph_id} v.{schedule.graph_version} not found.",
)
return await asyncio.to_thread(
lambda: execution_scheduler_client().add_execution_schedule(
graph_id=schedule.graph_id,
graph_version=graph.version,
cron=schedule.cron,
input_data=schedule.input_data,
user_id=user_id,
)
return await execution_scheduler_client().add_execution_schedule(
graph_id=schedule.graph_id,
graph_version=graph.version,
cron=schedule.cron,
input_data=schedule.input_data,
user_id=user_id,
)
@@ -734,11 +793,11 @@ async def create_schedule(
tags=["schedules"],
dependencies=[Depends(auth_middleware)],
)
def delete_schedule(
async def delete_schedule(
schedule_id: str,
user_id: Annotated[str, Depends(get_user_id)],
) -> dict[Any, Any]:
execution_scheduler_client().delete_schedule(schedule_id, user_id=user_id)
await execution_scheduler_client().delete_schedule(schedule_id, user_id=user_id)
return {"id": schedule_id}
@@ -747,11 +806,11 @@ def delete_schedule(
tags=["schedules"],
dependencies=[Depends(auth_middleware)],
)
def get_execution_schedules(
async def get_execution_schedules(
user_id: Annotated[str, Depends(get_user_id)],
graph_id: str | None = None,
) -> list[scheduler.ExecutionJobInfo]:
return execution_scheduler_client().get_execution_schedules(
return await execution_scheduler_client().get_execution_schedules(
user_id=user_id,
graph_id=graph_id,
)
@@ -792,7 +851,7 @@ async def create_api_key(
dependencies=[Depends(auth_middleware)],
)
async def get_api_keys(
user_id: Annotated[str, Depends(get_user_id)]
user_id: Annotated[str, Depends(get_user_id)],
) -> list[APIKeyWithoutHash]:
"""List all API keys for the user"""
try:

View File

@@ -0,0 +1,77 @@
import logging
import typing
from autogpt_libs.auth import requires_admin_user
from autogpt_libs.auth.depends import get_user_id
from fastapi import APIRouter, Body, Depends
from prisma import Json
from prisma.enums import CreditTransactionType
from backend.data.credit import admin_get_user_history, get_user_credit_model
from backend.server.v2.admin.model import AddUserCreditsResponse, UserHistoryResponse
logger = logging.getLogger(__name__)
_user_credit_model = get_user_credit_model()
router = APIRouter(
prefix="/admin",
tags=["credits", "admin"],
dependencies=[Depends(requires_admin_user)],
)
@router.post("/add_credits", response_model=AddUserCreditsResponse)
async def add_user_credits(
user_id: typing.Annotated[str, Body()],
amount: typing.Annotated[int, Body()],
comments: typing.Annotated[str, Body()],
admin_user: typing.Annotated[
str,
Depends(get_user_id),
],
):
""" """
logger.info(f"Admin user {admin_user} is adding {amount} credits to user {user_id}")
new_balance, transaction_key = await _user_credit_model._add_transaction(
user_id,
amount,
transaction_type=CreditTransactionType.GRANT,
metadata=Json({"admin_id": admin_user, "reason": comments}),
)
return {
"new_balance": new_balance,
"transaction_key": transaction_key,
}
@router.get(
"/users_history",
response_model=UserHistoryResponse,
)
async def admin_get_all_user_history(
admin_user: typing.Annotated[
str,
Depends(get_user_id),
],
search: typing.Optional[str] = None,
page: int = 1,
page_size: int = 20,
transaction_filter: typing.Optional[CreditTransactionType] = None,
):
""" """
logger.info(f"Admin user {admin_user} is getting grant history")
try:
resp = await admin_get_user_history(
page=page,
page_size=page_size,
search=search,
transaction_filter=transaction_filter,
)
logger.info(f"Admin user {admin_user} got {len(resp.history)} grant history entries")
return resp
except Exception as e:
logger.exception(f"Error getting grant history: {e}")
raise e

View File

@@ -0,0 +1,16 @@
from pydantic import BaseModel
from backend.data.model import UserTransaction
from backend.server.model import Pagination
class UserHistoryResponse(BaseModel):
"""Response model for listings with version history"""
history: list[UserTransaction]
pagination: Pagination
class AddUserCreditsResponse(BaseModel):
new_balance: int
transaction_key: str

View File

@@ -1,4 +1,5 @@
import logging
import tempfile
import typing
import autogpt_libs.auth.depends
@@ -9,6 +10,7 @@ import prisma.enums
import backend.server.v2.store.db
import backend.server.v2.store.exceptions
import backend.server.v2.store.model
import backend.util.json
logger = logging.getLogger(__name__)
@@ -98,3 +100,47 @@ async def review_submission(
status_code=500,
content={"detail": "An error occurred while reviewing the submission"},
)
@router.get(
"/submissions/download/{store_listing_version_id}",
tags=["store", "admin"],
dependencies=[fastapi.Depends(autogpt_libs.auth.depends.requires_admin_user)],
)
async def admin_download_agent_file(
user: typing.Annotated[
autogpt_libs.auth.models.User,
fastapi.Depends(autogpt_libs.auth.depends.requires_admin_user),
],
store_listing_version_id: str = fastapi.Path(
..., description="The ID of the agent to download"
),
) -> fastapi.responses.FileResponse:
"""
Download the agent file by streaming its content.
Args:
store_listing_version_id (str): The ID of the agent to download
Returns:
FileResponse: A JSON file response containing the agent's graph data.
Raises:
HTTPException: If the agent is not found or an unexpected error occurs.
"""
graph_data = await backend.server.v2.store.db.get_agent(
user_id=user.user_id,
store_listing_version_id=store_listing_version_id,
)
file_name = f"agent_{graph_data.id}_v{graph_data.version or 'latest'}.json"
# Sending graph as a stream (similar to marketplace v1)
with tempfile.NamedTemporaryFile(
mode="w", suffix=".json", delete=False
) as tmp_file:
tmp_file.write(backend.util.json.dumps(graph_data))
tmp_file.flush()
return fastapi.responses.FileResponse(
tmp_file.name, filename=file_name, media_type="application/json"
)
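The download endpoint above writes the graph to a named temporary file with `delete=False` so the file outlives the `with` block and can be served by `FileResponse`. A self-contained sketch of that round trip using only the standard library (the graph payload here is hypothetical):

```python
import json
import os
import tempfile

# Dump a payload to a named temp file (delete=False keeps it on disk
# after the `with` block, as FileResponse needs), then read it back
# the way a downloading client would.

graph_data = {"id": "agent-1", "version": 3}

with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as tmp:
    tmp.write(json.dumps(graph_data))
    tmp.flush()
    path = tmp.name

with open(path) as f:
    round_tripped = json.load(f)

os.remove(path)  # clean up; the real endpoint leaves this to the OS/framework
```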

View File

@@ -13,14 +13,17 @@ import backend.server.v2.library.model as library_model
import backend.server.v2.store.exceptions as store_exceptions
import backend.server.v2.store.image_gen as store_image_gen
import backend.server.v2.store.media as store_media
from backend.data import db
from backend.data import graph as graph_db
from backend.data.db import locked_transaction
from backend.data.execution import get_graph_execution
from backend.data.includes import library_agent_include
from backend.util.exceptions import NotFoundError
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.webhooks.graph_lifecycle_hooks import on_graph_activate
from backend.util.settings import Config
logger = logging.getLogger(__name__)
config = Config()
integration_creds_manager = IntegrationCredentialsManager()
async def list_library_agents(
@@ -142,7 +145,7 @@ async def get_library_agent(id: str, user_id: str) -> library_model.LibraryAgent
Get a specific agent from the user's library.
Args:
id: ID of the library agent to retrieve.
library_agent_id: ID of the library agent to retrieve.
user_id: ID of the authenticated user.
Returns:
@@ -208,7 +211,7 @@ async def add_generated_agent_image(
async def create_library_agent(
graph: backend.data.graph.GraphModel,
user_id: str,
) -> prisma.models.LibraryAgent:
) -> library_model.LibraryAgent:
"""
Adds an agent to the user's library (LibraryAgent table).
@@ -229,19 +232,21 @@ async def create_library_agent(
)
try:
return await prisma.models.LibraryAgent.prisma().create(
agent = await prisma.models.LibraryAgent.prisma().create(
data=prisma.types.LibraryAgentCreateInput(
isCreatedByUser=(user_id == graph.user_id),
useGraphIsActiveVersion=True,
User={"connect": {"id": user_id}},
# Creator={"connect": {"id": graph.user_id}},
# Creator={"connect": {"id": agent.userId}},
AgentGraph={
"connect": {
"graphVersionId": {"id": graph.id, "version": graph.version}
}
},
)
),
include={"AgentGraph": True},
)
return library_model.LibraryAgent.from_db(agent)
except prisma.errors.PrismaError as e:
logger.error(f"Database error creating agent in library: {e}")
raise store_exceptions.DatabaseError("Failed to create agent in library") from e
@@ -400,13 +405,11 @@ async def add_store_agent_to_library(
# Check if user already has this agent
existing_library_agent = (
await prisma.models.LibraryAgent.prisma().find_unique(
await prisma.models.LibraryAgent.prisma().find_first(
where={
"userId_agentGraphId_agentGraphVersion": {
"userId": user_id,
"agentGraphId": graph.id,
"agentGraphVersion": graph.version,
}
"userId": user_id,
"agentGraphId": graph.id,
"agentGraphVersion": graph.version,
},
include=library_agent_include(user_id),
)
@@ -414,13 +417,13 @@ async def add_store_agent_to_library(
if existing_library_agent:
if existing_library_agent.isDeleted:
# Even if agent exists it needs to be marked as not deleted
await update_library_agent(
existing_library_agent.id, user_id, is_deleted=False
await set_is_deleted_for_library_agent(
user_id, graph.id, graph.version, False
)
else:
logger.debug(
f"User #{user_id} already has graph #{graph.id} "
f"v{graph.version} in their library"
"in their library"
)
return library_model.LibraryAgent.from_db(existing_library_agent)
@@ -428,11 +431,8 @@ async def add_store_agent_to_library(
added_agent = await prisma.models.LibraryAgent.prisma().create(
data=prisma.types.LibraryAgentCreateInput(
userId=user_id,
AgentGraph={
"connect": {
"graphVersionId": {"id": graph.id, "version": graph.version}
}
},
agentGraphId=graph.id,
agentGraphVersion=graph.version,
isCreatedByUser=False,
),
include=library_agent_include(user_id),
@@ -452,22 +452,60 @@ async def add_store_agent_to_library(
raise store_exceptions.DatabaseError("Failed to add agent to library") from e
async def set_is_deleted_for_library_agent(
user_id: str, agent_id: str, agent_version: int, is_deleted: bool
) -> None:
"""
Changes the isDeleted flag for a library agent.
Args:
user_id: ID of the user whose library contains the agent.
agent_id: The ID of the agent graph to update.
agent_version: The version of the agent graph to update.
is_deleted: Whether the agent should be marked as deleted.
Raises:
DatabaseError: If there's an issue updating the library agent.
"""
logger.debug(
f"Setting isDeleted={is_deleted} for agent {agent_id} v{agent_version} "
f"in library for user {user_id}"
)
try:
logger.warning(
f"Setting isDeleted={is_deleted} for agent {agent_id} v{agent_version} in library for user {user_id}"
)
count = await prisma.models.LibraryAgent.prisma().update_many(
where={
"userId": user_id,
"agentGraphId": agent_id,
"agentGraphVersion": agent_version,
},
data={"isDeleted": is_deleted},
)
logger.warning(f"Updated {count} isDeleted library agents")
except prisma.errors.PrismaError as e:
logger.error(f"Database error setting agent isDeleted: {e}")
raise store_exceptions.DatabaseError(
"Failed to set agent isDeleted in library"
) from e
##############################################
########### Presets DB Functions #############
##############################################
async def list_presets(
user_id: str, page: int, page_size: int, graph_id: Optional[str] = None
async def get_presets(
user_id: str, page: int, page_size: int
) -> library_model.LibraryAgentPresetResponse:
"""
Retrieves a paginated list of AgentPresets for the specified user.
Args:
user_id: The user ID whose presets are being retrieved.
page: The current page index (1-based).
page: The current page index (0-based).
page_size: Number of items to retrieve per page.
graph_id: Agent Graph ID to filter by.
Returns:
A LibraryAgentPresetResponse containing a list of presets and pagination info.
@@ -479,24 +517,21 @@ async def list_presets(
f"Fetching presets for user #{user_id}, page={page}, page_size={page_size}"
)
if page < 1 or page_size < 1:
if page < 0 or page_size < 1:
logger.warning(
"Invalid pagination input: page=%d, page_size=%d", page, page_size
)
raise store_exceptions.DatabaseError("Invalid pagination parameters")
query_filter: prisma.types.AgentPresetWhereInput = {"userId": user_id}
if graph_id:
query_filter["agentGraphId"] = graph_id
try:
presets_records = await prisma.models.AgentPreset.prisma().find_many(
where=query_filter,
skip=(page - 1) * page_size,
where={"userId": user_id},
skip=page * page_size,
take=page_size,
include={"InputPresets": True},
)
total_items = await prisma.models.AgentPreset.prisma().count(where=query_filter)
total_items = await prisma.models.AgentPreset.prisma().count(
where={"userId": user_id}
)
total_pages = (total_items + page_size - 1) // page_size
presets = [
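The pagination arithmetic above (0-based `skip` offset and ceiling-division page count) can be sketched in isolation; the helper names here are illustrative, not part of the codebase:

```python
def total_pages(total_items: int, page_size: int) -> int:
    # Ceiling division without floats: equivalent to math.ceil(total_items / page_size).
    return (total_items + page_size - 1) // page_size


def page_offset(page: int, page_size: int, one_based: bool = True) -> int:
    # A 1-based page p skips (p - 1) * page_size rows; a 0-based page skips p * page_size.
    return (page - 1) * page_size if one_based else page * page_size


assert total_pages(0, 10) == 0
assert total_pages(25, 10) == 3
assert page_offset(1, 10) == 0                    # 1-based: first page starts at row 0
assert page_offset(0, 10, one_based=False) == 0   # 0-based: same starting row
```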
@@ -549,142 +584,69 @@ async def get_preset(
raise store_exceptions.DatabaseError("Failed to fetch preset") from e
async def create_preset(
async def upsert_preset(
user_id: str,
preset: library_model.LibraryAgentPresetCreatable,
preset: library_model.CreateLibraryAgentPresetRequest,
preset_id: Optional[str] = None,
) -> library_model.LibraryAgentPreset:
"""
Creates a new AgentPreset for a user.
Creates or updates an AgentPreset for a user.
Args:
user_id: The ID of the user creating the preset.
preset: The preset data used for creation.
user_id: The ID of the user creating/updating the preset.
preset: The preset data used for creation or update.
preset_id: An optional preset ID to update; if None, a new preset is created.
Returns:
The newly created LibraryAgentPreset.
The newly created or updated LibraryAgentPreset.
Raises:
DatabaseError: If there's a database error in creating the preset.
"""
logger.debug(
f"Creating preset ({repr(preset.name)}) for user #{user_id}",
)
try:
new_preset = await prisma.models.AgentPreset.prisma().create(
data=prisma.types.AgentPresetCreateInput(
userId=user_id,
name=preset.name,
description=preset.description,
agentGraphId=preset.graph_id,
agentGraphVersion=preset.graph_version,
isActive=preset.is_active,
InputPresets={
"create": [
prisma.types.AgentNodeExecutionInputOutputCreateWithoutRelationsInput( # noqa
name=name, data=prisma.fields.Json(data)
)
for name, data in preset.inputs.items()
]
},
),
include={"InputPresets": True},
)
return library_model.LibraryAgentPreset.from_db(new_preset)
except prisma.errors.PrismaError as e:
logger.error(f"Database error creating preset: {e}")
raise store_exceptions.DatabaseError("Failed to create preset") from e
async def create_preset_from_graph_execution(
user_id: str,
create_request: library_model.LibraryAgentPresetCreatableFromGraphExecution,
) -> library_model.LibraryAgentPreset:
"""
Creates a new AgentPreset from an AgentGraphExecution.
Params:
user_id: The ID of the user creating the preset.
create_request: The data used for creation.
Returns:
The newly created LibraryAgentPreset.
Raises:
DatabaseError: If there's a database error in creating the preset.
"""
graph_exec_id = create_request.graph_execution_id
graph_execution = await get_graph_execution(user_id, graph_exec_id)
if not graph_execution:
raise NotFoundError(f"Graph execution #{graph_exec_id} not found")
logger.debug(
f"Creating preset for user #{user_id} from graph execution #{graph_exec_id}",
)
return await create_preset(
user_id=user_id,
preset=library_model.LibraryAgentPresetCreatable(
inputs=graph_execution.inputs,
graph_id=graph_execution.graph_id,
graph_version=graph_execution.graph_version,
name=create_request.name,
description=create_request.description,
is_active=create_request.is_active,
),
)
async def update_preset(
user_id: str,
preset_id: str,
preset: library_model.LibraryAgentPresetUpdatable,
) -> library_model.LibraryAgentPreset:
"""
Updates an existing AgentPreset for a user.
Args:
user_id: The ID of the user updating the preset.
preset_id: The ID of the preset to update.
preset: The preset data used for the update.
Returns:
The updated LibraryAgentPreset.
Raises:
DatabaseError: If there's a database error in updating the preset.
DatabaseError: If there's a database error in creating or updating the preset.
ValueError: If attempting to update a non-existent preset.
"""
logger.debug(
f"Updating preset #{preset_id} ({repr(preset.name)}) for user #{user_id}",
f"Upserting preset #{preset_id} ({repr(preset.name)}) for user #{user_id}",
)
try:
update_data: prisma.types.AgentPresetUpdateInput = {}
if preset.name:
update_data["name"] = preset.name
if preset.description:
update_data["description"] = preset.description
if preset.inputs:
update_data["InputPresets"] = {
"create": [
prisma.types.AgentNodeExecutionInputOutputCreateWithoutRelationsInput( # noqa
name=name, data=prisma.fields.Json(data)
)
for name, data in preset.inputs.items()
]
}
if preset.is_active:
update_data["isActive"] = preset.is_active
updated = await prisma.models.AgentPreset.prisma().update(
where={"id": preset_id},
data=update_data,
include={"InputPresets": True},
)
if not updated:
raise ValueError(f"AgentPreset #{preset_id} not found")
return library_model.LibraryAgentPreset.from_db(updated)
inputs = [
prisma.types.AgentNodeExecutionInputOutputCreateWithoutRelationsInput(
name=name, data=prisma.fields.Json(data)
)
for name, data in preset.inputs.items()
]
if preset_id:
# Update existing preset
updated = await prisma.models.AgentPreset.prisma().update(
where={"id": preset_id},
data={
"name": preset.name,
"description": preset.description,
"isActive": preset.is_active,
"InputPresets": {"create": inputs},
},
include={"InputPresets": True},
)
if not updated:
raise ValueError(f"AgentPreset #{preset_id} not found")
return library_model.LibraryAgentPreset.from_db(updated)
else:
# Create new preset
new_preset = await prisma.models.AgentPreset.prisma().create(
data=prisma.types.AgentPresetCreateInput(
userId=user_id,
name=preset.name,
description=preset.description,
agentGraphId=preset.graph_id,
agentGraphVersion=preset.graph_version,
isActive=preset.is_active,
InputPresets={"create": inputs},
),
include={"InputPresets": True},
)
return library_model.LibraryAgentPreset.from_db(new_preset)
except prisma.errors.PrismaError as e:
logger.error(f"Database error updating preset: {e}")
raise store_exceptions.DatabaseError("Failed to update preset") from e
logger.error(f"Database error upserting preset: {e}")
raise store_exceptions.DatabaseError("Failed to create preset") from e
async def delete_preset(user_id: str, preset_id: str) -> None:
@@ -707,3 +669,47 @@ async def delete_preset(user_id: str, preset_id: str) -> None:
except prisma.errors.PrismaError as e:
logger.error(f"Database error deleting preset: {e}")
raise store_exceptions.DatabaseError("Failed to delete preset") from e
async def fork_library_agent(library_agent_id: str, user_id: str):
"""
Clones a library agent and its underlying graph and nodes (with new ids) for the given user.
Args:
library_agent_id: The ID of the library agent to fork.
user_id: The ID of the user who owns the library agent.
Returns:
The forked LibraryAgent.
Raises:
DatabaseError: If there's an error during the forking process.
"""
logger.debug(f"Forking library agent {library_agent_id} for user {user_id}")
try:
async with db.locked_transaction(f"usr_trx_{user_id}-fork_agent"):
# Fetch the original agent
original_agent = await get_library_agent(library_agent_id, user_id)
# Check if user owns the library agent
# TODO: once we have open/closed sourced agents this needs to be enabled ~kcze
# + update library/agents/[id]/page.tsx agent actions
# if not original_agent.can_access_graph:
# raise store_exceptions.DatabaseError(
# f"User {user_id} cannot access library agent graph {library_agent_id}"
# )
# Fork the underlying graph and nodes
new_graph = await graph_db.fork_graph(
original_agent.graph_id, original_agent.graph_version, user_id
)
new_graph = await on_graph_activate(
new_graph,
get_credentials=lambda id: integration_creds_manager.get(user_id, id),
)
# Create a library agent for the new graph
return await create_library_agent(new_graph, user_id)
except prisma.errors.PrismaError as e:
logger.error(f"Database error cloning library agent: {e}")
raise store_exceptions.DatabaseError("Failed to fork library agent") from e


@@ -168,62 +168,27 @@ class LibraryAgentResponse(pydantic.BaseModel):
pagination: server_model.Pagination
class LibraryAgentPresetCreatable(pydantic.BaseModel):
"""
Request model used when creating a new preset for a library agent.
"""
graph_id: str
graph_version: int
inputs: block_model.BlockInput
name: str
description: str
is_active: bool = True
class LibraryAgentPresetCreatableFromGraphExecution(pydantic.BaseModel):
"""
Request model used when creating a new preset for a library agent.
"""
graph_execution_id: str
name: str
description: str
is_active: bool = True
class LibraryAgentPresetUpdatable(pydantic.BaseModel):
"""
Request model used when updating a preset for a library agent.
"""
inputs: Optional[block_model.BlockInput] = None
name: Optional[str] = None
description: Optional[str] = None
is_active: Optional[bool] = None
class LibraryAgentPreset(LibraryAgentPresetCreatable):
class LibraryAgentPreset(pydantic.BaseModel):
"""Represents a preset configuration for a library agent."""
id: str
updated_at: datetime.datetime
graph_id: str
graph_version: int
name: str
description: str
is_active: bool
inputs: block_model.BlockInput
@classmethod
def from_db(cls, preset: prisma.models.AgentPreset) -> "LibraryAgentPreset":
if preset.InputPresets is None:
raise ValueError("Input values must be included in object")
input_data: block_model.BlockInput = {}
for preset_input in preset.InputPresets:
for preset_input in preset.InputPresets or []:
input_data[preset_input.name] = preset_input.data
return cls(
@@ -245,6 +210,19 @@ class LibraryAgentPresetResponse(pydantic.BaseModel):
pagination: server_model.Pagination
class CreateLibraryAgentPresetRequest(pydantic.BaseModel):
"""
Request model used when creating a new preset for a library agent.
"""
name: str
description: str
inputs: block_model.BlockInput
graph_id: str
graph_version: int
is_active: bool
class LibraryAgentFilter(str, Enum):
"""Possible filters for searching library agents."""


@@ -22,11 +22,13 @@ async def test_agent_preset_from_db():
userId="test-user-123",
isDeleted=False,
InputPresets=[
prisma.models.AgentNodeExecutionInputOutput(
id="input-123",
time=datetime.datetime.now(),
name="input1",
data=prisma.Json({"type": "string", "value": "test value"}),
prisma.models.AgentNodeExecutionInputOutput.model_validate(
{
"id": "input-123",
"time": datetime.datetime.now(),
"name": "input1",
"data": '{"type": "string", "value": "test value"}',
}
)
],
)


@@ -190,3 +190,14 @@ async def update_library_agent(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to update library agent",
) from e
@router.post("/{library_agent_id}/fork")
async def fork_library_agent(
library_agent_id: str,
user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
) -> library_model.LibraryAgent:
return await library_db.fork_library_agent(
library_agent_id=library_agent_id,
user_id=user_id,
)


@@ -1,39 +1,27 @@
import logging
from typing import Annotated, Any, Optional
from typing import Annotated, Any
import autogpt_libs.auth as autogpt_auth_lib
import autogpt_libs.utils.cache
from fastapi import APIRouter, Body, Depends, HTTPException, Query, status
from fastapi import APIRouter, Body, Depends, HTTPException, status
import backend.executor
import backend.server.v2.library.db as db
import backend.server.v2.library.model as models
import backend.util.service
from backend.util.exceptions import NotFoundError
from backend.executor.utils import add_graph_execution_async
logger = logging.getLogger(__name__)
router = APIRouter()
@autogpt_libs.utils.cache.thread_cached
def execution_manager_client() -> backend.executor.ExecutionManager:
"""Return a cached instance of ExecutionManager client."""
return backend.util.service.get_service_client(backend.executor.ExecutionManager)
@router.get(
"/presets",
summary="List presets",
description="Retrieve a paginated list of presets for the current user.",
)
async def list_presets(
async def get_presets(
user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
page: int = Query(default=1, ge=1),
page_size: int = Query(default=10, ge=1),
graph_id: Optional[str] = Query(
description="Filter presets by a specific agent graph"
),
page: int = 1,
page_size: int = 10,
) -> models.LibraryAgentPresetResponse:
"""
Retrieve a paginated list of presets for the current user.
@@ -42,18 +30,12 @@ async def list_presets(
user_id (str): ID of the authenticated user.
page (int): Page number for pagination.
page_size (int): Number of items per page.
graph_id: Optional agent graph ID to filter presets by.
Returns:
models.LibraryAgentPresetResponse: A response containing the list of presets.
"""
try:
return await db.list_presets(
user_id=user_id,
graph_id=graph_id,
page=page,
page_size=page_size,
)
return await db.get_presets(user_id, page, page_size)
except Exception as e:
logger.exception(f"Exception occurred while getting presets: {e}")
raise HTTPException(
@@ -106,17 +88,14 @@ async def get_preset(
description="Create a new preset for the current user.",
)
async def create_preset(
preset: (
models.LibraryAgentPresetCreatable
| models.LibraryAgentPresetCreatableFromGraphExecution
),
preset: models.CreateLibraryAgentPresetRequest,
user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
) -> models.LibraryAgentPreset:
"""
Create a new library agent preset. Automatically corrects node_input format if needed.
Args:
preset (models.LibraryAgentPresetCreatable): The preset data to create.
preset (models.CreateLibraryAgentPresetRequest): The preset data to create.
user_id (str): ID of the authenticated user.
Returns:
@@ -126,12 +105,7 @@ async def create_preset(
HTTPException: If an error occurs while creating the preset.
"""
try:
if isinstance(preset, models.LibraryAgentPresetCreatable):
return await db.create_preset(user_id, preset)
else:
return await db.create_preset_from_graph_execution(user_id, preset)
except NotFoundError as e:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
return await db.upsert_preset(user_id, preset)
except Exception as e:
logger.exception(f"Exception occurred while creating preset: {e}")
raise HTTPException(
@@ -140,22 +114,22 @@ async def create_preset(
)
@router.patch(
@router.put(
"/presets/{preset_id}",
summary="Update an existing preset",
description="Update an existing preset by its ID.",
)
async def update_preset(
preset_id: str,
preset: models.LibraryAgentPresetUpdatable,
preset: models.CreateLibraryAgentPresetRequest,
user_id: str = Depends(autogpt_auth_lib.depends.get_user_id),
) -> models.LibraryAgentPreset:
"""
Update an existing library agent preset.
Update an existing library agent preset by its ID.
Args:
preset_id (str): ID of the preset to update.
preset (models.LibraryAgentPresetUpdatable): The preset data to update.
preset (models.CreateLibraryAgentPresetRequest): The preset data to update.
user_id (str): ID of the authenticated user.
Returns:
@@ -165,9 +139,7 @@ async def update_preset(
HTTPException: If an error occurs while updating the preset.
"""
try:
return await db.update_preset(
user_id=user_id, preset_id=preset_id, preset=preset
)
return await db.upsert_preset(user_id, preset, preset_id)
except Exception as e:
logger.exception(f"Exception occurred whilst updating preset: {e}")
raise HTTPException(
@@ -246,17 +218,17 @@ async def execute_preset(
# Merge input overrides with preset inputs
merged_node_input = preset.inputs | node_input
execution = execution_manager_client().add_execution(
execution = await add_graph_execution_async(
graph_id=graph_id,
graph_version=graph_version,
data=merged_node_input,
user_id=user_id,
inputs=merged_node_input,
preset_id=preset_id,
graph_version=graph_version,
)
logger.debug(f"Execution added: {execution} with input: {merged_node_input}")
return {"id": execution.graph_exec_id}
return {"id": execution.id}
except HTTPException:
raise
except Exception as e:
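The `preset.inputs | node_input` merge above uses PEP 584 dict union, where keys from the right-hand operand win on conflict, so caller-supplied overrides take precedence over stored preset inputs. A quick illustration (with made-up input values):

```python
preset_inputs = {"query": "default search", "limit": 10}
node_input = {"query": "user override"}

# Right-hand operand wins on key conflicts; unmatched keys are kept.
merged_node_input = preset_inputs | node_input
assert merged_node_input == {"query": "user override", "limit": 10}
```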


@@ -793,6 +793,7 @@ async def create_store_version(
changes_summary=changes_summary,
version=next_version,
)
except prisma.errors.PrismaError as e:
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to create new store version"
@@ -1361,3 +1362,31 @@ async def get_admin_listings_with_versions(
page_size=page_size,
),
)
async def get_agent_as_admin(
user_id: str | None,
store_listing_version_id: str,
) -> GraphModel:
"""Get an agent graph using the store listing version ID."""
store_listing_version = (
await prisma.models.StoreListingVersion.prisma().find_unique(
where={"id": store_listing_version_id}
)
)
if not store_listing_version:
raise ValueError(f"Store listing version {store_listing_version_id} not found")
graph = await backend.data.graph.get_graph_as_admin(
user_id=user_id,
graph_id=store_listing_version.agentGraphId,
version=store_listing_version.agentGraphVersion,
for_export=True,
)
if not graph:
raise ValueError(
f"Agent {store_listing_version.agentGraphId} v{store_listing_version.agentGraphVersion} not found"
)
return graph


@@ -4,20 +4,7 @@ from typing import List
import prisma.enums
import pydantic
class Pagination(pydantic.BaseModel):
total_items: int = pydantic.Field(
description="Total number of items.", examples=[42]
)
total_pages: int = pydantic.Field(
description="Total number of pages.", examples=[97]
)
current_page: int = pydantic.Field(
description="Current page number.", examples=[1]
)
page_size: int = pydantic.Field(
description="Number of items per page.", examples=[25]
)
from backend.server.model import Pagination
class MyAgent(pydantic.BaseModel):


@@ -6,7 +6,6 @@ from typing import Protocol
import uvicorn
from autogpt_libs.auth import parse_jwt_token
from autogpt_libs.logging.utils import generate_uvicorn_config
from autogpt_libs.utils.cache import thread_cached
from fastapi import Depends, FastAPI, WebSocket, WebSocketDisconnect
from starlette.middleware.cors import CORSMiddleware
@@ -19,7 +18,7 @@ from backend.server.model import (
WSSubscribeGraphExecutionRequest,
WSSubscribeGraphExecutionsRequest,
)
from backend.util.service import AppProcess, get_service_client
from backend.util.service import AppProcess
from backend.util.settings import AppEnvironment, Config, Settings
logger = logging.getLogger(__name__)
@@ -46,13 +45,6 @@ def get_connection_manager():
return _connection_manager
@thread_cached
def get_db_client():
from backend.executor import DatabaseManager
return get_service_client(DatabaseManager)
async def event_broadcaster(manager: ConnectionManager):
try:
event_queue = AsyncRedisExecutionEventBus()


@@ -15,21 +15,25 @@ def to_dict(data) -> dict:
def dumps(data) -> str:
return json.dumps(jsonable_encoder(data))
return json.dumps(to_dict(data))
T = TypeVar("T")
@overload
def loads(data: str, *args, target_type: Type[T], **kwargs) -> T: ...
def loads(data: str | bytes, *args, target_type: Type[T], **kwargs) -> T: ...
@overload
def loads(data: str, *args, **kwargs) -> Any: ...
def loads(data: str | bytes, *args, **kwargs) -> Any: ...
def loads(data: str, *args, target_type: Type[T] | None = None, **kwargs) -> Any:
def loads(
data: str | bytes, *args, target_type: Type[T] | None = None, **kwargs
) -> Any:
if isinstance(data, bytes):
data = data.decode("utf-8")
parsed = json.loads(data, *args, **kwargs)
if target_type:
return type_match(parsed, target_type)


@@ -14,9 +14,7 @@ def sentry_init():
traces_sample_rate=1.0,
profiles_sample_rate=1.0,
environment=f"app:{Settings().config.app_env.value}-behave:{Settings().config.behave_as.value}",
_experiments={
"enable_logs": True,
},
_experiments={"enable_logs": True},
integrations=[
LoggingIntegration(sentry_logs_level=logging.INFO),
AnthropicIntegration(


@@ -3,7 +3,7 @@ import os
import signal
import sys
from abc import ABC, abstractmethod
from multiprocessing import Process, set_start_method
from multiprocessing import Process, get_all_start_methods, set_start_method
from typing import Optional
from backend.util.logging import configure_logging
@@ -30,7 +30,12 @@ class AppProcess(ABC):
process: Optional[Process] = None
cleaned_up = False
set_start_method("spawn", force=True)
if "forkserver" in get_all_start_methods():
set_start_method("forkserver", force=True)
else:
logger.warning("Forkserver start method is not available. Using spawn instead.")
set_start_method("spawn", force=True)
configure_logging()
sentry_init()
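The start-method selection above prefers `forkserver` (faster child startup than `spawn`, without `fork`'s thread-safety pitfalls in multithreaded parents) and falls back to `spawn` on platforms where it is unavailable, such as Windows. A minimal sketch of that selection logic:

```python
from multiprocessing import get_all_start_methods, set_start_method

# Prefer forkserver where the platform supports it; otherwise force spawn,
# which is available everywhere.
if "forkserver" in get_all_start_methods():
    set_start_method("forkserver", force=True)
else:
    set_start_method("spawn", force=True)
```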


@@ -73,3 +73,10 @@ def conn_retry(
return async_wrapper if is_coroutine else sync_wrapper
return decorator
func_retry = retry(
reraise=False,
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, min=1, max=30),
)


@@ -5,8 +5,10 @@ import os
import threading
import time
from abc import ABC, abstractmethod
from functools import cached_property, update_wrapper
from typing import (
Any,
Awaitable,
Callable,
Concatenate,
Coroutine,
@@ -42,24 +44,15 @@ api_call_timeout = config.rpc_client_call_timeout
P = ParamSpec("P")
R = TypeVar("R")
EXPOSED_FLAG = "__exposed__"
def expose(func: C) -> C:
func = getattr(func, "__func__", func)
setattr(func, "__exposed__", True)
setattr(func, EXPOSED_FLAG, True)
return func
def exposed_run_and_wait(
f: Callable[P, Coroutine[None, None, R]]
) -> Callable[Concatenate[object, P], R]:
# TODO:
# This function lies about its return type to make the DynamicClient
# call the function synchronously, fix this when DynamicClient can choose
# to call a function synchronously or asynchronously.
return expose(f) # type: ignore
# --------------------------------------------------
# AppService for IPC service based on HTTP request through FastAPI
# --------------------------------------------------
@@ -203,7 +196,7 @@ class AppService(BaseAppService, ABC):
# Register the exposed API routes.
for attr_name, attr in vars(type(self)).items():
if getattr(attr, "__exposed__", False):
if getattr(attr, EXPOSED_FLAG, False):
route_path = f"/{attr_name}"
self.fastapi_app.add_api_route(
route_path,
@@ -234,31 +227,52 @@ class AppService(BaseAppService, ABC):
AS = TypeVar("AS", bound=AppService)
def close_service_client(client: Any) -> None:
if hasattr(client, "close"):
client.close()
else:
logger.warning(f"Client {client} is not closable")
class AppServiceClient(ABC):
@classmethod
@abstractmethod
def get_service_type(cls) -> Type[AppService]:
pass
def health_check(self):
pass
def close(self):
pass
@conn_retry("FastAPI client", "Creating service client", max_retry=api_comm_retry)
ASC = TypeVar("ASC", bound=AppServiceClient)
@conn_retry("AppService client", "Creating service client", max_retry=api_comm_retry)
def get_service_client(
service_type: Type[AS],
service_client_type: Type[ASC],
call_timeout: int | None = api_call_timeout,
) -> AS:
) -> ASC:
class DynamicClient:
def __init__(self):
service_type = service_client_type.get_service_type()
host = service_type.get_host()
port = service_type.get_port()
self.base_url = f"http://{host}:{port}".rstrip("/")
self.client = httpx.Client(
@cached_property
def sync_client(self) -> httpx.Client:
return httpx.Client(
base_url=self.base_url,
timeout=call_timeout,
)
def _call_method(self, method_name: str, **kwargs) -> Any:
@cached_property
def async_client(self) -> httpx.AsyncClient:
return httpx.AsyncClient(
base_url=self.base_url,
timeout=call_timeout,
)
def _handle_call_method_response(
self, response: httpx.Response, method_name: str
) -> Any:
try:
response = self.client.post(method_name, json=to_dict(kwargs))
response.raise_for_status()
return response.json()
except httpx.HTTPStatusError as e:
@@ -269,36 +283,102 @@ def get_service_client(
*(error.args or [str(e)])
)
def _call_method_sync(self, method_name: str, **kwargs) -> Any:
return self._handle_call_method_response(
method_name=method_name,
response=self.sync_client.post(method_name, json=to_dict(kwargs)),
)
async def _call_method_async(self, method_name: str, **kwargs) -> Any:
return self._handle_call_method_response(
method_name=method_name,
response=await self.async_client.post(
method_name, json=to_dict(kwargs)
),
)
async def aclose(self):
self.sync_client.close()
await self.async_client.aclose()
def close(self):
self.client.close()
self.sync_client.close()
def _get_params(self, signature: inspect.Signature, *args, **kwargs) -> dict:
if args:
arg_names = list(signature.parameters.keys())
if arg_names[0] in ("self", "cls"):
arg_names = arg_names[1:]
kwargs.update(dict(zip(arg_names, args)))
return kwargs
def _get_return(self, expected_return: TypeAdapter | None, result: Any) -> Any:
if expected_return:
return expected_return.validate_python(result)
return result
def __getattr__(self, name: str) -> Callable[..., Any]:
# Try to get the original function from the service type.
orig_func = getattr(service_type, name, None)
if orig_func is None:
raise AttributeError(f"Method {name} not found in {service_type}")
original_func = getattr(service_client_type, name, None)
if original_func is None:
raise AttributeError(
f"Method {name} not found in {service_client_type}"
)
else:
name = original_func.__name__
sig = inspect.signature(orig_func)
sig = inspect.signature(original_func)
ret_ann = sig.return_annotation
if ret_ann != inspect.Signature.empty:
expected_return = TypeAdapter(ret_ann)
else:
expected_return = None
def method(*args, **kwargs) -> Any:
if args:
arg_names = list(sig.parameters.keys())
if arg_names[0] in ("self", "cls"):
arg_names = arg_names[1:]
kwargs.update(dict(zip(arg_names, args)))
result = self._call_method(name, **kwargs)
if expected_return:
return expected_return.validate_python(result)
return result
if inspect.iscoroutinefunction(original_func):
return method
async def async_method(*args, **kwargs) -> Any:
params = self._get_params(sig, *args, **kwargs)
result = await self._call_method_async(name, **params)
return self._get_return(expected_return, result)
client = cast(AS, DynamicClient())
return async_method
else:
def sync_method(*args, **kwargs) -> Any:
params = self._get_params(sig, *args, **kwargs)
result = self._call_method_sync(name, **params)
return self._get_return(expected_return, result)
return sync_method
client = cast(ASC, DynamicClient())
client.health_check()
return cast(AS, client)
return client
def endpoint_to_sync(
func: Callable[Concatenate[Any, P], Awaitable[R]],
) -> Callable[Concatenate[Any, P], R]:
"""
Produce a *typed* stub that **looks** synchronous to the typechecker.
"""
def _stub(*args: P.args, **kwargs: P.kwargs) -> R: # pragma: no cover
raise RuntimeError("should be intercepted by __getattr__")
update_wrapper(_stub, func)
return cast(Callable[Concatenate[Any, P], R], _stub)
def endpoint_to_async(
func: Callable[Concatenate[Any, P], R],
) -> Callable[Concatenate[Any, P], Awaitable[R]]:
"""
The async mirror of `endpoint_to_sync`.
"""
async def _stub(*args: P.args, **kwargs: P.kwargs) -> R: # pragma: no cover
raise RuntimeError("should be intercepted by __getattr__")
update_wrapper(_stub, func)
return cast(Callable[Concatenate[Any, P], Awaitable[R]], _stub)


@@ -117,6 +117,10 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
default=1,
description="Cost per execution in cents after each threshold.",
)
execution_counter_expiration_time: int = Field(
default=60 * 60 * 24,
description="Time in seconds after which the execution counter is reset.",
)
model_config = SettingsConfigDict(
env_file=".env",


@@ -0,0 +1,16 @@
-- Modify the OnboardingStep enum
ALTER TYPE "OnboardingStep" ADD VALUE 'GET_RESULTS';
ALTER TYPE "OnboardingStep" ADD VALUE 'MARKETPLACE_VISIT';
ALTER TYPE "OnboardingStep" ADD VALUE 'MARKETPLACE_ADD_AGENT';
ALTER TYPE "OnboardingStep" ADD VALUE 'MARKETPLACE_RUN_AGENT';
ALTER TYPE "OnboardingStep" ADD VALUE 'BUILDER_OPEN';
ALTER TYPE "OnboardingStep" ADD VALUE 'BUILDER_SAVE_AGENT';
ALTER TYPE "OnboardingStep" ADD VALUE 'BUILDER_RUN_AGENT';
-- Modify the UserOnboarding table
ALTER TABLE "UserOnboarding"
ADD COLUMN "updatedAt" TIMESTAMP(3),
ADD COLUMN "notificationDot" BOOLEAN NOT NULL DEFAULT true,
ADD COLUMN "notified" "OnboardingStep"[] DEFAULT '{}',
ADD COLUMN "rewardedFor" "OnboardingStep"[] DEFAULT '{}',
ADD COLUMN "onboardingAgentExecutionId" TEXT


@@ -0,0 +1,56 @@
-- Backfill nulls with empty arrays
UPDATE "UserOnboarding"
SET "integrations" = ARRAY[]::TEXT[]
WHERE "integrations" IS NULL;
UPDATE "UserOnboarding"
SET "completedSteps" = '{}'
WHERE "completedSteps" IS NULL;
UPDATE "UserOnboarding"
SET "notified" = '{}'
WHERE "notified" IS NULL;
UPDATE "UserOnboarding"
SET "rewardedFor" = '{}'
WHERE "rewardedFor" IS NULL;
UPDATE "IntegrationWebhook"
SET "events" = ARRAY[]::TEXT[]
WHERE "events" IS NULL;
UPDATE "APIKey"
SET "permissions" = '{}'
WHERE "permissions" IS NULL;
UPDATE "Profile"
SET "links" = ARRAY[]::TEXT[]
WHERE "links" IS NULL;
UPDATE "StoreListingVersion"
SET "imageUrls" = ARRAY[]::TEXT[]
WHERE "imageUrls" IS NULL;
UPDATE "StoreListingVersion"
SET "categories" = ARRAY[]::TEXT[]
WHERE "categories" IS NULL;
-- Enforce NOT NULL constraints
ALTER TABLE "UserOnboarding"
ALTER COLUMN "integrations" SET NOT NULL,
ALTER COLUMN "completedSteps" SET NOT NULL,
ALTER COLUMN "notified" SET NOT NULL,
ALTER COLUMN "rewardedFor" SET NOT NULL;
ALTER TABLE "IntegrationWebhook"
ALTER COLUMN "events" SET NOT NULL;
ALTER TABLE "APIKey"
ALTER COLUMN "permissions" SET NOT NULL;
ALTER TABLE "Profile"
ALTER COLUMN "links" SET NOT NULL;
ALTER TABLE "StoreListingVersion"
ALTER COLUMN "imageUrls" SET NOT NULL,
ALTER COLUMN "categories" SET NOT NULL;


@@ -0,0 +1,7 @@
-- AlterTable
ALTER TABLE "AgentGraph"
ADD COLUMN "forkedFromId" TEXT,
ADD COLUMN "forkedFromVersion" INTEGER;
-- AddForeignKey
ALTER TABLE "AgentGraph" ADD CONSTRAINT "AgentGraph_forkedFromId_forkedFromVersion_fkey" FOREIGN KEY ("forkedFromId", "forkedFromVersion") REFERENCES "AgentGraph"("id", "version") ON DELETE SET NULL ON UPDATE CASCADE;


@@ -3568,6 +3568,21 @@ files = [
[package.dependencies]
tqdm = "*"
[[package]]
name = "prometheus-client"
version = "0.21.1"
description = "Python client for the Prometheus monitoring system."
optional = false
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "prometheus_client-0.21.1-py3-none-any.whl", hash = "sha256:594b45c410d6f4f8888940fe80b5cc2521b305a1fafe1c58609ef715a001f301"},
{file = "prometheus_client-0.21.1.tar.gz", hash = "sha256:252505a722ac04b0456be05c05f75f45d760c2911ffc45f2a06bcaed9f3ae3fb"},
]
[package.extras]
twisted = ["twisted"]
[[package]]
name = "propcache"
version = "0.2.1"
@@ -6310,4 +6325,4 @@ cffi = ["cffi (>=1.11)"]
[metadata]
lock-version = "2.1"
python-versions = ">=3.10,<3.13"
content-hash = "781f77ec77cfce78b34fb57063dcc81df8e9c5a4be9a644033a0c197e0063730"
content-hash = "29ccee704d8296c57156daab98bb0cbbf5a43e83526b7f08a14c91fb7a4898f4"


@@ -64,6 +64,7 @@ websockets = "^14.2"
youtube-transcript-api = "^0.6.2"
zerobouncesdk = "^1.1.1"
# NOTE: please insert new dependencies in their alphabetical location
prometheus-client = "^0.21.1"
[tool.poetry.group.dev.dependencies]
aiohappyeyeballs = "^2.6.1"


@@ -1,6 +1,7 @@
datasource db {
provider = "postgresql"
url = env("DATABASE_URL")
provider = "postgresql"
url = env("DATABASE_URL")
directUrl = env("DIRECT_URL")
}
generator client {
@@ -58,6 +59,7 @@ model User {
}
enum OnboardingStep {
// Introductory onboarding (Library)
WELCOME
USAGE_REASON
INTEGRATIONS
@@ -65,18 +67,32 @@ enum OnboardingStep {
AGENT_NEW_RUN
AGENT_INPUT
CONGRATS
GET_RESULTS
// Marketplace
MARKETPLACE_VISIT
MARKETPLACE_ADD_AGENT
MARKETPLACE_RUN_AGENT
// Builder
BUILDER_OPEN
BUILDER_SAVE_AGENT
BUILDER_RUN_AGENT
}
model UserOnboarding {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime? @updatedAt
completedSteps OnboardingStep[] @default([])
notificationDot Boolean @default(true)
notified OnboardingStep[] @default([])
rewardedFor OnboardingStep[] @default([])
usageReason String?
integrations String[] @default([])
otherIntegrations String?
selectedStoreListingVersionId String?
agentInput Json?
onboardingAgentExecutionId String?
userId String @unique
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
@@ -102,6 +118,11 @@ model AgentGraph {
// This allows us to delete user data without deleting the agent, which may be in use by other users
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
forkedFromId String?
forkedFromVersion Int?
forkedFrom AgentGraph? @relation("AgentGraphForks", fields: [forkedFromId, forkedFromVersion], references: [id, version])
forks AgentGraph[] @relation("AgentGraphForks")
Nodes AgentNode[]
Executions AgentGraphExecution[]


@@ -6,10 +6,10 @@ from prisma.models import CreditTransaction
from backend.blocks.llm import AITextGeneratorBlock
from backend.data.block import get_block
from backend.data.credit import BetaUserCredit
from backend.data.credit import BetaUserCredit, UsageTransactionMetadata
from backend.data.execution import NodeExecutionEntry
from backend.data.user import DEFAULT_USER_ID
from backend.executor.utils import UsageTransactionMetadata, block_usage_cost
from backend.executor.utils import block_usage_cost
from backend.integrations.credentials_store import openai_credentials
from backend.util.test import SpinTestServer
@@ -46,6 +46,7 @@ async def spend_credits(entry: NodeExecutionEntry) -> int:
block_id=entry.block_id,
block=entry.block_id,
input=matching_filter,
reason=f"Ran block {entry.block_id} {block.name}",
),
)


@@ -1,4 +1,4 @@
from backend.data.execution import merge_execution_input, parse_execution_output
from backend.executor.utils import merge_execution_input, parse_execution_output
def test_parse_execution_output():


@@ -357,7 +357,7 @@ async def test_execute_preset(server: SpinTestServer):
test_graph = await create_graph(server, test_graph, test_user)
# Create preset with initial values
preset = backend.server.v2.library.model.LibraryAgentPresetCreatable(
preset = backend.server.v2.library.model.CreateLibraryAgentPresetRequest(
name="Test Preset With Clash",
description="Test preset with clashing input values",
graph_id=test_graph.id,
@@ -446,7 +446,7 @@ async def test_execute_preset_with_clash(server: SpinTestServer):
test_graph = await create_graph(server, test_graph, test_user)
# Create preset with initial values
preset = backend.server.v2.library.model.LibraryAgentPresetCreatable(
preset = backend.server.v2.library.model.CreateLibraryAgentPresetRequest(
name="Test Preset With Clash",
description="Test preset with clashing input values",
graph_id=test_graph.id,


@@ -1,7 +1,7 @@
import pytest
from backend.data import db
from backend.executor import Scheduler
from backend.executor.scheduler import SchedulerClient
from backend.server.model import CreateGraph
from backend.usecases.sample import create_test_graph, create_test_user
from backend.util.service import get_service_client
@@ -17,11 +17,11 @@ async def test_agent_schedule(server: SpinTestServer):
user_id=test_user.id,
)
scheduler = get_service_client(Scheduler)
schedules = scheduler.get_execution_schedules(test_graph.id, test_user.id)
scheduler = get_service_client(SchedulerClient)
schedules = await scheduler.get_execution_schedules(test_graph.id, test_user.id)
assert len(schedules) == 0
schedule = scheduler.add_execution_schedule(
schedule = await scheduler.add_execution_schedule(
graph_id=test_graph.id,
user_id=test_user.id,
graph_version=1,
@@ -30,10 +30,12 @@ async def test_agent_schedule(server: SpinTestServer):
)
assert schedule
schedules = scheduler.get_execution_schedules(test_graph.id, test_user.id)
schedules = await scheduler.get_execution_schedules(test_graph.id, test_user.id)
assert len(schedules) == 1
assert schedules[0].cron == "0 0 * * *"
scheduler.delete_schedule(schedule.id, user_id=test_user.id)
schedules = scheduler.get_execution_schedules(test_graph.id, user_id=test_user.id)
await scheduler.delete_schedule(schedule.id, user_id=test_user.id)
schedules = await scheduler.get_execution_schedules(
test_graph.id, user_id=test_user.id
)
assert len(schedules) == 0


@@ -1,6 +1,12 @@
import pytest
from backend.util.service import AppService, expose, get_service_client
from backend.util.service import (
AppService,
AppServiceClient,
endpoint_to_async,
expose,
get_service_client,
)
TEST_SERVICE_PORT = 8765
@@ -32,10 +38,25 @@ class ServiceTest(AppService):
return self.run_and_wait(add_async(a, b))
class ServiceTestClient(AppServiceClient):
@classmethod
def get_service_type(cls):
return ServiceTest
add = ServiceTest.add
subtract = ServiceTest.subtract
fun_with_async = ServiceTest.fun_with_async
add_async = endpoint_to_async(ServiceTest.add)
subtract_async = endpoint_to_async(ServiceTest.subtract)
@pytest.mark.asyncio(loop_scope="session")
async def test_service_creation(server):
with ServiceTest():
client = get_service_client(ServiceTest)
client = get_service_client(ServiceTestClient)
assert client.add(5, 3) == 8
assert client.subtract(10, 4) == 6
assert client.fun_with_async(5, 3) == 8
assert await client.add_async(5, 3) == 8
assert await client.subtract_async(10, 4) == 6


@@ -15,6 +15,7 @@ services:
condition: service_healthy
environment:
- DATABASE_URL=postgresql://postgres:your-super-secret-and-long-postgres-password@db:5432/postgres?connect_timeout=60&schema=platform
- DIRECT_URL=postgresql://postgres:your-super-secret-and-long-postgres-password@db:5432/postgres?connect_timeout=60&schema=platform
networks:
- app-network
restart: on-failure
@@ -77,6 +78,7 @@ services:
- SUPABASE_JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long
- SUPABASE_SERVICE_ROLE_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJzZXJ2aWNlX3JvbGUiLAogICAgImlzcyI6ICJzdXBhYmFzZS1kZW1vIiwKICAgICJpYXQiOiAxNjQxNzY5MjAwLAogICAgImV4cCI6IDE3OTk1MzU2MDAKfQ.DaYlNEoUrrEn2Ig7tqibS-PHK5vgusbcbo7X36XVt4Q
- DATABASE_URL=postgresql://postgres:your-super-secret-and-long-postgres-password@db:5432/postgres?connect_timeout=60&schema=platform
- DIRECT_URL=postgresql://postgres:your-super-secret-and-long-postgres-password@db:5432/postgres?connect_timeout=60&schema=platform
- REDIS_HOST=redis
- REDIS_PORT=6379
- RABBITMQ_HOST=rabbitmq
@@ -126,6 +128,7 @@ services:
- SUPABASE_JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long
- SUPABASE_SERVICE_ROLE_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJzZXJ2aWNlX3JvbGUiLAogICAgImlzcyI6ICJzdXBhYmFzZS1kZW1vIiwKICAgICJpYXQiOiAxNjQxNzY5MjAwLAogICAgImV4cCI6IDE3OTk1MzU2MDAKfQ.DaYlNEoUrrEn2Ig7tqibS-PHK5vgusbcbo7X36XVt4Q
- DATABASE_URL=postgresql://postgres:your-super-secret-and-long-postgres-password@db:5432/postgres?connect_timeout=60&schema=platform
- DIRECT_URL=postgresql://postgres:your-super-secret-and-long-postgres-password@db:5432/postgres?connect_timeout=60&schema=platform
- REDIS_HOST=redis
- REDIS_PORT=6379
- REDIS_PASSWORD=password
@@ -139,7 +142,7 @@ services:
- NOTIFICATIONMANAGER_HOST=rest_server
- ENCRYPTION_KEY=dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw= # DO NOT USE IN PRODUCTION!!
ports:
- "8002:8000"
- "8002:8002"
networks:
- app-network
@@ -167,6 +170,7 @@ services:
- DATABASEMANAGER_HOST=rest_server
- SUPABASE_JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long
- DATABASE_URL=postgresql://postgres:your-super-secret-and-long-postgres-password@db:5432/postgres?connect_timeout=60&schema=platform
- DIRECT_URL=postgresql://postgres:your-super-secret-and-long-postgres-password@db:5432/postgres?connect_timeout=60&schema=platform
- REDIS_HOST=redis
- REDIS_PORT=6379
- REDIS_PASSWORD=password
@@ -201,6 +205,7 @@ services:
# - NEXT_PUBLIC_SUPABASE_URL=http://kong:8000
# - NEXT_PUBLIC_SUPABASE_ANON_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJhbm9uIiwKICAgICJpc3MiOiAic3VwYWJhc2UtZGVtbyIsCiAgICAiaWF0IjogMTY0MTc2OTIwMCwKICAgICJleHAiOiAxNzk5NTM1NjAwCn0.dc_X5iR_VP_qT0zsiyj_I_OZ2T9FtRU2BBNWN8Bu4GE
# - DATABASE_URL=postgresql://agpt_user:pass123@postgres:5432/postgres?connect_timeout=60&schema=platform
# - DIRECT_URL=postgresql://agpt_user:pass123@postgres:5432/postgres?connect_timeout=60&schema=platform
# - NEXT_PUBLIC_AGPT_SERVER_URL=http://localhost:8006/api
# - NEXT_PUBLIC_AGPT_WS_SERVER_URL=ws://localhost:8001/ws
# - NEXT_PUBLIC_AGPT_MARKETPLACE_URL=http://localhost:8015/api/v1/market


@@ -24,3 +24,4 @@ GA_MEASUREMENT_ID=G-FH2XK2W4GN
# When running locally, set NEXT_PUBLIC_BEHAVE_AS=CLOUD to use a locally hosted marketplace (as is typical in development, and the cloud deployment), otherwise set it to LOCAL to have the marketplace open in a new tab
NEXT_PUBLIC_BEHAVE_AS=LOCAL
NEXT_PUBLIC_SHOW_BILLING_PAGE=false


@@ -42,6 +42,7 @@
"@radix-ui/react-separator": "^1.1.0",
"@radix-ui/react-slot": "^1.1.0",
"@radix-ui/react-switch": "^1.1.1",
"@radix-ui/react-tabs": "^1.1.4",
"@radix-ui/react-toast": "^1.2.5",
"@radix-ui/react-tooltip": "^1.1.7",
"@sentry/nextjs": "^9",
@@ -52,7 +53,6 @@
"@xyflow/react": "12.4.2",
"ajv": "^8.17.1",
"boring-avatars": "^1.11.2",
"canvas-confetti": "^1.9.3",
"class-variance-authority": "^0.7.1",
"clsx": "^2.1.1",
"cmdk": "1.0.4",
@@ -70,6 +70,7 @@
"moment": "^2.30.1",
"next": "^14.2.26",
"next-themes": "^0.4.5",
"party-js": "^2.2.0",
"react": "^18",
"react-day-picker": "^9.6.1",
"react-dom": "^18",


@@ -163,14 +163,7 @@ export default function Page() {
</div>
<OnboardingFooter>
<OnboardingButton
className="mb-2"
href="/onboarding/4-agent"
disabled={
state?.integrations.length === 0 &&
isEmptyOrWhitespace(state.otherIntegrations)
}
>
<OnboardingButton className="mb-2" href="/onboarding/4-agent">
Next
</OnboardingButton>
</OnboardingFooter>


@@ -59,7 +59,7 @@ export default function Page() {
<div className="my-12 flex items-center justify-between gap-5">
<OnboardingAgentCard
{...(agents[0] || {})}
agent={agents[0]}
selected={
agents[0] !== undefined
? state?.selectedStoreListingVersionId ==
@@ -74,7 +74,7 @@ export default function Page() {
}
/>
<OnboardingAgentCard
{...(agents[1] || {})}
agent={agents[1]}
selected={
agents[1] !== undefined
? state?.selectedStoreListingVersionId ==


@@ -9,7 +9,6 @@ import StarRating from "@/components/onboarding/StarRating";
import { Play } from "lucide-react";
import { cn } from "@/lib/utils";
import { useCallback, useEffect, useState } from "react";
import Image from "next/image";
import { GraphMeta, StoreAgentDetails } from "@/lib/autogpt-server-api";
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
import { useRouter } from "next/navigation";
@@ -17,6 +16,7 @@ import { useOnboarding } from "@/components/onboarding/onboarding-provider";
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
import SchemaTooltip from "@/components/SchemaTooltip";
import { TypeBasedInput } from "@/components/type-based-input";
import SmartImage from "@/components/agptui/SmartImage";
export default function Page() {
const { state, updateState, setStep } = useOnboarding(
@@ -80,15 +80,26 @@ export default function Page() {
if (!agent) {
return;
}
api.addMarketplaceAgentToLibrary(
storeAgent?.store_listing_version_id || "",
);
api.executeGraph(agent.id, agent.version, state?.agentInput || {});
router.push("/onboarding/6-congrats");
}, [api, agent, router, state?.agentInput]);
api
.addMarketplaceAgentToLibrary(storeAgent?.store_listing_version_id || "")
.then((libraryAgent) => {
api
.executeGraph(
libraryAgent.graph_id,
libraryAgent.graph_version,
state?.agentInput || {},
)
.then(({ graph_exec_id }) => {
updateState({
onboardingAgentExecutionId: graph_exec_id,
});
router.push("/onboarding/6-congrats");
});
});
}, [api, agent, router, state?.agentInput, storeAgent, updateState]);
const runYourAgent = (
<div className="ml-[54px] w-[481px] pl-5">
<div className="ml-[104px] w-[481px] pl-5">
<div className="flex flex-col">
<OnboardingText variant="header">Run your first agent</OnboardingText>
<span className="mt-9 text-base font-normal leading-normal text-zinc-600">
@@ -136,32 +147,25 @@ export default function Page() {
return (
<OnboardingStep dotted>
<OnboardingHeader backHref={"/onboarding/4-agent"} transparent />
<div
className={cn(
"flex w-full items-center justify-center",
showInput ? "mt-[32px]" : "mt-[192px]",
)}
>
{/* Left side */}
<div className="mr-[52px] w-[481px]">
<div className="h-[156px] w-[481px] rounded-xl bg-white px-6 pb-5 pt-4">
<span className="font-sans text-xs font-medium tracking-wide text-zinc-500">
SELECTED AGENT
</span>
{/* Agent card */}
<div className="fixed left-1/4 top-1/2 w-[481px] -translate-x-1/2 -translate-y-1/2">
<div className="h-[156px] w-[481px] rounded-xl bg-white px-6 pb-5 pt-4">
<span className="font-sans text-xs font-medium tracking-wide text-zinc-500">
SELECTED AGENT
</span>
{storeAgent ? (
<div className="mt-4 flex h-20 rounded-lg bg-violet-50 p-2">
{/* Left image */}
<Image
src={storeAgent?.agent_image[0] || ""}
alt="Description"
width={350}
height={196}
className="h-full w-auto rounded-lg object-contain"
<SmartImage
src={storeAgent?.agent_image[0]}
alt="Agent cover"
imageContain
className="w-[350px] rounded-lg"
/>
{/* Right content */}
<div className="ml-2 flex flex-1 flex-col">
<span className="w-[292px] truncate font-sans text-[14px] font-medium leading-normal text-zinc-800">
{agent?.name}
{storeAgent?.agent_name}
</span>
<span className="mt-[5px] w-[292px] truncate font-sans text-xs font-normal leading-tight text-zinc-600">
by {storeAgent?.creator}
@@ -178,13 +182,19 @@ export default function Page() {
</div>
</div>
</div>
</div>
) : (
<div className="mt-4 flex h-20 animate-pulse rounded-lg bg-gray-300 p-2" />
)}
</div>
</div>
<div className="flex min-h-[80vh] items-center justify-center">
{/* Left side */}
<div className="w-[481px]" />
{/* Right side */}
{!showInput ? (
runYourAgent
) : (
<div className="ml-[54px] w-[481px] pl-5">
<div className="ml-[104px] w-[481px] pl-5">
<div className="flex flex-col">
<OnboardingText variant="header">
Provide details for your agent


@@ -0,0 +1,18 @@
"use server";
import BackendAPI from "@/lib/autogpt-server-api";
import { revalidatePath } from "next/cache";
import { redirect } from "next/navigation";
export async function finishOnboarding() {
const api = new BackendAPI();
const onboarding = await api.getUserOnboarding();
const listingId = onboarding?.selectedStoreListingVersionId;
if (listingId) {
const libraryAgent = await api.addMarketplaceAgentToLibrary(listingId);
revalidatePath(`/library/agents/${libraryAgent.id}`, "layout");
redirect(`/library/agents/${libraryAgent.id}`);
} else {
revalidatePath("/library", "layout");
redirect("/library");
}
}


@@ -1,24 +1,26 @@
"use client";
import { useEffect, useState } from "react";
import { useEffect, useRef, useState } from "react";
import { cn } from "@/lib/utils";
import { finishOnboarding } from "./actions";
import confetti from "canvas-confetti";
import { useOnboarding } from "@/components/onboarding/onboarding-provider";
import * as party from "party-js";
export default function Page() {
useOnboarding(7, "AGENT_INPUT");
const { state, updateState } = useOnboarding(7, "AGENT_INPUT");
const [showText, setShowText] = useState(false);
const [showSubtext, setShowSubtext] = useState(false);
const divRef = useRef(null);
useEffect(() => {
confetti({
particleCount: 120,
spread: 360,
shapes: ["square", "circle"],
scalar: 2,
decay: 0.93,
origin: { y: 0.38, x: 0.51 },
});
if (divRef.current) {
party.confetti(divRef.current, {
count: 100,
spread: 180,
shapes: ["square", "circle"],
size: party.variation.range(2, 2), // scalar: 2
speed: party.variation.range(300, 1000),
});
}
const timer0 = setTimeout(() => {
setShowText(true);
@@ -29,6 +31,9 @@ export default function Page() {
}, 500);
const timer2 = setTimeout(() => {
updateState({
completedSteps: [...(state?.completedSteps || []), "CONGRATS"],
});
finishOnboarding();
}, 3000);
@@ -42,6 +47,7 @@ export default function Page() {
return (
<div className="flex h-screen w-screen flex-col items-center justify-center bg-violet-100">
<div
ref={divRef}
className={cn(
"z-10 -mb-16 text-9xl duration-500",
showText ? "opacity-100" : "opacity-0",
@@ -63,7 +69,7 @@ export default function Page() {
showSubtext ? "opacity-100" : "opacity-0",
)}
>
You earned 15$ for running your first agent
You earned 3$ for running your first agent
</p>
</div>
);


@@ -7,9 +7,9 @@ export default function OnboardingLayout({
}) {
return (
<div className="flex min-h-screen w-full items-center justify-center bg-gray-100">
<div className="mx-auto flex w-full flex-col items-center">
<main className="mx-auto flex w-full flex-col items-center">
{children}
</div>
</main>
</div>
);
}


@@ -13,7 +13,7 @@ export default async function OnboardingPage() {
// CONGRATS is the last step in intro onboarding
if (onboarding.completedSteps.includes("CONGRATS")) redirect("/marketplace");
else if (onboarding.completedSteps.includes("AGENT_INPUT"))
redirect("/onboarding/6-congrats");
redirect("/onboarding/5-run");
else if (onboarding.completedSteps.includes("AGENT_NEW_RUN"))
redirect("/onboarding/5-run");
else if (onboarding.completedSteps.includes("AGENT_CHOICE"))


@@ -5,11 +5,14 @@ export default async function OnboardingResetPage() {
const api = new BackendAPI();
await api.updateUserOnboarding({
completedSteps: [],
notificationDot: true,
notified: [],
usageReason: null,
integrations: [],
otherIntegrations: "",
selectedStoreListingVersionId: null,
agentInput: {},
onboardingAgentExecutionId: null,
});
redirect("/onboarding/1-welcome");
}


@@ -56,3 +56,9 @@ export async function getAdminListingsWithVersions(
const response = await api.getAdminListingsWithVersions(data);
return response;
}
export async function downloadAsAdmin(storeListingVersion: string) {
const api = new BackendAPI();
const file = await api.downloadStoreAgentAdmin(storeListingVersion);
return file;
}


@@ -3,9 +3,16 @@
import { useSearchParams } from "next/navigation";
import { GraphID } from "@/lib/autogpt-server-api/types";
import FlowEditor from "@/components/Flow";
import { useOnboarding } from "@/components/onboarding/onboarding-provider";
import { useEffect } from "react";
export default function Home() {
const query = useSearchParams();
const { completeStep } = useOnboarding();
useEffect(() => {
completeStep("BUILDER_OPEN");
}, []);
return (
<FlowEditor

Some files were not shown because too many files have changed in this diff