Mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-02-09 06:15:41 -05:00)

Compare commits: swiftyos/o...fix/execut (21 commits)
| SHA1 |
|---|
| 9b20f4cd13 |
| a3d0f9cbd2 |
| 02ddb51446 |
| 750e096f15 |
| ff5c8f324b |
| f121a22544 |
| 71157bddd7 |
| 152e747ea6 |
| 4d4741d558 |
| bd37fe946d |
| 7ff282c908 |
| 117bb05438 |
| 979d7c3b74 |
| 95200b67f8 |
| f8afc6044e |
| 7edf01777e |
| c9681f5d44 |
| 1305325813 |
| 4f349281bd |
| c4eb7edb65 |
| 3f690ea7b8 |
example_payloads/discussion.created.json
@@ -0,0 +1,108 @@
{
  "action": "created",
  "discussion": {
    "repository_url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT",
    "category": {
      "id": 12345678,
      "node_id": "DIC_kwDOJKSTjM4CXXXX",
      "repository_id": 614765452,
      "emoji": ":pray:",
      "name": "Q&A",
      "description": "Ask the community for help",
      "created_at": "2023-03-16T09:21:07Z",
      "updated_at": "2023-03-16T09:21:07Z",
      "slug": "q-a",
      "is_answerable": true
    },
    "answer_html_url": null,
    "answer_chosen_at": null,
    "answer_chosen_by": null,
    "html_url": "https://github.com/Significant-Gravitas/AutoGPT/discussions/9999",
    "id": 5000000001,
    "node_id": "D_kwDOJKSTjM4AYYYY",
    "number": 9999,
    "title": "How do I configure custom blocks?",
    "user": {
      "login": "curious-user",
      "id": 22222222,
      "node_id": "MDQ6VXNlcjIyMjIyMjIy",
      "avatar_url": "https://avatars.githubusercontent.com/u/22222222?v=4",
      "url": "https://api.github.com/users/curious-user",
      "html_url": "https://github.com/curious-user",
      "type": "User",
      "site_admin": false
    },
    "state": "open",
    "state_reason": null,
    "locked": false,
    "comments": 0,
    "created_at": "2024-12-01T17:00:00Z",
    "updated_at": "2024-12-01T17:00:00Z",
    "author_association": "NONE",
    "active_lock_reason": null,
    "body": "## Question\n\nI'm trying to create a custom block for my specific use case. I've read the documentation but I'm not sure how to:\n\n1. Define the input/output schema\n2. Handle authentication\n3. Test my block locally\n\nCan someone point me to examples or provide guidance?\n\n## Environment\n\n- AutoGPT Platform version: latest\n- Python: 3.11",
    "reactions": {
      "url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT/discussions/9999/reactions",
      "total_count": 0,
      "+1": 0,
      "-1": 0,
      "laugh": 0,
      "hooray": 0,
      "confused": 0,
      "heart": 0,
      "rocket": 0,
      "eyes": 0
    },
    "timeline_url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT/discussions/9999/timeline"
  },
  "repository": {
    "id": 614765452,
    "node_id": "R_kgDOJKSTjA",
    "name": "AutoGPT",
    "full_name": "Significant-Gravitas/AutoGPT",
    "private": false,
    "owner": {
      "login": "Significant-Gravitas",
      "id": 130738209,
      "node_id": "O_kgDOB8roIQ",
      "avatar_url": "https://avatars.githubusercontent.com/u/130738209?v=4",
      "url": "https://api.github.com/users/Significant-Gravitas",
      "html_url": "https://github.com/Significant-Gravitas",
      "type": "Organization",
      "site_admin": false
    },
    "html_url": "https://github.com/Significant-Gravitas/AutoGPT",
    "description": "AutoGPT is the vision of accessible AI for everyone, to use and to build on.",
    "fork": false,
    "url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT",
    "created_at": "2023-03-16T09:21:07Z",
    "updated_at": "2024-12-01T17:00:00Z",
    "pushed_at": "2024-12-01T12:00:00Z",
    "stargazers_count": 170000,
    "watchers_count": 170000,
    "language": "Python",
    "has_discussions": true,
    "forks_count": 45000,
    "visibility": "public",
    "default_branch": "master"
  },
  "organization": {
    "login": "Significant-Gravitas",
    "id": 130738209,
    "node_id": "O_kgDOB8roIQ",
    "url": "https://api.github.com/orgs/Significant-Gravitas",
    "avatar_url": "https://avatars.githubusercontent.com/u/130738209?v=4",
    "description": ""
  },
  "sender": {
    "login": "curious-user",
    "id": 22222222,
    "node_id": "MDQ6VXNlcjIyMjIyMjIy",
    "avatar_url": "https://avatars.githubusercontent.com/u/22222222?v=4",
    "gravatar_id": "",
    "url": "https://api.github.com/users/curious-user",
    "html_url": "https://github.com/curious-user",
    "type": "User",
    "site_admin": false
  }
}
example_payloads/issues.opened.json
@@ -0,0 +1,112 @@
{
  "action": "opened",
  "issue": {
    "url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT/issues/12345",
    "repository_url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT",
    "labels_url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT/issues/12345/labels{/name}",
    "comments_url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT/issues/12345/comments",
    "events_url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT/issues/12345/events",
    "html_url": "https://github.com/Significant-Gravitas/AutoGPT/issues/12345",
    "id": 2000000001,
    "node_id": "I_kwDOJKSTjM5wXXXX",
    "number": 12345,
    "title": "Bug: Application crashes when processing large files",
    "user": {
      "login": "bug-reporter",
      "id": 11111111,
      "node_id": "MDQ6VXNlcjExMTExMTEx",
      "avatar_url": "https://avatars.githubusercontent.com/u/11111111?v=4",
      "url": "https://api.github.com/users/bug-reporter",
      "html_url": "https://github.com/bug-reporter",
      "type": "User",
      "site_admin": false
    },
    "labels": [
      {
        "id": 5272676214,
        "node_id": "LA_kwDOJKSTjM8AAAABOkandg",
        "url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT/labels/bug",
        "name": "bug",
        "color": "d73a4a",
        "default": true,
        "description": "Something isn't working"
      }
    ],
    "state": "open",
    "locked": false,
    "assignee": null,
    "assignees": [],
    "milestone": null,
    "comments": 0,
    "created_at": "2024-12-01T16:00:00Z",
    "updated_at": "2024-12-01T16:00:00Z",
    "closed_at": null,
    "author_association": "NONE",
    "active_lock_reason": null,
    "body": "## Description\n\nWhen I try to process a file larger than 100MB, the application crashes with an out of memory error.\n\n## Steps to Reproduce\n\n1. Open the application\n2. Select a file larger than 100MB\n3. Click 'Process'\n4. Application crashes\n\n## Expected Behavior\n\nThe application should handle large files gracefully.\n\n## Environment\n\n- OS: Ubuntu 22.04\n- Python: 3.11\n- AutoGPT Version: 1.0.0",
    "reactions": {
      "url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT/issues/12345/reactions",
      "total_count": 0,
      "+1": 0,
      "-1": 0,
      "laugh": 0,
      "hooray": 0,
      "confused": 0,
      "heart": 0,
      "rocket": 0,
      "eyes": 0
    },
    "timeline_url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT/issues/12345/timeline",
    "state_reason": null
  },
  "repository": {
    "id": 614765452,
    "node_id": "R_kgDOJKSTjA",
    "name": "AutoGPT",
    "full_name": "Significant-Gravitas/AutoGPT",
    "private": false,
    "owner": {
      "login": "Significant-Gravitas",
      "id": 130738209,
      "node_id": "O_kgDOB8roIQ",
      "avatar_url": "https://avatars.githubusercontent.com/u/130738209?v=4",
      "url": "https://api.github.com/users/Significant-Gravitas",
      "html_url": "https://github.com/Significant-Gravitas",
      "type": "Organization",
      "site_admin": false
    },
    "html_url": "https://github.com/Significant-Gravitas/AutoGPT",
    "description": "AutoGPT is the vision of accessible AI for everyone, to use and to build on.",
    "fork": false,
    "url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT",
    "created_at": "2023-03-16T09:21:07Z",
    "updated_at": "2024-12-01T16:00:00Z",
    "pushed_at": "2024-12-01T12:00:00Z",
    "stargazers_count": 170000,
    "watchers_count": 170000,
    "language": "Python",
    "forks_count": 45000,
    "open_issues_count": 190,
    "visibility": "public",
    "default_branch": "master"
  },
  "organization": {
    "login": "Significant-Gravitas",
    "id": 130738209,
    "node_id": "O_kgDOB8roIQ",
    "url": "https://api.github.com/orgs/Significant-Gravitas",
    "avatar_url": "https://avatars.githubusercontent.com/u/130738209?v=4",
    "description": ""
  },
  "sender": {
    "login": "bug-reporter",
    "id": 11111111,
    "node_id": "MDQ6VXNlcjExMTExMTEx",
    "avatar_url": "https://avatars.githubusercontent.com/u/11111111?v=4",
    "gravatar_id": "",
    "url": "https://api.github.com/users/bug-reporter",
    "html_url": "https://github.com/bug-reporter",
    "type": "User",
    "site_admin": false
  }
}
example_payloads/release.published.json
@@ -0,0 +1,97 @@
{
  "action": "published",
  "release": {
    "url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT/releases/123456789",
    "assets_url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT/releases/123456789/assets",
    "upload_url": "https://uploads.github.com/repos/Significant-Gravitas/AutoGPT/releases/123456789/assets{?name,label}",
    "html_url": "https://github.com/Significant-Gravitas/AutoGPT/releases/tag/v1.0.0",
    "id": 123456789,
    "author": {
      "login": "ntindle",
      "id": 12345678,
      "node_id": "MDQ6VXNlcjEyMzQ1Njc4",
      "avatar_url": "https://avatars.githubusercontent.com/u/12345678?v=4",
      "gravatar_id": "",
      "url": "https://api.github.com/users/ntindle",
      "html_url": "https://github.com/ntindle",
      "type": "User",
      "site_admin": false
    },
    "node_id": "RE_kwDOJKSTjM4HWwAA",
    "tag_name": "v1.0.0",
    "target_commitish": "master",
    "name": "AutoGPT Platform v1.0.0",
    "draft": false,
    "prerelease": false,
    "created_at": "2024-12-01T10:00:00Z",
    "published_at": "2024-12-01T12:00:00Z",
    "assets": [
      {
        "url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT/releases/assets/987654321",
        "id": 987654321,
        "node_id": "RA_kwDOJKSTjM4HWwBB",
        "name": "autogpt-v1.0.0.zip",
        "label": "Release Package",
        "content_type": "application/zip",
        "state": "uploaded",
        "size": 52428800,
        "download_count": 0,
        "created_at": "2024-12-01T11:30:00Z",
        "updated_at": "2024-12-01T11:35:00Z",
        "browser_download_url": "https://github.com/Significant-Gravitas/AutoGPT/releases/download/v1.0.0/autogpt-v1.0.0.zip"
      }
    ],
    "tarball_url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT/tarball/v1.0.0",
    "zipball_url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT/zipball/v1.0.0",
    "body": "## What's New\n\n- Feature 1: Amazing new capability\n- Feature 2: Performance improvements\n- Bug fixes and stability improvements\n\n## Breaking Changes\n\nNone\n\n## Contributors\n\nThanks to all our contributors!"
  },
  "repository": {
    "id": 614765452,
    "node_id": "R_kgDOJKSTjA",
    "name": "AutoGPT",
    "full_name": "Significant-Gravitas/AutoGPT",
    "private": false,
    "owner": {
      "login": "Significant-Gravitas",
      "id": 130738209,
      "node_id": "O_kgDOB8roIQ",
      "avatar_url": "https://avatars.githubusercontent.com/u/130738209?v=4",
      "url": "https://api.github.com/users/Significant-Gravitas",
      "html_url": "https://github.com/Significant-Gravitas",
      "type": "Organization",
      "site_admin": false
    },
    "html_url": "https://github.com/Significant-Gravitas/AutoGPT",
    "description": "AutoGPT is the vision of accessible AI for everyone, to use and to build on.",
    "fork": false,
    "url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT",
    "created_at": "2023-03-16T09:21:07Z",
    "updated_at": "2024-12-01T12:00:00Z",
    "pushed_at": "2024-12-01T12:00:00Z",
    "stargazers_count": 170000,
    "watchers_count": 170000,
    "language": "Python",
    "forks_count": 45000,
    "visibility": "public",
    "default_branch": "master"
  },
  "organization": {
    "login": "Significant-Gravitas",
    "id": 130738209,
    "node_id": "O_kgDOB8roIQ",
    "url": "https://api.github.com/orgs/Significant-Gravitas",
    "avatar_url": "https://avatars.githubusercontent.com/u/130738209?v=4",
    "description": ""
  },
  "sender": {
    "login": "ntindle",
    "id": 12345678,
    "node_id": "MDQ6VXNlcjEyMzQ1Njc4",
    "avatar_url": "https://avatars.githubusercontent.com/u/12345678?v=4",
    "gravatar_id": "",
    "url": "https://api.github.com/users/ntindle",
    "html_url": "https://github.com/ntindle",
    "type": "User",
    "site_admin": false
  }
}
example_payloads/star.created.json
@@ -0,0 +1,53 @@
{
  "action": "created",
  "starred_at": "2024-12-01T15:30:00Z",
  "repository": {
    "id": 614765452,
    "node_id": "R_kgDOJKSTjA",
    "name": "AutoGPT",
    "full_name": "Significant-Gravitas/AutoGPT",
    "private": false,
    "owner": {
      "login": "Significant-Gravitas",
      "id": 130738209,
      "node_id": "O_kgDOB8roIQ",
      "avatar_url": "https://avatars.githubusercontent.com/u/130738209?v=4",
      "url": "https://api.github.com/users/Significant-Gravitas",
      "html_url": "https://github.com/Significant-Gravitas",
      "type": "Organization",
      "site_admin": false
    },
    "html_url": "https://github.com/Significant-Gravitas/AutoGPT",
    "description": "AutoGPT is the vision of accessible AI for everyone, to use and to build on.",
    "fork": false,
    "url": "https://api.github.com/repos/Significant-Gravitas/AutoGPT",
    "created_at": "2023-03-16T09:21:07Z",
    "updated_at": "2024-12-01T15:30:00Z",
    "pushed_at": "2024-12-01T12:00:00Z",
    "stargazers_count": 170001,
    "watchers_count": 170001,
    "language": "Python",
    "forks_count": 45000,
    "visibility": "public",
    "default_branch": "master"
  },
  "organization": {
    "login": "Significant-Gravitas",
    "id": 130738209,
    "node_id": "O_kgDOB8roIQ",
    "url": "https://api.github.com/orgs/Significant-Gravitas",
    "avatar_url": "https://avatars.githubusercontent.com/u/130738209?v=4",
    "description": ""
  },
  "sender": {
    "login": "awesome-contributor",
    "id": 98765432,
    "node_id": "MDQ6VXNlcjk4NzY1NDMy",
    "avatar_url": "https://avatars.githubusercontent.com/u/98765432?v=4",
    "gravatar_id": "",
    "url": "https://api.github.com/users/awesome-contributor",
    "html_url": "https://github.com/awesome-contributor",
    "type": "User",
    "site_admin": false
  }
}
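These four fixtures are read verbatim by the trigger blocks in the diff below. As a quick orientation, the sketch here (illustrative, not part of the changeset; the relative path assumes it runs from the directory containing `example_payloads/`) loads one fixture and checks exactly the fields the star trigger later yields:

```python
import json
from pathlib import Path

# Load the star fixture the same way the blocks below do via EXAMPLE_PAYLOAD_FILE.
payload = json.loads(
    (Path("example_payloads") / "star.created.json").read_text(encoding="utf-8")
)

# These are the payload fields GithubStarTriggerBlock yields further down.
assert payload["action"] == "created"
assert payload["starred_at"] == "2024-12-01T15:30:00Z"
assert payload["repository"]["stargazers_count"] == 170001
assert payload["repository"]["full_name"] == "Significant-Gravitas/AutoGPT"
```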
@@ -159,3 +159,391 @@ class GithubPullRequestTriggerBlock(GitHubTriggerBase, Block):


# --8<-- [end:GithubTriggerExample]


class GithubStarTriggerBlock(GitHubTriggerBase, Block):
    """Trigger block for GitHub star events - useful for milestone celebrations."""

    EXAMPLE_PAYLOAD_FILE = (
        Path(__file__).parent / "example_payloads" / "star.created.json"
    )

    class Input(GitHubTriggerBase.Input):
        class EventsFilter(BaseModel):
            """
            https://docs.github.com/en/webhooks/webhook-events-and-payloads#star
            """

            created: bool = False
            deleted: bool = False

        events: EventsFilter = SchemaField(
            title="Events", description="The star events to subscribe to"
        )

    class Output(GitHubTriggerBase.Output):
        event: str = SchemaField(
            description="The star event that triggered the webhook ('created' or 'deleted')"
        )
        starred_at: str = SchemaField(
            description="ISO timestamp when the repo was starred (empty if deleted)"
        )
        stargazers_count: int = SchemaField(
            description="Current number of stars on the repository"
        )
        repository_name: str = SchemaField(
            description="Full name of the repository (owner/repo)"
        )
        repository_url: str = SchemaField(description="URL to the repository")

    def __init__(self):
        from backend.integrations.webhooks.github import GithubWebhookType

        example_payload = json.loads(
            self.EXAMPLE_PAYLOAD_FILE.read_text(encoding="utf-8")
        )

        super().__init__(
            id="551e0a35-100b-49b7-89b8-3031322239b6",
            description="This block triggers on GitHub star events. "
            "Useful for celebrating milestones (e.g., 1k, 10k stars) or tracking engagement.",
            categories={BlockCategory.DEVELOPER_TOOLS, BlockCategory.INPUT},
            input_schema=GithubStarTriggerBlock.Input,
            output_schema=GithubStarTriggerBlock.Output,
            webhook_config=BlockWebhookConfig(
                provider=ProviderName.GITHUB,
                webhook_type=GithubWebhookType.REPO,
                resource_format="{repo}",
                event_filter_input="events",
                event_format="star.{event}",
            ),
            test_input={
                "repo": "Significant-Gravitas/AutoGPT",
                "events": {"created": True},
                "credentials": TEST_CREDENTIALS_INPUT,
                "payload": example_payload,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("payload", example_payload),
                ("triggered_by_user", example_payload["sender"]),
                ("event", example_payload["action"]),
                ("starred_at", example_payload.get("starred_at", "")),
                ("stargazers_count", example_payload["repository"]["stargazers_count"]),
                ("repository_name", example_payload["repository"]["full_name"]),
                ("repository_url", example_payload["repository"]["html_url"]),
            ],
        )

    async def run(self, input_data: Input, **kwargs) -> BlockOutput:  # type: ignore
        async for name, value in super().run(input_data, **kwargs):
            yield name, value
        yield "event", input_data.payload["action"]
        yield "starred_at", input_data.payload.get("starred_at", "")
        yield "stargazers_count", input_data.payload["repository"]["stargazers_count"]
        yield "repository_name", input_data.payload["repository"]["full_name"]
        yield "repository_url", input_data.payload["repository"]["html_url"]
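An aside on the webhook config above: `event_filter_input="events"` points at the boolean `EventsFilter` input, and `event_format="star.{event}"` turns each enabled flag into a concrete webhook event name. A minimal sketch of that mapping, assuming only the straightforward reading of those two fields (the actual subscription plumbing lives in `backend.integrations.webhooks.github`):

```python
# Hypothetical illustration: enabled filter fields -> subscribed event names.
events = {"created": True, "deleted": False}  # mirrors Input.events
subscribed = [f"star.{name}" for name, enabled in events.items() if enabled]
assert subscribed == ["star.created"]
```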

class GithubReleaseTriggerBlock(GitHubTriggerBase, Block):
    """Trigger block for GitHub release events - ideal for announcing new versions."""

    EXAMPLE_PAYLOAD_FILE = (
        Path(__file__).parent / "example_payloads" / "release.published.json"
    )

    class Input(GitHubTriggerBase.Input):
        class EventsFilter(BaseModel):
            """
            https://docs.github.com/en/webhooks/webhook-events-and-payloads#release
            """

            published: bool = False
            unpublished: bool = False
            created: bool = False
            edited: bool = False
            deleted: bool = False
            prereleased: bool = False
            released: bool = False

        events: EventsFilter = SchemaField(
            title="Events", description="The release events to subscribe to"
        )

    class Output(GitHubTriggerBase.Output):
        event: str = SchemaField(
            description="The release event that triggered the webhook (e.g., 'published')"
        )
        release: dict = SchemaField(description="The full release object")
        release_url: str = SchemaField(description="URL to the release page")
        tag_name: str = SchemaField(description="The release tag name (e.g., 'v1.0.0')")
        release_name: str = SchemaField(description="Human-readable release name")
        body: str = SchemaField(description="Release notes/description")
        prerelease: bool = SchemaField(description="Whether this is a prerelease")
        draft: bool = SchemaField(description="Whether this is a draft release")
        assets: list = SchemaField(description="List of release assets/files")

    def __init__(self):
        from backend.integrations.webhooks.github import GithubWebhookType

        example_payload = json.loads(
            self.EXAMPLE_PAYLOAD_FILE.read_text(encoding="utf-8")
        )

        super().__init__(
            id="2052dd1b-74e1-46ac-9c87-c7a0e057b60b",
            description="This block triggers on GitHub release events. "
            "Perfect for automating announcements to Discord, Twitter, or other platforms.",
            categories={BlockCategory.DEVELOPER_TOOLS, BlockCategory.INPUT},
            input_schema=GithubReleaseTriggerBlock.Input,
            output_schema=GithubReleaseTriggerBlock.Output,
            webhook_config=BlockWebhookConfig(
                provider=ProviderName.GITHUB,
                webhook_type=GithubWebhookType.REPO,
                resource_format="{repo}",
                event_filter_input="events",
                event_format="release.{event}",
            ),
            test_input={
                "repo": "Significant-Gravitas/AutoGPT",
                "events": {"published": True},
                "credentials": TEST_CREDENTIALS_INPUT,
                "payload": example_payload,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("payload", example_payload),
                ("triggered_by_user", example_payload["sender"]),
                ("event", example_payload["action"]),
                ("release", example_payload["release"]),
                ("release_url", example_payload["release"]["html_url"]),
                ("tag_name", example_payload["release"]["tag_name"]),
                ("release_name", example_payload["release"]["name"]),
                ("body", example_payload["release"]["body"]),
                ("prerelease", example_payload["release"]["prerelease"]),
                ("draft", example_payload["release"]["draft"]),
                ("assets", example_payload["release"]["assets"]),
            ],
        )

    async def run(self, input_data: Input, **kwargs) -> BlockOutput:  # type: ignore
        async for name, value in super().run(input_data, **kwargs):
            yield name, value
        release = input_data.payload["release"]
        yield "event", input_data.payload["action"]
        yield "release", release
        yield "release_url", release["html_url"]
        yield "tag_name", release["tag_name"]
        yield "release_name", release.get("name", "")
        yield "body", release.get("body", "")
        yield "prerelease", release["prerelease"]
        yield "draft", release["draft"]
        yield "assets", release["assets"]

class GithubIssuesTriggerBlock(GitHubTriggerBase, Block):
    """Trigger block for GitHub issues events - great for triage and notifications."""

    EXAMPLE_PAYLOAD_FILE = (
        Path(__file__).parent / "example_payloads" / "issues.opened.json"
    )

    class Input(GitHubTriggerBase.Input):
        class EventsFilter(BaseModel):
            """
            https://docs.github.com/en/webhooks/webhook-events-and-payloads#issues
            """

            opened: bool = False
            edited: bool = False
            deleted: bool = False
            closed: bool = False
            reopened: bool = False
            assigned: bool = False
            unassigned: bool = False
            labeled: bool = False
            unlabeled: bool = False
            locked: bool = False
            unlocked: bool = False
            transferred: bool = False
            milestoned: bool = False
            demilestoned: bool = False
            pinned: bool = False
            unpinned: bool = False

        events: EventsFilter = SchemaField(
            title="Events", description="The issue events to subscribe to"
        )

    class Output(GitHubTriggerBase.Output):
        event: str = SchemaField(
            description="The issue event that triggered the webhook (e.g., 'opened')"
        )
        number: int = SchemaField(description="The issue number")
        issue: dict = SchemaField(description="The full issue object")
        issue_url: str = SchemaField(description="URL to the issue")
        issue_title: str = SchemaField(description="The issue title")
        issue_body: str = SchemaField(description="The issue body/description")
        labels: list = SchemaField(description="List of labels on the issue")
        assignees: list = SchemaField(description="List of assignees")
        state: str = SchemaField(description="Issue state ('open' or 'closed')")

    def __init__(self):
        from backend.integrations.webhooks.github import GithubWebhookType

        example_payload = json.loads(
            self.EXAMPLE_PAYLOAD_FILE.read_text(encoding="utf-8")
        )

        super().__init__(
            id="b2605464-e486-4bf4-aad3-d8a213c8a48a",
            description="This block triggers on GitHub issues events. "
            "Useful for automated triage, notifications, and welcoming first-time contributors.",
            categories={BlockCategory.DEVELOPER_TOOLS, BlockCategory.INPUT},
            input_schema=GithubIssuesTriggerBlock.Input,
            output_schema=GithubIssuesTriggerBlock.Output,
            webhook_config=BlockWebhookConfig(
                provider=ProviderName.GITHUB,
                webhook_type=GithubWebhookType.REPO,
                resource_format="{repo}",
                event_filter_input="events",
                event_format="issues.{event}",
            ),
            test_input={
                "repo": "Significant-Gravitas/AutoGPT",
                "events": {"opened": True},
                "credentials": TEST_CREDENTIALS_INPUT,
                "payload": example_payload,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("payload", example_payload),
                ("triggered_by_user", example_payload["sender"]),
                ("event", example_payload["action"]),
                ("number", example_payload["issue"]["number"]),
                ("issue", example_payload["issue"]),
                ("issue_url", example_payload["issue"]["html_url"]),
                ("issue_title", example_payload["issue"]["title"]),
                ("issue_body", example_payload["issue"]["body"]),
                ("labels", example_payload["issue"]["labels"]),
                ("assignees", example_payload["issue"]["assignees"]),
                ("state", example_payload["issue"]["state"]),
            ],
        )

    async def run(self, input_data: Input, **kwargs) -> BlockOutput:  # type: ignore
        async for name, value in super().run(input_data, **kwargs):
            yield name, value
        issue = input_data.payload["issue"]
        yield "event", input_data.payload["action"]
        yield "number", issue["number"]
        yield "issue", issue
        yield "issue_url", issue["html_url"]
        yield "issue_title", issue["title"]
        yield "issue_body", issue.get("body") or ""
        yield "labels", issue["labels"]
        yield "assignees", issue["assignees"]
        yield "state", issue["state"]
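The issues block's description mentions welcoming first-time contributors; a downstream node could key off `author_association`, which the `issues.opened.json` fixture above sets to "NONE" for senders with no prior ties to the repository. A hedged sketch of such a check (illustrative only, not part of the diff):

```python
def should_welcome(payload: dict) -> bool:
    # "NONE" author_association marks a sender with no prior association,
    # as in the issues.opened.json fixture above.
    return (
        payload["action"] == "opened"
        and payload["issue"]["author_association"] == "NONE"
    )
```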

class GithubDiscussionTriggerBlock(GitHubTriggerBase, Block):
    """Trigger block for GitHub discussion events - perfect for community Q&A sync."""

    EXAMPLE_PAYLOAD_FILE = (
        Path(__file__).parent / "example_payloads" / "discussion.created.json"
    )

    class Input(GitHubTriggerBase.Input):
        class EventsFilter(BaseModel):
            """
            https://docs.github.com/en/webhooks/webhook-events-and-payloads#discussion
            """

            created: bool = False
            edited: bool = False
            deleted: bool = False
            answered: bool = False
            unanswered: bool = False
            labeled: bool = False
            unlabeled: bool = False
            locked: bool = False
            unlocked: bool = False
            category_changed: bool = False
            transferred: bool = False
            pinned: bool = False
            unpinned: bool = False

        events: EventsFilter = SchemaField(
            title="Events", description="The discussion events to subscribe to"
        )

    class Output(GitHubTriggerBase.Output):
        event: str = SchemaField(
            description="The discussion event that triggered the webhook"
        )
        number: int = SchemaField(description="The discussion number")
        discussion: dict = SchemaField(description="The full discussion object")
        discussion_url: str = SchemaField(description="URL to the discussion")
        title: str = SchemaField(description="The discussion title")
        body: str = SchemaField(description="The discussion body")
        category: dict = SchemaField(description="The discussion category object")
        category_name: str = SchemaField(description="Name of the category")
        state: str = SchemaField(description="Discussion state")

    def __init__(self):
        from backend.integrations.webhooks.github import GithubWebhookType

        example_payload = json.loads(
            self.EXAMPLE_PAYLOAD_FILE.read_text(encoding="utf-8")
        )

        super().__init__(
            id="87f847b3-d81a-424e-8e89-acadb5c9d52b",
            description="This block triggers on GitHub Discussions events. "
            "Great for syncing Q&A to Discord or auto-responding to common questions. "
            "Note: Discussions must be enabled on the repository.",
            categories={BlockCategory.DEVELOPER_TOOLS, BlockCategory.INPUT},
            input_schema=GithubDiscussionTriggerBlock.Input,
            output_schema=GithubDiscussionTriggerBlock.Output,
            webhook_config=BlockWebhookConfig(
                provider=ProviderName.GITHUB,
                webhook_type=GithubWebhookType.REPO,
                resource_format="{repo}",
                event_filter_input="events",
                event_format="discussion.{event}",
            ),
            test_input={
                "repo": "Significant-Gravitas/AutoGPT",
                "events": {"created": True},
                "credentials": TEST_CREDENTIALS_INPUT,
                "payload": example_payload,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("payload", example_payload),
                ("triggered_by_user", example_payload["sender"]),
                ("event", example_payload["action"]),
                ("number", example_payload["discussion"]["number"]),
                ("discussion", example_payload["discussion"]),
                ("discussion_url", example_payload["discussion"]["html_url"]),
                ("title", example_payload["discussion"]["title"]),
                ("body", example_payload["discussion"]["body"]),
                ("category", example_payload["discussion"]["category"]),
                ("category_name", example_payload["discussion"]["category"]["name"]),
                ("state", example_payload["discussion"]["state"]),
            ],
        )

    async def run(self, input_data: Input, **kwargs) -> BlockOutput:  # type: ignore
        async for name, value in super().run(input_data, **kwargs):
            yield name, value
        discussion = input_data.payload["discussion"]
        yield "event", input_data.payload["action"]
        yield "number", discussion["number"]
        yield "discussion", discussion
        yield "discussion_url", discussion["html_url"]
        yield "title", discussion["title"]
        yield "body", discussion.get("body") or ""
        yield "category", discussion["category"]
        yield "category_name", discussion["category"]["name"]
        yield "state", discussion["state"]

File diff suppressed because it is too large
backend/blocks/smart_decision_maker.py
@@ -1,8 +1,11 @@
 import logging
 import re
 from collections import Counter
+from concurrent.futures import Future
 from typing import TYPE_CHECKING, Any

+from pydantic import BaseModel
+
 import backend.blocks.llm as llm
 from backend.blocks.agent import AgentExecutorBlock
 from backend.data.block import (
@@ -20,16 +23,41 @@ from backend.data.dynamic_fields import (
     is_dynamic_field,
     is_tool_pin,
 )
+from backend.data.execution import ExecutionContext
 from backend.data.model import NodeExecutionStats, SchemaField
 from backend.util import json
 from backend.util.clients import get_database_manager_async_client
+from backend.util.prompt import MAIN_OBJECTIVE_PREFIX

 if TYPE_CHECKING:
     from backend.data.graph import Link, Node
+    from backend.executor.manager import ExecutionProcessor

 logger = logging.getLogger(__name__)


+class ToolInfo(BaseModel):
+    """Processed tool call information."""
+
+    tool_call: Any  # The original tool call object from LLM response
+    tool_name: str  # The function name
+    tool_def: dict[str, Any]  # The tool definition from tool_functions
+    input_data: dict[str, Any]  # Processed input data ready for tool execution
+    field_mapping: dict[str, str]  # Field name mapping for the tool
+
+
+class ExecutionParams(BaseModel):
+    """Tool execution parameters."""
+
+    user_id: str
+    graph_id: str
+    node_id: str
+    graph_version: int
+    graph_exec_id: str
+    node_exec_id: str
+    execution_context: "ExecutionContext"
+
+
 def _get_tool_requests(entry: dict[str, Any]) -> list[str]:
     """
     Return a list of tool_call_ids if the entry is a tool request.
@@ -105,6 +133,50 @@ def _create_tool_response(call_id: str, output: Any) -> dict[str, Any]:
     return {"role": "tool", "tool_call_id": call_id, "content": content}


+def _combine_tool_responses(tool_outputs: list[dict[str, Any]]) -> list[dict[str, Any]]:
+    """
+    Combine multiple Anthropic tool responses into a single user message.
+    For non-Anthropic formats, returns the original list unchanged.
+    """
+    if len(tool_outputs) <= 1:
+        return tool_outputs
+
+    # Anthropic responses have role="user", type="message", and content is a list with tool_result items
+    anthropic_responses = [
+        output
+        for output in tool_outputs
+        if (
+            output.get("role") == "user"
+            and output.get("type") == "message"
+            and isinstance(output.get("content"), list)
+            and any(
+                item.get("type") == "tool_result"
+                for item in output.get("content", [])
+                if isinstance(item, dict)
+            )
+        )
+    ]
+
+    if len(anthropic_responses) > 1:
+        combined_content = [
+            item for response in anthropic_responses for item in response["content"]
+        ]
+
+        combined_response = {
+            "role": "user",
+            "type": "message",
+            "content": combined_content,
+        }
+
+        non_anthropic_responses = [
+            output for output in tool_outputs if output not in anthropic_responses
+        ]
+
+        return [combined_response] + non_anthropic_responses
+
+    return tool_outputs
+
+
 def _convert_raw_response_to_dict(raw_response: Any) -> dict[str, Any]:
     """
     Safely convert raw_response to dictionary format for conversation history.
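Context for the helper above: Anthropic's message format expects all `tool_result` blocks answering one assistant turn to arrive in a single user message, which is why parallel tool outputs get merged. A usage sketch (the call IDs are invented for illustration):

```python
outputs = [
    {"role": "user", "type": "message",
     "content": [{"type": "tool_result", "tool_use_id": "call_a", "content": "1"}]},
    {"role": "user", "type": "message",
     "content": [{"type": "tool_result", "tool_use_id": "call_b", "content": "2"}]},
]
merged = _combine_tool_responses(outputs)
assert len(merged) == 1  # both tool results now live in one user message
assert [i["tool_use_id"] for i in merged[0]["content"]] == ["call_a", "call_b"]
```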
@@ -204,6 +276,17 @@ class SmartDecisionMakerBlock(Block):
             default="localhost:11434",
             description="Ollama host for local models",
         )
+        agent_mode_max_iterations: int = SchemaField(
+            title="Agent Mode Max Iterations",
+            description="Maximum iterations for agent mode. 0 = traditional mode (single LLM call, yield tool calls for external execution), -1 = infinite agent mode (loop until finished), 1+ = agent mode with max iterations limit.",
+            advanced=True,
+            default=0,
+        )
+        conversation_compaction: bool = SchemaField(
+            default=True,
+            title="Context window auto-compaction",
+            description="Automatically compact the context window once it hits the limit",
+        )

     @classmethod
     def get_missing_links(cls, data: BlockInput, links: list["Link"]) -> set[str]:
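The three-way semantics of `agent_mode_max_iterations` are worth restating, since 0, -1, and N >= 1 behave quite differently. A self-contained sketch of the loop guard used later in `_execute_tools_agent_mode` (illustrative, not part of the diff):

```python
def keep_looping(iteration: int, max_iterations: int) -> bool:
    # Matches `while max_iterations < 0 or iteration < max_iterations` below.
    # A value of 0 never reaches the loop: it selects traditional single-call mode.
    return max_iterations < 0 or iteration < max_iterations

assert keep_looping(0, -1) and keep_looping(99, -1)   # -1: loop until finished
assert keep_looping(2, 3) and not keep_looping(3, 3)  # 3: at most three iterations
```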
@@ -506,6 +589,7 @@ class SmartDecisionMakerBlock(Block):
         Returns the response if successful, raises ValueError if validation fails.
         """
         resp = await llm.llm_call(
+            compress_prompt_to_fit=input_data.conversation_compaction,
             credentials=credentials,
             llm_model=input_data.model,
             prompt=current_prompt,
@@ -593,6 +677,291 @@ class SmartDecisionMakerBlock(Block):

         return resp

+    def _process_tool_calls(
+        self, response, tool_functions: list[dict[str, Any]]
+    ) -> list[ToolInfo]:
+        """Process tool calls and extract tool definitions, arguments, and input data.
+
+        Returns a list of tool info dicts with:
+        - tool_call: The original tool call object
+        - tool_name: The function name
+        - tool_def: The tool definition from tool_functions
+        - input_data: Processed input data dict (includes None values)
+        - field_mapping: Field name mapping for the tool
+        """
+        if not response.tool_calls:
+            return []
+
+        processed_tools = []
+        for tool_call in response.tool_calls:
+            tool_name = tool_call.function.name
+            tool_args = json.loads(tool_call.function.arguments)
+
+            tool_def = next(
+                (
+                    tool
+                    for tool in tool_functions
+                    if tool["function"]["name"] == tool_name
+                ),
+                None,
+            )
+            if not tool_def:
+                if len(tool_functions) == 1:
+                    tool_def = tool_functions[0]
+                else:
+                    continue
+
+            # Build input data for the tool
+            input_data = {}
+            field_mapping = tool_def["function"].get("_field_mapping", {})
+            if "function" in tool_def and "parameters" in tool_def["function"]:
+                expected_args = tool_def["function"]["parameters"].get("properties", {})
+                for clean_arg_name in expected_args:
+                    original_field_name = field_mapping.get(
+                        clean_arg_name, clean_arg_name
+                    )
+                    arg_value = tool_args.get(clean_arg_name)
+                    # Include all expected parameters, even if None (for backward compatibility with tests)
+                    input_data[original_field_name] = arg_value
+
+            processed_tools.append(
+                ToolInfo(
+                    tool_call=tool_call,
+                    tool_name=tool_name,
+                    tool_def=tool_def,
+                    input_data=input_data,
+                    field_mapping=field_mapping,
+                )
+            )
+
+        return processed_tools
+
+    def _update_conversation(
+        self, prompt: list[dict], response, tool_outputs: list | None = None
+    ):
+        """Update conversation history with response and tool outputs."""
+        # Don't add separate reasoning message with tool calls (breaks Anthropic's tool_use->tool_result pairing)
+        assistant_message = _convert_raw_response_to_dict(response.raw_response)
+        has_tool_calls = isinstance(assistant_message.get("content"), list) and any(
+            item.get("type") == "tool_use"
+            for item in assistant_message.get("content", [])
+        )
+
+        if response.reasoning and not has_tool_calls:
+            prompt.append(
+                {"role": "assistant", "content": f"[Reasoning]: {response.reasoning}"}
+            )
+
+        prompt.append(assistant_message)
+
+        if tool_outputs:
+            prompt.extend(tool_outputs)
+
+    async def _execute_single_tool_with_manager(
+        self,
+        tool_info: ToolInfo,
+        execution_params: ExecutionParams,
+        execution_processor: "ExecutionProcessor",
+    ) -> dict:
+        """Execute a single tool using the execution manager for proper integration."""
+        # Lazy imports to avoid circular dependencies
+        from backend.data.execution import NodeExecutionEntry
+
+        tool_call = tool_info.tool_call
+        tool_def = tool_info.tool_def
+        raw_input_data = tool_info.input_data
+
+        # Get sink node and field mapping
+        sink_node_id = tool_def["function"]["_sink_node_id"]
+
+        # Use proper database operations for tool execution
+        db_client = get_database_manager_async_client()
+
+        # Get target node
+        target_node = await db_client.get_node(sink_node_id)
+        if not target_node:
+            raise ValueError(f"Target node {sink_node_id} not found")
+
+        # Create proper node execution using upsert_execution_input
+        node_exec_result = None
+        final_input_data = None
+
+        # Add all inputs to the execution
+        if not raw_input_data:
+            raise ValueError(f"Tool call has no input data: {tool_call}")
+
+        for input_name, input_value in raw_input_data.items():
+            node_exec_result, final_input_data = await db_client.upsert_execution_input(
+                node_id=sink_node_id,
+                graph_exec_id=execution_params.graph_exec_id,
+                input_name=input_name,
+                input_data=input_value,
+            )
+
+        assert node_exec_result is not None, "node_exec_result should not be None"
+
+        # Create NodeExecutionEntry for execution manager
+        node_exec_entry = NodeExecutionEntry(
+            user_id=execution_params.user_id,
+            graph_exec_id=execution_params.graph_exec_id,
+            graph_id=execution_params.graph_id,
+            graph_version=execution_params.graph_version,
+            node_exec_id=node_exec_result.node_exec_id,
+            node_id=sink_node_id,
+            block_id=target_node.block_id,
+            inputs=final_input_data or {},
+            execution_context=execution_params.execution_context,
+        )
+
+        # Use the execution manager to execute the tool node
+        try:
+            # Get NodeExecutionProgress from the execution manager's running nodes
+            node_exec_progress = execution_processor.running_node_execution[
+                sink_node_id
+            ]
+
+            # Use the execution manager's own graph stats
+            graph_stats_pair = (
+                execution_processor.execution_stats,
+                execution_processor.execution_stats_lock,
+            )
+
+            # Create a completed future for the task tracking system
+            node_exec_future = Future()
+            node_exec_progress.add_task(
+                node_exec_id=node_exec_result.node_exec_id,
+                task=node_exec_future,
+            )
+
+            # Execute the node directly since we're in the SmartDecisionMaker context
+            node_exec_future.set_result(
+                await execution_processor.on_node_execution(
+                    node_exec=node_exec_entry,
+                    node_exec_progress=node_exec_progress,
+                    nodes_input_masks=None,
+                    graph_stats_pair=graph_stats_pair,
+                )
+            )
+
+            # Get outputs from database after execution completes using database manager client
+            node_outputs = await db_client.get_execution_outputs_by_node_exec_id(
+                node_exec_result.node_exec_id
+            )
+
+            # Create tool response
+            tool_response_content = (
+                json.dumps(node_outputs)
+                if node_outputs
+                else "Tool executed successfully"
+            )
+            return _create_tool_response(tool_call.id, tool_response_content)
+
+        except Exception as e:
+            logger.error(f"Tool execution with manager failed: {e}")
+            # Return error response
+            return _create_tool_response(
+                tool_call.id, f"Tool execution failed: {str(e)}"
+            )
+
+    async def _execute_tools_agent_mode(
+        self,
+        input_data,
+        credentials,
+        tool_functions: list[dict[str, Any]],
+        prompt: list[dict],
+        graph_exec_id: str,
+        node_id: str,
+        node_exec_id: str,
+        user_id: str,
+        graph_id: str,
+        graph_version: int,
+        execution_context: ExecutionContext,
+        execution_processor: "ExecutionProcessor",
+    ):
+        """Execute tools in agent mode with a loop until finished."""
+        max_iterations = input_data.agent_mode_max_iterations
+        iteration = 0
+
+        # Execution parameters for tool execution
+        execution_params = ExecutionParams(
+            user_id=user_id,
+            graph_id=graph_id,
+            node_id=node_id,
+            graph_version=graph_version,
+            graph_exec_id=graph_exec_id,
+            node_exec_id=node_exec_id,
+            execution_context=execution_context,
+        )
+
+        current_prompt = list(prompt)
+
+        while max_iterations < 0 or iteration < max_iterations:
+            iteration += 1
+            logger.debug(f"Agent mode iteration {iteration}")
+
+            # Prepare prompt for this iteration
+            iteration_prompt = list(current_prompt)
+
+            # On the last iteration, add a special system message to encourage completion
+            if max_iterations > 0 and iteration == max_iterations:
+                last_iteration_message = {
+                    "role": "system",
+                    "content": f"{MAIN_OBJECTIVE_PREFIX}This is your last iteration ({iteration}/{max_iterations}). "
+                    "Try to complete the task with the information you have. If you cannot fully complete it, "
+                    "provide a summary of what you've accomplished and what remains to be done. "
+                    "Prefer finishing with a clear response rather than making additional tool calls.",
+                }
+                iteration_prompt.append(last_iteration_message)
+
+            # Get LLM response
+            try:
+                response = await self._attempt_llm_call_with_validation(
+                    credentials, input_data, iteration_prompt, tool_functions
+                )
+            except Exception as e:
+                yield "error", f"LLM call failed in agent mode iteration {iteration}: {str(e)}"
+                return
+
+            # Process tool calls
+            processed_tools = self._process_tool_calls(response, tool_functions)
+
+            # If no tool calls, we're done
+            if not processed_tools:
+                yield "finished", response.response
+                self._update_conversation(current_prompt, response)
+                yield "conversations", current_prompt
+                return
+
+            # Execute tools and collect responses
+            tool_outputs = []
+            for tool_info in processed_tools:
+                try:
+                    tool_response = await self._execute_single_tool_with_manager(
+                        tool_info, execution_params, execution_processor
+                    )
+                    tool_outputs.append(tool_response)
+                except Exception as e:
+                    logger.error(f"Tool execution failed: {e}")
+                    # Create error response for the tool
+                    error_response = _create_tool_response(
+                        tool_info.tool_call.id, f"Error: {str(e)}"
+                    )
+                    tool_outputs.append(error_response)
+
+            tool_outputs = _combine_tool_responses(tool_outputs)
+
+            self._update_conversation(current_prompt, response, tool_outputs)
+
+            # Yield intermediate conversation state
+            yield "conversations", current_prompt
+
+        # If we reach max iterations, yield the current state
+        if max_iterations < 0:
+            yield "finished", f"Agent mode completed after {iteration} iterations"
+        else:
+            yield "finished", f"Agent mode completed after {max_iterations} iterations (limit reached)"
+        yield "conversations", current_prompt
+
     async def run(
         self,
         input_data: Input,
@@ -603,8 +972,12 @@ class SmartDecisionMakerBlock(Block):
         graph_exec_id: str,
         node_exec_id: str,
         user_id: str,
+        graph_version: int,
+        execution_context: ExecutionContext,
+        execution_processor: "ExecutionProcessor",
         **kwargs,
     ) -> BlockOutput:
+
         tool_functions = await self._create_tool_node_signatures(node_id)
         yield "tool_functions", json.dumps(tool_functions)

@@ -648,24 +1021,52 @@ class SmartDecisionMakerBlock(Block):
         input_data.prompt = llm.fmt.format_string(input_data.prompt, values)
         input_data.sys_prompt = llm.fmt.format_string(input_data.sys_prompt, values)

-        prefix = "[Main Objective Prompt]: "
-
         if input_data.sys_prompt and not any(
-            p["role"] == "system" and p["content"].startswith(prefix) for p in prompt
+            p["role"] == "system" and p["content"].startswith(MAIN_OBJECTIVE_PREFIX)
+            for p in prompt
         ):
-            prompt.append({"role": "system", "content": prefix + input_data.sys_prompt})
+            prompt.append(
+                {
+                    "role": "system",
+                    "content": MAIN_OBJECTIVE_PREFIX + input_data.sys_prompt,
+                }
+            )

         if input_data.prompt and not any(
-            p["role"] == "user" and p["content"].startswith(prefix) for p in prompt
+            p["role"] == "user" and p["content"].startswith(MAIN_OBJECTIVE_PREFIX)
+            for p in prompt
         ):
-            prompt.append({"role": "user", "content": prefix + input_data.prompt})
+            prompt.append(
+                {"role": "user", "content": MAIN_OBJECTIVE_PREFIX + input_data.prompt}
+            )
+
+        # Execute tools based on the selected mode
+        if input_data.agent_mode_max_iterations != 0:
+            # In agent mode, execute tools directly in a loop until finished
+            async for result in self._execute_tools_agent_mode(
+                input_data=input_data,
+                credentials=credentials,
+                tool_functions=tool_functions,
+                prompt=prompt,
+                graph_exec_id=graph_exec_id,
+                node_id=node_id,
+                node_exec_id=node_exec_id,
+                user_id=user_id,
+                graph_id=graph_id,
+                graph_version=graph_version,
+                execution_context=execution_context,
+                execution_processor=execution_processor,
+            ):
+                yield result
+            return
+
+        # One-off mode: single LLM call and yield tool calls for external execution
         current_prompt = list(prompt)
         max_attempts = max(1, int(input_data.retry))
         response = None
+
         last_error = None
-        for attempt in range(max_attempts):
+        for _ in range(max_attempts):
             try:
                 response = await self._attempt_llm_call_with_validation(
                     credentials, input_data, current_prompt, tool_functions
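Before the test changes, a sketch of how a caller might consume the agent-mode outputs introduced above. The output names ("finished", "conversations", "error") come from `_execute_tools_agent_mode`; everything else here is illustrative:

```python
async def collect_agent_result(block, input_data, **kwargs):
    finished, conversations = None, None
    async for name, value in block.run(input_data, **kwargs):
        if name == "finished":
            finished = value          # final answer or completion summary
        elif name == "conversations":
            conversations = value     # latest conversation state wins
        elif name == "error":
            raise RuntimeError(value)
    return finished, conversations
```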
@@ -1,7 +1,11 @@
 import logging
+import threading
+from collections import defaultdict
+from unittest.mock import AsyncMock, MagicMock, patch

 import pytest

+from backend.data.execution import ExecutionContext
 from backend.data.model import ProviderName, User
 from backend.server.model import CreateGraph
 from backend.server.rest_api import AgentServer
@@ -17,10 +21,10 @@ async def create_graph(s: SpinTestServer, g, u: User):


 async def create_credentials(s: SpinTestServer, u: User):
-    import backend.blocks.llm as llm
+    import backend.blocks.llm as llm_module

     provider = ProviderName.OPENAI
-    credentials = llm.TEST_CREDENTIALS
+    credentials = llm_module.TEST_CREDENTIALS
     return await s.agent_server.test_create_credentials(u.id, provider, credentials)

@@ -196,8 +200,6 @@ async def test_smart_decision_maker_function_signature(server: SpinTestServer):
 @pytest.mark.asyncio
 async def test_smart_decision_maker_tracks_llm_stats():
     """Test that SmartDecisionMakerBlock correctly tracks LLM usage stats."""
-    from unittest.mock import MagicMock, patch
-
     import backend.blocks.llm as llm_module
     from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock

@@ -216,7 +218,6 @@ async def test_smart_decision_maker_tracks_llm_stats():
     }

     # Mock the _create_tool_node_signatures method to avoid database calls
-    from unittest.mock import AsyncMock

     with patch(
         "backend.blocks.llm.llm_call",
@@ -234,10 +235,19 @@ async def test_smart_decision_maker_tracks_llm_stats():
         prompt="Should I continue with this task?",
         model=llm_module.LlmModel.GPT4O,
         credentials=llm_module.TEST_CREDENTIALS_INPUT,  # type: ignore
+        agent_mode_max_iterations=0,
     )

     # Execute the block
     outputs = {}
+
+    # Create execution context
+    mock_execution_context = ExecutionContext(safe_mode=False)
+
+    # Create a mock execution processor for tests
+    mock_execution_processor = MagicMock()
+
     async for output_name, output_data in block.run(
         input_data,
         credentials=llm_module.TEST_CREDENTIALS,
@@ -246,6 +256,9 @@ async def test_smart_decision_maker_tracks_llm_stats():
|
|||||||
graph_exec_id="test-exec-id",
|
graph_exec_id="test-exec-id",
|
||||||
node_exec_id="test-node-exec-id",
|
node_exec_id="test-node-exec-id",
|
||||||
user_id="test-user-id",
|
user_id="test-user-id",
|
||||||
|
graph_version=1,
|
||||||
|
execution_context=mock_execution_context,
|
||||||
|
execution_processor=mock_execution_processor,
|
||||||
):
|
):
|
||||||
outputs[output_name] = output_data
|
outputs[output_name] = output_data
|
||||||
|
|
||||||
@@ -263,8 +276,6 @@ async def test_smart_decision_maker_tracks_llm_stats():
 @pytest.mark.asyncio
 async def test_smart_decision_maker_parameter_validation():
     """Test that SmartDecisionMakerBlock correctly validates tool call parameters."""
-    from unittest.mock import MagicMock, patch
-
     import backend.blocks.llm as llm_module
     from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock

@@ -311,8 +322,6 @@ async def test_smart_decision_maker_parameter_validation():
     mock_response_with_typo.reasoning = None
     mock_response_with_typo.raw_response = {"role": "assistant", "content": None}

-    from unittest.mock import AsyncMock
-
     with patch(
         "backend.blocks.llm.llm_call",
         new_callable=AsyncMock,
@@ -329,8 +338,17 @@ async def test_smart_decision_maker_parameter_validation():
         model=llm_module.LlmModel.GPT4O,
         credentials=llm_module.TEST_CREDENTIALS_INPUT,  # type: ignore
         retry=2,  # Set retry to 2 for testing
+        agent_mode_max_iterations=0,
     )

+    # Create execution context
+    mock_execution_context = ExecutionContext(safe_mode=False)
+
+    # Create a mock execution processor for tests
+    mock_execution_processor = MagicMock()
+
     # Should raise ValueError after retries due to typo'd parameter name
     with pytest.raises(ValueError) as exc_info:
         outputs = {}
@@ -342,6 +360,9 @@ async def test_smart_decision_maker_parameter_validation():
             graph_exec_id="test-exec-id",
             node_exec_id="test-node-exec-id",
             user_id="test-user-id",
+            graph_version=1,
+            execution_context=mock_execution_context,
+            execution_processor=mock_execution_processor,
         ):
             outputs[output_name] = output_data

@@ -368,8 +389,6 @@ async def test_smart_decision_maker_parameter_validation():
     mock_response_missing_required.reasoning = None
     mock_response_missing_required.raw_response = {"role": "assistant", "content": None}

-    from unittest.mock import AsyncMock
-
     with patch(
         "backend.blocks.llm.llm_call",
         new_callable=AsyncMock,
@@ -385,8 +404,17 @@ async def test_smart_decision_maker_parameter_validation():
         prompt="Search for keywords",
         model=llm_module.LlmModel.GPT4O,
         credentials=llm_module.TEST_CREDENTIALS_INPUT,  # type: ignore
+        agent_mode_max_iterations=0,
     )

+    # Create execution context
+    mock_execution_context = ExecutionContext(safe_mode=False)
+
+    # Create a mock execution processor for tests
+    mock_execution_processor = MagicMock()
+
     # Should raise ValueError due to missing required parameter
     with pytest.raises(ValueError) as exc_info:
         outputs = {}
@@ -398,6 +426,9 @@ async def test_smart_decision_maker_parameter_validation():
             graph_exec_id="test-exec-id",
             node_exec_id="test-node-exec-id",
             user_id="test-user-id",
+            graph_version=1,
+            execution_context=mock_execution_context,
+            execution_processor=mock_execution_processor,
         ):
             outputs[output_name] = output_data

@@ -418,8 +449,6 @@ async def test_smart_decision_maker_parameter_validation():
     mock_response_valid.reasoning = None
     mock_response_valid.raw_response = {"role": "assistant", "content": None}

-    from unittest.mock import AsyncMock
-
     with patch(
         "backend.blocks.llm.llm_call",
         new_callable=AsyncMock,
@@ -435,10 +464,19 @@ async def test_smart_decision_maker_parameter_validation():
         prompt="Search for keywords",
         model=llm_module.LlmModel.GPT4O,
         credentials=llm_module.TEST_CREDENTIALS_INPUT,  # type: ignore
+        agent_mode_max_iterations=0,
     )

     # Should succeed - optional parameter missing is OK
     outputs = {}
+
+    # Create execution context
+    mock_execution_context = ExecutionContext(safe_mode=False)
+
+    # Create a mock execution processor for tests
+    mock_execution_processor = MagicMock()
+
     async for output_name, output_data in block.run(
         input_data,
         credentials=llm_module.TEST_CREDENTIALS,
@@ -447,6 +485,9 @@ async def test_smart_decision_maker_parameter_validation():
         graph_exec_id="test-exec-id",
         node_exec_id="test-node-exec-id",
         user_id="test-user-id",
+        graph_version=1,
+        execution_context=mock_execution_context,
+        execution_processor=mock_execution_processor,
     ):
         outputs[output_name] = output_data

@@ -472,8 +513,6 @@ async def test_smart_decision_maker_parameter_validation():
     mock_response_all_params.reasoning = None
     mock_response_all_params.raw_response = {"role": "assistant", "content": None}

-    from unittest.mock import AsyncMock
-
     with patch(
         "backend.blocks.llm.llm_call",
         new_callable=AsyncMock,
@@ -489,10 +528,19 @@ async def test_smart_decision_maker_parameter_validation():
         prompt="Search for keywords",
         model=llm_module.LlmModel.GPT4O,
         credentials=llm_module.TEST_CREDENTIALS_INPUT,  # type: ignore
+        agent_mode_max_iterations=0,
     )

     # Should succeed with all parameters
     outputs = {}
+
+    # Create execution context
+    mock_execution_context = ExecutionContext(safe_mode=False)
+
+    # Create a mock execution processor for tests
+    mock_execution_processor = MagicMock()
+
     async for output_name, output_data in block.run(
         input_data,
         credentials=llm_module.TEST_CREDENTIALS,
@@ -501,6 +549,9 @@ async def test_smart_decision_maker_parameter_validation():
         graph_exec_id="test-exec-id",
         node_exec_id="test-node-exec-id",
         user_id="test-user-id",
+        graph_version=1,
+        execution_context=mock_execution_context,
+        execution_processor=mock_execution_processor,
     ):
         outputs[output_name] = output_data

@@ -513,8 +564,6 @@ async def test_smart_decision_maker_parameter_validation():
 @pytest.mark.asyncio
 async def test_smart_decision_maker_raw_response_conversion():
     """Test that SmartDecisionMaker correctly handles different raw_response types with retry mechanism."""
-    from unittest.mock import MagicMock, patch
-
     import backend.blocks.llm as llm_module
     from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock

@@ -584,7 +633,6 @@ async def test_smart_decision_maker_raw_response_conversion():
     )

     # Mock llm_call to return different responses on different calls
-    from unittest.mock import AsyncMock

     with patch(
         "backend.blocks.llm.llm_call", new_callable=AsyncMock
@@ -603,10 +651,19 @@ async def test_smart_decision_maker_raw_response_conversion():
         model=llm_module.LlmModel.GPT4O,
         credentials=llm_module.TEST_CREDENTIALS_INPUT,  # type: ignore
         retry=2,
+        agent_mode_max_iterations=0,
     )

     # Should succeed after retry, demonstrating our helper function works
     outputs = {}
+
+    # Create execution context
+    mock_execution_context = ExecutionContext(safe_mode=False)
+
+    # Create a mock execution processor for tests
+    mock_execution_processor = MagicMock()
+
     async for output_name, output_data in block.run(
         input_data,
         credentials=llm_module.TEST_CREDENTIALS,
@@ -615,6 +672,9 @@ async def test_smart_decision_maker_raw_response_conversion():
         graph_exec_id="test-exec-id",
         node_exec_id="test-node-exec-id",
         user_id="test-user-id",
+        graph_version=1,
+        execution_context=mock_execution_context,
+        execution_processor=mock_execution_processor,
     ):
         outputs[output_name] = output_data

@@ -650,8 +710,6 @@ async def test_smart_decision_maker_raw_response_conversion():
         "I'll help you with that."  # Ollama returns string
     )

-    from unittest.mock import AsyncMock
-
     with patch(
         "backend.blocks.llm.llm_call",
         new_callable=AsyncMock,
@@ -666,9 +724,18 @@ async def test_smart_decision_maker_raw_response_conversion():
         prompt="Simple prompt",
         model=llm_module.LlmModel.GPT4O,
         credentials=llm_module.TEST_CREDENTIALS_INPUT,  # type: ignore
+        agent_mode_max_iterations=0,
     )

     outputs = {}
+
+    # Create execution context
+    mock_execution_context = ExecutionContext(safe_mode=False)
+
+    # Create a mock execution processor for tests
+    mock_execution_processor = MagicMock()
+
     async for output_name, output_data in block.run(
         input_data,
         credentials=llm_module.TEST_CREDENTIALS,
@@ -677,6 +744,9 @@ async def test_smart_decision_maker_raw_response_conversion():
         graph_exec_id="test-exec-id",
         node_exec_id="test-node-exec-id",
         user_id="test-user-id",
+        graph_version=1,
+        execution_context=mock_execution_context,
+        execution_processor=mock_execution_processor,
     ):
         outputs[output_name] = output_data

@@ -696,8 +766,6 @@ async def test_smart_decision_maker_raw_response_conversion():
         "content": "Test response",
     }  # Dict format

-    from unittest.mock import AsyncMock
-
     with patch(
         "backend.blocks.llm.llm_call",
         new_callable=AsyncMock,
@@ -712,6 +780,160 @@ async def test_smart_decision_maker_raw_response_conversion():
         prompt="Another test",
         model=llm_module.LlmModel.GPT4O,
         credentials=llm_module.TEST_CREDENTIALS_INPUT,  # type: ignore
+        agent_mode_max_iterations=0,
+    )
+
+    outputs = {}
+
+    # Create execution context
+    mock_execution_context = ExecutionContext(safe_mode=False)
+
+    # Create a mock execution processor for tests
+    mock_execution_processor = MagicMock()
+
+    async for output_name, output_data in block.run(
+        input_data,
+        credentials=llm_module.TEST_CREDENTIALS,
+        graph_id="test-graph-id",
+        node_id="test-node-id",
+        graph_exec_id="test-exec-id",
+        node_exec_id="test-node-exec-id",
+        user_id="test-user-id",
+        graph_version=1,
+        execution_context=mock_execution_context,
+        execution_processor=mock_execution_processor,
+    ):
+        outputs[output_name] = output_data
+
+    assert "finished" in outputs
+    assert outputs["finished"] == "Test response"
+
+
+@pytest.mark.asyncio
+async def test_smart_decision_maker_agent_mode():
+    """Test that agent mode executes tools directly and loops until finished."""
+    import backend.blocks.llm as llm_module
+    from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
+
+    block = SmartDecisionMakerBlock()
+
+    # Mock tool call that requires multiple iterations
+    mock_tool_call_1 = MagicMock()
+    mock_tool_call_1.id = "call_1"
+    mock_tool_call_1.function.name = "search_keywords"
+    mock_tool_call_1.function.arguments = (
+        '{"query": "test", "max_keyword_difficulty": 50}'
+    )
+
+    mock_response_1 = MagicMock()
+    mock_response_1.response = None
+    mock_response_1.tool_calls = [mock_tool_call_1]
+    mock_response_1.prompt_tokens = 50
+    mock_response_1.completion_tokens = 25
+    mock_response_1.reasoning = "Using search tool"
+    mock_response_1.raw_response = {
+        "role": "assistant",
+        "content": None,
+        "tool_calls": [{"id": "call_1", "type": "function"}],
+    }
+
+    # Final response with no tool calls (finished)
+    mock_response_2 = MagicMock()
+    mock_response_2.response = "Task completed successfully"
+    mock_response_2.tool_calls = []
+    mock_response_2.prompt_tokens = 30
+    mock_response_2.completion_tokens = 15
+    mock_response_2.reasoning = None
+    mock_response_2.raw_response = {
+        "role": "assistant",
+        "content": "Task completed successfully",
+    }
+
+    # Mock the LLM call to return different responses on each iteration
+    llm_call_mock = AsyncMock()
+    llm_call_mock.side_effect = [mock_response_1, mock_response_2]
+
+    # Mock tool node signatures
+    mock_tool_signatures = [
+        {
+            "type": "function",
+            "function": {
+                "name": "search_keywords",
+                "_sink_node_id": "test-sink-node-id",
+                "_field_mapping": {},
+                "parameters": {
+                    "properties": {
+                        "query": {"type": "string"},
+                        "max_keyword_difficulty": {"type": "integer"},
+                    },
+                    "required": ["query", "max_keyword_difficulty"],
+                },
+            },
+        }
+    ]
+
+    # Mock database and execution components
+    mock_db_client = AsyncMock()
+    mock_node = MagicMock()
+    mock_node.block_id = "test-block-id"
+    mock_db_client.get_node.return_value = mock_node
+
+    # Mock upsert_execution_input to return proper NodeExecutionResult and input data
+    mock_node_exec_result = MagicMock()
+    mock_node_exec_result.node_exec_id = "test-tool-exec-id"
+    mock_input_data = {"query": "test", "max_keyword_difficulty": 50}
+    mock_db_client.upsert_execution_input.return_value = (
+        mock_node_exec_result,
+        mock_input_data,
+    )
+
+    # No longer need mock_execute_node since we use execution_processor.on_node_execution
+
+    with patch("backend.blocks.llm.llm_call", llm_call_mock), patch.object(
+        block, "_create_tool_node_signatures", return_value=mock_tool_signatures
+    ), patch(
+        "backend.blocks.smart_decision_maker.get_database_manager_async_client",
+        return_value=mock_db_client,
+    ), patch(
+        "backend.executor.manager.async_update_node_execution_status",
+        new_callable=AsyncMock,
+    ), patch(
+        "backend.integrations.creds_manager.IntegrationCredentialsManager"
+    ):
+        # Create a mock execution context
+        mock_execution_context = ExecutionContext(
+            safe_mode=False,
+        )
+
+        # Create a mock execution processor for agent mode tests
+        mock_execution_processor = AsyncMock()
+        # Configure the execution processor mock with required attributes
+        mock_execution_processor.running_node_execution = defaultdict(MagicMock)
+        mock_execution_processor.execution_stats = MagicMock()
+        mock_execution_processor.execution_stats_lock = threading.Lock()
+
+        # Mock the on_node_execution method to return successful stats
+        mock_node_stats = MagicMock()
+        mock_node_stats.error = None  # No error
+        mock_execution_processor.on_node_execution = AsyncMock(
+            return_value=mock_node_stats
+        )
+
+        # Mock the get_execution_outputs_by_node_exec_id method
+        mock_db_client.get_execution_outputs_by_node_exec_id.return_value = {
+            "result": {"status": "success", "data": "search completed"}
+        }
+
+        # Test agent mode with max_iterations = 3
+        input_data = SmartDecisionMakerBlock.Input(
+            prompt="Complete this task using tools",
+            model=llm_module.LlmModel.GPT4O,
+            credentials=llm_module.TEST_CREDENTIALS_INPUT,  # type: ignore
+            agent_mode_max_iterations=3,  # Enable agent mode with 3 max iterations
         )

         outputs = {}
@@ -723,8 +945,115 @@ async def test_smart_decision_maker_raw_response_conversion():
             graph_exec_id="test-exec-id",
             node_exec_id="test-node-exec-id",
             user_id="test-user-id",
+            graph_version=1,
+            execution_context=mock_execution_context,
+            execution_processor=mock_execution_processor,
         ):
             outputs[output_name] = output_data

+        # Verify agent mode behavior
+        assert "tool_functions" in outputs  # tool_functions is yielded in both modes
         assert "finished" in outputs
-        assert outputs["finished"] == "Test response"
+        assert outputs["finished"] == "Task completed successfully"
+        assert "conversations" in outputs
+
+        # Verify the conversation includes tool responses
+        conversations = outputs["conversations"]
+        assert len(conversations) > 2  # Should have multiple conversation entries
+
+        # Verify LLM was called twice (once for tool call, once for finish)
+        assert llm_call_mock.call_count == 2
+
+        # Verify tool was executed via execution processor
+        assert mock_execution_processor.on_node_execution.call_count == 1
+
+
+@pytest.mark.asyncio
+async def test_smart_decision_maker_traditional_mode_default():
+    """Test that default behavior (agent_mode_max_iterations=0) works as traditional mode."""
+    import backend.blocks.llm as llm_module
+    from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
+
+    block = SmartDecisionMakerBlock()
+
+    # Mock tool call
+    mock_tool_call = MagicMock()
+    mock_tool_call.function.name = "search_keywords"
+    mock_tool_call.function.arguments = (
+        '{"query": "test", "max_keyword_difficulty": 50}'
+    )
+
+    mock_response = MagicMock()
+    mock_response.response = None
+    mock_response.tool_calls = [mock_tool_call]
+    mock_response.prompt_tokens = 50
+    mock_response.completion_tokens = 25
+    mock_response.reasoning = None
+    mock_response.raw_response = {"role": "assistant", "content": None}
+
+    mock_tool_signatures = [
+        {
+            "type": "function",
+            "function": {
+                "name": "search_keywords",
+                "_sink_node_id": "test-sink-node-id",
+                "_field_mapping": {},
+                "parameters": {
+                    "properties": {
+                        "query": {"type": "string"},
+                        "max_keyword_difficulty": {"type": "integer"},
+                    },
+                    "required": ["query", "max_keyword_difficulty"],
+                },
+            },
+        }
+    ]
+
+    with patch(
+        "backend.blocks.llm.llm_call",
+        new_callable=AsyncMock,
+        return_value=mock_response,
+    ), patch.object(
+        block, "_create_tool_node_signatures", return_value=mock_tool_signatures
+    ):
+        # Test default behavior (traditional mode)
+        input_data = SmartDecisionMakerBlock.Input(
+            prompt="Test prompt",
+            model=llm_module.LlmModel.GPT4O,
+            credentials=llm_module.TEST_CREDENTIALS_INPUT,  # type: ignore
+            agent_mode_max_iterations=0,  # Traditional mode
+        )
+
+        # Create execution context
+        mock_execution_context = ExecutionContext(safe_mode=False)
+
+        # Create a mock execution processor for tests
+        mock_execution_processor = MagicMock()
+
+        outputs = {}
+        async for output_name, output_data in block.run(
+            input_data,
+            credentials=llm_module.TEST_CREDENTIALS,
+            graph_id="test-graph-id",
+            node_id="test-node-id",
+            graph_exec_id="test-exec-id",
+            node_exec_id="test-node-exec-id",
+            user_id="test-user-id",
+            graph_version=1,
+            execution_context=mock_execution_context,
+            execution_processor=mock_execution_processor,
+        ):
+            outputs[output_name] = output_data
+
+        # Verify traditional mode behavior
+        assert (
+            "tool_functions" in outputs
+        )  # Should yield tool_functions in traditional mode
+        assert (
+            "tools_^_test-sink-node-id_~_query" in outputs
+        )  # Should yield individual tool parameters
+        assert "tools_^_test-sink-node-id_~_max_keyword_difficulty" in outputs
+        assert "conversations" in outputs
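The `tools_^_<sink_node_id>_~_<field>` assertions above rely on the block's convention of encoding the target node and the input field into each yielded output name. A rough sketch of how such keys could be built and parsed — the delimiters come straight from the test expectations, but these helper functions are illustrative only, not the block's real code:

def make_tool_output_key(sink_node_id: str, field: str) -> str:
    # Compose "tools_^_<sink_node_id>_~_<field>" as seen in the assertions.
    return f"tools_^_{sink_node_id}_~_{field}"


def parse_tool_output_key(key: str) -> tuple[str, str]:
    # Split "tools_^_<sink_node_id>_~_<field>" back into its parts.
    prefix, _, rest = key.partition("_^_")
    assert prefix == "tools"
    sink_node_id, _, field = rest.partition("_~_")
    return sink_node_id, field


assert make_tool_output_key("test-sink-node-id", "query") == (
    "tools_^_test-sink-node-id_~_query"
)
assert parse_tool_output_key("tools_^_n1_~_query") == ("n1", "query")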
@@ -1,7 +1,7 @@
 """Comprehensive tests for SmartDecisionMakerBlock dynamic field handling."""

 import json
-from unittest.mock import AsyncMock, Mock, patch
+from unittest.mock import AsyncMock, MagicMock, Mock, patch

 import pytest

@@ -308,10 +308,47 @@ async def test_output_yielding_with_dynamic_fields():
     ) as mock_llm:
         mock_llm.return_value = mock_response

-        # Mock the function signature creation
-        with patch.object(
+        # Mock the database manager to avoid HTTP calls during tool execution
+        with patch(
+            "backend.blocks.smart_decision_maker.get_database_manager_async_client"
+        ) as mock_db_manager, patch.object(
             block, "_create_tool_node_signatures", new_callable=AsyncMock
         ) as mock_sig:
+            # Set up the mock database manager
+            mock_db_client = AsyncMock()
+            mock_db_manager.return_value = mock_db_client
+
+            # Mock the node retrieval
+            mock_target_node = Mock()
+            mock_target_node.id = "test-sink-node-id"
+            mock_target_node.block_id = "CreateDictionaryBlock"
+            mock_target_node.block = Mock()
+            mock_target_node.block.name = "Create Dictionary"
+            mock_db_client.get_node.return_value = mock_target_node
+
+            # Mock the execution result creation
+            mock_node_exec_result = Mock()
+            mock_node_exec_result.node_exec_id = "mock-node-exec-id"
+            mock_final_input_data = {
+                "values_#_name": "Alice",
+                "values_#_age": 30,
+                "values_#_email": "alice@example.com",
+            }
+            mock_db_client.upsert_execution_input.return_value = (
+                mock_node_exec_result,
+                mock_final_input_data,
+            )
+
+            # Mock the output retrieval
+            mock_outputs = {
+                "values_#_name": "Alice",
+                "values_#_age": 30,
+                "values_#_email": "alice@example.com",
+            }
+            mock_db_client.get_execution_outputs_by_node_exec_id.return_value = (
+                mock_outputs
+            )
+
             mock_sig.return_value = [
                 {
                     "type": "function",
@@ -337,10 +374,16 @@ async def test_output_yielding_with_dynamic_fields():
                 prompt="Create a user dictionary",
                 credentials=llm.TEST_CREDENTIALS_INPUT,
                 model=llm.LlmModel.GPT4O,
+                agent_mode_max_iterations=0,  # Use traditional mode to test output yielding
             )

             # Run the block
             outputs = {}
+            from backend.data.execution import ExecutionContext
+
+            mock_execution_context = ExecutionContext(safe_mode=False)
+            mock_execution_processor = MagicMock()
+
             async for output_name, output_value in block.run(
                 input_data,
                 credentials=llm.TEST_CREDENTIALS,
@@ -349,6 +392,9 @@ async def test_output_yielding_with_dynamic_fields():
                 graph_exec_id="test_exec",
                 node_exec_id="test_node_exec",
                 user_id="test_user",
+                graph_version=1,
+                execution_context=mock_execution_context,
+                execution_processor=mock_execution_processor,
             ):
                 outputs[output_name] = output_value

@@ -511,45 +557,108 @@ async def test_validation_errors_dont_pollute_conversation():
         }
     ]

-    # Create input data
-    from backend.blocks import llm
-
-    input_data = block.input_schema(
-        prompt="Test prompt",
-        credentials=llm.TEST_CREDENTIALS_INPUT,
-        model=llm.LlmModel.GPT4O,
-        retry=3,  # Allow retries
-    )
-
-    # Run the block
-    outputs = {}
-    async for output_name, output_value in block.run(
-        input_data,
-        credentials=llm.TEST_CREDENTIALS,
-        graph_id="test_graph",
-        node_id="test_node",
-        graph_exec_id="test_exec",
-        node_exec_id="test_node_exec",
-        user_id="test_user",
-    ):
-        outputs[output_name] = output_value
-
-    # Verify we had 2 LLM calls (initial + retry)
-    assert call_count == 2
-
-    # Check the final conversation output
-    final_conversation = outputs.get("conversations", [])
-
-    # The final conversation should NOT contain the validation error message
-    error_messages = [
-        msg
-        for msg in final_conversation
-        if msg.get("role") == "user"
-        and "parameter errors" in msg.get("content", "")
-    ]
-    assert (
-        len(error_messages) == 0
-    ), "Validation error leaked into final conversation"
-
-    # The final conversation should only have the successful response
-    assert final_conversation[-1]["content"] == "valid"
+    # Mock the database manager to avoid HTTP calls during tool execution
+    with patch(
+        "backend.blocks.smart_decision_maker.get_database_manager_async_client"
+    ) as mock_db_manager:
+        # Set up the mock database manager for agent mode
+        mock_db_client = AsyncMock()
+        mock_db_manager.return_value = mock_db_client
+
+        # Mock the node retrieval
+        mock_target_node = Mock()
+        mock_target_node.id = "test-sink-node-id"
+        mock_target_node.block_id = "TestBlock"
+        mock_target_node.block = Mock()
+        mock_target_node.block.name = "Test Block"
+        mock_db_client.get_node.return_value = mock_target_node
+
+        # Mock the execution result creation
+        mock_node_exec_result = Mock()
+        mock_node_exec_result.node_exec_id = "mock-node-exec-id"
+        mock_final_input_data = {"correct_param": "value"}
+        mock_db_client.upsert_execution_input.return_value = (
+            mock_node_exec_result,
+            mock_final_input_data,
+        )
+
+        # Mock the output retrieval
+        mock_outputs = {"correct_param": "value"}
+        mock_db_client.get_execution_outputs_by_node_exec_id.return_value = (
+            mock_outputs
+        )
+
+        # Create input data
+        from backend.blocks import llm
+
+        input_data = block.input_schema(
+            prompt="Test prompt",
+            credentials=llm.TEST_CREDENTIALS_INPUT,
+            model=llm.LlmModel.GPT4O,
+            retry=3,  # Allow retries
+            agent_mode_max_iterations=1,
+        )
+
+        # Run the block
+        outputs = {}
+        from backend.data.execution import ExecutionContext
+
+        mock_execution_context = ExecutionContext(safe_mode=False)
+
+        # Create a proper mock execution processor for agent mode
+        from collections import defaultdict
+
+        mock_execution_processor = AsyncMock()
+        mock_execution_processor.execution_stats = MagicMock()
+        mock_execution_processor.execution_stats_lock = MagicMock()
+
+        # Create a mock NodeExecutionProgress for the sink node
+        mock_node_exec_progress = MagicMock()
+        mock_node_exec_progress.add_task = MagicMock()
+        mock_node_exec_progress.pop_output = MagicMock(
+            return_value=None
+        )  # No outputs to process
+
+        # Set up running_node_execution as a defaultdict that returns our mock for any key
+        mock_execution_processor.running_node_execution = defaultdict(
+            lambda: mock_node_exec_progress
+        )
+
+        # Mock the on_node_execution method that gets called during tool execution
+        mock_node_stats = MagicMock()
+        mock_node_stats.error = None
+        mock_execution_processor.on_node_execution.return_value = (
+            mock_node_stats
+        )
+
+        async for output_name, output_value in block.run(
+            input_data,
+            credentials=llm.TEST_CREDENTIALS,
+            graph_id="test_graph",
+            node_id="test_node",
+            graph_exec_id="test_exec",
+            node_exec_id="test_node_exec",
+            user_id="test_user",
+            graph_version=1,
+            execution_context=mock_execution_context,
+            execution_processor=mock_execution_processor,
+        ):
+            outputs[output_name] = output_value
+
+        # Verify we had at least 1 LLM call
+        assert call_count >= 1
+
+        # Check the final conversation output
+        final_conversation = outputs.get("conversations", [])
+
+        # The final conversation should NOT contain validation error messages
+        # Even if retries don't happen in agent mode, we should not leak errors
+        error_messages = [
+            msg
+            for msg in final_conversation
+            if msg.get("role") == "user"
+            and "parameter errors" in msg.get("content", "")
+        ]
+        assert (
+            len(error_messages) == 0
+        ), "Validation error leaked into final conversation"
@@ -6,7 +6,7 @@ from typing import Optional
 from autogpt_libs.api_key.keysmith import APIKeySmith
 from prisma.enums import APIKeyPermission, APIKeyStatus
 from prisma.models import APIKey as PrismaAPIKey
-from prisma.types import APIKeyCreateInput, APIKeyWhereUniqueInput
+from prisma.types import APIKeyWhereUniqueInput
 from pydantic import BaseModel, Field

 from backend.data.includes import MAX_USER_API_KEYS_FETCH
@@ -83,17 +83,17 @@
     generated_key = keysmith.generate_key()

     saved_key_obj = await PrismaAPIKey.prisma().create(
-        data=APIKeyCreateInput(
-            id=str(uuid.uuid4()),
-            name=name,
-            head=generated_key.head,
-            tail=generated_key.tail,
-            hash=generated_key.hash,
-            salt=generated_key.salt,
-            permissions=permissions,
-            description=description,
-            userId=user_id,
-        )
+        data={
+            "id": str(uuid.uuid4()),
+            "name": name,
+            "head": generated_key.head,
+            "tail": generated_key.tail,
+            "hash": generated_key.hash,
+            "salt": generated_key.salt,
+            "permissions": [p for p in permissions],
+            "description": description,
+            "userId": user_id,
+        }
     )

     return APIKeyInfo.from_db(saved_key_obj), generated_key.key
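A side note on the dict-style `data` used above: assuming the generated input types are TypedDicts, as prisma-client-py produces, constructing one is a plain dict at runtime, so a literal dict is behaviorally equivalent and only static checking differs (and `[p for p in permissions]` additionally passes a fresh list rather than the caller's sequence). A tiny sketch with a stand-in type, not the real generated schema:

from typing import TypedDict


class APIKeyCreateInputSketch(TypedDict):  # illustrative stand-in only
    id: str
    name: str


typed = APIKeyCreateInputSketch(id="key-1", name="demo")
plain = {"id": "key-1", "name": "demo"}
assert typed == plain  # identical at runtime; only static type checking differs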
@@ -1,327 +0,0 @@
-"""
-Credential Grant data layer.
-
-Handles database operations for credential grants which allow OAuth clients
-to use credentials on behalf of users.
-"""
-
-from datetime import datetime, timezone
-from typing import Optional
-
-from prisma.enums import CredentialGrantPermission
-from prisma.models import CredentialGrant
-
-from backend.data.db import prisma
-
-
-async def create_credential_grant(
-    user_id: str,
-    client_id: str,
-    credential_id: str,
-    provider: str,
-    granted_scopes: list[str],
-    permissions: list[CredentialGrantPermission],
-    expires_at: Optional[datetime] = None,
-) -> CredentialGrant:
-    """
-    Create a new credential grant.
-
-    Args:
-        user_id: ID of the user granting access
-        client_id: Database ID of the OAuth client
-        credential_id: ID of the credential being granted
-        provider: Provider name (e.g., "google", "github")
-        granted_scopes: List of integration scopes granted
-        permissions: List of permissions (USE, DELETE)
-        expires_at: Optional expiration datetime
-
-    Returns:
-        Created CredentialGrant
-    """
-    return await prisma.credentialgrant.create(
-        data={  # type: ignore[typeddict-item]
-            "userId": user_id,
-            "clientId": client_id,
-            "credentialId": credential_id,
-            "provider": provider,
-            "grantedScopes": granted_scopes,
-            "permissions": permissions,
-            "expiresAt": expires_at,
-        }
-    )
-
-
-async def get_credential_grant(
-    grant_id: str,
-    user_id: Optional[str] = None,
-    client_id: Optional[str] = None,
-) -> Optional[CredentialGrant]:
-    """
-    Get a credential grant by ID.
-
-    Args:
-        grant_id: Grant ID
-        user_id: Optional user ID filter
-        client_id: Optional client database ID filter
-
-    Returns:
-        CredentialGrant or None
-    """
-    where: dict[str, str] = {"id": grant_id}
-    if user_id:
-        where["userId"] = user_id
-    if client_id:
-        where["clientId"] = client_id
-
-    return await prisma.credentialgrant.find_first(where=where)  # type: ignore[arg-type]
-
-
-async def get_grants_for_user_client(
-    user_id: str,
-    client_id: str,
-    include_revoked: bool = False,
-    include_expired: bool = False,
-) -> list[CredentialGrant]:
-    """
-    Get all credential grants for a user-client pair.
-
-    Args:
-        user_id: User ID
-        client_id: Client database ID
-        include_revoked: Include revoked grants
-        include_expired: Include expired grants
-
-    Returns:
-        List of CredentialGrant objects
-    """
-    where: dict[str, str | None] = {
-        "userId": user_id,
-        "clientId": client_id,
-    }
-
-    if not include_revoked:
-        where["revokedAt"] = None
-
-    grants = await prisma.credentialgrant.find_many(
-        where=where,  # type: ignore[arg-type]
-        order={"createdAt": "desc"},
-    )
-
-    # Filter expired if needed
-    if not include_expired:
-        now = datetime.now(timezone.utc)
-        grants = [g for g in grants if g.expiresAt is None or g.expiresAt > now]
-
-    return grants
-
-
-async def get_grants_for_credential(
-    user_id: str,
-    credential_id: str,
-) -> list[CredentialGrant]:
-    """
-    Get all active grants for a specific credential.
-
-    Args:
-        user_id: User ID
-        credential_id: Credential ID
-
-    Returns:
-        List of active CredentialGrant objects
-    """
-    now = datetime.now(timezone.utc)
-
-    grants = await prisma.credentialgrant.find_many(
-        where={
-            "userId": user_id,
-            "credentialId": credential_id,
-            "revokedAt": None,
-        },
-        include={"Client": True},
-    )
-
-    # Filter expired
-    return [g for g in grants if g.expiresAt is None or g.expiresAt > now]
-
-
-async def get_grant_by_credential_and_client(
-    user_id: str,
-    credential_id: str,
-    client_id: str,
-) -> Optional[CredentialGrant]:
-    """
-    Get the grant for a specific credential and client.
-
-    Args:
-        user_id: User ID
-        credential_id: Credential ID
-        client_id: Client database ID
-
-    Returns:
-        CredentialGrant or None
-    """
-    return await prisma.credentialgrant.find_first(
-        where={
-            "userId": user_id,
-            "credentialId": credential_id,
-            "clientId": client_id,
-            "revokedAt": None,
-        }
-    )
-
-
-async def update_grant_scopes(
-    grant_id: str,
-    granted_scopes: list[str],
-) -> CredentialGrant:
-    """
-    Update the granted scopes for a credential grant.
-
-    Args:
-        grant_id: Grant ID
-        granted_scopes: New list of granted scopes
-
-    Returns:
-        Updated CredentialGrant
-    """
-    result = await prisma.credentialgrant.update(
-        where={"id": grant_id},
-        data={"grantedScopes": granted_scopes},
-    )
-    if result is None:
-        raise ValueError(f"Grant {grant_id} not found")
-    return result
-
-
-async def update_grant_last_used(grant_id: str) -> None:
-    """
-    Update the lastUsedAt timestamp for a grant.
-
-    Args:
-        grant_id: Grant ID
-    """
-    await prisma.credentialgrant.update(
-        where={"id": grant_id},
-        data={"lastUsedAt": datetime.now(timezone.utc)},
-    )
-
-
-async def revoke_grant(grant_id: str) -> CredentialGrant:
-    """
-    Revoke a credential grant.
-
-    Args:
-        grant_id: Grant ID
-
-    Returns:
-        Revoked CredentialGrant
-    """
-    result = await prisma.credentialgrant.update(
-        where={"id": grant_id},
-        data={"revokedAt": datetime.now(timezone.utc)},
-    )
-    if result is None:
-        raise ValueError(f"Grant {grant_id} not found")
-    return result
-
-
-async def revoke_grants_for_credential(
-    user_id: str,
-    credential_id: str,
-) -> int:
-    """
-    Revoke all grants for a specific credential.
-
-    Args:
-        user_id: User ID
-        credential_id: Credential ID
-
-    Returns:
-        Number of grants revoked
-    """
-    return await prisma.credentialgrant.update_many(
-        where={
-            "userId": user_id,
-            "credentialId": credential_id,
-            "revokedAt": None,
-        },
-        data={"revokedAt": datetime.now(timezone.utc)},
-    )
-
-
-async def revoke_grants_for_client(
-    user_id: str,
-    client_id: str,
-) -> int:
-    """
-    Revoke all grants for a specific client.
-
-    Args:
-        user_id: User ID
-        client_id: Client database ID
-
-    Returns:
-        Number of grants revoked
-    """
-    return await prisma.credentialgrant.update_many(
-        where={
-            "userId": user_id,
-            "clientId": client_id,
-            "revokedAt": None,
-        },
-        data={"revokedAt": datetime.now(timezone.utc)},
-    )
-
-
-async def delete_grant(grant_id: str) -> None:
-    """
-    Permanently delete a credential grant.
-
-    Args:
-        grant_id: Grant ID
-    """
-    await prisma.credentialgrant.delete(where={"id": grant_id})
-
-
-async def check_grant_permission(
-    grant_id: str,
-    required_permission: CredentialGrantPermission,
-) -> bool:
-    """
-    Check if a grant has a specific permission.
-
-    Args:
-        grant_id: Grant ID
-        required_permission: Permission to check
-
-    Returns:
-        True if grant has the permission
-    """
-    grant = await prisma.credentialgrant.find_unique(where={"id": grant_id})
-    if not grant:
-        return False
-
-    return required_permission in grant.permissions
-
-
-async def is_grant_valid(grant_id: str) -> bool:
-    """
-    Check if a grant is valid (not revoked and not expired).
-
-    Args:
-        grant_id: Grant ID
-
-    Returns:
-        True if grant is valid
-    """
-    grant = await prisma.credentialgrant.find_unique(where={"id": grant_id})
-    if not grant:
-        return False
-
-    if grant.revokedAt:
-        return False
-
-    if grant.expiresAt and grant.expiresAt < datetime.now(timezone.utc):
-        return False
-
-    return True
@@ -11,7 +11,6 @@ import pytest
 from prisma.enums import CreditTransactionType
 from prisma.errors import UniqueViolationError
 from prisma.models import CreditTransaction, User, UserBalance
-from prisma.types import UserBalanceCreateInput, UserBalanceUpsertInput, UserCreateInput

 from backend.data.credit import UserCredit
 from backend.util.json import SafeJson
@@ -22,11 +21,11 @@ async def create_test_user(user_id: str) -> None:
     """Create a test user for ceiling tests."""
     try:
         await User.prisma().create(
-            data=UserCreateInput(
-                id=user_id,
-                email=f"test-{user_id}@example.com",
-                name=f"Test User {user_id[:8]}",
-            )
+            data={
+                "id": user_id,
+                "email": f"test-{user_id}@example.com",
+                "name": f"Test User {user_id[:8]}",
+            }
         )
     except UniqueViolationError:
         # User already exists, continue
@@ -34,10 +33,7 @@

     await UserBalance.prisma().upsert(
         where={"userId": user_id},
-        data=UserBalanceUpsertInput(
-            create=UserBalanceCreateInput(userId=user_id, balance=0),
-            update={"balance": 0},
-        ),
+        data={"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
     )

@@ -14,7 +14,6 @@ import pytest
 from prisma.enums import CreditTransactionType
 from prisma.errors import UniqueViolationError
 from prisma.models import CreditTransaction, User, UserBalance
-from prisma.types import UserBalanceCreateInput, UserBalanceUpsertInput, UserCreateInput

 from backend.data.credit import POSTGRES_INT_MAX, UsageTransactionMetadata, UserCredit
 from backend.util.exceptions import InsufficientBalanceError
@@ -29,11 +28,11 @@ async def create_test_user(user_id: str) -> None:
     """Create a test user with initial balance."""
     try:
         await User.prisma().create(
-            data=UserCreateInput(
-                id=user_id,
-                email=f"test-{user_id}@example.com",
-                name=f"Test User {user_id[:8]}",
-            )
+            data={
+                "id": user_id,
+                "email": f"test-{user_id}@example.com",
+                "name": f"Test User {user_id[:8]}",
+            }
         )
     except UniqueViolationError:
         # User already exists, continue
@@ -42,10 +41,7 @@
     # Ensure UserBalance record exists
     await UserBalance.prisma().upsert(
         where={"userId": user_id},
-        data=UserBalanceUpsertInput(
-            create=UserBalanceCreateInput(userId=user_id, balance=0),
-            update={"balance": 0},
-        ),
+        data={"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
     )

@@ -346,10 +342,10 @@ async def test_integer_overflow_protection(server: SpinTestServer):
     # First, set balance near max
     await UserBalance.prisma().upsert(
        where={"userId": user_id},
-        data=UserBalanceUpsertInput(
-            create=UserBalanceCreateInput(userId=user_id, balance=max_int - 100),
-            update={"balance": max_int - 100},
-        ),
+        data={
+            "create": {"userId": user_id, "balance": max_int - 100},
+            "update": {"balance": max_int - 100},
+        },
     )

     # Try to add more than possible - should clamp to POSTGRES_INT_MAX
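The upsert calls above all share one shape: a `where` filter plus a `data` dict with `create` and `update` branches — insert when the key is absent, otherwise apply the update. A plain-Python sketch of those semantics against an in-memory table (table and key names are illustrative, not the Prisma API):

def upsert(table: dict[str, dict], key: str, data: dict) -> dict:
    # Mimic upsert: create when missing, otherwise apply the update branch.
    if key in table:
        table[key].update(data["update"])
    else:
        table[key] = {**data["create"]}
    return table[key]


balances: dict[str, dict] = {}
upsert(balances, "u1", {"create": {"userId": "u1", "balance": 0}, "update": {"balance": 0}})
upsert(balances, "u1", {"create": {"userId": "u1", "balance": 0}, "update": {"balance": 500}})
assert balances["u1"]["balance"] == 500  # second call took the update branch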
@@ -8,7 +8,6 @@ which would have caught the CreditTransactionType enum casting bug.
 import pytest
 from prisma.enums import CreditTransactionType
 from prisma.models import CreditTransaction, User, UserBalance
-from prisma.types import UserCreateInput

 from backend.data.credit import (
     AutoTopUpConfig,
@@ -30,12 +29,12 @@ async def cleanup_test_user():
|
|||||||
# Create the user first
|
# Create the user first
|
||||||
try:
|
try:
|
||||||
await User.prisma().create(
|
await User.prisma().create(
|
||||||
data=UserCreateInput(
|
data={
|
||||||
id=user_id,
|
"id": user_id,
|
||||||
email=f"test-{user_id}@example.com",
|
"email": f"test-{user_id}@example.com",
|
||||||
topUpConfig=SafeJson({}),
|
"topUpConfig": SafeJson({}),
|
||||||
timezone="UTC",
|
"timezone": "UTC",
|
||||||
)
|
}
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
# User might already exist, that's fine
|
# User might already exist, that's fine
|
||||||
|
|||||||
@@ -12,12 +12,6 @@ import pytest
 import stripe
 from prisma.enums import CreditTransactionType
 from prisma.models import CreditRefundRequest, CreditTransaction, User, UserBalance
-from prisma.types import (
-    CreditRefundRequestCreateInput,
-    CreditTransactionCreateInput,
-    UserBalanceCreateInput,
-    UserCreateInput,
-)

 from backend.data.credit import UserCredit
 from backend.util.json import SafeJson
@@ -41,32 +35,32 @@ async def setup_test_user_with_topup():

     # Create user
     await User.prisma().create(
-        data=UserCreateInput(
-            id=REFUND_TEST_USER_ID,
-            email=f"{REFUND_TEST_USER_ID}@example.com",
-            name="Refund Test User",
-        )
+        data={
+            "id": REFUND_TEST_USER_ID,
+            "email": f"{REFUND_TEST_USER_ID}@example.com",
+            "name": "Refund Test User",
+        }
     )

     # Create user balance
     await UserBalance.prisma().create(
-        data=UserBalanceCreateInput(
-            userId=REFUND_TEST_USER_ID,
-            balance=1000,  # $10
-        )
+        data={
+            "userId": REFUND_TEST_USER_ID,
+            "balance": 1000,  # $10
+        }
     )

     # Create a top-up transaction that can be refunded
     topup_tx = await CreditTransaction.prisma().create(
-        data=CreditTransactionCreateInput(
-            userId=REFUND_TEST_USER_ID,
-            amount=1000,
-            type=CreditTransactionType.TOP_UP,
-            transactionKey="pi_test_12345",
-            runningBalance=1000,
-            isActive=True,
-            metadata=SafeJson({"stripe_payment_intent": "pi_test_12345"}),
-        )
+        data={
+            "userId": REFUND_TEST_USER_ID,
+            "amount": 1000,
+            "type": CreditTransactionType.TOP_UP,
+            "transactionKey": "pi_test_12345",
+            "runningBalance": 1000,
+            "isActive": True,
+            "metadata": SafeJson({"stripe_payment_intent": "pi_test_12345"}),
+        }
     )

     return topup_tx
@@ -99,12 +93,12 @@ async def test_deduct_credits_atomic(server: SpinTestServer):

     # Create refund request record (simulating webhook flow)
     await CreditRefundRequest.prisma().create(
-        data=CreditRefundRequestCreateInput(
-            userId=REFUND_TEST_USER_ID,
-            amount=500,
-            transactionKey=topup_tx.transactionKey,  # Should match the original transaction
-            reason="Test refund",
-        )
+        data={
+            "userId": REFUND_TEST_USER_ID,
+            "amount": 500,
+            "transactionKey": topup_tx.transactionKey,  # Should match the original transaction
+            "reason": "Test refund",
+        }
     )

     # Call deduct_credits
@@ -292,12 +286,12 @@ async def test_concurrent_refunds(server: SpinTestServer):
     refund_requests = []
     for i in range(5):
         req = await CreditRefundRequest.prisma().create(
-            data=CreditRefundRequestCreateInput(
-                userId=REFUND_TEST_USER_ID,
-                amount=100,  # $1 each
-                transactionKey=topup_tx.transactionKey,
-                reason=f"Test refund {i}",
-            )
+            data={
+                "userId": REFUND_TEST_USER_ID,
+                "amount": 100,  # $1 each
+                "transactionKey": topup_tx.transactionKey,
+                "reason": f"Test refund {i}",
+            }
         )
         refund_requests.append(req)

@@ -3,11 +3,6 @@ from datetime import datetime, timedelta, timezone
 import pytest
 from prisma.enums import CreditTransactionType
 from prisma.models import CreditTransaction, UserBalance
-from prisma.types import (
-    CreditTransactionCreateInput,
-    UserBalanceCreateInput,
-    UserBalanceUpsertInput,
-)

 from backend.blocks.llm import AITextGeneratorBlock
 from backend.data.block import get_block
@@ -28,10 +23,10 @@ async def disable_test_user_transactions():
     old_date = datetime.now(timezone.utc) - timedelta(days=35)  # More than a month ago
     await UserBalance.prisma().upsert(
         where={"userId": DEFAULT_USER_ID},
-        data=UserBalanceUpsertInput(
-            create=UserBalanceCreateInput(userId=DEFAULT_USER_ID, balance=0),
-            update={"balance": 0, "updatedAt": old_date},
-        ),
+        data={
+            "create": {"userId": DEFAULT_USER_ID, "balance": 0},
+            "update": {"balance": 0, "updatedAt": old_date},
+        },
     )

@@ -145,23 +140,23 @@ async def test_block_credit_reset(server: SpinTestServer):

     # Manually create a transaction with month 1 timestamp to establish history
     await CreditTransaction.prisma().create(
-        data=CreditTransactionCreateInput(
-            userId=DEFAULT_USER_ID,
-            amount=100,
-            type=CreditTransactionType.TOP_UP,
-            runningBalance=1100,
-            isActive=True,
-            createdAt=month1,  # Set specific timestamp
-        )
+        data={
+            "userId": DEFAULT_USER_ID,
+            "amount": 100,
+            "type": CreditTransactionType.TOP_UP,
+            "runningBalance": 1100,
+            "isActive": True,
+            "createdAt": month1,  # Set specific timestamp
+        }
     )

     # Update user balance to match
     await UserBalance.prisma().upsert(
         where={"userId": DEFAULT_USER_ID},
-        data=UserBalanceUpsertInput(
-            create=UserBalanceCreateInput(userId=DEFAULT_USER_ID, balance=1100),
-            update={"balance": 1100},
-        ),
+        data={
+            "create": {"userId": DEFAULT_USER_ID, "balance": 1100},
+            "update": {"balance": 1100},
+        },
     )

     # Now test month 2 behavior
@@ -180,14 +175,14 @@ async def test_block_credit_reset(server: SpinTestServer):

     # Create a month 2 transaction to update the last transaction time
     await CreditTransaction.prisma().create(
-        data=CreditTransactionCreateInput(
-            userId=DEFAULT_USER_ID,
-            amount=-700,  # Spent 700 to get to 400
-            type=CreditTransactionType.USAGE,
-            runningBalance=400,
-            isActive=True,
-            createdAt=month2,
-        )
+        data={
+            "userId": DEFAULT_USER_ID,
+            "amount": -700,  # Spent 700 to get to 400
+            "type": CreditTransactionType.USAGE,
+            "runningBalance": 400,
+            "isActive": True,
+            "createdAt": month2,
+        }
     )

     # Move to month 3
@@ -12,7 +12,6 @@ import pytest
 from prisma.enums import CreditTransactionType
 from prisma.errors import UniqueViolationError
 from prisma.models import CreditTransaction, User, UserBalance
-from prisma.types import UserBalanceCreateInput, UserBalanceUpsertInput, UserCreateInput

 from backend.data.credit import POSTGRES_INT_MIN, UserCredit
 from backend.util.test import SpinTestServer
@@ -22,11 +21,11 @@ async def create_test_user(user_id: str) -> None:
     """Create a test user for underflow tests."""
     try:
         await User.prisma().create(
-            data=UserCreateInput(
-                id=user_id,
-                email=f"test-{user_id}@example.com",
-                name=f"Test User {user_id[:8]}",
-            )
+            data={
+                "id": user_id,
+                "email": f"test-{user_id}@example.com",
+                "name": f"Test User {user_id[:8]}",
+            }
         )
     except UniqueViolationError:
         # User already exists, continue
@@ -34,10 +33,7 @@ async def create_test_user(user_id: str) -> None:

     await UserBalance.prisma().upsert(
         where={"userId": user_id},
-        data=UserBalanceUpsertInput(
-            create=UserBalanceCreateInput(userId=user_id, balance=0),
-            update={"balance": 0},
-        ),
+        data={"create": {"userId": user_id, "balance": 0}, "update": {"balance": 0}},
     )

@@ -70,14 +66,14 @@ async def test_debug_underflow_step_by_step(server: SpinTestServer):
     initial_balance_target = POSTGRES_INT_MIN + 100

     # Use direct database update to set the balance close to underflow
+    from prisma.models import UserBalance

     await UserBalance.prisma().upsert(
         where={"userId": user_id},
-        data=UserBalanceUpsertInput(
-            create=UserBalanceCreateInput(
-                userId=user_id, balance=initial_balance_target
-            ),
-            update={"balance": initial_balance_target},
-        ),
+        data={
+            "create": {"userId": user_id, "balance": initial_balance_target},
+            "update": {"balance": initial_balance_target},
+        },
     )

     current_balance = await credit_system.get_credits(user_id)
@@ -114,10 +110,10 @@ async def test_debug_underflow_step_by_step(server: SpinTestServer):
     # Set balance to exactly POSTGRES_INT_MIN
     await UserBalance.prisma().upsert(
         where={"userId": user_id},
-        data=UserBalanceUpsertInput(
-            create=UserBalanceCreateInput(userId=user_id, balance=POSTGRES_INT_MIN),
-            update={"balance": POSTGRES_INT_MIN},
-        ),
+        data={
+            "create": {"userId": user_id, "balance": POSTGRES_INT_MIN},
+            "update": {"balance": POSTGRES_INT_MIN},
+        },
     )

     edge_balance = await credit_system.get_credits(user_id)
@@ -151,13 +147,15 @@ async def test_underflow_protection_large_refunds(server: SpinTestServer):
     # Set up balance close to underflow threshold to test the protection
     # Set balance to POSTGRES_INT_MIN + 1000, then try to subtract 2000
     # This should trigger underflow protection
+    from prisma.models import UserBalance

     test_balance = POSTGRES_INT_MIN + 1000
     await UserBalance.prisma().upsert(
         where={"userId": user_id},
-        data=UserBalanceUpsertInput(
-            create=UserBalanceCreateInput(userId=user_id, balance=test_balance),
-            update={"balance": test_balance},
-        ),
+        data={
+            "create": {"userId": user_id, "balance": test_balance},
+            "update": {"balance": test_balance},
+        },
     )

     current_balance = await credit_system.get_credits(user_id)
@@ -214,13 +212,15 @@ async def test_multiple_large_refunds_cumulative_underflow(server: SpinTestServe

     try:
         # Set up balance close to underflow threshold
+        from prisma.models import UserBalance

         initial_balance = POSTGRES_INT_MIN + 500  # Close to minimum but with some room
         await UserBalance.prisma().upsert(
             where={"userId": user_id},
-            data=UserBalanceUpsertInput(
-                create=UserBalanceCreateInput(userId=user_id, balance=initial_balance),
-                update={"balance": initial_balance},
-            ),
+            data={
+                "create": {"userId": user_id, "balance": initial_balance},
+                "update": {"balance": initial_balance},
+            },
         )

         # Apply multiple refunds that would cumulatively underflow
@@ -290,13 +290,15 @@ async def test_concurrent_large_refunds_no_underflow(server: SpinTestServer):

     try:
         # Set up balance close to underflow threshold
+        from prisma.models import UserBalance

         initial_balance = POSTGRES_INT_MIN + 1000  # Close to minimum
         await UserBalance.prisma().upsert(
             where={"userId": user_id},
-            data=UserBalanceUpsertInput(
-                create=UserBalanceCreateInput(userId=user_id, balance=initial_balance),
-                update={"balance": initial_balance},
-            ),
+            data={
+                "create": {"userId": user_id, "balance": initial_balance},
+                "update": {"balance": initial_balance},
+            },
         )

         async def large_refund(amount: int, label: str):
@@ -14,7 +14,6 @@ import pytest
 from prisma.enums import CreditTransactionType
 from prisma.errors import UniqueViolationError
 from prisma.models import CreditTransaction, User, UserBalance
-from prisma.types import UserBalanceCreateInput, UserCreateInput

 from backend.data.credit import UsageTransactionMetadata, UserCredit
 from backend.util.json import SafeJson
@@ -25,11 +24,11 @@ async def create_test_user(user_id: str) -> None:
     """Create a test user for migration tests."""
     try:
         await User.prisma().create(
-            data=UserCreateInput(
-                id=user_id,
-                email=f"test-{user_id}@example.com",
-                name=f"Test User {user_id[:8]}",
-            )
+            data={
+                "id": user_id,
+                "email": f"test-{user_id}@example.com",
+                "name": f"Test User {user_id[:8]}",
+            }
         )
     except UniqueViolationError:
         # User already exists, continue
@@ -122,7 +121,7 @@ async def test_detect_stale_user_balance_queries(server: SpinTestServer):
     try:
         # Create UserBalance with specific value
         await UserBalance.prisma().create(
-            data=UserBalanceCreateInput(userId=user_id, balance=5000)  # $50
+            data={"userId": user_id, "balance": 5000}  # $50
         )

         # Verify that get_credits returns UserBalance value (5000), not any stale User.balance value
@@ -161,9 +160,7 @@ async def test_concurrent_operations_use_userbalance_only(server: SpinTestServer

     try:
         # Set initial balance in UserBalance
-        await UserBalance.prisma().create(
-            data=UserBalanceCreateInput(userId=user_id, balance=1000)
-        )
+        await UserBalance.prisma().create(data={"userId": user_id, "balance": 1000})

         # Run concurrent operations to ensure they all use UserBalance atomic operations
         async def concurrent_spend(amount: int, label: str):
@@ -1,10 +1,10 @@
 import logging
+import queue
 from collections import defaultdict
 from datetime import datetime, timedelta, timezone
 from enum import Enum
-from multiprocessing import Manager
-from queue import Empty
 from typing import (
+    TYPE_CHECKING,
     Annotated,
     Any,
     AsyncGenerator,
@@ -27,7 +27,6 @@ from prisma.models import (
     AgentNodeExecutionKeyValueData,
 )
 from prisma.types import (
-    AgentGraphExecutionCreateInput,
     AgentGraphExecutionUpdateManyMutationInput,
     AgentGraphExecutionWhereInput,
     AgentNodeExecutionCreateInput,
@@ -35,7 +34,7 @@ from prisma.types import (
     AgentNodeExecutionKeyValueDataCreateInput,
     AgentNodeExecutionUpdateInput,
     AgentNodeExecutionWhereInput,
-    _AgentNodeExecutionWhereUnique_id_Input,
+    AgentNodeExecutionWhereUniqueInput,
 )
 from pydantic import BaseModel, ConfigDict, JsonValue, ValidationError
 from pydantic.fields import Field
@@ -66,19 +65,15 @@ from .includes import (
 )
 from .model import CredentialsMetaInput, GraphExecutionStats, NodeExecutionStats

+if TYPE_CHECKING:
+    pass
+
 T = TypeVar("T")

 logger = logging.getLogger(__name__)
 config = Config()


-class GrantResolverContext(BaseModel):
-    """Context for grant-based credential resolution in external API executions."""
-
-    client_db_id: str  # The OAuth client database UUID
-    grant_ids: list[str]  # List of grant IDs to use for credential resolution
-
-
 class ExecutionContext(BaseModel):
     """
     Unified context that carries execution-level data throughout the entire execution flow.
@@ -89,8 +84,6 @@ class ExecutionContext(BaseModel):
     user_timezone: str = "UTC"
     root_execution_id: Optional[str] = None
     parent_execution_id: Optional[str] = None
-    # For external API executions using credential grants
-    grant_resolver_context: Optional[GrantResolverContext] = None


 # -------------------------- Models -------------------------- #
@@ -715,18 +708,18 @@ async def create_graph_execution(
         The id of the AgentGraphExecution and the list of ExecutionResult for each node.
     """
     result = await AgentGraphExecution.prisma().create(
-        data=AgentGraphExecutionCreateInput(
-            agentGraphId=graph_id,
-            agentGraphVersion=graph_version,
-            executionStatus=ExecutionStatus.INCOMPLETE,
-            inputs=SafeJson(inputs),
-            credentialInputs=(
+        data={
+            "agentGraphId": graph_id,
+            "agentGraphVersion": graph_version,
+            "executionStatus": ExecutionStatus.INCOMPLETE,
+            "inputs": SafeJson(inputs),
+            "credentialInputs": (
                 SafeJson(credential_inputs) if credential_inputs else Json({})
             ),
-            nodesInputMasks=(
+            "nodesInputMasks": (
                 SafeJson(nodes_input_masks) if nodes_input_masks else Json({})
             ),
-            NodeExecutions={
+            "NodeExecutions": {
                 "create": [
                     AgentNodeExecutionCreateInput(
                         agentNodeId=node_id,
@@ -742,10 +735,10 @@ async def create_graph_execution(
                     for node_id, node_input in starting_nodes_input
                 ]
             },
-            userId=user_id,
-            agentPresetId=preset_id,
-            parentGraphExecutionId=parent_graph_exec_id,
-        ),
+            "userId": user_id,
+            "agentPresetId": preset_id,
+            "parentGraphExecutionId": parent_graph_exec_id,
+        },
         include=GRAPH_EXECUTION_INCLUDE_WITH_NODES,
     )

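Note on the hunk above: the outer `data` payload becomes a dict literal while the nested `NodeExecutions` relation write keeps its `AgentNodeExecutionCreateInput` entries, so the two styles can be mixed freely. A trimmed, hypothetical sketch of the nested-write shape (the field subset and IDs are illustrative only, not the full required schema):

from prisma.models import AgentGraphExecution

async def create_minimal_execution(graph_id: str, user_id: str):
    # Nested write: related AgentNodeExecution rows are created in the same
    # call under the relation's "create" key. Fields shown are a subset.
    return await AgentGraphExecution.prisma().create(
        data={
            "agentGraphId": graph_id,
            "agentGraphVersion": 1,
            "userId": user_id,
            "NodeExecutions": {"create": [{"agentNodeId": "node-1"}]},
        },
    )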
@@ -837,15 +830,39 @@ async def upsert_execution_output(
     """
     Insert AgentNodeExecutionInputOutput record for as one of AgentNodeExecution.Output.
     """
-    data = AgentNodeExecutionInputOutputCreateInput(
-        name=output_name,
-        referencedByOutputExecId=node_exec_id,
-    )
+    data: AgentNodeExecutionInputOutputCreateInput = {
+        "name": output_name,
+        "referencedByOutputExecId": node_exec_id,
+    }
     if output_data is not None:
         data["data"] = SafeJson(output_data)
     await AgentNodeExecutionInputOutput.prisma().create(data=data)


+async def get_execution_outputs_by_node_exec_id(
+    node_exec_id: str,
+) -> dict[str, Any]:
+    """
+    Get all execution outputs for a specific node execution ID.
+
+    Args:
+        node_exec_id: The node execution ID to get outputs for
+
+    Returns:
+        Dictionary mapping output names to their data values
+    """
+    outputs = await AgentNodeExecutionInputOutput.prisma().find_many(
+        where={"referencedByOutputExecId": node_exec_id}
+    )
+
+    result = {}
+    for output in outputs:
+        if output.data is not None:
+            result[output.name] = type_utils.convert(output.data, JsonValue)
+
+    return result
+
+
 async def update_graph_execution_start_time(
     graph_exec_id: str,
 ) -> GraphExecution | None:
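The new `get_execution_outputs_by_node_exec_id` helper returns a plain name-to-value mapping, with stored JSON already converted back to `JsonValue`. A short usage sketch (the execution ID is a placeholder):

from backend.data.execution import get_execution_outputs_by_node_exec_id

async def print_node_outputs(node_exec_id: str) -> None:
    # Returns {} when the node execution has no outputs recorded yet.
    outputs = await get_execution_outputs_by_node_exec_id(node_exec_id)
    for name, value in outputs.items():
        print(f"{name}: {value!r}")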
@@ -958,7 +975,7 @@ async def update_node_execution_status(

     if res := await AgentNodeExecution.prisma().update(
         where=cast(
-            _AgentNodeExecutionWhereUnique_id_Input,
+            AgentNodeExecutionWhereUniqueInput,
             {
                 "id": node_exec_id,
                 "executionStatus": {"in": [s.value for s in allowed_from]},
@@ -1146,12 +1163,16 @@ class NodeExecutionEntry(BaseModel):

 class ExecutionQueue(Generic[T]):
     """
-    Queue for managing the execution of agents.
-    This will be shared between different processes
+    Thread-safe queue for managing node execution within a single graph execution.
+
+    Note: Uses queue.Queue (not multiprocessing.Queue) since all access is from
+    threads within the same process. If migrating back to ProcessPoolExecutor,
+    replace with multiprocessing.Manager().Queue() for cross-process safety.
     """

     def __init__(self):
-        self.queue = Manager().Queue()
+        # Thread-safe queue (not multiprocessing) — see class docstring
+        self.queue: queue.Queue[T] = queue.Queue()

     def add(self, execution: T) -> T:
         self.queue.put(execution)
@@ -1166,7 +1187,7 @@ class ExecutionQueue(Generic[T]):
     def get_or_none(self) -> T | None:
         try:
             return self.queue.get_nowait()
-        except Empty:
+        except queue.Empty:
             return None

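The `ExecutionQueue` change swaps `multiprocessing.Manager().Queue()`, a proxy backed by a separate manager process, for the stdlib `queue.Queue`, which is already thread-safe and avoids a per-operation IPC round trip. A standalone sketch of the same pattern, assuming all producers and consumers are threads in one process (the condition the new docstring states):

import queue
import threading

work: queue.Queue[str] = queue.Queue()  # thread-safe; no manager process needed

def drain() -> None:
    while True:
        try:
            item = work.get_nowait()
        except queue.Empty:  # mirrors ExecutionQueue.get_or_none()
            return
        print("processing", item)

for i in range(3):
    work.put(f"job-{i}")

worker = threading.Thread(target=drain)
worker.start()
worker.join()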
@@ -0,0 +1,60 @@
+"""Tests for ExecutionQueue thread-safety."""
+
+import queue
+import threading
+
+import pytest
+
+from backend.data.execution import ExecutionQueue
+
+
+def test_execution_queue_uses_stdlib_queue():
+    """Verify ExecutionQueue uses queue.Queue (not multiprocessing)."""
+    q = ExecutionQueue()
+    assert isinstance(q.queue, queue.Queue)
+
+
+def test_basic_operations():
+    """Test add, get, empty, and get_or_none."""
+    q = ExecutionQueue()
+
+    assert q.empty() is True
+    assert q.get_or_none() is None
+
+    result = q.add("item1")
+    assert result == "item1"
+    assert q.empty() is False
+
+    item = q.get()
+    assert item == "item1"
+    assert q.empty() is True
+
+
+def test_thread_safety():
+    """Test concurrent access from multiple threads."""
+    q = ExecutionQueue()
+    results = []
+    num_items = 100
+
+    def producer():
+        for i in range(num_items):
+            q.add(f"item_{i}")
+
+    def consumer():
+        count = 0
+        while count < num_items:
+            item = q.get_or_none()
+            if item is not None:
+                results.append(item)
+                count += 1
+
+    producer_thread = threading.Thread(target=producer)
+    consumer_thread = threading.Thread(target=consumer)
+
+    producer_thread.start()
+    consumer_thread.start()
+
+    producer_thread.join(timeout=5)
+    consumer_thread.join(timeout=5)
+
+    assert len(results) == num_items
@@ -10,11 +10,7 @@ from typing import Optional

 from prisma.enums import ReviewStatus
 from prisma.models import PendingHumanReview
-from prisma.types import (
-    PendingHumanReviewCreateInput,
-    PendingHumanReviewUpdateInput,
-    PendingHumanReviewUpsertInput,
-)
+from prisma.types import PendingHumanReviewUpdateInput
 from pydantic import BaseModel

 from backend.server.v2.executions.review.model import (
@@ -70,20 +66,20 @@ async def get_or_create_human_review(
     # Upsert - get existing or create new review
     review = await PendingHumanReview.prisma().upsert(
         where={"nodeExecId": node_exec_id},
-        data=PendingHumanReviewUpsertInput(
-            create=PendingHumanReviewCreateInput(
-                userId=user_id,
-                nodeExecId=node_exec_id,
-                graphExecId=graph_exec_id,
-                graphId=graph_id,
-                graphVersion=graph_version,
-                payload=SafeJson(input_data),
-                instructions=message,
-                editable=editable,
-                status=ReviewStatus.WAITING,
-            ),
-            update={},  # Do nothing on update - keep existing review as is
-        ),
+        data={
+            "create": {
+                "userId": user_id,
+                "nodeExecId": node_exec_id,
+                "graphExecId": graph_exec_id,
+                "graphId": graph_id,
+                "graphVersion": graph_version,
+                "payload": SafeJson(input_data),
+                "instructions": message,
+                "editable": editable,
+                "status": ReviewStatus.WAITING,
+            },
+            "update": {},  # Do nothing on update - keep existing review as is
+        },
     )

     logger.info(
@@ -1,302 +0,0 @@
-"""
-Integration scopes mapping.
-
-Maps AutoGPT's fine-grained integration scopes to provider-specific OAuth scopes.
-These scopes are used to request granular permissions when connecting integrations
-through the Credential Broker.
-"""
-
-from enum import Enum
-from typing import Optional
-
-from backend.integrations.providers import ProviderName
-
-
-class IntegrationScope(str, Enum):
-    """
-    Fine-grained integration scopes for credential grants.
-
-    Format: {provider}:{resource}.{permission}
-    """
-
-    # Google scopes
-    GOOGLE_EMAIL_READ = "google:email.read"
-    GOOGLE_GMAIL_READONLY = "google:gmail.readonly"
-    GOOGLE_GMAIL_SEND = "google:gmail.send"
-    GOOGLE_GMAIL_MODIFY = "google:gmail.modify"
-    GOOGLE_DRIVE_READONLY = "google:drive.readonly"
-    GOOGLE_DRIVE_FILE = "google:drive.file"
-    GOOGLE_CALENDAR_READONLY = "google:calendar.readonly"
-    GOOGLE_CALENDAR_EVENTS = "google:calendar.events"
-    GOOGLE_SHEETS_READONLY = "google:sheets.readonly"
-    GOOGLE_SHEETS = "google:sheets"
-    GOOGLE_DOCS_READONLY = "google:docs.readonly"
-    GOOGLE_DOCS = "google:docs"
-
-    # GitHub scopes
-    GITHUB_REPOS_READ = "github:repos.read"
-    GITHUB_REPOS_WRITE = "github:repos.write"
-    GITHUB_ISSUES_READ = "github:issues.read"
-    GITHUB_ISSUES_WRITE = "github:issues.write"
-    GITHUB_USER_READ = "github:user.read"
-    GITHUB_GISTS = "github:gists"
-    GITHUB_NOTIFICATIONS = "github:notifications"
-
-    # Discord scopes
-    DISCORD_IDENTIFY = "discord:identify"
-    DISCORD_EMAIL = "discord:email"
-    DISCORD_GUILDS = "discord:guilds"
-    DISCORD_MESSAGES_READ = "discord:messages.read"
-
-    # Twitter scopes
-    TWITTER_READ = "twitter:read"
-    TWITTER_WRITE = "twitter:write"
-    TWITTER_DM = "twitter:dm"
-
-    # Notion scopes
-    NOTION_READ = "notion:read"
-    NOTION_WRITE = "notion:write"
-
-    # Todoist scopes
-    TODOIST_READ = "todoist:read"
-    TODOIST_WRITE = "todoist:write"
-
-
-# Scope descriptions for consent UI
-INTEGRATION_SCOPE_DESCRIPTIONS: dict[str, str] = {
-    # Google
-    IntegrationScope.GOOGLE_EMAIL_READ.value: "Read your email address",
-    IntegrationScope.GOOGLE_GMAIL_READONLY.value: "Read your Gmail messages",
-    IntegrationScope.GOOGLE_GMAIL_SEND.value: "Send emails on your behalf",
-    IntegrationScope.GOOGLE_GMAIL_MODIFY.value: "Read, send, and manage your emails",
-    IntegrationScope.GOOGLE_DRIVE_READONLY.value: "View files in your Google Drive",
-    IntegrationScope.GOOGLE_DRIVE_FILE.value: "Create and edit files in Google Drive",
-    IntegrationScope.GOOGLE_CALENDAR_READONLY.value: "View your calendar",
-    IntegrationScope.GOOGLE_CALENDAR_EVENTS.value: "Create and edit calendar events",
-    IntegrationScope.GOOGLE_SHEETS_READONLY.value: "View your spreadsheets",
-    IntegrationScope.GOOGLE_SHEETS.value: "Create and edit spreadsheets",
-    IntegrationScope.GOOGLE_DOCS_READONLY.value: "View your documents",
-    IntegrationScope.GOOGLE_DOCS.value: "Create and edit documents",
-    # GitHub
-    IntegrationScope.GITHUB_REPOS_READ.value: "Read repository information",
-    IntegrationScope.GITHUB_REPOS_WRITE.value: "Create and manage repositories",
-    IntegrationScope.GITHUB_ISSUES_READ.value: "Read issues and pull requests",
-    IntegrationScope.GITHUB_ISSUES_WRITE.value: "Create and manage issues",
-    IntegrationScope.GITHUB_USER_READ.value: "Read your GitHub profile",
-    IntegrationScope.GITHUB_GISTS.value: "Create and manage gists",
-    IntegrationScope.GITHUB_NOTIFICATIONS.value: "Access notifications",
-    # Discord
-    IntegrationScope.DISCORD_IDENTIFY.value: "Access your Discord username",
-    IntegrationScope.DISCORD_EMAIL.value: "Access your Discord email",
-    IntegrationScope.DISCORD_GUILDS.value: "View your server list",
-    IntegrationScope.DISCORD_MESSAGES_READ.value: "Read messages",
-    # Twitter
-    IntegrationScope.TWITTER_READ.value: "Read tweets and profile",
-    IntegrationScope.TWITTER_WRITE.value: "Post tweets on your behalf",
-    IntegrationScope.TWITTER_DM.value: "Send and read direct messages",
-    # Notion
-    IntegrationScope.NOTION_READ.value: "View Notion pages",
-    IntegrationScope.NOTION_WRITE.value: "Create and edit Notion pages",
-    # Todoist
-    IntegrationScope.TODOIST_READ.value: "View your tasks",
-    IntegrationScope.TODOIST_WRITE.value: "Create and manage tasks",
-}
-
-
-# Mapping from integration scopes to provider OAuth scopes
-INTEGRATION_SCOPE_MAPPING: dict[str, dict[str, list[str]]] = {
-    ProviderName.GOOGLE.value: {
-        IntegrationScope.GOOGLE_EMAIL_READ.value: [
-            "https://www.googleapis.com/auth/userinfo.email",
-            "openid",
-        ],
-        IntegrationScope.GOOGLE_GMAIL_READONLY.value: [
-            "https://www.googleapis.com/auth/gmail.readonly",
-        ],
-        IntegrationScope.GOOGLE_GMAIL_SEND.value: [
-            "https://www.googleapis.com/auth/gmail.send",
-        ],
-        IntegrationScope.GOOGLE_GMAIL_MODIFY.value: [
-            "https://www.googleapis.com/auth/gmail.modify",
-        ],
-        IntegrationScope.GOOGLE_DRIVE_READONLY.value: [
-            "https://www.googleapis.com/auth/drive.readonly",
-        ],
-        IntegrationScope.GOOGLE_DRIVE_FILE.value: [
-            "https://www.googleapis.com/auth/drive.file",
-        ],
-        IntegrationScope.GOOGLE_CALENDAR_READONLY.value: [
-            "https://www.googleapis.com/auth/calendar.readonly",
-        ],
-        IntegrationScope.GOOGLE_CALENDAR_EVENTS.value: [
-            "https://www.googleapis.com/auth/calendar.events",
-        ],
-        IntegrationScope.GOOGLE_SHEETS_READONLY.value: [
-            "https://www.googleapis.com/auth/spreadsheets.readonly",
-        ],
-        IntegrationScope.GOOGLE_SHEETS.value: [
-            "https://www.googleapis.com/auth/spreadsheets",
-        ],
-        IntegrationScope.GOOGLE_DOCS_READONLY.value: [
-            "https://www.googleapis.com/auth/documents.readonly",
-        ],
-        IntegrationScope.GOOGLE_DOCS.value: [
-            "https://www.googleapis.com/auth/documents",
-        ],
-    },
-    ProviderName.GITHUB.value: {
-        IntegrationScope.GITHUB_REPOS_READ.value: [
-            "repo:status",
-            "public_repo",
-        ],
-        IntegrationScope.GITHUB_REPOS_WRITE.value: [
-            "repo",
-        ],
-        IntegrationScope.GITHUB_ISSUES_READ.value: [
-            "repo:status",
-        ],
-        IntegrationScope.GITHUB_ISSUES_WRITE.value: [
-            "repo",
-        ],
-        IntegrationScope.GITHUB_USER_READ.value: [
-            "read:user",
-            "user:email",
-        ],
-        IntegrationScope.GITHUB_GISTS.value: [
-            "gist",
-        ],
-        IntegrationScope.GITHUB_NOTIFICATIONS.value: [
-            "notifications",
-        ],
-    },
-    ProviderName.DISCORD.value: {
-        IntegrationScope.DISCORD_IDENTIFY.value: [
-            "identify",
-        ],
-        IntegrationScope.DISCORD_EMAIL.value: [
-            "email",
-        ],
-        IntegrationScope.DISCORD_GUILDS.value: [
-            "guilds",
-        ],
-        IntegrationScope.DISCORD_MESSAGES_READ.value: [
-            "messages.read",
-        ],
-    },
-    ProviderName.TWITTER.value: {
-        IntegrationScope.TWITTER_READ.value: [
-            "tweet.read",
-            "users.read",
-        ],
-        IntegrationScope.TWITTER_WRITE.value: [
-            "tweet.write",
-        ],
-        IntegrationScope.TWITTER_DM.value: [
-            "dm.read",
-            "dm.write",
-        ],
-    },
-    ProviderName.NOTION.value: {
-        IntegrationScope.NOTION_READ.value: [],  # Notion uses workspace-level access
-        IntegrationScope.NOTION_WRITE.value: [],
-    },
-    ProviderName.TODOIST.value: {
-        IntegrationScope.TODOIST_READ.value: [
-            "data:read",
-        ],
-        IntegrationScope.TODOIST_WRITE.value: [
-            "data:read_write",
-        ],
-    },
-}
-
-
-def get_provider_scopes(
-    provider: ProviderName | str, integration_scopes: list[str]
-) -> list[str]:
-    """
-    Convert integration scopes to provider-specific OAuth scopes.
-
-    Args:
-        provider: The provider name
-        integration_scopes: List of integration scope strings
-
-    Returns:
-        List of provider-specific OAuth scopes
-    """
-    provider_value = provider.value if isinstance(provider, ProviderName) else provider
-    provider_mapping = INTEGRATION_SCOPE_MAPPING.get(provider_value, {})
-
-    oauth_scopes: set[str] = set()
-    for scope in integration_scopes:
-        if scope in provider_mapping:
-            oauth_scopes.update(provider_mapping[scope])
-
-    return list(oauth_scopes)
-
-
-def get_provider_for_scope(scope: str) -> Optional[ProviderName]:
-    """
-    Get the provider for an integration scope.
-
-    Args:
-        scope: Integration scope string (e.g., "google:gmail.readonly")
-
-    Returns:
-        ProviderName or None if not recognized
-    """
-    if ":" not in scope:
-        return None
-
-    provider_prefix = scope.split(":")[0]
-
-    # Map prefixes to providers
-    prefix_mapping = {
-        "google": ProviderName.GOOGLE,
-        "github": ProviderName.GITHUB,
-        "discord": ProviderName.DISCORD,
-        "twitter": ProviderName.TWITTER,
-        "notion": ProviderName.NOTION,
-        "todoist": ProviderName.TODOIST,
-    }
-
-    return prefix_mapping.get(provider_prefix)
-
-
-def validate_integration_scopes(scopes: list[str]) -> tuple[bool, list[str]]:
-    """
-    Validate a list of integration scopes.
-
-    Args:
-        scopes: List of integration scope strings
-
-    Returns:
-        Tuple of (valid, invalid_scopes)
-    """
-    valid_scopes = {s.value for s in IntegrationScope}
-    invalid = [s for s in scopes if s not in valid_scopes]
-    return len(invalid) == 0, invalid
-
-
-def group_scopes_by_provider(
-    scopes: list[str],
-) -> dict[ProviderName, list[str]]:
-    """
-    Group integration scopes by their provider.
-
-    Args:
-        scopes: List of integration scope strings
-
-    Returns:
-        Dictionary mapping providers to their scopes
-    """
-    grouped: dict[ProviderName, list[str]] = {}
-
-    for scope in scopes:
-        provider = get_provider_for_scope(scope)
-        if provider:
-            if provider not in grouped:
-                grouped[provider] = []
-            grouped[provider].append(scope)
-
-    return grouped
@@ -1,176 +0,0 @@
-"""
-OAuth Audit Logging.
-
-Logs all OAuth-related operations for security auditing and compliance.
-"""
-
-import logging
-from datetime import datetime, timezone
-from enum import Enum
-from typing import Any, Optional
-
-from backend.data.db import prisma
-
-logger = logging.getLogger(__name__)
-
-
-class OAuthEventType(str, Enum):
-    """Types of OAuth events to audit."""
-
-    # Client events
-    CLIENT_REGISTERED = "client.registered"
-    CLIENT_UPDATED = "client.updated"
-    CLIENT_DELETED = "client.deleted"
-    CLIENT_SECRET_ROTATED = "client.secret_rotated"
-    CLIENT_SUSPENDED = "client.suspended"
-    CLIENT_ACTIVATED = "client.activated"
-
-    # Authorization events
-    AUTHORIZATION_REQUESTED = "authorization.requested"
-    AUTHORIZATION_GRANTED = "authorization.granted"
-    AUTHORIZATION_DENIED = "authorization.denied"
-    AUTHORIZATION_REVOKED = "authorization.revoked"
-
-    # Token events
-    TOKEN_ISSUED = "token.issued"
-    TOKEN_REFRESHED = "token.refreshed"
-    TOKEN_REVOKED = "token.revoked"
-    TOKEN_EXPIRED = "token.expired"
-
-    # Grant events
-    GRANT_CREATED = "grant.created"
-    GRANT_UPDATED = "grant.updated"
-    GRANT_REVOKED = "grant.revoked"
-    GRANT_USED = "grant.used"
-
-    # Credential events
-    CREDENTIAL_CONNECTED = "credential.connected"
-    CREDENTIAL_DELETED = "credential.deleted"
-
-    # Execution events
-    EXECUTION_STARTED = "execution.started"
-    EXECUTION_COMPLETED = "execution.completed"
-    EXECUTION_FAILED = "execution.failed"
-    EXECUTION_CANCELLED = "execution.cancelled"
-
-
-async def log_oauth_event(
-    event_type: OAuthEventType,
-    user_id: Optional[str] = None,
-    client_id: Optional[str] = None,
-    grant_id: Optional[str] = None,
-    ip_address: Optional[str] = None,
-    user_agent: Optional[str] = None,
-    details: Optional[dict[str, Any]] = None,
-) -> str:
-    """
-    Log an OAuth audit event.
-
-    Args:
-        event_type: Type of event
-        user_id: User ID involved (if any)
-        client_id: OAuth client ID involved (if any)
-        grant_id: Grant ID involved (if any)
-        ip_address: Client IP address
-        user_agent: Client user agent
-        details: Additional event details
-
-    Returns:
-        ID of the created audit log entry
-    """
-    try:
-        from prisma import Json
-
-        audit_entry = await prisma.oauthauditlog.create(
-            data={  # type: ignore[typeddict-item]
-                "eventType": event_type.value,
-                "userId": user_id,
-                "clientId": client_id,
-                "grantId": grant_id,
-                "ipAddress": ip_address,
-                "userAgent": user_agent,
-                "details": Json(details or {}),
-            }
-        )
-
-        logger.debug(
-            f"OAuth audit: {event_type.value} - "
-            f"user={user_id}, client={client_id}, grant={grant_id}"
-        )
-
-        return audit_entry.id
-
-    except Exception as e:
-        # Log but don't fail the operation if audit logging fails
-        logger.error(f"Failed to create OAuth audit log: {e}")
-        return ""
-
-
-async def get_audit_logs(
-    user_id: Optional[str] = None,
-    client_id: Optional[str] = None,
-    event_type: Optional[OAuthEventType] = None,
-    start_date: Optional[datetime] = None,
-    end_date: Optional[datetime] = None,
-    limit: int = 100,
-    offset: int = 0,
-) -> list:
-    """
-    Query OAuth audit logs.
-
-    Args:
-        user_id: Filter by user ID
-        client_id: Filter by client ID
-        event_type: Filter by event type
-        start_date: Filter by start date
-        end_date: Filter by end date
-        limit: Maximum number of results
-        offset: Offset for pagination
-
-    Returns:
-        List of audit log entries
-    """
-    where: dict[str, Any] = {}
-
-    if user_id:
-        where["userId"] = user_id
-    if client_id:
-        where["clientId"] = client_id
-    if event_type:
-        where["eventType"] = event_type.value
-    if start_date:
-        where["createdAt"] = {"gte": start_date}
-    if end_date:
-        if "createdAt" in where:
-            where["createdAt"]["lte"] = end_date
-        else:
-            where["createdAt"] = {"lte": end_date}
-
-    return await prisma.oauthauditlog.find_many(
-        where=where if where else None,  # type: ignore[arg-type]
-        order={"createdAt": "desc"},
-        take=limit,
-        skip=offset,
-    )
-
-
-async def cleanup_old_audit_logs(days_to_keep: int = 90) -> int:
-    """
-    Delete audit logs older than the specified number of days.
-
-    Args:
-        days_to_keep: Number of days of logs to retain
-
-    Returns:
-        Number of logs deleted
-    """
-    from datetime import timedelta
-
-    cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_to_keep)
-
-    result = await prisma.oauthauditlog.delete_many(
-        where={"createdAt": {"lt": cutoff_date}}
-    )
-
-    logger.info(f"Cleaned up {result} OAuth audit logs older than {days_to_keep} days")
-    return result
@@ -7,11 +7,6 @@ import prisma
 import pydantic
 from prisma.enums import OnboardingStep
 from prisma.models import UserOnboarding
-from prisma.types import (
-    UserOnboardingCreateInput,
-    UserOnboardingUpdateInput,
-    UserOnboardingUpsertInput,
-)
+from prisma.types import UserOnboardingCreateInput, UserOnboardingUpdateInput

 from backend.data import execution as execution_db
 from backend.data.credit import get_user_credit_model
@@ -116,10 +112,10 @@ async def update_user_onboarding(user_id: str, data: UserOnboardingUpdate):

     return await UserOnboarding.prisma().upsert(
         where={"userId": user_id},
-        data=UserOnboardingUpsertInput(
-            create=UserOnboardingCreateInput(userId=user_id, **update),
-            update=update,
-        ),
+        data={
+            "create": {"userId": user_id, **update},
+            "update": update,
+        },
     )

@@ -13,6 +13,7 @@ from backend.data.execution import (
     get_block_error_stats,
     get_child_graph_executions,
     get_execution_kv_data,
+    get_execution_outputs_by_node_exec_id,
     get_frequently_executed_graphs,
     get_graph_execution_meta,
     get_graph_executions,
@@ -147,6 +148,7 @@ class DatabaseManager(AppService):
     update_graph_execution_stats = _(update_graph_execution_stats)
     upsert_execution_input = _(upsert_execution_input)
    upsert_execution_output = _(upsert_execution_output)
+    get_execution_outputs_by_node_exec_id = _(get_execution_outputs_by_node_exec_id)
     get_execution_kv_data = _(get_execution_kv_data)
     set_execution_kv_data = _(set_execution_kv_data)
     get_block_error_stats = _(get_block_error_stats)
@@ -277,6 +279,7 @@ class DatabaseManagerAsyncClient(AppServiceClient):
     get_user_integrations = d.get_user_integrations
     upsert_execution_input = d.upsert_execution_input
     upsert_execution_output = d.upsert_execution_output
+    get_execution_outputs_by_node_exec_id = d.get_execution_outputs_by_node_exec_id
     update_graph_execution_stats = d.update_graph_execution_stats
     update_node_execution_status = d.update_node_execution_status
     update_node_execution_status_batch = d.update_node_execution_status_batch
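The two service hunks above expose the new helper on `DatabaseManager` via the existing `_(...)` wrapper and mirror it on `DatabaseManagerAsyncClient`, so executor-side code can call it over the service boundary like any other DB function. A hedged sketch of such a call (how the client instance is obtained varies by call site and is not shown here):

async def fetch_outputs(db_client, node_exec_id: str) -> dict:
    # db_client: a connected DatabaseManagerAsyncClient; the new method
    # proxies to get_execution_outputs_by_node_exec_id on the service.
    return await db_client.get_execution_outputs_by_node_exec_id(node_exec_id)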
@@ -67,7 +67,6 @@ from backend.executor.utils import (
     validate_exec,
 )
 from backend.integrations.creds_manager import IntegrationCredentialsManager
-from backend.integrations.webhook_notifier import get_webhook_notifier
 from backend.notifications.notifications import queue_notification
 from backend.server.v2.AutoMod.manager import automod_manager
 from backend.util import json
@@ -134,9 +133,8 @@ def execute_graph(
     cluster_lock: ClusterLock,
 ):
     """Execute graph using thread-local ExecutionProcessor instance"""
-    return _tls.processor.on_graph_execution(
-        graph_exec_entry, cancel_event, cluster_lock
-    )
+    processor: ExecutionProcessor = _tls.processor
+    return processor.on_graph_execution(graph_exec_entry, cancel_event, cluster_lock)


 T = TypeVar("T")
@@ -144,8 +142,8 @@ T = TypeVar("T")

 async def execute_node(
     node: Node,
-    creds_manager: IntegrationCredentialsManager,
     data: NodeExecutionEntry,
+    execution_processor: "ExecutionProcessor",
     execution_stats: NodeExecutionStats | None = None,
     nodes_input_masks: Optional[NodesInputMasks] = None,
 ) -> BlockOutput:
@@ -170,6 +168,7 @@ async def execute_node(
     node_id = data.node_id
     node_block = node.block
     execution_context = data.execution_context
+    creds_manager = execution_processor.creds_manager

     log_metadata = LogMetadata(
         logger=_logger,
@@ -213,6 +212,7 @@ async def execute_node(
         "node_exec_id": node_exec_id,
         "user_id": user_id,
         "execution_context": execution_context,
+        "execution_processor": execution_processor,
     }

     # Last-minute fetch credentials + acquire a system-wide read-write lock to prevent
@@ -222,31 +222,11 @@ async def execute_node(
     creds_locks: list[AsyncRedisLock] = []
     input_model = cast(type[BlockSchema], node_block.input_schema)

-    # Check if this is an external API execution using grant-based credential resolution
-    grant_resolver = None
-    if execution_context and execution_context.grant_resolver_context:
-        from backend.integrations.grant_resolver import GrantBasedCredentialResolver
-
-        grant_ctx = execution_context.grant_resolver_context
-        grant_resolver = GrantBasedCredentialResolver(
-            user_id=user_id,
-            client_id=grant_ctx.client_db_id,
-            grant_ids=grant_ctx.grant_ids,
-        )
-        await grant_resolver.initialize()
-
     # Handle regular credentials fields
     for field_name, input_type in input_model.get_credentials_fields().items():
         credentials_meta = input_type(**input_data[field_name])
-        if grant_resolver:
-            # External API execution - use grant resolver (no locking needed)
-            credentials = await grant_resolver.resolve_credential(credentials_meta.id)
-        else:
-            # Normal execution - use credentials manager with locking
-            credentials, lock = await creds_manager.acquire(
-                user_id, credentials_meta.id
-            )
-            creds_locks.append(lock)
+        credentials, lock = await creds_manager.acquire(user_id, credentials_meta.id)
+        creds_locks.append(lock)
         extra_exec_kwargs[field_name] = credentials

     # Handle auto-generated credentials (e.g., from GoogleDriveFileInput)
@@ -264,17 +244,10 @@ async def execute_node(
             )
             file_name = field_data.get("name", "selected file")
             try:
-                if grant_resolver:
-                    # External API execution - use grant resolver
-                    credentials = await grant_resolver.resolve_credential(
-                        cred_id
-                    )
-                else:
-                    # Normal execution - use credentials manager
-                    credentials, lock = await creds_manager.acquire(
-                        user_id, cred_id
-                    )
-                    creds_locks.append(lock)
+                credentials, lock = await creds_manager.acquire(
+                    user_id, cred_id
+                )
+                creds_locks.append(lock)
                 extra_exec_kwargs[kwarg_name] = credentials
             except ValueError:
                 # Credential was deleted or doesn't exist
@@ -636,8 +609,8 @@ class ExecutionProcessor:

         async for output_name, output_data in execute_node(
             node=node,
-            creds_manager=self.creds_manager,
             data=node_exec,
+            execution_processor=self,
             execution_stats=stats,
             nodes_input_masks=nodes_input_masks,
         ):
@@ -813,7 +786,6 @@ class ExecutionProcessor:
             graph_exec_id=graph_exec.graph_exec_id,
             status=exec_meta.status,
             stats=exec_stats,
-            event_loop=self.node_execution_loop,
         )

     def _charge_usage(
@@ -889,12 +861,17 @@ class ExecutionProcessor:
         execution_stats_lock = threading.Lock()

         # State holders ----------------------------------------------------
-        running_node_execution: dict[str, NodeExecutionProgress] = defaultdict(
+        self.running_node_execution: dict[str, NodeExecutionProgress] = defaultdict(
|
||||||
NodeExecutionProgress
|
NodeExecutionProgress
|
||||||
)
|
)
|
||||||
running_node_evaluation: dict[str, Future] = {}
|
self.running_node_evaluation: dict[str, Future] = {}
|
||||||
|
self.execution_stats = execution_stats
|
||||||
|
self.execution_stats_lock = execution_stats_lock
|
||||||
execution_queue = ExecutionQueue[NodeExecutionEntry]()
|
execution_queue = ExecutionQueue[NodeExecutionEntry]()
|
||||||
|
|
||||||
|
running_node_execution = self.running_node_execution
|
||||||
|
running_node_evaluation = self.running_node_evaluation
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if db_client.get_credits(graph_exec.user_id) <= 0:
|
if db_client.get_credits(graph_exec.user_id) <= 0:
|
||||||
raise InsufficientBalanceError(
|
raise InsufficientBalanceError(
|
||||||
@@ -1945,53 +1922,6 @@ def update_node_execution_status(
|
|||||||
return exec_update
|
return exec_update
|
||||||
|
|
||||||
|
|
||||||
async def _notify_execution_webhook(
|
|
||||||
execution_id: str,
|
|
||||||
agent_id: str,
|
|
||||||
status: ExecutionStatus,
|
|
||||||
outputs: dict[str, Any] | None = None,
|
|
||||||
error: str | None = None,
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Send webhook notification for execution completion if registered.
|
|
||||||
|
|
||||||
This is a fire-and-forget operation that checks if a webhook was registered
|
|
||||||
for this execution and sends the appropriate notification.
|
|
||||||
"""
|
|
||||||
from backend.data.db import prisma
|
|
||||||
|
|
||||||
try:
|
|
||||||
webhook = await prisma.executionwebhook.find_first(
|
|
||||||
where={"executionId": execution_id}
|
|
||||||
)
|
|
||||||
if not webhook:
|
|
||||||
return
|
|
||||||
|
|
||||||
notifier = get_webhook_notifier()
|
|
||||||
|
|
||||||
if status == ExecutionStatus.COMPLETED:
|
|
||||||
await notifier.notify_execution_completed(
|
|
||||||
execution_id=execution_id,
|
|
||||||
agent_id=agent_id,
|
|
||||||
client_id=webhook.clientId,
|
|
||||||
webhook_url=webhook.webhookUrl,
|
|
||||||
outputs=outputs or {},
|
|
||||||
webhook_secret=webhook.secret,
|
|
||||||
)
|
|
||||||
elif status == ExecutionStatus.FAILED:
|
|
||||||
await notifier.notify_execution_failed(
|
|
||||||
execution_id=execution_id,
|
|
||||||
agent_id=agent_id,
|
|
||||||
client_id=webhook.clientId,
|
|
||||||
webhook_url=webhook.webhookUrl,
|
|
||||||
error=error or "Execution failed",
|
|
||||||
webhook_secret=webhook.secret,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
# Don't let webhook failures affect execution state updates
|
|
||||||
logger.warning(f"Failed to send webhook notification for {execution_id}: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
async def async_update_graph_execution_state(
|
async def async_update_graph_execution_state(
|
||||||
db_client: "DatabaseManagerAsyncClient",
|
db_client: "DatabaseManagerAsyncClient",
|
||||||
graph_exec_id: str,
|
graph_exec_id: str,
|
||||||
@@ -2004,17 +1934,6 @@ async def async_update_graph_execution_state(
|
|||||||
)
|
)
|
||||||
if graph_update:
|
if graph_update:
|
||||||
await send_async_execution_update(graph_update)
|
await send_async_execution_update(graph_update)
|
||||||
|
|
||||||
# Send webhook notification for terminal states
|
|
||||||
if status == ExecutionStatus.COMPLETED or status == ExecutionStatus.FAILED:
|
|
||||||
await _notify_execution_webhook(
|
|
||||||
execution_id=graph_exec_id,
|
|
||||||
agent_id=graph_update.graph_id,
|
|
||||||
status=status,
|
|
||||||
outputs=(
|
|
||||||
graph_update.outputs if hasattr(graph_update, "outputs") else None
|
|
||||||
),
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
logger.error(f"Failed to update graph execution stats for {graph_exec_id}")
|
logger.error(f"Failed to update graph execution stats for {graph_exec_id}")
|
||||||
return graph_update
|
return graph_update
|
||||||
@@ -2025,33 +1944,11 @@ def update_graph_execution_state(
|
|||||||
graph_exec_id: str,
|
graph_exec_id: str,
|
||||||
status: ExecutionStatus | None = None,
|
status: ExecutionStatus | None = None,
|
||||||
stats: GraphExecutionStats | None = None,
|
stats: GraphExecutionStats | None = None,
|
||||||
event_loop: asyncio.AbstractEventLoop | None = None,
|
|
||||||
) -> GraphExecution | None:
|
) -> GraphExecution | None:
|
||||||
"""Sets status and fetches+broadcasts the latest state of the graph execution"""
|
"""Sets status and fetches+broadcasts the latest state of the graph execution"""
|
||||||
graph_update = db_client.update_graph_execution_stats(graph_exec_id, status, stats)
|
graph_update = db_client.update_graph_execution_stats(graph_exec_id, status, stats)
|
||||||
if graph_update:
|
if graph_update:
|
||||||
send_execution_update(graph_update)
|
send_execution_update(graph_update)
|
||||||
|
|
||||||
# Send webhook notification for terminal states (fire-and-forget)
|
|
||||||
if (
|
|
||||||
status == ExecutionStatus.COMPLETED or status == ExecutionStatus.FAILED
|
|
||||||
) and event_loop:
|
|
||||||
try:
|
|
||||||
asyncio.run_coroutine_threadsafe(
|
|
||||||
_notify_execution_webhook(
|
|
||||||
execution_id=graph_exec_id,
|
|
||||||
agent_id=graph_update.graph_id,
|
|
||||||
status=status,
|
|
||||||
outputs=(
|
|
||||||
graph_update.outputs
|
|
||||||
if hasattr(graph_update, "outputs")
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
),
|
|
||||||
event_loop,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to schedule webhook notification: {e}")
|
|
||||||
else:
|
else:
|
||||||
logger.error(f"Failed to update graph execution stats for {graph_exec_id}")
|
logger.error(f"Failed to update graph execution stats for {graph_exec_id}")
|
||||||
return graph_update
|
return graph_update
|
||||||
|
|||||||
@@ -1,278 +0,0 @@
-"""
-Grant-Based Credential Resolver.
-
-Resolves credentials during agent execution based on credential grants.
-External applications can only use credentials they have been granted access to,
-and only for the scopes that were granted.
-
-Credentials are NEVER exposed to external applications - this resolver
-provides the credentials to the execution engine internally.
-"""
-
-import logging
-from datetime import datetime, timezone
-from typing import Optional
-
-from prisma.enums import CredentialGrantPermission
-from prisma.models import CredentialGrant
-
-from backend.data import credential_grants as grants_db
-from backend.data.db import prisma
-from backend.data.model import Credentials
-from backend.integrations.creds_manager import IntegrationCredentialsManager
-
-logger = logging.getLogger(__name__)
-
-
-class GrantValidationError(Exception):
-    """Raised when a grant is invalid or lacks required permissions."""
-
-    pass
-
-
-class CredentialNotFoundError(Exception):
-    """Raised when a credential referenced by a grant is not found."""
-
-    pass
-
-
-class ScopeMismatchError(Exception):
-    """Raised when the grant doesn't cover required scopes."""
-
-    pass
-
-
-class GrantBasedCredentialResolver:
-    """
-    Resolves credentials for agent execution based on credential grants.
-
-    This resolver validates that:
-    1. The grant exists and is valid (not revoked/expired)
-    2. The grant has USE permission
-    3. The grant covers the required scopes (if specified)
-    4. The underlying credential exists
-
-    Then it provides the credential to the execution engine internally.
-    The credential value is NEVER exposed to external applications.
-    """
-
-    def __init__(
-        self,
-        user_id: str,
-        client_id: str,
-        grant_ids: list[str],
-    ):
-        """
-        Initialize the resolver.
-
-        Args:
-            user_id: User ID who owns the credentials
-            client_id: Database ID of the OAuth client
-            grant_ids: List of grant IDs the client is using for this execution
-        """
-        self.user_id = user_id
-        self.client_id = client_id
-        self.grant_ids = grant_ids
-        self._grants: dict[str, CredentialGrant] = {}
-        self._credentials_manager = IntegrationCredentialsManager()
-        self._initialized = False
-
-    async def initialize(self) -> None:
-        """
-        Load and validate all grants.
-
-        This should be called before any credential resolution.
-
-        Raises:
-            GrantValidationError: If any grant is invalid
-        """
-        now = datetime.now(timezone.utc)
-
-        for grant_id in self.grant_ids:
-            grant = await grants_db.get_credential_grant(
-                grant_id=grant_id,
-                user_id=self.user_id,
-                client_id=self.client_id,
-            )
-
-            if not grant:
-                raise GrantValidationError(f"Grant {grant_id} not found")
-
-            # Check if revoked
-            if grant.revokedAt:
-                raise GrantValidationError(f"Grant {grant_id} has been revoked")
-
-            # Check if expired
-            if grant.expiresAt and grant.expiresAt < now:
-                raise GrantValidationError(f"Grant {grant_id} has expired")
-
-            # Check USE permission
-            if CredentialGrantPermission.USE not in grant.permissions:
-                raise GrantValidationError(
-                    f"Grant {grant_id} does not have USE permission"
-                )
-
-            self._grants[grant_id] = grant
-
-        self._initialized = True
-        logger.info(
-            f"Initialized grant resolver with {len(self._grants)} grants "
-            f"for user {self.user_id}, client {self.client_id}"
-        )
-
-    async def resolve_credential(
-        self,
-        credential_id: str,
-        required_scopes: Optional[list[str]] = None,
-    ) -> Credentials:
-        """
-        Resolve a credential for agent execution.
-
-        This method:
-        1. Finds a grant that covers this credential
-        2. Validates the grant covers required scopes
-        3. Retrieves the actual credential
-        4. Updates grant usage tracking
-
-        Args:
-            credential_id: ID of the credential to resolve
-            required_scopes: Optional list of scopes the credential must have
-
-        Returns:
-            The resolved Credentials object
-
-        Raises:
-            GrantValidationError: If no valid grant covers this credential
-            ScopeMismatchError: If the grant doesn't cover required scopes
-            CredentialNotFoundError: If the underlying credential doesn't exist
-        """
-        if not self._initialized:
-            raise RuntimeError("Resolver not initialized. Call initialize() first.")
-
-        # Find a grant that covers this credential
-        matching_grant: Optional[CredentialGrant] = None
-        for grant in self._grants.values():
-            if grant.credentialId == credential_id:
-                matching_grant = grant
-                break
-
-        if not matching_grant:
-            raise GrantValidationError(f"No grant found for credential {credential_id}")
-
-        # Validate scopes if required
-        if required_scopes:
-            granted_scopes = set(matching_grant.grantedScopes)
-            required_scopes_set = set(required_scopes)
-
-            missing_scopes = required_scopes_set - granted_scopes
-            if missing_scopes:
-                raise ScopeMismatchError(
-                    f"Grant {matching_grant.id} is missing required scopes: "
-                    f"{', '.join(missing_scopes)}"
-                )
-
-        # Get the actual credential
-        credentials = await self._credentials_manager.get(
-            user_id=self.user_id,
-            credentials_id=credential_id,
-            lock=True,
-        )
-
-        if not credentials:
-            raise CredentialNotFoundError(
-                f"Credential {credential_id} not found for user {self.user_id}"
-            )
-
-        # Update last used timestamp for the grant
-        await grants_db.update_grant_last_used(matching_grant.id)
-
-        logger.debug(
-            f"Resolved credential {credential_id} via grant {matching_grant.id} "
-            f"for client {self.client_id}"
-        )
-
-        return credentials
-
-    async def get_available_credentials(self) -> list[dict]:
-        """
-        Get list of available credentials based on grants.
-
-        Returns a list of credential metadata (NOT the actual credential values).
-
-        Returns:
-            List of dicts with credential metadata
-        """
-        if not self._initialized:
-            raise RuntimeError("Resolver not initialized. Call initialize() first.")
-
-        credentials_info = []
-        for grant in self._grants.values():
-            credentials_info.append(
-                {
-                    "grant_id": grant.id,
-                    "credential_id": grant.credentialId,
-                    "provider": grant.provider,
-                    "granted_scopes": grant.grantedScopes,
-                }
-            )
-
-        return credentials_info
-
-    def get_grant_for_credential(self, credential_id: str) -> Optional[CredentialGrant]:
-        """
-        Get the grant for a specific credential.
-
-        Args:
-            credential_id: ID of the credential
-
-        Returns:
-            CredentialGrant or None if not found
-        """
-        for grant in self._grants.values():
-            if grant.credentialId == credential_id:
-                return grant
-        return None
-
-
-async def create_resolver_from_oauth_token(
-    user_id: str,
-    client_public_id: str,
-    grant_ids: Optional[list[str]] = None,
-) -> GrantBasedCredentialResolver:
-    """
-    Create a credential resolver from OAuth token context.
-
-    This is a convenience function for creating a resolver from
-    the context available in OAuth-authenticated requests.
-
-    Args:
-        user_id: User ID from the OAuth token
-        client_public_id: Public client ID from the OAuth token
-        grant_ids: Optional list of grant IDs to use
-
-    Returns:
-        Initialized GrantBasedCredentialResolver
-    """
-    # Look up the OAuth client database ID from the public client ID
-    client = await prisma.oauthclient.find_unique(where={"clientId": client_public_id})
-    if not client:
-        raise GrantValidationError(f"OAuth client {client_public_id} not found")
-
-    # If no grant IDs specified, get all grants for this client+user
-    if grant_ids is None:
-        grants = await grants_db.get_grants_for_user_client(
-            user_id=user_id,
-            client_id=client.id,
-            include_revoked=False,
-            include_expired=False,
-        )
-        grant_ids = [g.id for g in grants]
-
-    resolver = GrantBasedCredentialResolver(
-        user_id=user_id,
-        client_id=client.id,
-        grant_ids=grant_ids,
-    )
-    await resolver.initialize()
-
-    return resolver
@@ -1,331 +0,0 @@
-"""
-Webhook Notification System for External API.
-
-Sends webhook notifications to external applications for execution events.
-"""
-
-import asyncio
-import hashlib
-import hmac
-import json
-import logging
-import weakref
-from datetime import datetime, timezone
-from typing import Any, Coroutine, Optional
-from urllib.parse import urlparse
-
-import httpx
-
-logger = logging.getLogger(__name__)
-
-# Webhook delivery settings
-WEBHOOK_TIMEOUT_SECONDS = 30
-WEBHOOK_MAX_RETRIES = 3
-WEBHOOK_RETRY_DELAYS = [5, 30, 300]  # seconds: 5s, 30s, 5min
-
-
-class WebhookDeliveryError(Exception):
-    """Raised when webhook delivery fails."""
-
-    pass
-
-
-def sign_webhook_payload(payload: dict[str, Any], secret: str) -> str:
-    """
-    Create HMAC-SHA256 signature for webhook payload.
-
-    Args:
-        payload: The webhook payload to sign
-        secret: The webhook secret key
-
-    Returns:
-        Hex-encoded HMAC-SHA256 signature
-    """
-    payload_bytes = json.dumps(payload, sort_keys=True, separators=(",", ":")).encode()
-    signature = hmac.new(
-        secret.encode(),
-        payload_bytes,
-        hashlib.sha256,
-    ).hexdigest()
-    return signature
-
-
-def verify_webhook_signature(
-    payload: dict[str, Any],
-    signature: str,
-    secret: str,
-) -> bool:
-    """
-    Verify a webhook signature.
-
-    Args:
-        payload: The webhook payload
-        signature: The signature to verify
-        secret: The webhook secret key
-
-    Returns:
-        True if signature is valid
-    """
-    expected = sign_webhook_payload(payload, secret)
-    return hmac.compare_digest(expected, signature)
-
-
-def validate_webhook_url(url: str, allowed_domains: list[str]) -> bool:
-    """
-    Validate that a webhook URL is allowed.
-
-    Args:
-        url: The webhook URL to validate
-        allowed_domains: List of allowed domains (from OAuth client config)
-
-    Returns:
-        True if URL is valid and allowed
-    """
-    from backend.util.url import hostname_matches_any_domain
-
-    try:
-        parsed = urlparse(url)
-
-        # Must be HTTPS (except for localhost in development)
-        if parsed.scheme != "https":
-            if not (
-                parsed.scheme == "http"
-                and parsed.hostname in ["localhost", "127.0.0.1"]
-            ):
-                return False
-
-        # Must have a host
-        if not parsed.hostname:
-            return False
-
-        # Check against allowed domains
-        return hostname_matches_any_domain(parsed.hostname, allowed_domains)
-
-    except Exception:
-        return False
-
-
-async def send_webhook(
-    url: str,
-    payload: dict[str, Any],
-    secret: Optional[str] = None,
-    timeout: int = WEBHOOK_TIMEOUT_SECONDS,
-) -> bool:
-    """
-    Send a webhook notification.
-
-    Args:
-        url: Webhook URL
-        payload: Payload to send
-        secret: Optional secret for signature
-        timeout: Request timeout in seconds
-
-    Returns:
-        True if webhook was delivered successfully
-    """
-    headers = {
-        "Content-Type": "application/json",
-        "User-Agent": "AutoGPT-Webhook/1.0",
-        "X-Webhook-Timestamp": datetime.now(timezone.utc).isoformat(),
-    }
-
-    if secret:
-        signature = sign_webhook_payload(payload, secret)
-        headers["X-Webhook-Signature"] = f"sha256={signature}"
-
-    try:
-        async with httpx.AsyncClient(timeout=timeout) as client:
-            response = await client.post(
-                url,
-                json=payload,
-                headers=headers,
-            )
-
-            if response.status_code >= 200 and response.status_code < 300:
-                logger.debug(f"Webhook delivered successfully to {url}")
-                return True
-            else:
-                logger.warning(
-                    f"Webhook delivery failed: {url} returned {response.status_code}"
-                )
-                return False
-
-    except httpx.TimeoutException:
-        logger.warning(f"Webhook delivery timed out: {url}")
-        return False
-    except Exception as e:
-        logger.error(f"Webhook delivery error: {url} - {str(e)}")
-        return False
-
-
-async def send_webhook_with_retry(
-    url: str,
-    payload: dict[str, Any],
-    secret: Optional[str] = None,
-    max_retries: int = WEBHOOK_MAX_RETRIES,
-) -> bool:
-    """
-    Send a webhook with automatic retries.
-
-    Args:
-        url: Webhook URL
-        payload: Payload to send
-        secret: Optional secret for signature
-        max_retries: Maximum number of retry attempts
-
-    Returns:
-        True if webhook was eventually delivered successfully
-    """
-    for attempt in range(max_retries + 1):
-        if await send_webhook(url, payload, secret):
-            return True
-
-        if attempt < max_retries:
-            delay = WEBHOOK_RETRY_DELAYS[min(attempt, len(WEBHOOK_RETRY_DELAYS) - 1)]
-            logger.info(
-                f"Webhook delivery failed, retrying in {delay}s (attempt {attempt + 1})"
-            )
-            await asyncio.sleep(delay)
-
-    logger.error(f"Webhook delivery failed after {max_retries} retries: {url}")
-    return False
-
-
-# Track pending webhook tasks to prevent garbage collection
-# Using WeakSet so tasks are automatically removed when they complete and are dereferenced
-_pending_webhook_tasks: weakref.WeakSet[asyncio.Task[Any]] = weakref.WeakSet()
-
-
-def _create_tracked_task(coro: Coroutine[Any, Any, bool]) -> asyncio.Task[bool]:
-    """Create a task that is tracked to prevent garbage collection."""
-    task = asyncio.create_task(coro)
-    _pending_webhook_tasks.add(task)
-    # No explicit done callback needed - WeakSet automatically removes
-    # references when tasks are garbage collected after completion
-    return task
-
-
-class WebhookNotifier:
-    """
-    Service for sending webhook notifications to external applications.
-    """
-
-    def __init__(self):
-        pass
-
-    async def notify_execution_started(
-        self,
-        execution_id: str,
-        agent_id: str,
-        client_id: str,
-        webhook_url: str,
-        webhook_secret: Optional[str] = None,
-    ) -> None:
-        """
-        Notify external app that an execution has started.
-        """
-        payload = {
-            "event": "execution.started",
-            "timestamp": datetime.now(timezone.utc).isoformat(),
-            "data": {
-                "execution_id": execution_id,
-                "agent_id": agent_id,
-                "status": "running",
-            },
-        }
-
-        _create_tracked_task(
-            send_webhook_with_retry(webhook_url, payload, webhook_secret)
-        )
-
-    async def notify_execution_completed(
-        self,
-        execution_id: str,
-        agent_id: str,
-        client_id: str,
-        webhook_url: str,
-        outputs: dict[str, Any],
-        webhook_secret: Optional[str] = None,
-    ) -> None:
-        """
-        Notify external app that an execution has completed successfully.
-        """
-        payload = {
-            "event": "execution.completed",
-            "timestamp": datetime.now(timezone.utc).isoformat(),
-            "data": {
-                "execution_id": execution_id,
-                "agent_id": agent_id,
-                "status": "completed",
-                "outputs": outputs,
-            },
-        }
-
-        _create_tracked_task(
-            send_webhook_with_retry(webhook_url, payload, webhook_secret)
-        )
-
-    async def notify_execution_failed(
-        self,
-        execution_id: str,
-        agent_id: str,
-        client_id: str,
-        webhook_url: str,
-        error: str,
-        webhook_secret: Optional[str] = None,
-    ) -> None:
-        """
-        Notify external app that an execution has failed.
-        """
-        payload = {
-            "event": "execution.failed",
-            "timestamp": datetime.now(timezone.utc).isoformat(),
-            "data": {
-                "execution_id": execution_id,
-                "agent_id": agent_id,
-                "status": "failed",
-                "error": error,
-            },
-        }
-
-        _create_tracked_task(
-            send_webhook_with_retry(webhook_url, payload, webhook_secret)
-        )
-
-    async def notify_grant_revoked(
-        self,
-        grant_id: str,
-        credential_id: str,
-        provider: str,
-        client_id: str,
-        webhook_url: str,
-        webhook_secret: Optional[str] = None,
-    ) -> None:
-        """
-        Notify external app that a credential grant has been revoked.
-        """
-        payload = {
-            "event": "grant.revoked",
-            "timestamp": datetime.now(timezone.utc).isoformat(),
-            "data": {
-                "grant_id": grant_id,
-                "credential_id": credential_id,
-                "provider": provider,
-            },
-        }
-
-        _create_tracked_task(
-            send_webhook_with_retry(webhook_url, payload, webhook_secret)
-        )
-
-
-# Module-level singleton
-_webhook_notifier: Optional[WebhookNotifier] = None
-
-
-def get_webhook_notifier() -> WebhookNotifier:
-    """Get the singleton webhook notifier instance."""
-    global _webhook_notifier
-    if _webhook_notifier is None:
-        _webhook_notifier = WebhookNotifier()
-    return _webhook_notifier
@@ -3,19 +3,21 @@ from fastapi import FastAPI
 from backend.monitoring.instrumentation import instrument_fastapi
 from backend.server.middleware.security import SecurityHeadersMiddleware

-from .routes.execution import execution_router
-from .routes.grants import grants_router
+from .routes.integrations import integrations_router
+from .routes.tools import tools_router
+from .routes.v1 import v1_router

 external_app = FastAPI(
     title="AutoGPT External API",
-    description="External API for AutoGPT integrations (OAuth-based)",
+    description="External API for AutoGPT integrations",
     docs_url="/docs",
     version="1.0",
 )

 external_app.add_middleware(SecurityHeadersMiddleware)
-external_app.include_router(grants_router, prefix="/v1")
-external_app.include_router(execution_router, prefix="/v1")
+external_app.include_router(v1_router, prefix="/v1")
+external_app.include_router(tools_router, prefix="/v1")
+external_app.include_router(integrations_router, prefix="/v1")

 # Add Prometheus instrumentation
 instrument_fastapi(
autogpt_platform/backend/backend/server/external/middleware.py (new file)
@@ -0,0 +1,36 @@
+from fastapi import HTTPException, Security
+from fastapi.security import APIKeyHeader
+from prisma.enums import APIKeyPermission
+
+from backend.data.api_key import APIKeyInfo, has_permission, validate_api_key
+
+api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
+
+
+async def require_api_key(api_key: str | None = Security(api_key_header)) -> APIKeyInfo:
+    """Base middleware for API key authentication"""
+    if api_key is None:
+        raise HTTPException(status_code=401, detail="Missing API key")
+
+    api_key_obj = await validate_api_key(api_key)
+
+    if not api_key_obj:
+        raise HTTPException(status_code=401, detail="Invalid API key")
+
+    return api_key_obj
+
+
+def require_permission(permission: APIKeyPermission):
+    """Dependency function for checking specific permissions"""
+
+    async def check_permission(
+        api_key: APIKeyInfo = Security(require_api_key),
+    ) -> APIKeyInfo:
+        if not has_permission(api_key, permission):
+            raise HTTPException(
+                status_code=403,
+                detail=f"API key lacks the required permission '{permission}'",
+            )
+        return api_key
+
+    return check_permission
@@ -1,164 +0,0 @@
-"""
-OAuth Access Token middleware for external API.
-
-Validates OAuth access tokens and provides user/client context
-for external API endpoints that use OAuth authentication.
-"""
-
-from datetime import datetime, timezone
-from typing import Optional
-
-import jwt
-from fastapi import HTTPException, Security
-from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
-from pydantic import BaseModel
-
-from backend.data.db import prisma
-from backend.server.oauth.token_service import get_token_service
-
-
-class OAuthTokenInfo(BaseModel):
-    """Information extracted from a validated OAuth access token."""
-
-    user_id: str
-    client_id: str
-    scopes: list[str]
-    token_id: str
-
-
-# HTTP Bearer token extractor
-oauth_bearer = HTTPBearer(auto_error=False)
-
-
-async def require_oauth_token(
-    credentials: Optional[HTTPAuthorizationCredentials] = Security(oauth_bearer),
-) -> OAuthTokenInfo:
-    """
-    Validate an OAuth access token and return token info.
-
-    Extracts the Bearer token from the Authorization header,
-    validates the JWT signature and claims, and checks that
-    the token hasn't been revoked.
-
-    Raises:
-        HTTPException: 401 if token is missing, invalid, or revoked
-    """
-    if credentials is None:
-        raise HTTPException(
-            status_code=401,
-            detail="Missing authorization token",
-            headers={"WWW-Authenticate": "Bearer"},
-        )
-
-    token = credentials.credentials
-    token_service = get_token_service()
-
-    try:
-        # Verify JWT signature and claims
-        claims = token_service.verify_access_token(token)
-
-        # Check if token is in database and not revoked
-        token_hash = token_service.hash_token(token)
-        stored_token = await prisma.oauthaccesstoken.find_unique(
-            where={"tokenHash": token_hash}
-        )
-
-        if not stored_token:
-            raise HTTPException(
-                status_code=401,
-                detail="Token not found",
-                headers={"WWW-Authenticate": "Bearer"},
-            )
-
-        if stored_token.revokedAt:
-            raise HTTPException(
-                status_code=401,
-                detail="Token has been revoked",
-                headers={"WWW-Authenticate": "Bearer"},
-            )
-
-        if stored_token.expiresAt < datetime.now(timezone.utc):
-            raise HTTPException(
-                status_code=401,
-                detail="Token has expired",
-                headers={"WWW-Authenticate": "Bearer"},
-            )
-
-        # Update last used timestamp (fire and forget)
-        await prisma.oauthaccesstoken.update(
-            where={"id": stored_token.id},
-            data={"lastUsedAt": datetime.now(timezone.utc)},
-        )
-
-        return OAuthTokenInfo(
-            user_id=claims.sub,
-            client_id=claims.client_id,
-            scopes=claims.scope.split() if claims.scope else [],
-            token_id=stored_token.id,
-        )
-
-    except jwt.ExpiredSignatureError:
-        raise HTTPException(
-            status_code=401,
-            detail="Token has expired",
-            headers={"WWW-Authenticate": "Bearer"},
-        )
-    except jwt.InvalidTokenError as e:
-        raise HTTPException(
-            status_code=401,
-            detail=f"Invalid token: {str(e)}",
-            headers={"WWW-Authenticate": "Bearer"},
-        )
-
-
-def require_scope(required_scope: str):
-    """
-    Dependency that validates OAuth token and checks for required scope.
-
-    Args:
-        required_scope: The scope required for this endpoint
-
-    Returns:
-        Dependency function that returns OAuthTokenInfo if authorized
-    """
-
-    async def check_scope(
-        token: OAuthTokenInfo = Security(require_oauth_token),
-    ) -> OAuthTokenInfo:
-        if required_scope not in token.scopes:
-            raise HTTPException(
-                status_code=403,
-                detail=f"Token lacks required scope '{required_scope}'",
-                headers={"WWW-Authenticate": f'Bearer scope="{required_scope}"'},
-            )
-        return token
-
-    return check_scope
-
-
-def require_any_scope(*required_scopes: str):
-    """
-    Dependency that validates OAuth token and checks for any of the required scopes.
-
-    Args:
-        required_scopes: At least one of these scopes is required
-
-    Returns:
-        Dependency function that returns OAuthTokenInfo if authorized
-    """
-
-    async def check_scopes(
-        token: OAuthTokenInfo = Security(require_oauth_token),
-    ) -> OAuthTokenInfo:
-        for scope in required_scopes:
-            if scope in token.scopes:
-                return token
-
-        scope_list = " ".join(required_scopes)
-        raise HTTPException(
-            status_code=403,
-            detail=f"Token lacks required scopes (need one of: {scope_list})",
-            headers={"WWW-Authenticate": f'Bearer scope="{scope_list}"'},
-        )
-
-    return check_scopes
@@ -1,377 +0,0 @@
-"""
-Agent Execution endpoints for external OAuth clients.
-
-Allows external applications to:
-- Execute agents using granted credentials
-- Poll execution status
-- Cancel running executions
-- Get available capabilities
-
-External apps can only use credentials they have been granted access to.
-"""
-
-import logging
-from datetime import datetime
-from typing import Any, Optional
-
-from fastapi import APIRouter, HTTPException, Security
-from prisma.enums import AgentExecutionStatus
-from pydantic import BaseModel, Field
-
-from backend.data import execution as execution_db
-from backend.data import graph as graph_db
-from backend.data.db import prisma
-from backend.data.execution import ExecutionContext, GrantResolverContext
-from backend.executor.utils import add_graph_execution
-from backend.integrations.grant_resolver import (
-    GrantValidationError,
-    create_resolver_from_oauth_token,
-)
-from backend.integrations.webhook_notifier import validate_webhook_url
-from backend.server.external.oauth_middleware import OAuthTokenInfo, require_scope
-
-logger = logging.getLogger(__name__)
-
-execution_router = APIRouter(prefix="/executions", tags=["executions"])
-
-
-# ================================================================
-# Request/Response Models
-# ================================================================
-
-
-class ExecuteAgentRequest(BaseModel):
-    """Request to execute an agent."""
-
-    inputs: dict[str, Any] = Field(
-        default_factory=dict,
-        description="Input values for the agent",
-    )
-    grant_ids: Optional[list[str]] = Field(
-        default=None,
-        description="Specific grant IDs to use. If not provided, uses all available grants.",
-    )
-    webhook_url: Optional[str] = Field(
-        default=None,
-        description="URL to receive execution status webhooks",
-    )
-
-
-class ExecuteAgentResponse(BaseModel):
-    """Response from starting an agent execution."""
-
-    execution_id: str
-    status: str
-    message: str
-
-
-class ExecutionStatusResponse(BaseModel):
-    """Response with execution status."""
-
-    execution_id: str
-    status: str
-    started_at: Optional[datetime] = None
-    completed_at: Optional[datetime] = None
-    outputs: Optional[dict[str, Any]] = None
-    error: Optional[str] = None
-
-
-class GrantInfo(BaseModel):
-    """Summary of a credential grant for capabilities."""
-
-    grant_id: str
-    provider: str
-    scopes: list[str]
-
-
-class CapabilitiesResponse(BaseModel):
-    """Response describing what the client can do."""
-
-    user_id: str
-    client_id: str
-    grants: list[GrantInfo]
-    available_scopes: list[str]
-
-
-# ================================================================
-# Endpoints
-# ================================================================
-
-
-@execution_router.get("/capabilities", response_model=CapabilitiesResponse)
-async def get_capabilities(
-    token: OAuthTokenInfo = Security(require_scope("agents:execute")),
-) -> CapabilitiesResponse:
-    """
-    Get the capabilities available to this client for the authenticated user.
-
-    Returns information about:
-    - Available credential grants (NOT credential values)
-    - Scopes the client has access to
-    """
-    try:
-        resolver = await create_resolver_from_oauth_token(
-            user_id=token.user_id,
-            client_public_id=token.client_id,
-        )
-        credentials_info = await resolver.get_available_credentials()
-
-        grants = [
-            GrantInfo(
-                grant_id=info["grant_id"],
-                provider=info["provider"],
-                scopes=info["granted_scopes"],
-            )
-            for info in credentials_info
-        ]
-
-        return CapabilitiesResponse(
-            user_id=token.user_id,
-            client_id=token.client_id,
-            grants=grants,
-            available_scopes=token.scopes,
-        )
-    except GrantValidationError:
-        # No grants available is not an error, just empty capabilities
-        return CapabilitiesResponse(
-            user_id=token.user_id,
-            client_id=token.client_id,
-            grants=[],
-            available_scopes=token.scopes,
-        )
-
-
-@execution_router.post(
-    "/agents/{agent_id}/execute",
-    response_model=ExecuteAgentResponse,
-)
-async def execute_agent(
-    agent_id: str,
-    request: ExecuteAgentRequest,
-    token: OAuthTokenInfo = Security(require_scope("agents:execute")),
-) -> ExecuteAgentResponse:
-    """
-    Execute an agent using granted credentials.
-
-    The agent must be accessible to the user, and the client must have
-    valid credential grants that satisfy the agent's requirements.
-
-    Args:
-        agent_id: The agent (graph) ID to execute
-        request: Execution parameters including inputs and optional grant IDs
-    """
-    # Verify the agent exists and user has access
-    # First try to get the latest version
-    graph = await graph_db.get_graph(
-        graph_id=agent_id,
-        version=None,
-        user_id=token.user_id,
-    )
-
-    if not graph:
-        # Try to find it in the store (public agents)
-        graph = await graph_db.get_graph(
-            graph_id=agent_id,
-            version=None,
-            user_id=None,
-            skip_access_check=True,
-        )
-        if not graph:
-            raise HTTPException(
-                status_code=404,
-                detail=f"Agent {agent_id} not found or not accessible",
-            )
-
-    # Initialize the grant resolver to validate grants exist
-    # The resolver context will be passed to the execution engine
-    grant_resolver_context = None
-    try:
-        resolver = await create_resolver_from_oauth_token(
-            user_id=token.user_id,
-            client_public_id=token.client_id,
-            grant_ids=request.grant_ids,
-        )
-        # Get available credentials info to build resolver context
-        credentials_info = await resolver.get_available_credentials()
-        grant_resolver_context = GrantResolverContext(
-            client_db_id=resolver.client_id,
-            grant_ids=[c["grant_id"] for c in credentials_info],
-        )
-    except GrantValidationError as e:
-        raise HTTPException(
-            status_code=403,
-            detail=f"Grant validation failed: {str(e)}",
-        )
-
-    try:
-        # Build execution context with grant resolver info
-        execution_context = ExecutionContext(
-            grant_resolver_context=grant_resolver_context,
-        )
-
-        # Execute the agent with grant resolver context
-        graph_exec = await add_graph_execution(
-            graph_id=agent_id,
-            user_id=token.user_id,
-            inputs=request.inputs,
-            graph_version=graph.version,
-            execution_context=execution_context,
-        )
-
-        # Log the execution for audit
-        logger.info(
-            f"External execution started: agent={agent_id}, "
-            f"execution={graph_exec.id}, client={token.client_id}, "
-            f"user={token.user_id}"
-        )
-
-        # Register webhook if provided
-        if request.webhook_url:
-            # Get client to check webhook domains
-            client = await prisma.oauthclient.find_unique(
-                where={"clientId": token.client_id}
-            )
-            if client:
-                if not validate_webhook_url(request.webhook_url, client.webhookDomains):
-                    raise HTTPException(
-                        status_code=400,
-                        detail="Webhook URL not in allowed domains for this client",
-                    )
-
-                # Store webhook registration with client's webhook secret
-                await prisma.executionwebhook.create(
-                    data={  # type: ignore[typeddict-item]
-                        "executionId": graph_exec.id,
-                        "webhookUrl": request.webhook_url,
-                        "clientId": client.id,
-                        "userId": token.user_id,
-                        "secret": client.webhookSecret,
-                    }
-                )
-                logger.info(
-                    f"Registered webhook for execution {graph_exec.id}: {request.webhook_url}"
-                )
-
-        return ExecuteAgentResponse(
-            execution_id=graph_exec.id,
-            status="queued",
-            message="Agent execution has been queued",
-        )
-
-    except ValueError as e:
-        # Client error - invalid input or configuration
-        logger.warning(
-            f"Invalid execution request: agent={agent_id}, "
-            f"client={token.client_id}, error={str(e)}"
-        )
-        raise HTTPException(
-            status_code=400,
-            detail=f"Invalid request: {str(e)}",
-        )
-    except HTTPException:
-        # Re-raise HTTP exceptions as-is
-        raise
-    except Exception:
-        # Server error - log full exception but don't expose details to client
-        logger.exception(
-            f"Unexpected error starting execution: agent={agent_id}, "
-            f"client={token.client_id}"
-        )
-        raise HTTPException(
-            status_code=500,
-            detail="An internal error occurred while starting execution",
-        )
-
-
-@execution_router.get(
-    "/{execution_id}",
-    response_model=ExecutionStatusResponse,
-)
-async def get_execution_status(
-    execution_id: str,
-    token: OAuthTokenInfo = Security(require_scope("agents:execute")),
-) -> ExecutionStatusResponse:
-    """
-    Get the status of an agent execution.
-
-    Returns current status, outputs (if completed), and any error messages.
-    """
-    graph_exec = await execution_db.get_graph_execution(
-        user_id=token.user_id,
-        execution_id=execution_id,
-        include_node_executions=False,
-    )
-
-    if not graph_exec:
-        raise HTTPException(
-            status_code=404,
-            detail=f"Execution {execution_id} not found",
-        )
-
-    # Build response
-    outputs = None
-    error = None
-
-    if graph_exec.status == AgentExecutionStatus.COMPLETED:
-        outputs = graph_exec.outputs
-    elif graph_exec.status == AgentExecutionStatus.FAILED:
-        # Get error from execution stats
-        # Note: Currently no standard error field in stats, but could be added
-        error = "Execution failed"
-
-    return ExecutionStatusResponse(
-        execution_id=execution_id,
-        status=graph_exec.status.value,
-        started_at=graph_exec.started_at,
-        completed_at=graph_exec.ended_at,
-        outputs=outputs,
-        error=error,
-    )
-
-
-@execution_router.post("/{execution_id}/cancel")
-async def cancel_execution(
-    execution_id: str,
-    token: OAuthTokenInfo = Security(require_scope("agents:execute")),
-) -> dict:
-    """
-    Cancel a running agent execution.
-
-    Only executions in QUEUED or RUNNING status can be cancelled.
-    """
-    graph_exec = await execution_db.get_graph_execution(
-        user_id=token.user_id,
-        execution_id=execution_id,
-        include_node_executions=False,
-    )
-
-    if not graph_exec:
-        raise HTTPException(
-            status_code=404,
-            detail=f"Execution {execution_id} not found",
-        )
-
-    # Check if execution can be cancelled
-    if graph_exec.status not in [
-        AgentExecutionStatus.QUEUED,
-        AgentExecutionStatus.RUNNING,
-    ]:
-        raise HTTPException(
-            status_code=400,
-            detail=f"Cannot cancel execution with status {graph_exec.status.value}",
-        )
-
-    # Update execution status to TERMINATED
-    # Note: This is a simplified implementation. A full implementation would
-    # need to signal the executor to stop processing.
-    await prisma.agentgraphexecution.update(
-        where={"id": execution_id},
-        data={"executionStatus": AgentExecutionStatus.TERMINATED},
-    )

-    logger.info(
-        f"Execution terminated: execution={execution_id}, "
-        f"client={token.client_id}, user={token.user_id}"
-    )
-
-    return {"message": "Execution terminated", "execution_id": execution_id}
@@ -1,207 +0,0 @@
-"""
-Credential Grants endpoints for external OAuth clients.
-
-Allows external applications to:
-- List their credential grants (metadata only, NOT credential values)
-- Get grant details
-- Delete credentials via grants (if permitted)
-
-Credentials are NEVER returned to external applications.
-"""
-
-from datetime import datetime, timezone
-from typing import Optional
-
-from fastapi import APIRouter, HTTPException, Security
-from pydantic import BaseModel
-
-from backend.data import credential_grants as grants_db
-from backend.data.db import prisma
-from backend.integrations.credentials_store import IntegrationCredentialsStore
-from backend.server.external.oauth_middleware import OAuthTokenInfo, require_scope
-
-grants_router = APIRouter(prefix="/grants", tags=["grants"])
-
-
-# ================================================================
-# Response Models
-# ================================================================
-
-
-class GrantSummary(BaseModel):
-    """Summary of a credential grant (returned in list endpoints)."""
-
-    id: str
-    provider: str
-    granted_scopes: list[str]
-    permissions: list[str]
-    created_at: datetime
-    last_used_at: Optional[datetime] = None
-    expires_at: Optional[datetime] = None
-
-
-class GrantDetail(BaseModel):
-    """Detailed grant information."""
-
-    id: str
-    provider: str
-    credential_id: str
-    granted_scopes: list[str]
-    permissions: list[str]
-    created_at: datetime
-    updated_at: datetime
-    last_used_at: Optional[datetime] = None
-    expires_at: Optional[datetime] = None
-    revoked_at: Optional[datetime] = None
-
-
-# ================================================================
-# Endpoints
-# ================================================================
-
-
-@grants_router.get("/", response_model=list[GrantSummary])
-async def list_grants(
-    token: OAuthTokenInfo = Security(require_scope("integrations:list")),
-) -> list[GrantSummary]:
-    """
-    List all active credential grants for this client and user.
-
-    Returns grant metadata but NOT credential values.
-    Credentials are never exposed to external applications.
-    """
-    # Get the OAuth client's database ID from the public client_id
-    client = await prisma.oauthclient.find_unique(where={"clientId": token.client_id})
-    if not client:
-        raise HTTPException(status_code=400, detail="Invalid client")
-
-    grants = await grants_db.get_grants_for_user_client(
-        user_id=token.user_id,
-        client_id=client.id,
-        include_revoked=False,
-        include_expired=False,
-    )
-
-    return [
-        GrantSummary(
-            id=grant.id,
-            provider=grant.provider,
-            granted_scopes=grant.grantedScopes,
-            permissions=[p.value for p in grant.permissions],
-            created_at=grant.createdAt,
-            last_used_at=grant.lastUsedAt,
-            expires_at=grant.expiresAt,
-        )
-        for grant in grants
-    ]
-
-
-@grants_router.get("/{grant_id}", response_model=GrantDetail)
-async def get_grant(
-    grant_id: str,
-    token: OAuthTokenInfo = Security(require_scope("integrations:list")),
-) -> GrantDetail:
-    """
-    Get detailed information about a specific grant.
-
-    Returns grant metadata including scopes and permissions.
-    Does NOT return the credential value.
-    """
-    # Get the OAuth client's database ID
-    client = await prisma.oauthclient.find_unique(where={"clientId": token.client_id})
-    if not client:
-        raise HTTPException(status_code=400, detail="Invalid client")
-
-    grant = await grants_db.get_credential_grant(
-        grant_id=grant_id,
-        user_id=token.user_id,
-        client_id=client.id,
-    )
-
-    if not grant:
-        raise HTTPException(status_code=404, detail="Grant not found")
-
-    # Check if expired
-    if grant.expiresAt and grant.expiresAt < datetime.now(timezone.utc):
-        raise HTTPException(status_code=404, detail="Grant has expired")
-
-    # Check if revoked
-    if grant.revokedAt:
-        raise HTTPException(status_code=404, detail="Grant has been revoked")
-
-    return GrantDetail(
-        id=grant.id,
-        provider=grant.provider,
-        credential_id=grant.credentialId,
-        granted_scopes=grant.grantedScopes,
-        permissions=[p.value for p in grant.permissions],
-        created_at=grant.createdAt,
-        updated_at=grant.updatedAt,
-        last_used_at=grant.lastUsedAt,
-        expires_at=grant.expiresAt,
-        revoked_at=grant.revokedAt,
-    )
-
-
-@grants_router.delete("/{grant_id}/credential")
-async def delete_credential_via_grant(
-    grant_id: str,
-    token: OAuthTokenInfo = Security(require_scope("integrations:delete")),
-) -> dict:
-    """
-    Delete the underlying credential associated with a grant.
-
-    This requires the grant to have the DELETE permission.
-    Deleting the credential also invalidates all grants for that credential.
-    """
-    from prisma.enums import CredentialGrantPermission
-
-    # Get the OAuth client's database ID
-    client = await prisma.oauthclient.find_unique(where={"clientId": token.client_id})
-    if not client:
-        raise HTTPException(status_code=400, detail="Invalid client")
-
-    # Get the grant
-    grant = await grants_db.get_credential_grant(
-        grant_id=grant_id,
-        user_id=token.user_id,
-        client_id=client.id,
-    )
-
-    if not grant:
-        raise HTTPException(status_code=404, detail="Grant not found")
-
-    # Check if grant is valid
-    if grant.revokedAt:
-        raise HTTPException(status_code=400, detail="Grant has been revoked")
-
-    if grant.expiresAt and grant.expiresAt < datetime.now(timezone.utc):
-        raise HTTPException(status_code=400, detail="Grant has expired")
-
-    # Check DELETE permission
-    if CredentialGrantPermission.DELETE not in grant.permissions:
-        raise HTTPException(
-            status_code=403,
-            detail="Grant does not have DELETE permission for this credential",
-        )
-
-    # Delete the credential using the credentials store
-    try:
-        creds_store = IntegrationCredentialsStore()
|
|
||||||
await creds_store.delete_creds_by_id(
|
|
||||||
user_id=token.user_id,
|
|
||||||
credentials_id=grant.credentialId,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=500,
|
|
||||||
detail=f"Failed to delete credential: {str(e)}",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Revoke all grants for this credential
|
|
||||||
await grants_db.revoke_grants_for_credential(
|
|
||||||
user_id=token.user_id,
|
|
||||||
credential_id=grant.credentialId,
|
|
||||||
)
|
|
||||||
|
|
||||||
return {"message": "Credential deleted successfully"}
|
|
||||||
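As a sketch of how an external application might drive these grant endpoints (not part of the diff; the base URL and token are hypothetical, and the access token must carry the integrations:list and integrations:delete scopes):

# Hypothetical client sketch; BASE_URL and ACCESS_TOKEN are placeholders.
import httpx

BASE_URL = "https://platform.example.com/external-api"
ACCESS_TOKEN = "..."  # OAuth access token with integrations:list / integrations:delete

async def list_and_prune_grants() -> None:
    headers = {"Authorization": f"Bearer {ACCESS_TOKEN}"}
    async with httpx.AsyncClient(base_url=BASE_URL, headers=headers) as client:
        # Only grant metadata comes back; credential values never do.
        grants = (await client.get("/grants/")).json()
        for grant in grants:
            detail = (await client.get(f"/grants/{grant['id']}")).json()
            if "DELETE" in detail["permissions"]:
                # Deletes the credential and revokes every grant referencing it.
                await client.delete(f"/grants/{grant['id']}/credential")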
autogpt_platform/backend/backend/server/external/routes/integrations.py (vendored, new file, 650 lines)
@@ -0,0 +1,650 @@
"""
External API endpoints for integrations and credentials.

This module provides endpoints for external applications (like Autopilot) to:
- Initiate OAuth flows with custom callback URLs
- Complete OAuth flows by exchanging authorization codes
- Create API key, user/password, and host-scoped credentials
- List and manage user credentials
"""

import logging
from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional, Union
from urllib.parse import urlparse

from fastapi import APIRouter, Body, HTTPException, Path, Security, status
from prisma.enums import APIKeyPermission
from pydantic import BaseModel, Field, SecretStr

from backend.data.api_key import APIKeyInfo
from backend.data.model import (
    APIKeyCredentials,
    Credentials,
    CredentialsType,
    HostScopedCredentials,
    OAuth2Credentials,
    UserPasswordCredentials,
)
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName
from backend.server.external.middleware import require_permission
from backend.server.integrations.models import get_all_provider_names
from backend.util.settings import Settings

if TYPE_CHECKING:
    from backend.integrations.oauth import BaseOAuthHandler

logger = logging.getLogger(__name__)
settings = Settings()
creds_manager = IntegrationCredentialsManager()

integrations_router = APIRouter(prefix="/integrations", tags=["integrations"])


# ==================== Request/Response Models ==================== #


class OAuthInitiateRequest(BaseModel):
    """Request model for initiating an OAuth flow."""

    callback_url: str = Field(
        ..., description="The external app's callback URL for OAuth redirect"
    )
    scopes: list[str] = Field(
        default_factory=list, description="OAuth scopes to request"
    )
    state_metadata: dict[str, Any] = Field(
        default_factory=dict,
        description="Arbitrary metadata to echo back on completion",
    )


class OAuthInitiateResponse(BaseModel):
    """Response model for OAuth initiation."""

    login_url: str = Field(..., description="URL to redirect user for OAuth consent")
    state_token: str = Field(..., description="State token for CSRF protection")
    expires_at: int = Field(
        ..., description="Unix timestamp when the state token expires"
    )


class OAuthCompleteRequest(BaseModel):
    """Request model for completing an OAuth flow."""

    code: str = Field(..., description="Authorization code from OAuth provider")
    state_token: str = Field(..., description="State token from initiate request")


class OAuthCompleteResponse(BaseModel):
    """Response model for OAuth completion."""

    credentials_id: str = Field(..., description="ID of the stored credentials")
    provider: str = Field(..., description="Provider name")
    type: str = Field(..., description="Credential type (oauth2)")
    title: Optional[str] = Field(None, description="Credential title")
    scopes: list[str] = Field(default_factory=list, description="Granted scopes")
    username: Optional[str] = Field(None, description="Username from provider")
    state_metadata: dict[str, Any] = Field(
        default_factory=dict, description="Echoed metadata from initiate request"
    )


class CredentialSummary(BaseModel):
    """Summary of a credential without sensitive data."""

    id: str
    provider: str
    type: CredentialsType
    title: Optional[str] = None
    scopes: Optional[list[str]] = None
    username: Optional[str] = None
    host: Optional[str] = None


class ProviderInfo(BaseModel):
    """Information about an integration provider."""

    name: str
    supports_oauth: bool = False
    supports_api_key: bool = False
    supports_user_password: bool = False
    supports_host_scoped: bool = False
    default_scopes: list[str] = Field(default_factory=list)


# ==================== Credential Creation Models ==================== #


class CreateAPIKeyCredentialRequest(BaseModel):
    """Request model for creating API key credentials."""

    type: Literal["api_key"] = "api_key"
    api_key: str = Field(..., description="The API key")
    title: str = Field(..., description="A name for this credential")
    expires_at: Optional[int] = Field(
        None, description="Unix timestamp when the API key expires"
    )


class CreateUserPasswordCredentialRequest(BaseModel):
    """Request model for creating username/password credentials."""

    type: Literal["user_password"] = "user_password"
    username: str = Field(..., description="Username")
    password: str = Field(..., description="Password")
    title: str = Field(..., description="A name for this credential")


class CreateHostScopedCredentialRequest(BaseModel):
    """Request model for creating host-scoped credentials."""

    type: Literal["host_scoped"] = "host_scoped"
    host: str = Field(..., description="Host/domain pattern to match")
    headers: dict[str, str] = Field(..., description="Headers to include in requests")
    title: str = Field(..., description="A name for this credential")


# Union type for credential creation
CreateCredentialRequest = Annotated[
    CreateAPIKeyCredentialRequest
    | CreateUserPasswordCredentialRequest
    | CreateHostScopedCredentialRequest,
    Field(discriminator="type"),
]
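# --- Editor's sketch (illustrative; not part of this file) ---
# The Field(discriminator="type") annotation makes CreateCredentialRequest a
# tagged union: pydantic v2 selects the concrete model from the "type" literal.
# A minimal sketch of how the union resolves, using pydantic's TypeAdapter;
# the payload values are hypothetical.
def _example_discriminated_union_parse() -> None:
    from pydantic import TypeAdapter

    adapter = TypeAdapter(CreateCredentialRequest)
    req = adapter.validate_python(
        {
            "type": "host_scoped",
            "host": "*.example.com",  # hypothetical values throughout
            "headers": {"X-Api-Key": "secret"},
            "title": "internal API",
        }
    )
    assert isinstance(req, CreateHostScopedCredentialRequest)
# --- end sketch ---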
class CreateCredentialResponse(BaseModel):
    """Response model for credential creation."""

    id: str
    provider: str
    type: CredentialsType
    title: Optional[str] = None


# ==================== Helper Functions ==================== #


def validate_callback_url(callback_url: str) -> bool:
    """Validate that the callback URL is from an allowed origin."""
    allowed_origins = settings.config.external_oauth_callback_origins

    try:
        parsed = urlparse(callback_url)
        callback_origin = f"{parsed.scheme}://{parsed.netloc}"

        for allowed in allowed_origins:
            # Simple origin matching
            if callback_origin == allowed:
                return True

        # Allow localhost with any port in development (proper hostname check)
        if parsed.hostname == "localhost":
            for allowed in allowed_origins:
                allowed_parsed = urlparse(allowed)
                if allowed_parsed.hostname == "localhost":
                    return True

        return False
    except Exception:
        return False
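# --- Editor's sketch (illustrative; not part of this file) ---
# Concretely, assuming external_oauth_callback_origins were configured as
# ["https://autopilot.example.com", "http://localhost:3000"] (hypothetical):
#
#   validate_callback_url("https://autopilot.example.com/cb")  -> True   (exact origin match)
#   validate_callback_url("https://evil.example.com/cb")       -> False  (origin not allowed)
#   validate_callback_url("http://localhost:5173/cb")          -> True   (any localhost port,
#                                     because an allowed origin is itself localhost)
# --- end sketch ---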
def _get_oauth_handler_for_external(
    provider_name: str, redirect_uri: str
) -> "BaseOAuthHandler":
    """Get an OAuth handler configured with an external redirect URI."""
    # Ensure blocks are loaded so SDK providers are available
    try:
        from backend.blocks import load_all_blocks

        load_all_blocks()
    except Exception as e:
        logger.warning(f"Failed to load blocks: {e}")

    if provider_name not in HANDLERS_BY_NAME:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Provider '{provider_name}' does not support OAuth",
        )

    # Check if this provider has custom OAuth credentials
    oauth_credentials = CREDENTIALS_BY_PROVIDER.get(provider_name)

    if oauth_credentials and not oauth_credentials.use_secrets:
        import os

        client_id = (
            os.getenv(oauth_credentials.client_id_env_var)
            if oauth_credentials.client_id_env_var
            else None
        )
        client_secret = (
            os.getenv(oauth_credentials.client_secret_env_var)
            if oauth_credentials.client_secret_env_var
            else None
        )
    else:
        client_id = getattr(settings.secrets, f"{provider_name}_client_id", None)
        client_secret = getattr(
            settings.secrets, f"{provider_name}_client_secret", None
        )

    if not (client_id and client_secret):
        logger.error(f"Attempt to use unconfigured {provider_name} OAuth integration")
        raise HTTPException(
            status_code=status.HTTP_501_NOT_IMPLEMENTED,
            detail={
                "message": f"Integration with provider '{provider_name}' is not configured.",
                "hint": "Set client ID and secret in the application's deployment environment",
            },
        )

    handler_class = HANDLERS_BY_NAME[provider_name]
    return handler_class(
        client_id=client_id,
        client_secret=client_secret,
        redirect_uri=redirect_uri,
    )


# ==================== Endpoints ==================== #


@integrations_router.get("/providers", response_model=list[ProviderInfo])
async def list_providers(
    api_key: APIKeyInfo = Security(
        require_permission(APIKeyPermission.READ_INTEGRATIONS)
    ),
) -> list[ProviderInfo]:
    """
    List all available integration providers.

    Returns a list of all providers with their supported credential types.
    Most providers support API key credentials, and some also support OAuth.
    """
    # Ensure blocks are loaded
    try:
        from backend.blocks import load_all_blocks

        load_all_blocks()
    except Exception as e:
        logger.warning(f"Failed to load blocks: {e}")

    from backend.sdk.registry import AutoRegistry

    providers = []
    for name in get_all_provider_names():
        supports_oauth = name in HANDLERS_BY_NAME
        handler_class = HANDLERS_BY_NAME.get(name)
        default_scopes = (
            getattr(handler_class, "DEFAULT_SCOPES", []) if handler_class else []
        )

        # Check if provider has specific auth types from SDK registration
        sdk_provider = AutoRegistry.get_provider(name)
        if sdk_provider and sdk_provider.supported_auth_types:
            supports_api_key = "api_key" in sdk_provider.supported_auth_types
            supports_user_password = (
                "user_password" in sdk_provider.supported_auth_types
            )
            supports_host_scoped = "host_scoped" in sdk_provider.supported_auth_types
        else:
            # Fallback for legacy providers
            supports_api_key = True  # All providers can accept API keys
            supports_user_password = name in ("smtp",)
            supports_host_scoped = name == "http"

        providers.append(
            ProviderInfo(
                name=name,
                supports_oauth=supports_oauth,
                supports_api_key=supports_api_key,
                supports_user_password=supports_user_password,
                supports_host_scoped=supports_host_scoped,
                default_scopes=default_scopes,
            )
        )

    return providers


@integrations_router.post(
    "/{provider}/oauth/initiate",
    response_model=OAuthInitiateResponse,
    summary="Initiate OAuth flow",
)
async def initiate_oauth(
    provider: Annotated[str, Path(title="The OAuth provider")],
    request: OAuthInitiateRequest,
    api_key: APIKeyInfo = Security(
        require_permission(APIKeyPermission.MANAGE_INTEGRATIONS)
    ),
) -> OAuthInitiateResponse:
    """
    Initiate an OAuth flow for an external application.

    This endpoint allows external apps to start an OAuth flow with a custom
    callback URL. The callback URL must be from an allowed origin configured
    in the platform settings.

    Returns a login URL to redirect the user to, along with a state token
    for CSRF protection.
    """
    # Validate callback URL
    if not validate_callback_url(request.callback_url):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Callback URL origin is not allowed. Allowed origins: {settings.config.external_oauth_callback_origins}",
        )

    # Validate provider
    try:
        provider_name = ProviderName(provider)
    except ValueError:
        # Check if it's a dynamically registered provider
        if provider not in HANDLERS_BY_NAME:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Provider '{provider}' not found",
            )
        provider_name = provider

    # Get OAuth handler with external callback URL
    handler = _get_oauth_handler_for_external(
        provider if isinstance(provider_name, str) else provider_name.value,
        request.callback_url,
    )

    # Store state token with external flow metadata
    state_token, code_challenge = await creds_manager.store.store_state_token(
        user_id=api_key.user_id,
        provider=provider if isinstance(provider_name, str) else provider_name.value,
        scopes=request.scopes,
        callback_url=request.callback_url,
        state_metadata=request.state_metadata,
        initiated_by_api_key_id=api_key.id,
    )

    # Build login URL
    login_url = handler.get_login_url(
        request.scopes, state_token, code_challenge=code_challenge
    )

    # Calculate expiration (10 minutes from now)
    from datetime import datetime, timedelta, timezone

    expires_at = int((datetime.now(timezone.utc) + timedelta(minutes=10)).timestamp())

    return OAuthInitiateResponse(
        login_url=login_url,
        state_token=state_token,
        expires_at=expires_at,
    )


@integrations_router.post(
    "/{provider}/oauth/complete",
    response_model=OAuthCompleteResponse,
    summary="Complete OAuth flow",
)
async def complete_oauth(
    provider: Annotated[str, Path(title="The OAuth provider")],
    request: OAuthCompleteRequest,
    api_key: APIKeyInfo = Security(
        require_permission(APIKeyPermission.MANAGE_INTEGRATIONS)
    ),
) -> OAuthCompleteResponse:
    """
    Complete an OAuth flow by exchanging the authorization code for tokens.

    This endpoint should be called after the user has authorized the application
    and been redirected back to the external app's callback URL with an
    authorization code.
    """
    # Verify state token
    valid_state = await creds_manager.store.verify_state_token(
        api_key.user_id, request.state_token, provider
    )

    if not valid_state:
        logger.warning(f"Invalid or expired state token for provider {provider}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid or expired state token",
        )

    # Verify this is an external flow (callback_url must be set)
    if not valid_state.callback_url:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="State token was not created for external OAuth flow",
        )

    # Get OAuth handler with the original callback URL
    handler = _get_oauth_handler_for_external(provider, valid_state.callback_url)

    try:
        scopes = valid_state.scopes
        scopes = handler.handle_default_scopes(scopes)

        credentials = await handler.exchange_code_for_tokens(
            request.code, scopes, valid_state.code_verifier
        )

        # Handle Linear's space-separated scopes
        if len(credentials.scopes) == 1 and " " in credentials.scopes[0]:
            credentials.scopes = credentials.scopes[0].split(" ")

        # Check scope mismatch
        if not set(scopes).issubset(set(credentials.scopes)):
            logger.warning(
                f"Granted scopes {credentials.scopes} for provider {provider} "
                f"do not include all requested scopes {scopes}"
            )

    except Exception as e:
        logger.error(f"OAuth2 Code->Token exchange failed for provider {provider}: {e}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"OAuth2 callback failed to exchange code for tokens: {str(e)}",
        )

    # Store credentials
    await creds_manager.create(api_key.user_id, credentials)

    logger.info(f"Successfully completed external OAuth for provider {provider}")

    return OAuthCompleteResponse(
        credentials_id=credentials.id,
        provider=credentials.provider,
        type=credentials.type,
        title=credentials.title,
        scopes=credentials.scopes,
        username=credentials.username,
        state_metadata=valid_state.state_metadata,
    )
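# --- Editor's sketch (illustrative; not part of this file) ---
# Putting initiate and complete together, an external app's flow looks roughly
# like this. The base URL, API key, and auth header name are assumptions, and
# the redirect hop in the middle happens in the user's browser.
def _example_external_oauth_flow() -> None:
    import httpx

    headers = {"X-API-Key": "agpt-..."}  # header name is an assumption
    with httpx.Client(
        base_url="https://platform.example.com/external-api", headers=headers
    ) as client:
        # 1. Initiate: register our callback and receive a login URL + state token.
        init = client.post(
            "/integrations/github/oauth/initiate",
            json={
                "callback_url": "https://autopilot.example.com/oauth/callback",
                "scopes": ["repo"],
                "state_metadata": {"session": "abc123"},
            },
        ).json()
        # 2. Redirect the user to init["login_url"]; the provider sends them
        #    back to the callback URL with ?code=...
        code = "<authorization code from callback>"  # placeholder
        # 3. Complete: exchange the code; credentials stay server-side only.
        done = client.post(
            "/integrations/github/oauth/complete",
            json={"code": code, "state_token": init["state_token"]},
        ).json()
        print(done["credentials_id"], done["state_metadata"])
# --- end sketch ---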
@integrations_router.get("/credentials", response_model=list[CredentialSummary])
async def list_credentials(
    api_key: APIKeyInfo = Security(
        require_permission(APIKeyPermission.READ_INTEGRATIONS)
    ),
) -> list[CredentialSummary]:
    """
    List all credentials for the authenticated user.

    Returns metadata about each credential without exposing sensitive tokens.
    """
    credentials = await creds_manager.store.get_all_creds(api_key.user_id)
    return [
        CredentialSummary(
            id=cred.id,
            provider=cred.provider,
            type=cred.type,
            title=cred.title,
            scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
            username=cred.username if isinstance(cred, OAuth2Credentials) else None,
            host=cred.host if isinstance(cred, HostScopedCredentials) else None,
        )
        for cred in credentials
    ]


@integrations_router.get(
    "/{provider}/credentials", response_model=list[CredentialSummary]
)
async def list_credentials_by_provider(
    provider: Annotated[str, Path(title="The provider to list credentials for")],
    api_key: APIKeyInfo = Security(
        require_permission(APIKeyPermission.READ_INTEGRATIONS)
    ),
) -> list[CredentialSummary]:
    """
    List credentials for a specific provider.
    """
    credentials = await creds_manager.store.get_creds_by_provider(
        api_key.user_id, provider
    )
    return [
        CredentialSummary(
            id=cred.id,
            provider=cred.provider,
            type=cred.type,
            title=cred.title,
            scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
            username=cred.username if isinstance(cred, OAuth2Credentials) else None,
            host=cred.host if isinstance(cred, HostScopedCredentials) else None,
        )
        for cred in credentials
    ]


@integrations_router.post(
    "/{provider}/credentials",
    response_model=CreateCredentialResponse,
    status_code=status.HTTP_201_CREATED,
    summary="Create credentials",
)
async def create_credential(
    provider: Annotated[str, Path(title="The provider to create credentials for")],
    request: Union[
        CreateAPIKeyCredentialRequest,
        CreateUserPasswordCredentialRequest,
        CreateHostScopedCredentialRequest,
    ] = Body(..., discriminator="type"),
    api_key: APIKeyInfo = Security(
        require_permission(APIKeyPermission.MANAGE_INTEGRATIONS)
    ),
) -> CreateCredentialResponse:
    """
    Create non-OAuth credentials for a provider.

    Supports creating:
    - API key credentials (type: "api_key")
    - Username/password credentials (type: "user_password")
    - Host-scoped credentials (type: "host_scoped")

    For OAuth credentials, use the OAuth initiate/complete flow instead.
    """
    # Validate provider exists
    all_providers = get_all_provider_names()
    if provider not in all_providers:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Provider '{provider}' not found",
        )

    # Create the appropriate credential type
    credentials: Credentials
    if request.type == "api_key":
        credentials = APIKeyCredentials(
            provider=provider,
            api_key=SecretStr(request.api_key),
            title=request.title,
            expires_at=request.expires_at,
        )
    elif request.type == "user_password":
        credentials = UserPasswordCredentials(
            provider=provider,
            username=SecretStr(request.username),
            password=SecretStr(request.password),
            title=request.title,
        )
    elif request.type == "host_scoped":
        # Convert string headers to SecretStr
        secret_headers = {k: SecretStr(v) for k, v in request.headers.items()}
        credentials = HostScopedCredentials(
            provider=provider,
            host=request.host,
            headers=secret_headers,
            title=request.title,
        )
    else:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Unsupported credential type: {request.type}",
        )

    # Store credentials
    try:
        await creds_manager.create(api_key.user_id, credentials)
    except Exception as e:
        logger.error(f"Failed to store credentials: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to store credentials: {str(e)}",
        )

    logger.info(f"Created {request.type} credentials for provider {provider}")

    return CreateCredentialResponse(
        id=credentials.id,
        provider=provider,
        type=credentials.type,
        title=credentials.title,
    )
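# --- Editor's sketch (illustrative; not part of this file) ---
# For non-OAuth providers a single POST suffices; the payload shape follows the
# discriminated union above. Host, API key, and auth header name are assumptions.
def _example_create_api_key_credential() -> None:
    import httpx

    resp = httpx.post(
        "https://platform.example.com/external-api/integrations/openai/credentials",
        headers={"X-API-Key": "agpt-..."},  # header name is an assumption
        json={"type": "api_key", "api_key": "sk-...", "title": "team key"},
    )
    resp.raise_for_status()  # 201 Created on success
    print(resp.json()["id"])
# --- end sketch ---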
class DeleteCredentialResponse(BaseModel):
    """Response model for deleting a credential."""

    deleted: bool = Field(..., description="Whether the credential was deleted")
    credentials_id: str = Field(..., description="ID of the deleted credential")


@integrations_router.delete(
    "/{provider}/credentials/{cred_id}",
    response_model=DeleteCredentialResponse,
)
async def delete_credential(
    provider: Annotated[str, Path(title="The provider")],
    cred_id: Annotated[str, Path(title="The credential ID to delete")],
    api_key: APIKeyInfo = Security(
        require_permission(APIKeyPermission.DELETE_INTEGRATIONS)
    ),
) -> DeleteCredentialResponse:
    """
    Delete a credential.

    Note: This does not revoke the tokens with the provider. For full cleanup,
    use the main API's delete endpoint which handles webhook cleanup and
    token revocation.
    """
    creds = await creds_manager.store.get_creds_by_id(api_key.user_id, cred_id)
    if not creds:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Credentials not found"
        )
    if creds.provider != provider:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Credentials do not match the specified provider",
        )

    await creds_manager.delete(api_key.user_id, cred_id)

    return DeleteCredentialResponse(deleted=True, credentials_id=cred_id)
autogpt_platform/backend/backend/server/external/routes/tools.py (vendored, new file, 148 lines)
@@ -0,0 +1,148 @@
"""External API routes for chat tools - stateless HTTP endpoints.

Note: These endpoints use ephemeral sessions that are not persisted to Redis.
As a result, session-based rate limiting (max_agent_runs, max_agent_schedules)
is not enforced for external API calls. Each request creates a fresh session
with zeroed counters. Rate limiting for external API consumers should be
handled separately (e.g., via API key quotas).
"""

import logging
from typing import Any

from fastapi import APIRouter, Security
from prisma.enums import APIKeyPermission
from pydantic import BaseModel, Field

from backend.data.api_key import APIKeyInfo
from backend.server.external.middleware import require_permission
from backend.server.v2.chat.model import ChatSession
from backend.server.v2.chat.tools import find_agent_tool, run_agent_tool
from backend.server.v2.chat.tools.models import ToolResponseBase

logger = logging.getLogger(__name__)

tools_router = APIRouter(prefix="/tools", tags=["tools"])

# Note: We use Security() as a function parameter dependency (api_key: APIKeyInfo = Security(...))
# rather than in the decorator's dependencies= list. This avoids duplicate permission checks
# while still enforcing auth AND giving us access to the api_key for extracting user_id.


# Request models
class FindAgentRequest(BaseModel):
    query: str = Field(..., description="Search query for finding agents")


class RunAgentRequest(BaseModel):
    """Request to run or schedule an agent.

    The tool automatically handles the setup flow:
    - First call returns available inputs so user can decide what values to use
    - Returns missing credentials if user needs to configure them
    - Executes when inputs are provided OR use_defaults=true
    - Schedules execution if schedule_name and cron are provided
    """

    username_agent_slug: str = Field(
        ...,
        description="The marketplace agent slug (e.g., 'username/agent-name')",
    )
    inputs: dict[str, Any] = Field(
        default_factory=dict,
        description="Dictionary of input values for the agent",
    )
    use_defaults: bool = Field(
        default=False,
        description="Set to true to run with default values (user must confirm)",
    )
    schedule_name: str | None = Field(
        None,
        description="Name for scheduled execution (triggers scheduling mode)",
    )
    cron: str | None = Field(
        None,
        description="Cron expression (5 fields: minute hour day month weekday)",
    )
    timezone: str = Field(
        default="UTC",
        description="IANA timezone (e.g., 'America/New_York', 'UTC')",
    )


def _create_ephemeral_session(user_id: str | None) -> ChatSession:
    """Create an ephemeral session for stateless API requests."""
    return ChatSession.new(user_id)


@tools_router.post(
    path="/find-agent",
)
async def find_agent(
    request: FindAgentRequest,
    api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.USE_TOOLS)),
) -> dict[str, Any]:
    """
    Search for agents in the marketplace based on capabilities and user needs.

    Args:
        request: Search query for finding agents

    Returns:
        List of matching agents or no results response
    """
    session = _create_ephemeral_session(api_key.user_id)
    result = await find_agent_tool._execute(
        user_id=api_key.user_id,
        session=session,
        query=request.query,
    )
    return _response_to_dict(result)


@tools_router.post(
    path="/run-agent",
)
async def run_agent(
    request: RunAgentRequest,
    api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.USE_TOOLS)),
) -> dict[str, Any]:
    """
    Run or schedule an agent from the marketplace.

    The endpoint automatically handles the setup flow:
    - Returns missing inputs if required fields are not provided
    - Returns missing credentials if user needs to configure them
    - Executes immediately if all requirements are met
    - Schedules execution if schedule_name and cron are provided

    For scheduled execution:
    - Cron format: "minute hour day month weekday"
    - Examples: "0 9 * * 1-5" (9am weekdays), "0 0 * * *" (daily at midnight)
    - Timezone: Use IANA timezone names like "America/New_York"

    Args:
        request: Agent slug, inputs, and optional schedule config

    Returns:
        - setup_requirements: If inputs or credentials are missing
        - execution_started: If agent was run or scheduled successfully
        - error: If something went wrong
    """
    session = _create_ephemeral_session(api_key.user_id)
    result = await run_agent_tool._execute(
        user_id=api_key.user_id,
        session=session,
        username_agent_slug=request.username_agent_slug,
        inputs=request.inputs,
        use_defaults=request.use_defaults,
        schedule_name=request.schedule_name or "",
        cron=request.cron or "",
        timezone=request.timezone,
    )
    return _response_to_dict(result)


def _response_to_dict(result: ToolResponseBase) -> dict[str, Any]:
    """Convert a tool response to a dictionary for JSON serialization."""
    return result.model_dump()
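A sketch of scheduling an agent through the stateless tools API; the host, auth header name, and agent slug are hypothetical, while the 5-field cron string and IANA timezone follow the docstring above:

# Hypothetical call scheduling a marketplace agent every weekday at 9am.
import httpx

resp = httpx.post(
    "https://platform.example.com/external-api/tools/run-agent",
    headers={"X-API-Key": "agpt-..."},  # header name is an assumption
    json={
        "username_agent_slug": "someuser/daily-report",  # hypothetical slug
        "inputs": {"topic": "AI news"},
        "schedule_name": "weekday-report",
        "cron": "0 9 * * 1-5",  # minute hour day month weekday
        "timezone": "America/New_York",
    },
)
print(resp.json())  # setup_requirements, execution_started, or error payload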
autogpt_platform/backend/backend/server/external/routes/v1.py (vendored, new file, 295 lines)
@@ -0,0 +1,295 @@
import logging
import urllib.parse
from collections import defaultdict
from typing import Annotated, Any, Literal, Optional, Sequence

from fastapi import APIRouter, Body, HTTPException, Security
from prisma.enums import AgentExecutionStatus, APIKeyPermission
from typing_extensions import TypedDict

import backend.data.block
import backend.server.v2.store.cache as store_cache
import backend.server.v2.store.model as store_model
from backend.data import execution as execution_db
from backend.data import graph as graph_db
from backend.data.api_key import APIKeyInfo
from backend.data.block import BlockInput, CompletedBlockOutput
from backend.executor.utils import add_graph_execution
from backend.server.external.middleware import require_permission
from backend.util.settings import Settings

settings = Settings()
logger = logging.getLogger(__name__)

v1_router = APIRouter()


class NodeOutput(TypedDict):
    key: str
    value: Any


class ExecutionNode(TypedDict):
    node_id: str
    input: Any
    output: dict[str, Any]


class ExecutionNodeOutput(TypedDict):
    node_id: str
    outputs: list[NodeOutput]


class GraphExecutionResult(TypedDict):
    execution_id: str
    status: str
    nodes: list[ExecutionNode]
    output: Optional[list[dict[str, str]]]


@v1_router.get(
    path="/blocks",
    tags=["blocks"],
    dependencies=[Security(require_permission(APIKeyPermission.READ_BLOCK))],
)
async def get_graph_blocks() -> Sequence[dict[Any, Any]]:
    blocks = [block() for block in backend.data.block.get_blocks().values()]
    return [b.to_dict() for b in blocks if not b.disabled]


@v1_router.post(
    path="/blocks/{block_id}/execute",
    tags=["blocks"],
    dependencies=[Security(require_permission(APIKeyPermission.EXECUTE_BLOCK))],
)
async def execute_graph_block(
    block_id: str,
    data: BlockInput,
    api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.EXECUTE_BLOCK)),
) -> CompletedBlockOutput:
    obj = backend.data.block.get_block(block_id)
    if not obj:
        raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.")

    output = defaultdict(list)
    # Note: the loop variable rebinds `data`; the generator is created from the
    # original request payload before the first rebind, so the shadowing is safe.
    async for name, data in obj.execute(data):
        output[name].append(data)
    return output
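# --- Editor's sketch (illustrative; not part of this file) ---
# The execute endpoint folds every (name, value) pair a block yields into a
# per-name list, so yielding the same output name twice returns both values.
# A self-contained sketch of that accumulation shape:
def _example_output_accumulation() -> None:
    import asyncio
    from collections import defaultdict

    async def fake_block():
        yield "result", 1
        yield "result", 2  # same output name emitted twice

    async def collect():
        output = defaultdict(list)
        async for name, value in fake_block():
            output[name].append(value)
        return output

    assert asyncio.run(collect()) == {"result": [1, 2]}
# --- end sketch ---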
@v1_router.post(
    path="/graphs/{graph_id}/execute/{graph_version}",
    tags=["graphs"],
)
async def execute_graph(
    graph_id: str,
    graph_version: int,
    node_input: Annotated[dict[str, Any], Body(..., embed=True, default_factory=dict)],
    api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.EXECUTE_GRAPH)),
) -> dict[str, Any]:
    try:
        graph_exec = await add_graph_execution(
            graph_id=graph_id,
            user_id=api_key.user_id,
            inputs=node_input,
            graph_version=graph_version,
        )
        return {"id": graph_exec.id}
    except Exception as e:
        msg = str(e).encode().decode("unicode_escape")
        raise HTTPException(status_code=400, detail=msg)


@v1_router.get(
    path="/graphs/{graph_id}/executions/{graph_exec_id}/results",
    tags=["graphs"],
)
async def get_graph_execution_results(
    graph_id: str,
    graph_exec_id: str,
    api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.READ_GRAPH)),
) -> GraphExecutionResult:
    graph_exec = await execution_db.get_graph_execution(
        user_id=api_key.user_id,
        execution_id=graph_exec_id,
        include_node_executions=True,
    )
    if not graph_exec:
        raise HTTPException(
            status_code=404, detail=f"Graph execution #{graph_exec_id} not found."
        )

    if not await graph_db.get_graph(
        graph_id=graph_exec.graph_id,
        version=graph_exec.graph_version,
        user_id=api_key.user_id,
    ):
        raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")

    return GraphExecutionResult(
        execution_id=graph_exec_id,
        status=graph_exec.status.value,
        nodes=[
            ExecutionNode(
                node_id=node_exec.node_id,
                input=node_exec.input_data.get("value", node_exec.input_data),
                output={k: v for k, v in node_exec.output_data.items()},
            )
            for node_exec in graph_exec.node_executions
        ],
        output=(
            [
                {name: value}
                for name, values in graph_exec.outputs.items()
                for value in values
            ]
            if graph_exec.status == AgentExecutionStatus.COMPLETED
            else None
        ),
    )


##############################################
############### Store Endpoints ##############
##############################################


@v1_router.get(
    path="/store/agents",
    tags=["store"],
    dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
    response_model=store_model.StoreAgentsResponse,
)
async def get_store_agents(
    featured: bool = False,
    creator: str | None = None,
    sorted_by: Literal["rating", "runs", "name", "updated_at"] | None = None,
    search_query: str | None = None,
    category: str | None = None,
    page: int = 1,
    page_size: int = 20,
) -> store_model.StoreAgentsResponse:
    """
    Get a paginated list of agents from the store with optional filtering and sorting.

    Args:
        featured: Filter to only show featured agents
        creator: Filter agents by creator username
        sorted_by: Sort agents by "runs", "rating", "name", or "updated_at"
        search_query: Search agents by name, subheading and description
        category: Filter agents by category
        page: Page number for pagination (default 1)
        page_size: Number of agents per page (default 20)

    Returns:
        StoreAgentsResponse: Paginated list of agents matching the filters
    """
    if page < 1:
        raise HTTPException(status_code=422, detail="Page must be greater than 0")

    if page_size < 1:
        raise HTTPException(status_code=422, detail="Page size must be greater than 0")

    agents = await store_cache._get_cached_store_agents(
        featured=featured,
        creator=creator,
        sorted_by=sorted_by,
        search_query=search_query,
        category=category,
        page=page,
        page_size=page_size,
    )
    return agents
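# --- Editor's sketch (illustrative; not part of this file) ---
# Querying the store listing from outside maps one-to-one onto the query args
# above. Host, auth header name, and the response field name are assumptions.
def _example_store_agents_query() -> None:
    import httpx

    resp = httpx.get(
        "https://platform.example.com/external-api/store/agents",
        headers={"X-API-Key": "agpt-..."},  # header name is an assumption
        params={
            "sorted_by": "rating",
            "category": "productivity",  # hypothetical category
            "page": 2,
            "page_size": 10,
        },
    )
    agents = resp.json()["agents"]  # field name assumed from StoreAgentsResponse
    print(len(agents))
# --- end sketch ---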
@v1_router.get(
    path="/store/agents/{username}/{agent_name}",
    tags=["store"],
    dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
    response_model=store_model.StoreAgentDetails,
)
async def get_store_agent(
    username: str,
    agent_name: str,
) -> store_model.StoreAgentDetails:
    """
    Get details of a specific store agent by username and agent name.

    Args:
        username: Creator's username
        agent_name: Name/slug of the agent

    Returns:
        StoreAgentDetails: Detailed information about the agent
    """
    username = urllib.parse.unquote(username).lower()
    agent_name = urllib.parse.unquote(agent_name).lower()
    agent = await store_cache._get_cached_agent_details(
        username=username, agent_name=agent_name
    )
    return agent


@v1_router.get(
    path="/store/creators",
    tags=["store"],
    dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
    response_model=store_model.CreatorsResponse,
)
async def get_store_creators(
    featured: bool = False,
    search_query: str | None = None,
    sorted_by: Literal["agent_rating", "agent_runs", "num_agents"] | None = None,
    page: int = 1,
    page_size: int = 20,
) -> store_model.CreatorsResponse:
    """
    Get a paginated list of store creators with optional filtering and sorting.

    Args:
        featured: Filter to only show featured creators
        search_query: Search creators by profile description
        sorted_by: Sort by "agent_rating", "agent_runs", or "num_agents"
        page: Page number for pagination (default 1)
        page_size: Number of creators per page (default 20)

    Returns:
        CreatorsResponse: Paginated list of creators matching the filters
    """
    if page < 1:
        raise HTTPException(status_code=422, detail="Page must be greater than 0")

    if page_size < 1:
        raise HTTPException(status_code=422, detail="Page size must be greater than 0")

    creators = await store_cache._get_cached_store_creators(
        featured=featured,
        search_query=search_query,
        sorted_by=sorted_by,
        page=page,
        page_size=page_size,
    )
    return creators


@v1_router.get(
    path="/store/creators/{username}",
    tags=["store"],
    dependencies=[Security(require_permission(APIKeyPermission.READ_STORE))],
    response_model=store_model.CreatorDetails,
)
async def get_store_creator(
    username: str,
) -> store_model.CreatorDetails:
    """
    Get details of a specific store creator by username.

    Args:
        username: Creator's username

    Returns:
        CreatorDetails: Detailed information about the creator
    """
    username = urllib.parse.unquote(username).lower()
    creator = await store_cache._get_cached_creator_details(username=username)
    return creator
(deleted file; path not shown — the web view collapsed this diff as too large)
@@ -1,471 +0,0 @@
"""
Security utilities for the integration connect popup flow.

Handles state management, nonce validation, and origin verification
for the OAuth-style popup flow when connecting integrations.
"""

import hashlib
import logging
import secrets
from datetime import datetime, timezone
from typing import Any, Optional
from urllib.parse import urlparse

from prisma.models import OAuthClient
from pydantic import BaseModel

from backend.data.redis_client import get_redis_async

logger = logging.getLogger(__name__)

# State expiration time
STATE_EXPIRATION_SECONDS = 600  # 10 minutes
NONCE_EXPIRATION_SECONDS = 3600  # 1 hour (nonces valid for longer to prevent races)
LOGIN_STATE_EXPIRATION_SECONDS = 600  # 10 minutes for login redirect flow


class ConnectState(BaseModel):
    """Pydantic model for connect state stored in Redis."""

    user_id: str
    client_id: str
    provider: str
    requested_scopes: list[str]
    redirect_origin: str
    nonce: str
    credential_id: Optional[str] = None
    created_at: str
    expires_at: str


class ConnectContinuationState(BaseModel):
    """
    State for continuing the connect flow after OAuth completes.

    When a user chooses to "connect new" during the connect flow,
    we store this state so we can complete the grant creation after
    the OAuth callback.
    """

    user_id: str
    client_id: str  # Public client ID
    client_db_id: str  # Database UUID of the OAuth client
    provider: str
    requested_scopes: list[str]  # Integration scopes (e.g., "google:gmail.readonly")
    redirect_origin: str
    nonce: str
    created_at: str


class ConnectLoginState(BaseModel):
    """
    State for connect flow when user needs to log in first.

    When an unauthenticated user tries to access /connect/{provider},
    we store the connect parameters and redirect to login. After login,
    the user is redirected back to complete the connect flow.
    """

    client_id: str
    provider: str
    requested_scopes: list[str]
    redirect_origin: str
    nonce: str
    created_at: str
    expires_at: str


# Continuation state expiration (same as regular state)
CONTINUATION_EXPIRATION_SECONDS = 600  # 10 minutes
async def store_connect_continuation(
|
|
||||||
user_id: str,
|
|
||||||
client_id: str,
|
|
||||||
client_db_id: str,
|
|
||||||
provider: str,
|
|
||||||
requested_scopes: list[str],
|
|
||||||
redirect_origin: str,
|
|
||||||
nonce: str,
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
Store continuation state for completing connect flow after OAuth.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: User initiating the connection
|
|
||||||
client_id: Public OAuth client ID
|
|
||||||
client_db_id: Database UUID of the OAuth client
|
|
||||||
provider: Integration provider name
|
|
||||||
requested_scopes: Requested integration scopes
|
|
||||||
redirect_origin: Origin to send postMessage to
|
|
||||||
nonce: Client-provided nonce for replay protection
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Continuation token to be stored in OAuth state metadata
|
|
||||||
"""
|
|
||||||
token = generate_connect_token()
|
|
||||||
now = datetime.now(timezone.utc)
|
|
||||||
|
|
||||||
state = ConnectContinuationState(
|
|
||||||
user_id=user_id,
|
|
||||||
client_id=client_id,
|
|
||||||
client_db_id=client_db_id,
|
|
||||||
provider=provider,
|
|
||||||
requested_scopes=requested_scopes,
|
|
||||||
redirect_origin=redirect_origin,
|
|
||||||
nonce=nonce,
|
|
||||||
created_at=now.isoformat(),
|
|
||||||
)
|
|
||||||
|
|
||||||
redis = await get_redis_async()
|
|
||||||
key = f"connect_continuation:{token}"
|
|
||||||
await redis.setex(key, CONTINUATION_EXPIRATION_SECONDS, state.model_dump_json())
|
|
||||||
|
|
||||||
logger.debug(f"Stored connect continuation state for token {token[:8]}...")
|
|
||||||
return token
|
|
||||||
|
|
||||||
|
|
||||||
async def get_connect_continuation(token: str) -> Optional[ConnectContinuationState]:
|
|
||||||
"""
|
|
||||||
Get continuation state without consuming it.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
token: Continuation token
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ConnectContinuationState or None if not found/expired
|
|
||||||
"""
|
|
||||||
redis = await get_redis_async()
|
|
||||||
key = f"connect_continuation:{token}"
|
|
||||||
data = await redis.get(key)
|
|
||||||
|
|
||||||
if not data:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return ConnectContinuationState.model_validate_json(data)
|
|
||||||
|
|
||||||
|
|
||||||
async def consume_connect_continuation(
|
|
||||||
token: str,
|
|
||||||
) -> Optional[ConnectContinuationState]:
|
|
||||||
"""
|
|
||||||
Get and consume (delete) continuation state.
|
|
||||||
|
|
||||||
This ensures the token can only be used once.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
token: Continuation token
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ConnectContinuationState or None if not found/expired
|
|
||||||
"""
|
|
||||||
redis = await get_redis_async()
|
|
||||||
key = f"connect_continuation:{token}"
|
|
||||||
|
|
||||||
# Atomic get-and-delete to prevent race conditions
|
|
||||||
data = await redis.getdel(key)
|
|
||||||
if not data:
|
|
||||||
return None
|
|
||||||
|
|
||||||
state = ConnectContinuationState.model_validate_json(data)
|
|
||||||
logger.debug(f"Consumed connect continuation state for token {token[:8]}...")
|
|
||||||
|
|
||||||
return state
|
|
||||||
|
|
||||||
|
|
||||||
def generate_connect_token() -> str:
|
|
||||||
"""Generate a secure random token for connect state."""
|
|
||||||
return secrets.token_urlsafe(32)
|
|
||||||
|
|
||||||
|
|
||||||
async def store_connect_state(
|
|
||||||
user_id: str,
|
|
||||||
client_id: str,
|
|
||||||
provider: str,
|
|
||||||
requested_scopes: list[str],
|
|
||||||
redirect_origin: str,
|
|
||||||
nonce: str,
|
|
||||||
credential_id: Optional[str] = None,
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
Store connect state in Redis and return a state token.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: User initiating the connection
|
|
||||||
client_id: OAuth client ID (public identifier)
|
|
||||||
provider: Integration provider name
|
|
||||||
requested_scopes: Requested integration scopes
|
|
||||||
redirect_origin: Origin to send postMessage to
|
|
||||||
nonce: Client-provided nonce for replay protection
|
|
||||||
credential_id: Optional existing credential to grant access to
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
State token to be used in the connect flow
|
|
||||||
"""
|
|
||||||
token = generate_connect_token()
|
|
||||||
now = datetime.now(timezone.utc)
|
|
||||||
expires_at = now.timestamp() + STATE_EXPIRATION_SECONDS
|
|
||||||
|
|
||||||
state = ConnectState(
|
|
||||||
user_id=user_id,
|
|
||||||
client_id=client_id,
|
|
||||||
provider=provider,
|
|
||||||
requested_scopes=requested_scopes,
|
|
||||||
redirect_origin=redirect_origin,
|
|
||||||
nonce=nonce,
|
|
||||||
credential_id=credential_id,
|
|
||||||
created_at=now.isoformat(),
|
|
||||||
expires_at=datetime.fromtimestamp(expires_at, tz=timezone.utc).isoformat(),
|
|
||||||
)
|
|
||||||
|
|
||||||
redis = await get_redis_async()
|
|
||||||
key = f"connect_state:{token}"
|
|
||||||
await redis.setex(key, STATE_EXPIRATION_SECONDS, state.model_dump_json())
|
|
||||||
|
|
||||||
logger.debug(f"Stored connect state for token {token[:8]}...")
|
|
||||||
return token
|
|
||||||
|
|
||||||
|
|
||||||
async def get_connect_state(token: str) -> Optional[ConnectState]:
    """
    Get connect state without consuming it.

    Args:
        token: State token

    Returns:
        ConnectState or None if not found/expired
    """
    redis = await get_redis_async()
    key = f"connect_state:{token}"
    data = await redis.get(key)

    if not data:
        return None

    return ConnectState.model_validate_json(data)


async def consume_connect_state(token: str) -> Optional[ConnectState]:
    """
    Get and consume (delete) connect state.

    This ensures the token can only be used once.

    Args:
        token: State token

    Returns:
        ConnectState or None if not found/expired
    """
    redis = await get_redis_async()
    key = f"connect_state:{token}"

    # Atomic get-and-delete to prevent race conditions
    data = await redis.getdel(key)
    if not data:
        return None

    state = ConnectState.model_validate_json(data)
    logger.debug(f"Consumed connect state for token {token[:8]}...")

    return state


async def store_connect_login_state(
    client_id: str,
    provider: str,
    requested_scopes: list[str],
    redirect_origin: str,
    nonce: str,
) -> str:
    """
    Store connect parameters for unauthenticated users.

    When a user isn't logged in, we store the connect params and redirect
    to login. After login, the frontend calls /connect/resume with the token.

    Args:
        client_id: OAuth client ID
        provider: Integration provider name
        requested_scopes: Requested integration scopes
        redirect_origin: Origin to send postMessage to
        nonce: Client-provided nonce for replay protection

    Returns:
        Login state token to be used after login completes
    """
    token = generate_connect_token()
    now = datetime.now(timezone.utc)
    expires_at = now.timestamp() + LOGIN_STATE_EXPIRATION_SECONDS

    state = ConnectLoginState(
        client_id=client_id,
        provider=provider,
        requested_scopes=requested_scopes,
        redirect_origin=redirect_origin,
        nonce=nonce,
        created_at=now.isoformat(),
        expires_at=datetime.fromtimestamp(expires_at, tz=timezone.utc).isoformat(),
    )

    redis = await get_redis_async()
    key = f"connect_login_state:{token}"
    await redis.setex(key, LOGIN_STATE_EXPIRATION_SECONDS, state.model_dump_json())

    logger.debug(f"Stored connect login state for token {token[:8]}...")
    return token


async def get_connect_login_state(token: str) -> Optional[ConnectLoginState]:
    """
    Get connect login state without consuming it.

    Args:
        token: Login state token

    Returns:
        ConnectLoginState or None if not found/expired
    """
    redis = await get_redis_async()
    key = f"connect_login_state:{token}"
    data = await redis.get(key)

    if not data:
        return None

    return ConnectLoginState.model_validate_json(data)


async def consume_connect_login_state(token: str) -> Optional[ConnectLoginState]:
    """
    Get and consume (delete) connect login state.

    This ensures the token can only be used once.

    Args:
        token: Login state token

    Returns:
        ConnectLoginState or None if not found/expired
    """
    redis = await get_redis_async()
    key = f"connect_login_state:{token}"

    # Atomic get-and-delete to prevent race conditions
    data = await redis.getdel(key)
    if not data:
        return None

    state = ConnectLoginState.model_validate_json(data)
    logger.debug(f"Consumed connect login state for token {token[:8]}...")

    return state


async def validate_nonce(client_id: str, nonce: str) -> bool:
    """
    Validate that a nonce hasn't been used before (replay protection).

    Uses atomic SET NX EX for check-and-set with automatic TTL expiry.

    Args:
        client_id: OAuth client ID
        nonce: Client-provided nonce

    Returns:
        True if nonce is valid (not replayed)
    """
    redis = await get_redis_async()

    # Create a hash of the nonce for storage
    nonce_hash = hashlib.sha256(nonce.encode()).hexdigest()
    key = f"nonce:{client_id}:{nonce_hash}"

    # Atomic set-if-not-exists with expiration (prevents race condition)
    was_set = await redis.set(key, "1", nx=True, ex=NONCE_EXPIRATION_SECONDS)
    if was_set:
        return True

    logger.warning(f"Nonce replay detected for client {client_id}")
    return False


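# Standalone sketch of the SET NX EX check-and-set used above (illustrative,
# not part of this module); assumes a reachable Redis at localhost:6379.
import asyncio
import hashlib

import redis.asyncio as aioredis


async def _nonce_demo() -> None:
    r = aioredis.Redis()
    key = "nonce:app_demo:" + hashlib.sha256(b"abc").hexdigest()

    # Only the first SET with nx=True succeeds; within the TTL window, every
    # replay of the same nonce returns None instead of True.
    print(await r.set(key, "1", nx=True, ex=600))  # True
    print(await r.set(key, "1", nx=True, ex=600))  # None -- replay detected


if __name__ == "__main__":
    asyncio.run(_nonce_demo())

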
def validate_redirect_origin(origin: str, client: OAuthClient) -> bool:
    """
    Validate that a redirect origin is allowed for the client.

    The origin must match one of the client's registered redirect URIs
    or webhook domains.

    Args:
        origin: Origin URL to validate
        client: OAuth client to check against

    Returns:
        True if origin is allowed
    """
    from backend.util.url import hostname_matches_any_domain

    try:
        parsed_origin = urlparse(origin)
        origin_host = parsed_origin.netloc.lower()

        # Check against redirect URIs
        for redirect_uri in client.redirectUris:
            parsed_redirect = urlparse(redirect_uri)
            if parsed_redirect.netloc.lower() == origin_host:
                return True

        # Check against webhook domains
        if hostname_matches_any_domain(origin_host, client.webhookDomains):
            return True

        return False

    except Exception:
        return False


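# Stdlib-only illustration of the host comparison performed above; the real
# check additionally consults the client's registered webhook domains. The
# URIs below are hypothetical.
from urllib.parse import urlparse

registered_uris = ["https://app.example.com/callback", "http://localhost:3000/cb"]
origin = "https://app.example.com"

origin_host = urlparse(origin).netloc.lower()
allowed = any(urlparse(u).netloc.lower() == origin_host for u in registered_uris)
print(allowed)  # True -- host (and port) match; scheme and path are ignored

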
def create_post_message_data(
    success: bool,
    grant_id: Optional[str] = None,
    credential_id: Optional[str] = None,
    provider: Optional[str] = None,
    error: Optional[str] = None,
    error_description: Optional[str] = None,
    nonce: Optional[str] = None,
) -> dict[str, Any]:
    """
    Create the postMessage data to send back to the opener.

    Args:
        success: Whether the operation succeeded
        grant_id: ID of the created grant (if successful)
        credential_id: ID of the credential (if successful)
        provider: Provider name
        error: Error code (if failed)
        error_description: Human-readable error description
        nonce: Original nonce for correlation

    Returns:
        Dictionary to be sent via postMessage
    """
    data: dict[str, Any] = {
        "type": "autogpt_connect_result",
        "success": success,
    }

    if nonce:
        data["nonce"] = nonce

    if success:
        data["grant_id"] = grant_id
        data["credential_id"] = credential_id
        data["provider"] = provider
    else:
        data["error"] = error
        data["error_description"] = error_description

    return data
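
# Quick illustration of the payload shapes produced above (hypothetical IDs;
# illustrative only, not part of this module).
print(create_post_message_data(
    success=True, grant_id="grant_123", credential_id="cred_456",
    provider="github", nonce="n-1",
))
# {'type': 'autogpt_connect_result', 'success': True, 'nonce': 'n-1',
#  'grant_id': 'grant_123', 'credential_id': 'cred_456', 'provider': 'github'}
print(create_post_message_data(
    success=False, error="access_denied",
    error_description="User cancelled", nonce="n-2",
))
# {'type': 'autogpt_connect_result', 'success': False, 'nonce': 'n-2',
#  'error': 'access_denied', 'error_description': 'User cancelled'}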
@@ -1,20 +0,0 @@
"""
OAuth 2.0 Provider module for AutoGPT Platform.

This module implements AutoGPT as an OAuth 2.0 Authorization Server,
allowing external applications to authenticate users and access
platform resources with user consent.

Key components:
- router.py: OAuth authorization and token endpoints
- discovery_router.py: OIDC discovery endpoints
- client_router.py: OAuth client management
- token_service.py: JWT generation and validation
- service.py: Core OAuth business logic
"""

from backend.server.oauth.client_router import client_router
from backend.server.oauth.discovery_router import discovery_router
from backend.server.oauth.router import oauth_router

__all__ = ["oauth_router", "discovery_router", "client_router"]
@@ -1,367 +0,0 @@
"""
OAuth Client Management endpoints.

Implements self-service client registration and management:
- POST /oauth/clients - Register a new client
- GET /oauth/clients - List owned clients
- GET /oauth/clients/{client_id} - Get client details
- PATCH /oauth/clients/{client_id} - Update client
- DELETE /oauth/clients/{client_id} - Delete client
- POST /oauth/clients/{client_id}/rotate-secret - Rotate client secret
"""

import hashlib
import secrets

from autogpt_libs.auth import get_user_id
from fastapi import APIRouter, HTTPException, Security
from prisma.enums import OAuthClientStatus
from pydantic import BaseModel

from backend.data.db import prisma
from backend.server.oauth.models import (
    ClientResponse,
    ClientSecretResponse,
    OAuthScope,
    RegisterClientRequest,
    UpdateClientRequest,
)

client_router = APIRouter(prefix="/oauth/clients", tags=["oauth-clients"])


def _generate_client_id() -> str:
    """Generate a unique client ID."""
    return f"app_{secrets.token_urlsafe(16)}"


def _generate_client_secret() -> str:
    """Generate a secure client secret."""
    return secrets.token_urlsafe(32)


def _generate_webhook_secret() -> str:
    """Generate a secure webhook secret for HMAC signing."""
    return secrets.token_urlsafe(32)


def _hash_secret(secret: str, salt: str) -> str:
    """Hash a client secret with salt."""
    return hashlib.sha256(f"{salt}{secret}".encode()).hexdigest()


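# A minimal verification counterpart to _hash_secret, sketched here for
# illustration (not part of this module); hmac.compare_digest gives a
# constant-time comparison so timing does not leak matching prefixes.
import hmac


def _verify_secret(candidate: str, salt: str, stored_hash: str) -> bool:
    computed = hashlib.sha256(f"{salt}{candidate}".encode()).hexdigest()
    return hmac.compare_digest(computed, stored_hash)

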
def _client_to_response(client) -> ClientResponse:
    """Convert Prisma client to response model."""
    return ClientResponse(
        id=client.id,
        client_id=client.clientId,
        client_type=client.clientType,
        name=client.name,
        description=client.description,
        logo_url=client.logoUrl,
        homepage_url=client.homepageUrl,
        privacy_policy_url=client.privacyPolicyUrl,
        terms_of_service_url=client.termsOfServiceUrl,
        redirect_uris=client.redirectUris,
        allowed_scopes=client.allowedScopes,
        webhook_domains=client.webhookDomains,
        status=client.status,
        created_at=client.createdAt,
        updated_at=client.updatedAt,
    )


# Default allowed scopes for new clients
DEFAULT_ALLOWED_SCOPES = [
    OAuthScope.OPENID.value,
    OAuthScope.PROFILE.value,
    OAuthScope.EMAIL.value,
    OAuthScope.INTEGRATIONS_LIST.value,
    OAuthScope.INTEGRATIONS_CONNECT.value,
    OAuthScope.INTEGRATIONS_DELETE.value,
    OAuthScope.AGENTS_EXECUTE.value,
]


@client_router.post("/", response_model=ClientSecretResponse)
async def register_client(
    request: RegisterClientRequest,
    user_id: str = Security(get_user_id),
) -> ClientSecretResponse:
    """
    Register a new OAuth client.

    The client is immediately active (no admin approval required).
    For confidential clients, the client_secret is returned only once.
    The webhook_secret is always generated and returned only once.
    """
    # Generate client credentials
    client_id = _generate_client_id()
    client_secret = None
    client_secret_hash = None
    client_secret_salt = None

    if request.client_type == "confidential":
        client_secret = _generate_client_secret()
        client_secret_salt = secrets.token_urlsafe(16)
        client_secret_hash = _hash_secret(client_secret, client_secret_salt)

    # Generate webhook secret for HMAC signing
    webhook_secret = _generate_webhook_secret()

    # Create client
    await prisma.oauthclient.create(
        data={  # type: ignore[typeddict-item]
            "clientId": client_id,
            "clientSecretHash": client_secret_hash,
            "clientSecretSalt": client_secret_salt,
            "clientType": request.client_type,
            "name": request.name,
            "description": request.description,
            "logoUrl": str(request.logo_url) if request.logo_url else None,
            "homepageUrl": str(request.homepage_url) if request.homepage_url else None,
            "privacyPolicyUrl": (
                str(request.privacy_policy_url) if request.privacy_policy_url else None
            ),
            "termsOfServiceUrl": (
                str(request.terms_of_service_url)
                if request.terms_of_service_url
                else None
            ),
            "redirectUris": request.redirect_uris,
            "allowedScopes": DEFAULT_ALLOWED_SCOPES,
            "webhookDomains": request.webhook_domains,
            "webhookSecret": webhook_secret,
            "status": OAuthClientStatus.ACTIVE,
            "ownerId": user_id,
        }
    )

    return ClientSecretResponse(
        client_id=client_id,
        client_secret=client_secret or "",
        webhook_secret=webhook_secret,
    )


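# Hypothetical registration call against the endpoint above (illustrative,
# not part of this module); assumes the API is served at
# https://platform.example.com, a valid bearer token, and the httpx package.
import httpx

resp = httpx.post(
    "https://platform.example.com/oauth/clients/",
    headers={"Authorization": "Bearer <access-token>"},
    json={
        "client_type": "confidential",
        "name": "My App",
        "redirect_uris": ["https://myapp.example.com/callback"],
        "webhook_domains": ["myapp.example.com"],
    },
)
creds = resp.json()  # client_secret and webhook_secret are returned only once

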
@client_router.get("/", response_model=list[ClientResponse])
async def list_clients(
    user_id: str = Security(get_user_id),
) -> list[ClientResponse]:
    """List all OAuth clients owned by the current user."""
    clients = await prisma.oauthclient.find_many(
        where={"ownerId": user_id},
        order={"createdAt": "desc"},
    )
    return [_client_to_response(c) for c in clients]


@client_router.get("/{client_id}", response_model=ClientResponse)
async def get_client(
    client_id: str,
    user_id: str = Security(get_user_id),
) -> ClientResponse:
    """Get details of a specific OAuth client."""
    client = await prisma.oauthclient.find_first(
        where={"clientId": client_id, "ownerId": user_id}
    )

    if not client:
        raise HTTPException(status_code=404, detail="Client not found")

    return _client_to_response(client)


@client_router.patch("/{client_id}", response_model=ClientResponse)
async def update_client(
    client_id: str,
    request: UpdateClientRequest,
    user_id: str = Security(get_user_id),
) -> ClientResponse:
    """Update an OAuth client."""
    client = await prisma.oauthclient.find_first(
        where={"clientId": client_id, "ownerId": user_id}
    )

    if not client:
        raise HTTPException(status_code=404, detail="Client not found")

    # Build update data
    update_data: dict[str, str | list[str] | None] = {}
    if request.name is not None:
        update_data["name"] = request.name
    if request.description is not None:
        update_data["description"] = request.description
    if request.logo_url is not None:
        update_data["logoUrl"] = str(request.logo_url)
    if request.homepage_url is not None:
        update_data["homepageUrl"] = str(request.homepage_url)
    if request.privacy_policy_url is not None:
        update_data["privacyPolicyUrl"] = str(request.privacy_policy_url)
    if request.terms_of_service_url is not None:
        update_data["termsOfServiceUrl"] = str(request.terms_of_service_url)
    if request.redirect_uris is not None:
        update_data["redirectUris"] = request.redirect_uris
    if request.webhook_domains is not None:
        update_data["webhookDomains"] = request.webhook_domains

    if not update_data:
        return _client_to_response(client)

    updated = await prisma.oauthclient.update(
        where={"id": client.id},
        data=update_data,  # type: ignore[arg-type]
    )

    return _client_to_response(updated)


@client_router.delete("/{client_id}")
async def delete_client(
    client_id: str,
    user_id: str = Security(get_user_id),
) -> dict:
    """
    Delete an OAuth client.

    This will also revoke all tokens and authorizations for this client.
    """
    client = await prisma.oauthclient.find_first(
        where={"clientId": client_id, "ownerId": user_id}
    )

    if not client:
        raise HTTPException(status_code=404, detail="Client not found")

    # Delete cascades will handle tokens, codes, and authorizations
    await prisma.oauthclient.delete(where={"id": client.id})

    return {"status": "deleted", "client_id": client_id}


@client_router.post("/{client_id}/rotate-secret", response_model=ClientSecretResponse)
async def rotate_client_secret(
    client_id: str,
    user_id: str = Security(get_user_id),
) -> ClientSecretResponse:
    """
    Rotate the client secret for a confidential client.

    The new secret is returned only once. All existing tokens remain valid.
    Also rotates the webhook secret for security.
    """
    client = await prisma.oauthclient.find_first(
        where={"clientId": client_id, "ownerId": user_id}
    )

    if not client:
        raise HTTPException(status_code=404, detail="Client not found")

    if client.clientType != "confidential":
        raise HTTPException(
            status_code=400,
            detail="Cannot rotate secret for public clients",
        )

    # Generate new secrets
    new_secret = _generate_client_secret()
    new_salt = secrets.token_urlsafe(16)
    new_hash = _hash_secret(new_secret, new_salt)
    new_webhook_secret = _generate_webhook_secret()

    await prisma.oauthclient.update(
        where={"id": client.id},
        data={
            "clientSecretHash": new_hash,
            "clientSecretSalt": new_salt,
            "webhookSecret": new_webhook_secret,
        },
    )

    return ClientSecretResponse(
        client_id=client_id,
        client_secret=new_secret,
        webhook_secret=new_webhook_secret,
    )


class WebhookSecretResponse(BaseModel):
    """Response containing newly generated webhook secret."""

    client_id: str
    webhook_secret: str


@client_router.post(
    "/{client_id}/rotate-webhook-secret", response_model=WebhookSecretResponse
)
async def rotate_webhook_secret(
    client_id: str,
    user_id: str = Security(get_user_id),
) -> WebhookSecretResponse:
    """
    Rotate only the webhook secret for a client.

    The new webhook secret is returned only once.
    """
    client = await prisma.oauthclient.find_first(
        where={"clientId": client_id, "ownerId": user_id}
    )

    if not client:
        raise HTTPException(status_code=404, detail="Client not found")

    # Generate new webhook secret
    new_webhook_secret = _generate_webhook_secret()

    await prisma.oauthclient.update(
        where={"id": client.id},
        data={"webhookSecret": new_webhook_secret},
    )

    return WebhookSecretResponse(
        client_id=client_id,
        webhook_secret=new_webhook_secret,
    )


@client_router.post("/{client_id}/suspend")
async def suspend_client(
    client_id: str,
    user_id: str = Security(get_user_id),
) -> ClientResponse:
    """Suspend an OAuth client (prevents new authorizations)."""
    client = await prisma.oauthclient.find_first(
        where={"clientId": client_id, "ownerId": user_id}
    )

    if not client:
        raise HTTPException(status_code=404, detail="Client not found")

    updated = await prisma.oauthclient.update(
        where={"id": client.id},
        data={"status": OAuthClientStatus.SUSPENDED},
    )

    return _client_to_response(updated)


@client_router.post("/{client_id}/activate")
async def activate_client(
    client_id: str,
    user_id: str = Security(get_user_id),
) -> ClientResponse:
    """Reactivate a suspended OAuth client."""
    client = await prisma.oauthclient.find_first(
        where={"clientId": client_id, "ownerId": user_id}
    )

    if not client:
        raise HTTPException(status_code=404, detail="Client not found")

    updated = await prisma.oauthclient.update(
        where={"id": client.id},
        data={"status": OAuthClientStatus.ACTIVE},
    )

    return _client_to_response(updated)
@@ -1,678 +0,0 @@
"""
Server-rendered HTML templates for OAuth consent UI.

These templates are used for the OAuth authorization flow
when the user needs to approve access for an external application.
"""

import html
from typing import Optional

from backend.server.oauth.models import SCOPE_DESCRIPTIONS


def _base_styles() -> str:
    """Common CSS styles for all OAuth pages."""
    return """
    * {
        box-sizing: border-box;
        margin: 0;
        padding: 0;
    }
    body {
        font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif;
        background: linear-gradient(135deg, #1a1a2e 0%, #16213e 100%);
        min-height: 100vh;
        display: flex;
        align-items: center;
        justify-content: center;
        padding: 20px;
        color: #e4e4e7;
    }
    .container {
        background: #27272a;
        border-radius: 16px;
        box-shadow: 0 25px 50px -12px rgba(0, 0, 0, 0.5);
        max-width: 420px;
        width: 100%;
        padding: 32px;
    }
    .header {
        text-align: center;
        margin-bottom: 24px;
    }
    .logo {
        width: 64px;
        height: 64px;
        border-radius: 12px;
        margin-bottom: 16px;
        background: #3f3f46;
        display: flex;
        align-items: center;
        justify-content: center;
        margin-left: auto;
        margin-right: auto;
    }
    .logo img {
        max-width: 48px;
        max-height: 48px;
        border-radius: 8px;
    }
    .logo-placeholder {
        font-size: 28px;
        color: #a1a1aa;
    }
    h1 {
        font-size: 20px;
        font-weight: 600;
        margin-bottom: 8px;
    }
    .subtitle {
        color: #a1a1aa;
        font-size: 14px;
    }
    .app-name {
        color: #22d3ee;
        font-weight: 600;
    }
    .divider {
        height: 1px;
        background: #3f3f46;
        margin: 24px 0;
    }
    .scopes-section h2 {
        font-size: 14px;
        font-weight: 500;
        color: #a1a1aa;
        margin-bottom: 16px;
    }
    .scope-item {
        display: flex;
        align-items: flex-start;
        gap: 12px;
        padding: 12px 0;
        border-bottom: 1px solid #3f3f46;
    }
    .scope-item:last-child {
        border-bottom: none;
    }
    .scope-icon {
        width: 20px;
        height: 20px;
        color: #22d3ee;
        flex-shrink: 0;
        margin-top: 2px;
    }
    .scope-text {
        font-size: 14px;
        line-height: 1.5;
    }
    .buttons {
        display: flex;
        gap: 12px;
        margin-top: 24px;
    }
    .btn {
        flex: 1;
        padding: 12px 24px;
        border-radius: 8px;
        font-size: 14px;
        font-weight: 500;
        cursor: pointer;
        border: none;
        transition: all 0.2s;
    }
    .btn-cancel {
        background: #3f3f46;
        color: #e4e4e7;
    }
    .btn-cancel:hover {
        background: #52525b;
    }
    .btn-allow {
        background: #22d3ee;
        color: #0f172a;
    }
    .btn-allow:hover {
        background: #06b6d4;
    }
    .footer {
        margin-top: 24px;
        text-align: center;
        font-size: 12px;
        color: #71717a;
    }
    .footer a {
        color: #a1a1aa;
        text-decoration: none;
    }
    .footer a:hover {
        text-decoration: underline;
    }
    .error-container {
        text-align: center;
    }
    .error-icon {
        width: 64px;
        height: 64px;
        margin: 0 auto 16px;
        color: #ef4444;
    }
    .error-title {
        color: #ef4444;
        font-size: 18px;
        font-weight: 600;
        margin-bottom: 8px;
    }
    .error-message {
        color: #a1a1aa;
        font-size: 14px;
        margin-bottom: 24px;
    }
    .success-icon {
        width: 64px;
        height: 64px;
        margin: 0 auto 16px;
        color: #22c55e;
    }
    .success-title {
        color: #22c55e;
    }
    """


def _check_icon() -> str:
    """SVG checkmark icon."""
    return """
    <svg class="scope-icon" viewBox="0 0 20 20" fill="currentColor">
        <path fill-rule="evenodd" d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z" clip-rule="evenodd"/>
    </svg>
    """


def _error_icon() -> str:
    """SVG error icon."""
    return """
    <svg class="error-icon" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2">
        <circle cx="12" cy="12" r="10"/>
        <line x1="15" y1="9" x2="9" y2="15"/>
        <line x1="9" y1="9" x2="15" y2="15"/>
    </svg>
    """


def _success_icon() -> str:
    """SVG success icon."""
    return """
    <svg class="success-icon" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2">
        <circle cx="12" cy="12" r="10"/>
        <path d="M9 12l2 2 4-4"/>
    </svg>
    """


def render_consent_page(
    client_name: str,
    client_logo: Optional[str],
    scopes: list[str],
    consent_token: str,
    action_url: str,
    privacy_policy_url: Optional[str] = None,
    terms_url: Optional[str] = None,
) -> str:
    """
    Render the OAuth consent page.

    Args:
        client_name: Name of the requesting application
        client_logo: URL to the client's logo (optional)
        scopes: List of requested scopes
        consent_token: CSRF token for the consent form
        action_url: URL to submit the consent form
        privacy_policy_url: Client's privacy policy URL (optional)
        terms_url: Client's terms of service URL (optional)

    Returns:
        HTML string for the consent page
    """
    # Escape user-provided values to prevent XSS
    safe_client_name = html.escape(client_name)
    safe_client_logo = html.escape(client_logo) if client_logo else None

    # Build logo HTML
    if safe_client_logo:
        logo_html = f'<img src="{safe_client_logo}" alt="{safe_client_name}">'
    else:
        logo_html = f'<span class="logo-placeholder">{html.escape(client_name[0].upper())}</span>'

    # Build scopes HTML
    scopes_html = ""
    for scope in scopes:
        description = SCOPE_DESCRIPTIONS.get(scope, scope)
        scopes_html += f"""
        <div class="scope-item">
            {_check_icon()}
            <span class="scope-text">{html.escape(description)}</span>
        </div>
        """

    # Build footer links (escape URLs)
    footer_links = []
    if privacy_policy_url:
        footer_links.append(
            f'<a href="{html.escape(privacy_policy_url)}" target="_blank">Privacy Policy</a>'
        )
    if terms_url:
        footer_links.append(
            f'<a href="{html.escape(terms_url)}" target="_blank">Terms of Service</a>'
        )
    footer_html = " • ".join(footer_links) if footer_links else ""

    # Escape action_url and consent_token
    safe_action_url = html.escape(action_url)
    safe_consent_token = html.escape(consent_token)

    return f"""
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <title>Authorize {safe_client_name} - AutoGPT</title>
        <style>{_base_styles()}</style>
    </head>
    <body>
        <div class="container">
            <div class="header">
                <div class="logo">{logo_html}</div>
                <h1>Authorize <span class="app-name">{safe_client_name}</span></h1>
                <p class="subtitle">wants to access your AutoGPT account</p>
            </div>

            <div class="divider"></div>

            <div class="scopes-section">
                <h2>This will allow {safe_client_name} to:</h2>
                {scopes_html}
            </div>

            <form method="POST" action="{safe_action_url}">
                <input type="hidden" name="consent_token" value="{safe_consent_token}">
                <div class="buttons">
                    <button type="submit" name="authorize" value="false" class="btn btn-cancel">
                        Cancel
                    </button>
                    <button type="submit" name="authorize" value="true" class="btn btn-allow">
                        Allow
                    </button>
                </div>
            </form>

            {f'<div class="footer">{footer_html}</div>' if footer_html else ''}
        </div>
    </body>
    </html>
    """


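# Hypothetical usage of the renderer above (illustrative only; the action
# URL and CSRF token are made-up values).
page = render_consent_page(
    client_name="Example App",
    client_logo=None,
    scopes=["openid", "email", "agents:execute"],
    consent_token="<csrf-token-from-session>",
    action_url="/oauth/authorize/consent",
)
# `page` is a complete HTML document; a FastAPI handler can wrap it in
# fastapi.responses.HTMLResponse(content=page).

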
def render_error_page(
    error: str,
    error_description: str,
    redirect_url: Optional[str] = None,
) -> str:
    """
    Render an OAuth error page.

    Args:
        error: Error code
        error_description: Human-readable error description
        redirect_url: Optional URL to redirect back (if safe)

    Returns:
        HTML string for the error page
    """
    # Escape user-provided values to prevent XSS
    safe_error = html.escape(error)
    safe_error_description = html.escape(error_description)

    redirect_html = ""
    if redirect_url:
        safe_redirect_url = html.escape(redirect_url)
        redirect_html = f"""
        <a href="{safe_redirect_url}" class="btn btn-cancel" style="display: inline-block; text-decoration: none;">
            Go Back
        </a>
        """

    return f"""
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <title>Authorization Error - AutoGPT</title>
        <style>{_base_styles()}</style>
    </head>
    <body>
        <div class="container">
            <div class="error-container">
                {_error_icon()}
                <h1 class="error-title">Authorization Failed</h1>
                <p class="error-message">{safe_error_description}</p>
                <p class="error-message" style="font-size: 12px; color: #52525b;">
                    Error code: {safe_error}
                </p>
                {redirect_html}
            </div>
        </div>
    </body>
    </html>
    """


def render_success_page(
    message: str,
    redirect_origin: Optional[str] = None,
    post_message_data: Optional[dict] = None,
) -> str:
    """
    Render a success page, optionally with postMessage for popup flows.

    Args:
        message: Success message to display
        redirect_origin: Origin for postMessage (popup flows)
        post_message_data: Data to send via postMessage (popup flows)

    Returns:
        HTML string for the success page
    """
    # Escape user-provided values to prevent XSS
    safe_message = html.escape(message)

    # PostMessage script for popup flows
    post_message_script = ""
    if redirect_origin and post_message_data:
        import json

        # json.dumps escapes for JS context, but we also escape < > for HTML context
        safe_json_origin = (
            json.dumps(redirect_origin).replace("<", "\\u003c").replace(">", "\\u003e")
        )
        safe_json_data = (
            json.dumps(post_message_data)
            .replace("<", "\\u003c")
            .replace(">", "\\u003e")
        )

        post_message_script = f"""
        <script>
        (function() {{
            var targetOrigin = {safe_json_origin};
            var message = {safe_json_data};
            if (window.opener) {{
                window.opener.postMessage(message, targetOrigin);
                setTimeout(function() {{ window.close(); }}, 1000);
            }}
        }})();
        </script>
        """

    return f"""
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <title>Authorization Successful - AutoGPT</title>
        <style>{_base_styles()}</style>
    </head>
    <body>
        <div class="container">
            <div class="error-container">
                {_success_icon()}
                <h1 class="success-title">Success!</h1>
                <p class="error-message">{safe_message}</p>
                <p class="error-message" style="font-size: 12px;">
                    This window will close automatically...
                </p>
            </div>
        </div>
        {post_message_script}
    </body>
    </html>
    """


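# Standalone demonstration (illustrative, not part of this module) of why
# "<" and ">" are unicode-escaped above: a literal "</script>" inside the
# JSON payload would otherwise terminate the inline <script> block early.
import json

payload = {"note": "</script><script>alert(1)</script>"}
safe = json.dumps(payload).replace("<", "\\u003c").replace(">", "\\u003e")
print(safe)
# {"note": "\u003c/script\u003e\u003cscript\u003ealert(1)\u003c/script\u003e"}
# The JS engine still decodes \u003c back to "<", so the message payload
# reaches the opener window unchanged.

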
def render_login_redirect_page(login_url: str) -> str:
    """
    Render a page that redirects to login.

    Args:
        login_url: URL to redirect to for login

    Returns:
        HTML string with auto-redirect
    """
    # Escape URL to prevent XSS
    safe_login_url = html.escape(login_url)

    return f"""
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <meta http-equiv="refresh" content="0;url={safe_login_url}">
        <title>Login Required - AutoGPT</title>
        <style>{_base_styles()}</style>
    </head>
    <body>
        <div class="container">
            <div class="error-container">
                <p class="error-message">Redirecting to login...</p>
                <a href="{safe_login_url}" class="btn btn-allow" style="display: inline-block; text-decoration: none;">
                    Click here if not redirected
                </a>
            </div>
        </div>
    </body>
    </html>
    """


def _login_form_styles() -> str:
    """Additional CSS styles for login form."""
    return """
    .form-group {
        margin-bottom: 16px;
    }
    .form-group label {
        display: block;
        font-size: 14px;
        font-weight: 500;
        color: #a1a1aa;
        margin-bottom: 8px;
    }
    .form-group input {
        width: 100%;
        padding: 12px 16px;
        border-radius: 8px;
        border: 1px solid #3f3f46;
        background: #18181b;
        color: #e4e4e7;
        font-size: 14px;
        outline: none;
        transition: border-color 0.2s;
    }
    .form-group input:focus {
        border-color: #22d3ee;
    }
    .form-group input::placeholder {
        color: #52525b;
    }
    .error-alert {
        background: rgba(239, 68, 68, 0.1);
        border: 1px solid #ef4444;
        border-radius: 8px;
        padding: 12px 16px;
        margin-bottom: 16px;
        color: #fca5a5;
        font-size: 14px;
    }
    .btn-login {
        width: 100%;
        padding: 12px 24px;
        border-radius: 8px;
        font-size: 14px;
        font-weight: 500;
        cursor: pointer;
        border: none;
        background: #22d3ee;
        color: #0f172a;
        transition: all 0.2s;
        margin-top: 8px;
    }
    .btn-login:hover {
        background: #06b6d4;
    }
    .btn-login:disabled {
        background: #3f3f46;
        color: #71717a;
        cursor: not-allowed;
    }
    .signup-link {
        text-align: center;
        margin-top: 16px;
        font-size: 14px;
        color: #a1a1aa;
    }
    .signup-link a {
        color: #22d3ee;
        text-decoration: none;
    }
    .signup-link a:hover {
        text-decoration: underline;
    }
    """


def render_login_page(
    action_url: str,
    login_state: str,
    client_name: Optional[str] = None,
    error_message: Optional[str] = None,
    signup_url: Optional[str] = None,
    browser_login_url: Optional[str] = None,
) -> str:
    """
    Render an embedded login page for OAuth flow.

    Args:
        action_url: URL to submit the login form
        login_state: State token to preserve OAuth parameters
        client_name: Name of the application requesting access (optional)
        error_message: Error message to display (optional)
        signup_url: URL to signup page (optional)
        browser_login_url: URL to redirect to frontend login (optional)

    Returns:
        HTML string for the login page
    """
    # Escape all user-provided values to prevent XSS
    safe_action_url = html.escape(action_url)
    safe_login_state = html.escape(login_state)
    safe_client_name = html.escape(client_name) if client_name else None

    error_html = ""
    if error_message:
        safe_error_message = html.escape(error_message)
        error_html = f'<div class="error-alert">{safe_error_message}</div>'

    subtitle = "wants to access your AutoGPT account" if safe_client_name else ""
    title_html = (
        '<h1>Sign in to <span class="app-name">AutoGPT</span></h1>'
        if not safe_client_name
        else f'<h1><span class="app-name">{safe_client_name}</span></h1>'
    )

    signup_html = ""
    if signup_url:
        safe_signup_url = html.escape(signup_url)
        signup_html = f"""
        <div class="signup-link">
            Don't have an account? <a href="{safe_signup_url}">Sign up</a>
        </div>
        """

    browser_login_html = ""
    if browser_login_url:
        safe_browser_login_url = html.escape(browser_login_url)
        browser_login_html = f"""
        <div class="divider"></div>
        <div class="signup-link">
            <a href="{safe_browser_login_url}">Sign in with Google or other providers</a>
        </div>
        """

    return f"""
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <title>Sign In - AutoGPT</title>
        <style>
        {_base_styles()}
        {_login_form_styles()}
        </style>
    </head>
    <body>
        <div class="container">
            <div class="header">
                <div class="logo">
                    <span class="logo-placeholder">A</span>
                </div>
                {title_html}
                <p class="subtitle">{subtitle}</p>
            </div>

            <div class="divider"></div>

            {error_html}

            <form method="POST" action="{safe_action_url}">
                <input type="hidden" name="login_state" value="{safe_login_state}">

                <div class="form-group">
                    <label for="email">Email</label>
                    <input
                        type="email"
                        id="email"
                        name="email"
                        placeholder="you@example.com"
                        required
                        autocomplete="email"
                    >
                </div>

                <div class="form-group">
                    <label for="password">Password</label>
                    <input
                        type="password"
                        id="password"
                        name="password"
                        placeholder="Enter your password"
                        required
                        autocomplete="current-password"
                    >
                </div>

                <button type="submit" class="btn-login">Sign In</button>
            </form>

            {signup_html}
            {browser_login_html}
        </div>
    </body>
    </html>
    """
@@ -1,71 +0,0 @@
"""
OIDC Discovery endpoints.

Implements:
- GET /.well-known/openid-configuration - OIDC Discovery Document
- GET /.well-known/jwks.json - JSON Web Key Set
"""

from fastapi import APIRouter

from backend.server.oauth.models import JWKS, OpenIDConfiguration
from backend.server.oauth.token_service import get_token_service
from backend.util.settings import Settings

discovery_router = APIRouter(tags=["oidc-discovery"])


@discovery_router.get(
    "/.well-known/openid-configuration",
    response_model=OpenIDConfiguration,
)
async def openid_configuration() -> OpenIDConfiguration:
    """
    OIDC Discovery Document.

    Returns metadata about the OAuth 2.0 authorization server including
    endpoints, supported features, and algorithms.
    """
    settings = Settings()
    base_url = settings.config.platform_base_url or "https://platform.agpt.co"

    return OpenIDConfiguration(
        issuer=base_url,
        authorization_endpoint=f"{base_url}/oauth/authorize",
        token_endpoint=f"{base_url}/oauth/token",
        userinfo_endpoint=f"{base_url}/oauth/userinfo",
        revocation_endpoint=f"{base_url}/oauth/revoke",
        jwks_uri=f"{base_url}/.well-known/jwks.json",
        scopes_supported=[
            "openid",
            "profile",
            "email",
            "integrations:list",
            "integrations:connect",
            "integrations:delete",
            "agents:execute",
        ],
        response_types_supported=["code"],
        grant_types_supported=["authorization_code", "refresh_token"],
        token_endpoint_auth_methods_supported=[
            "client_secret_post",
            "client_secret_basic",
            "none",  # For public clients with PKCE
        ],
        code_challenge_methods_supported=["S256"],
        subject_types_supported=["public"],
        id_token_signing_alg_values_supported=["RS256"],
    )


@discovery_router.get("/.well-known/jwks.json", response_model=JWKS)
async def jwks() -> dict:
    """
    JSON Web Key Set (JWKS).

    Returns the public key(s) used to verify JWT signatures.
    External applications can use these keys to verify access tokens
    and ID tokens issued by this authorization server.
    """
    token_service = get_token_service()
    return token_service.get_jwks()
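
# Client-side sketch (illustrative, not part of this module): fetch the
# discovery document and key set. Uses the default issuer from above and
# requires the httpx package.
import httpx

base = "https://platform.agpt.co"
config = httpx.get(f"{base}/.well-known/openid-configuration").json()
jwks_doc = httpx.get(config["jwks_uri"]).json()
print(config["token_endpoint"], len(jwks_doc.get("keys", [])))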
@@ -1,162 +0,0 @@
"""
OAuth 2.0 Error Responses (RFC 6749 Section 5.2).
"""

from enum import Enum
from typing import Optional
from urllib.parse import urlencode

from fastapi import HTTPException
from fastapi.responses import RedirectResponse
from pydantic import BaseModel


class OAuthErrorCode(str, Enum):
    """Standard OAuth 2.0 error codes."""

    # Authorization endpoint errors (RFC 6749 Section 4.1.2.1)
    INVALID_REQUEST = "invalid_request"
    UNAUTHORIZED_CLIENT = "unauthorized_client"
    ACCESS_DENIED = "access_denied"
    UNSUPPORTED_RESPONSE_TYPE = "unsupported_response_type"
    INVALID_SCOPE = "invalid_scope"
    SERVER_ERROR = "server_error"
    TEMPORARILY_UNAVAILABLE = "temporarily_unavailable"

    # Token endpoint errors (RFC 6749 Section 5.2)
    INVALID_CLIENT = "invalid_client"
    INVALID_GRANT = "invalid_grant"
    UNSUPPORTED_GRANT_TYPE = "unsupported_grant_type"

    # Extension errors
    LOGIN_REQUIRED = "login_required"
    CONSENT_REQUIRED = "consent_required"


class OAuthErrorResponse(BaseModel):
    """OAuth error response model."""

    error: str
    error_description: Optional[str] = None
    error_uri: Optional[str] = None


class OAuthError(Exception):
    """Base OAuth error exception."""

    def __init__(
        self,
        error: OAuthErrorCode,
        description: Optional[str] = None,
        uri: Optional[str] = None,
        state: Optional[str] = None,
    ):
        self.error = error
        self.description = description
        self.uri = uri
        self.state = state
        super().__init__(description or error.value)

    def to_response(self) -> OAuthErrorResponse:
        """Convert to response model."""
        return OAuthErrorResponse(
            error=self.error.value,
            error_description=self.description,
            error_uri=self.uri,
        )

    def to_redirect(self, redirect_uri: str) -> RedirectResponse:
        """Convert to redirect response with error in query params."""
        params = {"error": self.error.value}
        if self.description:
            params["error_description"] = self.description
        if self.uri:
            params["error_uri"] = self.uri
        if self.state:
            params["state"] = self.state

        separator = "&" if "?" in redirect_uri else "?"
        url = f"{redirect_uri}{separator}{urlencode(params)}"
        return RedirectResponse(url=url, status_code=302)

    def to_http_exception(self, status_code: int = 400) -> HTTPException:
        """Convert to FastAPI HTTPException."""
        return HTTPException(
            status_code=status_code,
            detail=self.to_response().model_dump(exclude_none=True),
        )


# Convenience error classes
class InvalidRequestError(OAuthError):
    """The request is missing a required parameter or is otherwise malformed."""

    def __init__(self, description: str, state: Optional[str] = None):
        super().__init__(OAuthErrorCode.INVALID_REQUEST, description, state=state)


class UnauthorizedClientError(OAuthError):
    """The client is not authorized to request an authorization code."""

    def __init__(self, description: str, state: Optional[str] = None):
        super().__init__(OAuthErrorCode.UNAUTHORIZED_CLIENT, description, state=state)


class AccessDeniedError(OAuthError):
    """The resource owner denied the request."""

    def __init__(self, description: str = "Access denied", state: Optional[str] = None):
        super().__init__(OAuthErrorCode.ACCESS_DENIED, description, state=state)


class InvalidScopeError(OAuthError):
    """The requested scope is invalid, unknown, or malformed."""

    def __init__(self, description: str, state: Optional[str] = None):
        super().__init__(OAuthErrorCode.INVALID_SCOPE, description, state=state)


class InvalidClientError(OAuthError):
    """Client authentication failed."""

    def __init__(self, description: str = "Invalid client"):
        super().__init__(OAuthErrorCode.INVALID_CLIENT, description)


class InvalidGrantError(OAuthError):
    """The provided authorization code or refresh token is invalid."""

    def __init__(self, description: str = "Invalid grant"):
        super().__init__(OAuthErrorCode.INVALID_GRANT, description)


class UnsupportedGrantTypeError(OAuthError):
    """The authorization grant type is not supported."""

    def __init__(self, grant_type: str):
        super().__init__(
            OAuthErrorCode.UNSUPPORTED_GRANT_TYPE,
            f"Grant type '{grant_type}' is not supported",
        )


class LoginRequiredError(OAuthError):
    """User must be logged in to complete the request."""

    def __init__(self, state: Optional[str] = None):
        super().__init__(
            OAuthErrorCode.LOGIN_REQUIRED,
            "User authentication required",
            state=state,
        )


class ConsentRequiredError(OAuthError):
    """User consent is required for the requested scopes."""

    def __init__(self, state: Optional[str] = None):
        super().__init__(
            OAuthErrorCode.CONSENT_REQUIRED,
            "User consent required",
            state=state,
        )
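
# Usage sketch for the error classes above (illustrative; the callback URL
# is hypothetical).
err = InvalidScopeError("Scope 'admin' is not allowed", state="xyz")
response = err.to_redirect("https://myapp.example.com/callback")
# 302 redirect to:
#   https://myapp.example.com/callback?error=invalid_scope
#   &error_description=Scope+%27admin%27+is+not+allowed&state=xyz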
@@ -1,288 +0,0 @@
"""
Pydantic models for OAuth 2.0 requests and responses.
"""

from datetime import datetime
from enum import Enum
from typing import Literal, Optional

from pydantic import BaseModel, Field, HttpUrl

# ============================================================
# Enums and Constants
# ============================================================


class OAuthScope(str, Enum):
    """Supported OAuth scopes."""

    # OpenID Connect standard scopes
    OPENID = "openid"
    PROFILE = "profile"
    EMAIL = "email"

    # AutoGPT-specific scopes
    INTEGRATIONS_LIST = "integrations:list"
    INTEGRATIONS_CONNECT = "integrations:connect"
    INTEGRATIONS_DELETE = "integrations:delete"
    AGENTS_EXECUTE = "agents:execute"


SCOPE_DESCRIPTIONS: dict[str, str] = {
    OAuthScope.OPENID.value: "Access your user ID",
    OAuthScope.PROFILE.value: "Access your profile information (name)",
    OAuthScope.EMAIL.value: "Access your email address",
    OAuthScope.INTEGRATIONS_LIST.value: "View your connected integrations",
    OAuthScope.INTEGRATIONS_CONNECT.value: "Connect new integrations on your behalf",
    OAuthScope.INTEGRATIONS_DELETE.value: "Delete integrations on your behalf",
    OAuthScope.AGENTS_EXECUTE.value: "Run agents on your behalf",
}


# ============================================================
# Authorization Request/Response Models
# ============================================================


class AuthorizationRequest(BaseModel):
    """OAuth 2.0 Authorization Request (RFC 6749 Section 4.1.1)."""

    response_type: Literal["code"] = Field(
        ..., description="Must be 'code' for authorization code flow"
    )
    client_id: str = Field(..., description="Client identifier")
    redirect_uri: str = Field(..., description="Redirect URI after authorization")
    scope: str = Field(default="", description="Space-separated list of scopes")
    state: str = Field(..., description="CSRF protection token (required)")
    code_challenge: str = Field(..., description="PKCE code challenge (required)")
    code_challenge_method: Literal["S256"] = Field(
        default="S256", description="PKCE method (only S256 supported)"
    )
    nonce: Optional[str] = Field(None, description="OIDC nonce for replay protection")
    prompt: Optional[Literal["consent", "login", "none"]] = Field(
        None, description="Prompt behavior"
    )


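# How a client derives the S256 code_challenge from its code_verifier per
# RFC 7636 -- standard PKCE, shown here for illustration (not part of this
# module).
import base64
import hashlib
import secrets

code_verifier = secrets.token_urlsafe(64)  # 43-128 chars, as RFC 7636 requires
code_challenge = (
    base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode("ascii")).digest())
    .rstrip(b"=")
    .decode("ascii")
)

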
class ConsentFormData(BaseModel):
    """Consent form submission data."""

    consent_token: str = Field(..., description="CSRF token for consent")
    authorize: bool = Field(..., description="Whether user authorized")


# ============================================================
# Token Request/Response Models
# ============================================================


class TokenRequest(BaseModel):
|
|
||||||
"""OAuth 2.0 Token Request (RFC 6749 Section 4.1.3)."""
|
|
||||||
|
|
||||||
grant_type: Literal["authorization_code", "refresh_token"] = Field(
|
|
||||||
..., description="Grant type"
|
|
||||||
)
|
|
||||||
code: Optional[str] = Field(
|
|
||||||
None, description="Authorization code (for authorization_code grant)"
|
|
||||||
)
|
|
||||||
redirect_uri: Optional[str] = Field(
|
|
||||||
None, description="Must match authorization request"
|
|
||||||
)
|
|
||||||
client_id: str = Field(..., description="Client identifier")
|
|
||||||
client_secret: Optional[str] = Field(
|
|
||||||
None, description="Client secret (for confidential clients)"
|
|
||||||
)
|
|
||||||
code_verifier: Optional[str] = Field(
|
|
||||||
None, description="PKCE code verifier (for authorization_code grant)"
|
|
||||||
)
|
|
||||||
refresh_token: Optional[str] = Field(
|
|
||||||
None, description="Refresh token (for refresh_token grant)"
|
|
||||||
)
|
|
||||||
scope: Optional[str] = Field(
|
|
||||||
None, description="Requested scopes (for refresh_token grant)"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TokenResponse(BaseModel):
|
|
||||||
"""OAuth 2.0 Token Response (RFC 6749 Section 5.1)."""
|
|
||||||
|
|
||||||
access_token: str = Field(..., description="Access token")
|
|
||||||
token_type: Literal["Bearer"] = Field(default="Bearer", description="Token type")
|
|
||||||
expires_in: int = Field(..., description="Token lifetime in seconds")
|
|
||||||
refresh_token: Optional[str] = Field(None, description="Refresh token")
|
|
||||||
scope: Optional[str] = Field(None, description="Granted scopes")
|
|
||||||
id_token: Optional[str] = Field(None, description="OIDC ID token")
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================
|
|
||||||
# UserInfo Response Model
|
|
||||||
# ============================================================
|
|
||||||
|
|
||||||
|
|
||||||
class UserInfoResponse(BaseModel):
|
|
||||||
"""OIDC UserInfo Response."""
|
|
||||||
|
|
||||||
sub: str = Field(..., description="User ID (subject)")
|
|
||||||
email: Optional[str] = Field(None, description="User email")
|
|
||||||
email_verified: Optional[bool] = Field(
|
|
||||||
None, description="Whether email is verified"
|
|
||||||
)
|
|
||||||
name: Optional[str] = Field(None, description="User display name")
|
|
||||||
updated_at: Optional[int] = Field(None, description="Last profile update timestamp")
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================
|
|
||||||
# OIDC Discovery Models
|
|
||||||
# ============================================================
|
|
||||||
|
|
||||||
|
|
||||||
class OpenIDConfiguration(BaseModel):
|
|
||||||
"""OIDC Discovery Document."""
|
|
||||||
|
|
||||||
issuer: str
|
|
||||||
authorization_endpoint: str
|
|
||||||
token_endpoint: str
|
|
||||||
userinfo_endpoint: str
|
|
||||||
revocation_endpoint: str
|
|
||||||
jwks_uri: str
|
|
||||||
scopes_supported: list[str]
|
|
||||||
response_types_supported: list[str]
|
|
||||||
grant_types_supported: list[str]
|
|
||||||
token_endpoint_auth_methods_supported: list[str]
|
|
||||||
code_challenge_methods_supported: list[str]
|
|
||||||
subject_types_supported: list[str]
|
|
||||||
id_token_signing_alg_values_supported: list[str]
|
|
||||||
|
|
||||||
|
|
||||||
class JWK(BaseModel):
|
|
||||||
"""JSON Web Key."""
|
|
||||||
|
|
||||||
kty: str = Field(..., description="Key type (RSA)")
|
|
||||||
use: str = Field(default="sig", description="Key use (signature)")
|
|
||||||
kid: str = Field(..., description="Key ID")
|
|
||||||
alg: str = Field(default="RS256", description="Algorithm")
|
|
||||||
n: str = Field(..., description="RSA modulus")
|
|
||||||
e: str = Field(..., description="RSA exponent")
|
|
||||||
|
|
||||||
|
|
||||||
class JWKS(BaseModel):
|
|
||||||
"""JSON Web Key Set."""
|
|
||||||
|
|
||||||
keys: list[JWK]
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================
|
|
||||||
# Client Management Models
|
|
||||||
# ============================================================
|
|
||||||
|
|
||||||
|
|
||||||
class RegisterClientRequest(BaseModel):
|
|
||||||
"""Request to register a new OAuth client."""
|
|
||||||
|
|
||||||
name: str = Field(..., min_length=1, max_length=100, description="Client name")
|
|
||||||
description: Optional[str] = Field(
|
|
||||||
None, max_length=500, description="Client description"
|
|
||||||
)
|
|
||||||
logo_url: Optional[HttpUrl] = Field(None, description="Logo URL")
|
|
||||||
homepage_url: Optional[HttpUrl] = Field(None, description="Homepage URL")
|
|
||||||
privacy_policy_url: Optional[HttpUrl] = Field(
|
|
||||||
None, description="Privacy policy URL"
|
|
||||||
)
|
|
||||||
terms_of_service_url: Optional[HttpUrl] = Field(
|
|
||||||
None, description="Terms of service URL"
|
|
||||||
)
|
|
||||||
redirect_uris: list[str] = Field(
|
|
||||||
..., min_length=1, description="Allowed redirect URIs"
|
|
||||||
)
|
|
||||||
client_type: Literal["public", "confidential"] = Field(
|
|
||||||
default="public", description="Client type"
|
|
||||||
)
|
|
||||||
webhook_domains: list[str] = Field(
|
|
||||||
default_factory=list, description="Allowed webhook domains"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class UpdateClientRequest(BaseModel):
|
|
||||||
"""Request to update an OAuth client."""
|
|
||||||
|
|
||||||
name: Optional[str] = Field(None, min_length=1, max_length=100)
|
|
||||||
description: Optional[str] = Field(None, max_length=500)
|
|
||||||
logo_url: Optional[HttpUrl] = None
|
|
||||||
homepage_url: Optional[HttpUrl] = None
|
|
||||||
privacy_policy_url: Optional[HttpUrl] = None
|
|
||||||
terms_of_service_url: Optional[HttpUrl] = None
|
|
||||||
redirect_uris: Optional[list[str]] = None
|
|
||||||
webhook_domains: Optional[list[str]] = None
|
|
||||||
|
|
||||||
|
|
||||||
class ClientResponse(BaseModel):
|
|
||||||
"""OAuth client response."""
|
|
||||||
|
|
||||||
id: str
|
|
||||||
client_id: str
|
|
||||||
client_type: str
|
|
||||||
name: str
|
|
||||||
description: Optional[str]
|
|
||||||
logo_url: Optional[str]
|
|
||||||
homepage_url: Optional[str]
|
|
||||||
privacy_policy_url: Optional[str]
|
|
||||||
terms_of_service_url: Optional[str]
|
|
||||||
redirect_uris: list[str]
|
|
||||||
allowed_scopes: list[str]
|
|
||||||
webhook_domains: list[str]
|
|
||||||
status: str
|
|
||||||
created_at: datetime
|
|
||||||
updated_at: datetime
|
|
||||||
|
|
||||||
|
|
||||||
class ClientSecretResponse(BaseModel):
|
|
||||||
"""Response containing newly generated client credentials."""
|
|
||||||
|
|
||||||
client_id: str
|
|
||||||
client_secret: str = Field(
|
|
||||||
..., description="Client secret (only shown once, store securely)"
|
|
||||||
)
|
|
||||||
webhook_secret: str = Field(
|
|
||||||
...,
|
|
||||||
description="Webhook secret for HMAC signing (only shown once, store securely)",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================
|
|
||||||
# Token Introspection/Revocation Models
|
|
||||||
# ============================================================
|
|
||||||
|
|
||||||
|
|
||||||
class TokenRevocationRequest(BaseModel):
|
|
||||||
"""Token revocation request (RFC 7009)."""
|
|
||||||
|
|
||||||
token: str = Field(..., description="Token to revoke")
|
|
||||||
token_type_hint: Optional[Literal["access_token", "refresh_token"]] = Field(
|
|
||||||
None, description="Hint about token type"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TokenIntrospectionRequest(BaseModel):
|
|
||||||
"""Token introspection request (RFC 7662)."""
|
|
||||||
|
|
||||||
token: str = Field(..., description="Token to introspect")
|
|
||||||
token_type_hint: Optional[Literal["access_token", "refresh_token"]] = Field(
|
|
||||||
None, description="Hint about token type"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TokenIntrospectionResponse(BaseModel):
|
|
||||||
"""Token introspection response."""
|
|
||||||
|
|
||||||
active: bool = Field(..., description="Whether the token is active")
|
|
||||||
scope: Optional[str] = Field(None, description="Token scopes")
|
|
||||||
client_id: Optional[str] = Field(
|
|
||||||
None, description="Client that token was issued to"
|
|
||||||
)
|
|
||||||
username: Optional[str] = Field(None, description="User identifier")
|
|
||||||
token_type: Optional[str] = Field(None, description="Token type")
|
|
||||||
exp: Optional[int] = Field(None, description="Expiration timestamp")
|
|
||||||
iat: Optional[int] = Field(None, description="Issued at timestamp")
|
|
||||||
sub: Optional[str] = Field(None, description="Subject (user ID)")
|
|
||||||
aud: Optional[str] = Field(None, description="Audience")
|
|
||||||
iss: Optional[str] = Field(None, description="Issuer")
|
|
||||||
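For illustration only (values are hypothetical, not from the diff), Pydantic enforces the Literal and required fields above when a request is parsed:

    req = AuthorizationRequest(
        response_type="code",
        client_id="example-client",
        redirect_uri="https://app.example.com/callback",
        scope="openid email integrations:list",
        state="random-csrf-token",
        code_challenge="E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM",
    )
    assert req.code_challenge_method == "S256"  # default applied automatically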
@@ -1,66 +0,0 @@
"""
PKCE (Proof Key for Code Exchange) implementation for OAuth 2.0.

RFC 7636: https://tools.ietf.org/html/rfc7636
"""

import base64
import hashlib
import secrets


def generate_code_verifier(length: int = 64) -> str:
    """
    Generate a cryptographically random code verifier.

    Args:
        length: Length of the verifier (43-128 characters, default 64)

    Returns:
        URL-safe base64 encoded random string
    """
    if not 43 <= length <= 128:
        raise ValueError("Code verifier length must be between 43 and 128")
    return secrets.token_urlsafe(length)[:length]


def generate_code_challenge(verifier: str, method: str = "S256") -> str:
    """
    Generate a code challenge from the verifier.

    Args:
        verifier: The code verifier string
        method: Challenge method ("S256" or "plain")

    Returns:
        The code challenge string
    """
    if method == "S256":
        digest = hashlib.sha256(verifier.encode("ascii")).digest()
        # URL-safe base64 encoding without padding
        return base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")
    elif method == "plain":
        return verifier
    else:
        raise ValueError(f"Unsupported code challenge method: {method}")


def verify_code_challenge(
    verifier: str,
    challenge: str,
    method: str = "S256",
) -> bool:
    """
    Verify that a code verifier matches the stored challenge.

    Args:
        verifier: The code verifier from the token request
        challenge: The code challenge stored during authorization
        method: The challenge method used

    Returns:
        True if the verifier matches the challenge
    """
    expected = generate_code_challenge(verifier, method)
    # Use constant-time comparison to prevent timing attacks
    return secrets.compare_digest(expected, challenge)
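The three helpers compose into a simple roundtrip; a usage sketch:

    verifier = generate_code_verifier()                # kept secret by the client
    challenge = generate_code_challenge(verifier)      # sent with /oauth/authorize
    assert verify_code_challenge(verifier, challenge)  # checked at token exchange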
@@ -1,860 +0,0 @@
"""
OAuth 2.0 Authorization Server endpoints.

Implements:
- GET /oauth/authorize - Authorization endpoint
- POST /oauth/authorize/consent - Consent form submission
- POST /oauth/token - Token endpoint
- GET /oauth/userinfo - OIDC UserInfo endpoint
- POST /oauth/revoke - Token revocation endpoint

Authentication:
- X-API-Key header - API key for external apps (preferred)
- Authorization: Bearer <jwt> - JWT token authentication
- access_token cookie - Browser-based auth
"""

import json
import logging
import secrets
from datetime import datetime, timedelta, timezone
from typing import Optional
from urllib.parse import urlencode

from fastapi import APIRouter, Form, HTTPException, Query, Request
from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse

from backend.data.db import prisma
from backend.data.redis_client import get_redis_async
from backend.server.oauth.consent_templates import (
    render_consent_page,
    render_error_page,
    render_login_redirect_page,
)
from backend.server.oauth.errors import (
    InvalidClientError,
    InvalidRequestError,
    OAuthError,
    UnsupportedGrantTypeError,
)
from backend.server.oauth.models import TokenResponse, UserInfoResponse
from backend.server.oauth.service import get_oauth_service
from backend.server.oauth.token_service import get_token_service
from backend.util.rate_limiter import check_rate_limit

logger = logging.getLogger(__name__)

oauth_router = APIRouter(prefix="/oauth", tags=["oauth"])

# Redis key prefix and TTL for consent state storage
CONSENT_STATE_PREFIX = "oauth:consent:"
CONSENT_STATE_TTL = 600  # 10 minutes

# Redis key prefix and TTL for login redirect state storage
LOGIN_STATE_PREFIX = "oauth:login:"
LOGIN_STATE_TTL = 900  # 15 minutes (longer to allow time for login)


async def _store_login_state(token: str, state: dict) -> None:
    """Store OAuth login state in Redis with TTL."""
    redis = await get_redis_async()
    await redis.setex(
        f"{LOGIN_STATE_PREFIX}{token}",
        LOGIN_STATE_TTL,
        json.dumps(state, default=str),
    )


async def _get_and_delete_login_state(token: str) -> Optional[dict]:
    """Retrieve and delete login state from Redis (one-time use, atomic)."""
    redis = await get_redis_async()
    key = f"{LOGIN_STATE_PREFIX}{token}"
    # Use GETDEL for atomic get+delete to prevent race conditions
    state_json = await redis.getdel(key)
    if state_json:
        return json.loads(state_json)
    return None


async def _store_consent_state(token: str, state: dict) -> None:
    """Store consent state in Redis with TTL."""
    redis = await get_redis_async()
    await redis.setex(
        f"{CONSENT_STATE_PREFIX}{token}",
        CONSENT_STATE_TTL,
        json.dumps(state, default=str),
    )


async def _get_and_delete_consent_state(token: str) -> Optional[dict]:
    """Retrieve and delete consent state from Redis (atomic get+delete)."""
    redis = await get_redis_async()
    key = f"{CONSENT_STATE_PREFIX}{token}"
    # Use GETDEL for atomic get+delete to prevent race conditions
    state_json = await redis.getdel(key)
    if state_json:
        return json.loads(state_json)
    return None


async def _get_user_id_from_request(
    request: Request, strict_bearer: bool = False
) -> Optional[str]:
    """
    Extract user ID from request, checking API key, Authorization header, and cookie.

    Supports:
    1. X-API-Key header - API key authentication (preferred for external apps)
    2. Authorization: Bearer <jwt> - JWT token authentication
    3. access_token cookie - Cookie-based auth (for browser flows)

    Args:
        request: The incoming request
        strict_bearer: If True and a Bearer token is provided but invalid,
            do NOT fall through to cookie auth (prevents auth downgrade attacks)
    """
    from autogpt_libs.auth.jwt_utils import parse_jwt_token

    from backend.data.api_key import validate_api_key

    # First try X-API-Key header (for external apps)
    api_key = request.headers.get("X-API-Key")
    if api_key:
        try:
            api_key_info = await validate_api_key(api_key)
            if api_key_info:
                return api_key_info.user_id
        except Exception:
            logger.debug("API key validation failed")

    # Then try Authorization header (JWT)
    auth_header = request.headers.get("Authorization", "")
    if auth_header.startswith("Bearer "):
        try:
            token = auth_header[7:]
            payload = parse_jwt_token(token)
            return payload.get("sub")
        except Exception as e:
            logger.debug("JWT token validation failed: %s", type(e).__name__)
            # Security fix: if a Bearer token was provided but invalid,
            # don't fall through to weaker auth methods when strict_bearer is True
            if strict_bearer:
                return None

    # Finally try cookie (browser-based auth)
    token = request.cookies.get("access_token")
    if token:
        try:
            payload = parse_jwt_token(token)
            return payload.get("sub")
        except Exception as e:
            logger.debug("Cookie token validation failed: %s", type(e).__name__)

    return None


def _parse_scopes(scope_str: str) -> list[str]:
    """Parse a space-separated scope string into a list."""
    if not scope_str:
        return []
    return [s.strip() for s in scope_str.split() if s.strip()]


def _get_client_ip(request: Request) -> str:
    """Get client IP address from request."""
    forwarded = request.headers.get("X-Forwarded-For")
    if forwarded:
        return forwarded.split(",")[0].strip()
    return request.client.host if request.client else "unknown"


# ================================================================
# Authorization Endpoint
# ================================================================


@oauth_router.get("/authorize", response_model=None)
async def authorize(
    request: Request,
    response_type: str = Query(..., description="Must be 'code'"),
    client_id: str = Query(..., description="Client identifier"),
    redirect_uri: str = Query(..., description="Redirect URI"),
    state: str = Query(..., description="CSRF state parameter"),
    code_challenge: str = Query(..., description="PKCE code challenge"),
    code_challenge_method: str = Query("S256", description="PKCE method"),
    scope: str = Query("", description="Space-separated scopes"),
    nonce: Optional[str] = Query(None, description="OIDC nonce"),
    prompt: Optional[str] = Query(None, description="Prompt behavior"),
) -> HTMLResponse | RedirectResponse:
    """
    OAuth 2.0 Authorization Endpoint.

    Validates the request, checks user authentication, and either:
    - Returns an error if the user is not authenticated (API key or JWT required)
    - Shows the consent page if the user hasn't authorized these scopes
    - Redirects with an authorization code if already authorized

    Authentication methods (in order of preference):
    1. X-API-Key header - API key for external apps
    2. Authorization: Bearer <jwt> - JWT token
    3. access_token cookie - Browser-based auth
    """
    # Get user ID from API key, Authorization header, or cookie
    user_id = await _get_user_id_from_request(request)

    # Rate limiting - use client IP as identifier for the authorize endpoint
    client_ip = _get_client_ip(request)
    rate_result = await check_rate_limit(client_ip, "oauth_authorize")
    if not rate_result.allowed:
        return HTMLResponse(
            render_error_page(
                "rate_limit_exceeded",
                "Too many authorization requests. Please try again later.",
            ),
            status_code=429,
        )

    oauth_service = get_oauth_service()

    try:
        # Validate response_type
        if response_type != "code":
            raise InvalidRequestError(
                "Only 'code' response_type is supported", state=state
            )

        # Validate PKCE method
        if code_challenge_method != "S256":
            raise InvalidRequestError(
                "Only 'S256' code_challenge_method is supported", state=state
            )

        # Parse scopes
        scopes = _parse_scopes(scope)

        # Validate client and redirect URI
        client = await oauth_service.validate_client(client_id, redirect_uri, scopes)

        # Check if user is authenticated
        if not user_id:
            # User needs to log in - store OAuth params and redirect to frontend login
            from backend.util.settings import Settings

            settings = Settings()

            login_token = secrets.token_urlsafe(32)
            logger.info(f"Storing login state with token: {login_token}")
            await _store_login_state(
                login_token,
                {
                    "client_id": client_id,
                    "redirect_uri": redirect_uri,
                    "scopes": scopes,
                    "state": state,
                    "code_challenge": code_challenge,
                    "code_challenge_method": code_challenge_method,
                    "nonce": nonce,
                    "prompt": prompt,
                    "created_at": datetime.now(timezone.utc).isoformat(),
                    "expires_at": (
                        datetime.now(timezone.utc) + timedelta(seconds=LOGIN_STATE_TTL)
                    ).isoformat(),
                },
            )
            logger.info(f"Login state stored successfully for token: {login_token}")

            # Build redirect URL to frontend login
            frontend_base_url = settings.config.frontend_base_url
            if not frontend_base_url:
                return _add_security_headers(
                    HTMLResponse(
                        render_error_page(
                            "server_error", "Frontend URL not configured"
                        ),
                        status_code=500,
                    )
                )

            # Redirect to frontend login with oauth_session parameter
            login_url = f"{frontend_base_url}/login?oauth_session={login_token}"
            return _add_security_headers(
                HTMLResponse(render_login_redirect_page(login_url))
            )

        # Check if user has already authorized these scopes
        if prompt != "consent":
            has_auth = await oauth_service.has_valid_authorization(
                user_id, client_id, scopes
            )
            if has_auth:
                # Skip consent, issue code directly
                code = await oauth_service.create_authorization_code(
                    user_id=user_id,
                    client_id=client_id,
                    redirect_uri=redirect_uri,
                    scopes=scopes,
                    code_challenge=code_challenge,
                    code_challenge_method=code_challenge_method,
                    nonce=nonce,
                )
                redirect_url = (
                    f"{redirect_uri}?{urlencode({'code': code, 'state': state})}"
                )
                return RedirectResponse(url=redirect_url, status_code=302)

        # Generate consent token and store state in Redis
        consent_token = secrets.token_urlsafe(32)
        await _store_consent_state(
            consent_token,
            {
                "user_id": user_id,
                "client_id": client_id,
                "redirect_uri": redirect_uri,
                "scopes": scopes,
                "state": state,
                "code_challenge": code_challenge,
                "code_challenge_method": code_challenge_method,
                "nonce": nonce,
                "expires_at": (
                    datetime.now(timezone.utc) + timedelta(minutes=10)
                ).isoformat(),
            },
        )

        # Render consent page
        return _add_security_headers(
            HTMLResponse(
                render_consent_page(
                    client_name=client.name,
                    client_logo=client.logoUrl,
                    scopes=scopes,
                    consent_token=consent_token,
                    action_url="/oauth/authorize/consent",
                    privacy_policy_url=client.privacyPolicyUrl,
                    terms_url=client.termsOfServiceUrl,
                )
            )
        )

    except OAuthError as e:
        # If we have a valid redirect_uri, redirect with the error;
        # otherwise show an error page
        try:
            client = await oauth_service.get_client(client_id)
            if client and redirect_uri in client.redirectUris:
                return e.to_redirect(redirect_uri)
        except Exception:
            pass

        return _add_security_headers(
            HTMLResponse(
                render_error_page(e.error.value, e.description or "An error occurred"),
                status_code=400,
            )
        )
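For illustration (not part of the diff), a client builds the request this endpoint expects roughly as follows; the host and client values are hypothetical, and `challenge` comes from the PKCE sketch above:

    from urllib.parse import urlencode
    import secrets

    params = {
        "response_type": "code",
        "client_id": "example-client",
        "redirect_uri": "https://app.example.com/callback",
        "state": secrets.token_urlsafe(16),  # CSRF token, echoed back on redirect
        "code_challenge": challenge,
        "code_challenge_method": "S256",
        "scope": "openid profile",
    }
    authorize_url = f"https://server.example.com/oauth/authorize?{urlencode(params)}"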
@oauth_router.post("/authorize/consent", response_model=None)
async def submit_consent(
    request: Request,
    consent_token: str = Form(...),
    authorize: str = Form(...),
) -> HTMLResponse | RedirectResponse:
    """
    Process consent form submission.

    Creates an authorization code and redirects to the client's redirect_uri.
    """
    # Rate limiting on consent submission to prevent brute force attacks
    client_ip = _get_client_ip(request)
    rate_result = await check_rate_limit(client_ip, "oauth_consent")
    if not rate_result.allowed:
        return _add_security_headers(
            HTMLResponse(
                render_error_page(
                    "rate_limit_exceeded",
                    "Too many consent requests. Please try again later.",
                ),
                status_code=429,
            )
        )

    oauth_service = get_oauth_service()

    # Validate consent token (retrieves and deletes from Redis atomically)
    consent_state = await _get_and_delete_consent_state(consent_token)
    if not consent_state:
        return HTMLResponse(
            render_error_page("invalid_request", "Invalid or expired consent token"),
            status_code=400,
        )

    # Check expiration (expires_at is stored as an ISO string in Redis)
    expires_at = datetime.fromisoformat(consent_state["expires_at"])
    if expires_at < datetime.now(timezone.utc):
        return HTMLResponse(
            render_error_page("invalid_request", "Consent session expired"),
            status_code=400,
        )

    redirect_uri = consent_state["redirect_uri"]
    state = consent_state["state"]

    # Check if user denied
    if authorize.lower() != "true":
        error_params = urlencode(
            {
                "error": "access_denied",
                "error_description": "User denied the authorization request",
                "state": state,
            }
        )
        return RedirectResponse(
            url=f"{redirect_uri}?{error_params}",
            status_code=302,
        )

    try:
        # Create authorization code
        code = await oauth_service.create_authorization_code(
            user_id=consent_state["user_id"],
            client_id=consent_state["client_id"],
            redirect_uri=redirect_uri,
            scopes=consent_state["scopes"],
            code_challenge=consent_state["code_challenge"],
            code_challenge_method=consent_state["code_challenge_method"],
            nonce=consent_state["nonce"],
        )

        # Redirect with code
        return RedirectResponse(
            url=f"{redirect_uri}?{urlencode({'code': code, 'state': state})}",
            status_code=302,
        )

    except OAuthError as e:
        return e.to_redirect(redirect_uri)


def _wants_json(request: Request) -> bool:
    """Check if the client prefers a JSON response (for frontend fetch calls)."""
    accept = request.headers.get("Accept", "")
    return "application/json" in accept


def _add_security_headers(response: HTMLResponse) -> HTMLResponse:
    """Add security headers to OAuth HTML responses."""
    response.headers["X-Frame-Options"] = "DENY"
    response.headers["Content-Security-Policy"] = "frame-ancestors 'none'"
    response.headers["X-Content-Type-Options"] = "nosniff"
    return response


@oauth_router.get("/authorize/resume", response_model=None)
async def resume_authorization(
    request: Request,
    session_id: str = Query(..., description="OAuth login session ID"),
) -> HTMLResponse | RedirectResponse | JSONResponse:
    """
    Resume OAuth authorization after user login.

    This endpoint is called after the user completes login on the frontend.
    It retrieves the stored OAuth parameters and continues the authorization flow.

    Supports the Accept: application/json header to return JSON for frontend
    fetch calls, which avoids the CORS issues that redirect responses can cause.
    """
    wants_json = _wants_json(request)

    # Rate limiting - use client IP
    client_ip = _get_client_ip(request)
    rate_result = await check_rate_limit(client_ip, "oauth_authorize")
    if not rate_result.allowed:
        if wants_json:
            return JSONResponse(
                {
                    "error": "rate_limit_exceeded",
                    "error_description": "Too many requests",
                },
                status_code=429,
            )
        return _add_security_headers(
            HTMLResponse(
                render_error_page(
                    "rate_limit_exceeded",
                    "Too many authorization requests. Please try again later.",
                ),
                status_code=429,
            )
        )

    # Verify user is now authenticated (use strict_bearer to prevent auth downgrade)
    user_id = await _get_user_id_from_request(request, strict_bearer=True)
    if not user_id:
        from backend.util.settings import Settings

        frontend_url = Settings().config.frontend_base_url or "http://localhost:3000"
        if wants_json:
            return JSONResponse(
                {
                    "error": "login_required",
                    "error_description": "Authentication required",
                    "redirect_url": f"{frontend_url}/login",
                },
                status_code=401,
            )
        return _add_security_headers(
            HTMLResponse(
                render_error_page(
                    "login_required",
                    "Authentication required. Please log in and try again.",
                    redirect_url=f"{frontend_url}/login",
                ),
                status_code=401,
            )
        )

    # Retrieve and delete login state (one-time use)
    logger.info(f"Attempting to retrieve login state for session_id: {session_id}")
    login_state = await _get_and_delete_login_state(session_id)
    if not login_state:
        logger.warning(f"Login state not found for session_id: {session_id}")
        if wants_json:
            return JSONResponse(
                {
                    "error": "invalid_request",
                    "error_description": "Invalid or expired authorization session",
                },
                status_code=400,
            )
        return _add_security_headers(
            HTMLResponse(
                render_error_page(
                    "invalid_request",
                    "Invalid or expired authorization session. Please start over.",
                ),
                status_code=400,
            )
        )

    # Check expiration
    expires_at = datetime.fromisoformat(login_state["expires_at"])
    if expires_at < datetime.now(timezone.utc):
        if wants_json:
            return JSONResponse(
                {
                    "error": "invalid_request",
                    "error_description": "Authorization session has expired",
                },
                status_code=400,
            )
        return _add_security_headers(
            HTMLResponse(
                render_error_page(
                    "invalid_request",
                    "Authorization session has expired. Please start over.",
                ),
                status_code=400,
            )
        )

    # Extract stored OAuth parameters
    client_id = login_state["client_id"]
    redirect_uri = login_state["redirect_uri"]
    scopes = login_state["scopes"]
    state = login_state["state"]
    code_challenge = login_state["code_challenge"]
    code_challenge_method = login_state["code_challenge_method"]
    nonce = login_state.get("nonce")
    prompt = login_state.get("prompt")

    oauth_service = get_oauth_service()

    try:
        # Re-validate client (in case it was deactivated during login)
        client = await oauth_service.validate_client(client_id, redirect_uri, scopes)

        # Check if user has already authorized these scopes (skip consent if yes)
        if prompt != "consent":
            has_auth = await oauth_service.has_valid_authorization(
                user_id, client_id, scopes
            )
            if has_auth:
                # Skip consent, issue code directly
                code = await oauth_service.create_authorization_code(
                    user_id=user_id,
                    client_id=client_id,
                    redirect_uri=redirect_uri,
                    scopes=scopes,
                    code_challenge=code_challenge,
                    code_challenge_method=code_challenge_method,
                    nonce=nonce,
                )
                redirect_url = (
                    f"{redirect_uri}?{urlencode({'code': code, 'state': state})}"
                )
                # Return JSON with redirect URL for the frontend to handle
                if wants_json:
                    return JSONResponse(
                        {"redirect_url": redirect_url, "needs_consent": False}
                    )
                return RedirectResponse(url=redirect_url, status_code=302)

        # Generate consent token and store state in Redis
        consent_token = secrets.token_urlsafe(32)
        await _store_consent_state(
            consent_token,
            {
                "user_id": user_id,
                "client_id": client_id,
                "redirect_uri": redirect_uri,
                "scopes": scopes,
                "state": state,
                "code_challenge": code_challenge,
                "code_challenge_method": code_challenge_method,
                "nonce": nonce,
                "expires_at": (
                    datetime.now(timezone.utc) + timedelta(minutes=10)
                ).isoformat(),
            },
        )

        # For JSON requests, return consent data instead of HTML
        if wants_json:
            from backend.server.oauth.models import SCOPE_DESCRIPTIONS

            scope_details = [
                {"scope": s, "description": SCOPE_DESCRIPTIONS.get(s, s)}
                for s in scopes
            ]
            return JSONResponse(
                {
                    "needs_consent": True,
                    "consent_token": consent_token,
                    "client": {
                        "name": client.name,
                        "logo_url": client.logoUrl,
                        "privacy_policy_url": client.privacyPolicyUrl,
                        "terms_url": client.termsOfServiceUrl,
                    },
                    "scopes": scope_details,
                    "action_url": "/oauth/authorize/consent",
                }
            )

        # Render consent page (HTML response)
        return _add_security_headers(
            HTMLResponse(
                render_consent_page(
                    client_name=client.name,
                    client_logo=client.logoUrl,
                    scopes=scopes,
                    consent_token=consent_token,
                    action_url="/oauth/authorize/consent",
                    privacy_policy_url=client.privacyPolicyUrl,
                    terms_url=client.termsOfServiceUrl,
                )
            )
        )

    except OAuthError as e:
        if wants_json:
            return JSONResponse(
                {"error": e.error.value, "error_description": e.description},
                status_code=400,
            )
        # If we have a valid redirect_uri, redirect with the error
        try:
            client = await oauth_service.get_client(client_id)
            if client and redirect_uri in client.redirectUris:
                return e.to_redirect(redirect_uri)
        except Exception:
            pass

        return _add_security_headers(
            HTMLResponse(
                render_error_page(e.error.value, e.description or "An error occurred"),
                status_code=400,
            )
        )


# ================================================================
# Token Endpoint
# ================================================================


@oauth_router.post("/token", response_model=TokenResponse)
async def token(
    request: Request,
    grant_type: str = Form(...),
    code: Optional[str] = Form(None),
    redirect_uri: Optional[str] = Form(None),
    client_id: str = Form(...),
    client_secret: Optional[str] = Form(None),
    code_verifier: Optional[str] = Form(None),
    refresh_token: Optional[str] = Form(None),
    scope: Optional[str] = Form(None),
) -> TokenResponse:
    """
    OAuth 2.0 Token Endpoint.

    Supports:
    - authorization_code grant (with PKCE)
    - refresh_token grant
    """
    # Rate limiting - use client_id as identifier
    rate_result = await check_rate_limit(client_id, "oauth_token")
    if not rate_result.allowed:
        raise HTTPException(
            status_code=429,
            detail="Rate limit exceeded",
            headers={
                "Retry-After": str(int(rate_result.retry_after or 60)),
                "X-RateLimit-Remaining": "0",
                "X-RateLimit-Reset": str(int(rate_result.reset_at)),
            },
        )

    oauth_service = get_oauth_service()

    try:
        # Validate client authentication
        await oauth_service.validate_client_secret(client_id, client_secret)

        if grant_type == "authorization_code":
            # Validate required parameters
            if not code:
                raise InvalidRequestError("'code' is required")
            if not redirect_uri:
                raise InvalidRequestError("'redirect_uri' is required")
            if not code_verifier:
                raise InvalidRequestError("'code_verifier' is required for PKCE")

            return await oauth_service.exchange_authorization_code(
                code=code,
                client_id=client_id,
                redirect_uri=redirect_uri,
                code_verifier=code_verifier,
            )

        elif grant_type == "refresh_token":
            if not refresh_token:
                raise InvalidRequestError("'refresh_token' is required")

            requested_scopes = _parse_scopes(scope) if scope else None
            return await oauth_service.refresh_access_token(
                refresh_token=refresh_token,
                client_id=client_id,
                requested_scopes=requested_scopes,
            )

        else:
            raise UnsupportedGrantTypeError(grant_type)

    except OAuthError as e:
        # 401 for client auth failure, 400 for other validation errors (per RFC 6749)
        raise e.to_http_exception(401 if isinstance(e, InvalidClientError) else 400)
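A client-side sketch of the exchange this endpoint performs; httpx and all values are assumptions, with `code` taken from the redirect query string and `verifier` from the PKCE sketch above:

    import httpx

    resp = httpx.post(
        "https://server.example.com/oauth/token",
        data={
            "grant_type": "authorization_code",
            "code": code,
            "redirect_uri": "https://app.example.com/callback",
            "client_id": "example-client",
            "code_verifier": verifier,  # proves possession of the PKCE secret
        },
    )
    tokens = resp.json()  # shaped like the TokenResponse model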
# ================================================================
# UserInfo Endpoint
# ================================================================


@oauth_router.get("/userinfo", response_model=UserInfoResponse)
async def userinfo(request: Request) -> UserInfoResponse:
    """
    OIDC UserInfo Endpoint.

    Returns user profile information based on the granted scopes.
    """
    token_service = get_token_service()

    # Extract bearer token
    auth_header = request.headers.get("Authorization", "")
    if not auth_header.startswith("Bearer "):
        raise HTTPException(
            status_code=401,
            detail="Bearer token required",
            headers={"WWW-Authenticate": "Bearer"},
        )

    token = auth_header[7:]

    try:
        # Verify token
        claims = token_service.verify_access_token(token)

        # Check token is not revoked
        token_hash = token_service.hash_token(token)
        stored_token = await prisma.oauthaccesstoken.find_unique(
            where={"tokenHash": token_hash}
        )

        if not stored_token or stored_token.revokedAt:
            raise HTTPException(
                status_code=401,
                detail="Token has been revoked",
                headers={"WWW-Authenticate": "Bearer"},
            )

        # Update last used
        await prisma.oauthaccesstoken.update(
            where={"id": stored_token.id},
            data={"lastUsedAt": datetime.now(timezone.utc)},
        )

        # Get user info based on scopes
        user = await prisma.user.find_unique(where={"id": claims.sub})
        if not user:
            raise HTTPException(status_code=404, detail="User not found")

        scopes = claims.scope.split()

        # Build response based on scopes
        email = user.email if "email" in scopes else None
        email_verified = user.emailVerified if "email" in scopes else None
        name = user.name if "profile" in scopes else None
        updated_at = int(user.updatedAt.timestamp()) if "profile" in scopes else None

        return UserInfoResponse(
            sub=claims.sub,
            email=email,
            email_verified=email_verified,
            name=name,
            updated_at=updated_at,
        )

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=401,
            detail=f"Invalid token: {str(e)}",
            headers={"WWW-Authenticate": "Bearer"},
        )
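Continuing the hypothetical client sketch from the token endpoint above, the issued access token is presented as a Bearer credential:

    profile = httpx.get(
        "https://server.example.com/oauth/userinfo",
        headers={"Authorization": f"Bearer {tokens['access_token']}"},
    ).json()  # fields beyond `sub` depend on the granted scopes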
# ================================================================
# Token Revocation Endpoint
# ================================================================


@oauth_router.post("/revoke")
async def revoke(
    request: Request,
    token: str = Form(...),
    token_type_hint: Optional[str] = Form(None),
) -> JSONResponse:
    """
    OAuth 2.0 Token Revocation Endpoint (RFC 7009).

    Revokes an access token or refresh token.
    """
    oauth_service = get_oauth_service()

    # Note: per RFC 7009, always return 200 even if the token was not found
    await oauth_service.revoke_token(token, token_type_hint)

    return JSONResponse(content={}, status_code=200)
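To finish the hypothetical client sketch: because the endpoint returns 200 regardless of whether the token existed, the revocation call needs no special error handling for unknown tokens:

    httpx.post(
        "https://server.example.com/oauth/revoke",
        data={"token": tokens["refresh_token"], "token_type_hint": "refresh_token"},
    )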
@@ -1,625 +0,0 @@
|
|||||||
"""
|
|
||||||
Core OAuth 2.0 service logic.
|
|
||||||
|
|
||||||
Handles:
|
|
||||||
- Client validation and lookup
|
|
||||||
- Authorization code generation and exchange
|
|
||||||
- Token issuance and refresh
|
|
||||||
- User consent management
|
|
||||||
- Audit logging
|
|
||||||
"""
|
|
||||||
|
|
||||||
import hashlib
|
|
||||||
import json
|
|
||||||
import secrets
|
|
||||||
from datetime import datetime, timedelta, timezone
|
|
||||||
from typing import Any, Optional
|
|
||||||
|
|
||||||
from prisma.enums import OAuthClientStatus
|
|
||||||
from prisma.models import OAuthAuthorization, OAuthClient, User
|
|
||||||
|
|
||||||
from backend.data.db import prisma
|
|
||||||
from backend.server.oauth.errors import (
|
|
||||||
InvalidClientError,
|
|
||||||
InvalidGrantError,
|
|
||||||
InvalidRequestError,
|
|
||||||
InvalidScopeError,
|
|
||||||
)
|
|
||||||
from backend.server.oauth.models import TokenResponse
|
|
||||||
from backend.server.oauth.pkce import verify_code_challenge
|
|
||||||
from backend.server.oauth.token_service import OAuthTokenService, get_token_service
|
|
||||||
|
|
||||||
|
|
||||||
class OAuthService:
|
|
||||||
"""Core OAuth 2.0 service."""
|
|
||||||
|
|
||||||
def __init__(self, token_service: Optional[OAuthTokenService] = None):
|
|
||||||
self.token_service = token_service or get_token_service()
|
|
||||||
|
|
||||||
# ================================================================
|
|
||||||
# Client Operations
|
|
||||||
# ================================================================
|
|
||||||
|
|
||||||
async def get_client(self, client_id: str) -> Optional[OAuthClient]:
|
|
||||||
"""Get an OAuth client by client_id."""
|
|
||||||
return await prisma.oauthclient.find_unique(where={"clientId": client_id})
|
|
||||||
|
|
||||||
async def validate_client(
|
|
||||||
self,
|
|
||||||
client_id: str,
|
|
||||||
redirect_uri: str,
|
|
||||||
scopes: list[str],
|
|
||||||
) -> OAuthClient:
|
|
||||||
"""
|
|
||||||
Validate a client for authorization.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
client_id: Client identifier
|
|
||||||
redirect_uri: Requested redirect URI
|
|
||||||
scopes: Requested scopes
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Validated OAuthClient
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
InvalidClientError: Client not found or inactive
|
|
||||||
InvalidRequestError: Invalid redirect URI
|
|
||||||
InvalidScopeError: Invalid scopes requested
|
|
||||||
"""
|
|
||||||
client = await self.get_client(client_id)
|
|
||||||
|
|
||||||
if not client:
|
|
||||||
raise InvalidClientError(f"Client '{client_id}' not found")
|
|
||||||
|
|
||||||
if client.status != OAuthClientStatus.ACTIVE:
|
|
||||||
raise InvalidClientError(f"Client '{client_id}' is not active")
|
|
||||||
|
|
||||||
# Validate redirect URI (exact match required)
|
|
||||||
if redirect_uri not in client.redirectUris:
|
|
||||||
raise InvalidRequestError(
|
|
||||||
f"Redirect URI '{redirect_uri}' is not registered for this client"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Validate scopes
|
|
||||||
invalid_scopes = set(scopes) - set(client.allowedScopes)
|
|
||||||
if invalid_scopes:
|
|
||||||
raise InvalidScopeError(
|
|
||||||
f"Scopes not allowed for this client: {', '.join(invalid_scopes)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return client
|
|
||||||
|
|
||||||
async def validate_client_secret(
|
|
||||||
self,
|
|
||||||
client_id: str,
|
|
||||||
client_secret: Optional[str],
|
|
||||||
) -> OAuthClient:
|
|
||||||
"""
|
|
||||||
Validate client authentication for token endpoint.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
client_id: Client identifier
|
|
||||||
client_secret: Client secret (for confidential clients)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Validated OAuthClient
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
InvalidClientError: Invalid client or credentials
|
|
||||||
"""
|
|
||||||
client = await self.get_client(client_id)
|
|
||||||
|
|
||||||
if not client:
|
|
||||||
raise InvalidClientError(f"Client '{client_id}' not found")
|
|
||||||
|
|
||||||
if client.status != OAuthClientStatus.ACTIVE:
|
|
||||||
raise InvalidClientError(f"Client '{client_id}' is not active")
|
|
||||||
|
|
||||||
# Confidential clients must provide secret
|
|
||||||
if client.clientType == "confidential":
|
|
||||||
if not client_secret:
|
|
||||||
raise InvalidClientError("Client secret required")
|
|
||||||
|
|
||||||
# Hash and compare
|
|
||||||
secret_hash = self._hash_secret(
|
|
||||||
client_secret, client.clientSecretSalt or ""
|
|
||||||
)
|
|
||||||
if not secrets.compare_digest(secret_hash, client.clientSecretHash or ""):
|
|
||||||
raise InvalidClientError("Invalid client credentials")
|
|
||||||
|
|
||||||
return client
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _hash_secret(secret: str, salt: str) -> str:
|
|
||||||
"""Hash a client secret with salt."""
|
|
||||||
return hashlib.sha256(f"{salt}{secret}".encode()).hexdigest()
|
|
||||||
|
|
||||||
# ================================================================
|
|
||||||
# Authorization Code Operations
|
|
||||||
# ================================================================
|
|
||||||
|
|
||||||
async def create_authorization_code(
|
|
||||||
self,
|
|
||||||
user_id: str,
|
|
||||||
client_id: str,
|
|
||||||
redirect_uri: str,
|
|
||||||
scopes: list[str],
|
|
||||||
code_challenge: str,
|
|
||||||
code_challenge_method: str = "S256",
|
|
||||||
nonce: Optional[str] = None,
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
Create a new authorization code.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: User who authorized
|
|
||||||
client_id: Client being authorized
|
|
||||||
redirect_uri: Redirect URI for callback
|
|
||||||
scopes: Granted scopes
|
|
||||||
code_challenge: PKCE code challenge
|
|
||||||
code_challenge_method: PKCE method (S256)
|
|
||||||
nonce: OIDC nonce (optional)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Authorization code string
|
|
||||||
"""
|
|
||||||
code = secrets.token_urlsafe(32)
|
|
||||||
code_hash = self.token_service.hash_token(code)
|
|
||||||
|
|
||||||
# Get the OAuthClient to link
|
|
||||||
client = await self.get_client(client_id)
|
|
||||||
if not client:
|
|
||||||
raise InvalidClientError(f"Client '{client_id}' not found")
|
|
||||||
|
|
||||||
await prisma.oauthauthorizationcode.create(
|
|
||||||
data={ # type: ignore[typeddict-item]
|
|
||||||
"codeHash": code_hash,
|
|
||||||
"userId": user_id,
|
|
||||||
"clientId": client.id,
|
|
||||||
"redirectUri": redirect_uri,
|
|
||||||
"scopes": scopes,
|
|
||||||
"codeChallenge": code_challenge,
|
|
||||||
"codeChallengeMethod": code_challenge_method,
|
|
||||||
"nonce": nonce,
|
|
||||||
"expiresAt": datetime.now(timezone.utc) + timedelta(minutes=10),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return code
|
|
||||||
|
|
||||||
async def exchange_authorization_code(
|
|
||||||
self,
|
|
||||||
code: str,
|
|
||||||
client_id: str,
|
|
||||||
redirect_uri: str,
|
|
||||||
code_verifier: str,
|
|
||||||
) -> TokenResponse:
|
|
||||||
"""
|
|
||||||
Exchange an authorization code for tokens.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
code: Authorization code
|
|
||||||
client_id: Client identifier
|
|
||||||
redirect_uri: Must match original redirect URI
|
|
||||||
code_verifier: PKCE code verifier
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
TokenResponse with access token, refresh token, etc.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
InvalidGrantError: Invalid or expired code
|
|
||||||
InvalidRequestError: PKCE verification failed
|
|
||||||
"""
|
|
||||||
code_hash = self.token_service.hash_token(code)
|
|
||||||
|
|
||||||
# Find the authorization code
|
|
||||||
auth_code = await prisma.oauthauthorizationcode.find_unique(
|
|
||||||
where={"codeHash": code_hash},
|
|
||||||
include={"Client": True, "User": True},
|
|
||||||
)
|
|
||||||
|
|
||||||
if not auth_code:
|
|
||||||
raise InvalidGrantError("Authorization code not found")
|
|
||||||
|
|
||||||
# Ensure Client relation is loaded
|
|
||||||
if not auth_code.Client:
|
|
||||||
raise InvalidGrantError("Authorization code client not found")
|
|
||||||
|
|
||||||
# Check if already used
|
|
||||||
if auth_code.usedAt:
|
|
||||||
# Code reuse is a security incident - revoke all tokens for this authorization
|
|
||||||
await self._revoke_tokens_for_client_user(
|
|
||||||
auth_code.Client.clientId, auth_code.userId
|
|
||||||
)
|
|
||||||
raise InvalidGrantError("Authorization code has already been used")
|
|
||||||
|
|
||||||
# Check expiration
|
|
||||||
if auth_code.expiresAt < datetime.now(timezone.utc):
|
|
||||||
raise InvalidGrantError("Authorization code has expired")
|
|
||||||
|
|
||||||
# Validate client
|
|
||||||
if auth_code.Client.clientId != client_id:
|
|
||||||
raise InvalidGrantError("Client ID mismatch")
|
|
||||||
        # Validate redirect URI
        if auth_code.redirectUri != redirect_uri:
            raise InvalidGrantError("Redirect URI mismatch")

        # Verify PKCE
        if not verify_code_challenge(
            code_verifier, auth_code.codeChallenge, auth_code.codeChallengeMethod
        ):
            raise InvalidRequestError("PKCE verification failed")

        # Mark code as used
        await prisma.oauthauthorizationcode.update(
            where={"id": auth_code.id},
            data={"usedAt": datetime.now(timezone.utc)},
        )

        # Create or update authorization record
        await self._upsert_authorization(
            auth_code.userId, auth_code.Client.id, auth_code.scopes
        )

        # Generate tokens
        return await self._create_tokens(
            user_id=auth_code.userId,
            client=auth_code.Client,
            scopes=auth_code.scopes,
            nonce=auth_code.nonce,
            user=auth_code.User,
        )

    async def refresh_access_token(
        self,
        refresh_token: str,
        client_id: str,
        requested_scopes: Optional[list[str]] = None,
    ) -> TokenResponse:
        """
        Refresh an access token using a refresh token.

        Args:
            refresh_token: Refresh token string
            client_id: Client identifier
            requested_scopes: Optionally request fewer scopes

        Returns:
            New TokenResponse

        Raises:
            InvalidGrantError: Invalid or expired refresh token
        """
        token_hash = self.token_service.hash_token(refresh_token)

        # Find the refresh token
        stored_token = await prisma.oauthrefreshtoken.find_unique(
            where={"tokenHash": token_hash},
            include={"Client": True, "User": True},
        )

        if not stored_token:
            raise InvalidGrantError("Refresh token not found")

        # Ensure Client relation is loaded
        if not stored_token.Client:
            raise InvalidGrantError("Refresh token client not found")

        # Check if revoked
        if stored_token.revokedAt:
            raise InvalidGrantError("Refresh token has been revoked")

        # Check expiration
        if stored_token.expiresAt < datetime.now(timezone.utc):
            raise InvalidGrantError("Refresh token has expired")

        # Validate client
        if stored_token.Client.clientId != client_id:
            raise InvalidGrantError("Client ID mismatch")

        # Determine scopes
        scopes = stored_token.scopes
        if requested_scopes:
            # Can only request a subset of original scopes
            invalid = set(requested_scopes) - set(stored_token.scopes)
            if invalid:
                raise InvalidScopeError(
                    f"Cannot request scopes not in original grant: {', '.join(invalid)}"
                )
            scopes = requested_scopes

        # Generate new tokens (rotates refresh token)
        return await self._create_tokens(
            user_id=stored_token.userId,
            client=stored_token.Client,
            scopes=scopes,
            user=stored_token.User,
            old_refresh_token_id=stored_token.id,
        )

    # ================================================================
    # Token Operations
    # ================================================================

    async def _create_tokens(
        self,
        user_id: str,
        client: OAuthClient,
        scopes: list[str],
        user: Optional[User] = None,
        nonce: Optional[str] = None,
        old_refresh_token_id: Optional[str] = None,
    ) -> TokenResponse:
        """
        Create access and refresh tokens.

        Args:
            user_id: User ID
            client: OAuth client
            scopes: Granted scopes
            user: User object (for ID token claims)
            nonce: OIDC nonce
            old_refresh_token_id: ID of refresh token being rotated

        Returns:
            TokenResponse
        """
        # Generate access token
        access_token, access_expires_at = self.token_service.generate_access_token(
            user_id=user_id,
            client_id=client.clientId,
            scopes=scopes,
            expires_in=client.tokenLifetimeSecs,
        )

        # Store access token hash
        await prisma.oauthaccesstoken.create(
            data={  # type: ignore[typeddict-item]
                "tokenHash": self.token_service.hash_token(access_token),
                "userId": user_id,
                "clientId": client.id,
                "scopes": scopes,
                "expiresAt": access_expires_at,
            }
        )

        # Generate refresh token
        refresh_token = self.token_service.generate_refresh_token()
        refresh_expires_at = datetime.now(timezone.utc) + timedelta(
            seconds=client.refreshTokenLifetimeSecs
        )

        await prisma.oauthrefreshtoken.create(
            data={  # type: ignore[typeddict-item]
                "tokenHash": self.token_service.hash_token(refresh_token),
                "userId": user_id,
                "clientId": client.id,
                "scopes": scopes,
                "expiresAt": refresh_expires_at,
            }
        )

        # Revoke old refresh token if rotating
        if old_refresh_token_id:
            await prisma.oauthrefreshtoken.update(
                where={"id": old_refresh_token_id},
                data={"revokedAt": datetime.now(timezone.utc)},
            )

        # Generate ID token if openid scope requested
        id_token = None
        if "openid" in scopes and user:
            email = user.email if "email" in scopes else None
            name = user.name if "profile" in scopes else None
            id_token = self.token_service.generate_id_token(
                user_id=user_id,
                client_id=client.clientId,
                email=email,
                name=name,
                nonce=nonce,
            )

        # Audit log
        await self._audit_log(
            event_type="token.issued",
            user_id=user_id,
            client_id=client.clientId,
            details={"scopes": scopes},
        )

        return TokenResponse(
            access_token=access_token,
            token_type="Bearer",
            expires_in=client.tokenLifetimeSecs,
            refresh_token=refresh_token,
            scope=" ".join(scopes),
            id_token=id_token,
        )
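A minimal sketch of the rotation invariant `_create_tokens` maintains when `old_refresh_token_id` is set: the new refresh token is stored (as a hash) and the old one is revoked in the same flow, so at most one live refresh token exists per grant. The in-memory `store` dict below is a stand-in for the `oauthrefreshtoken` table, not the real Prisma client.

import hashlib
from datetime import datetime, timezone

store: dict[str, dict] = {}  # tokenHash -> row; stand-in for the DB table

def issue(token: str) -> None:
    # Only the SHA-256 hash is persisted, never the raw token
    store[hashlib.sha256(token.encode()).hexdigest()] = {"revokedAt": None}

def rotate(old_token: str, new_token: str) -> None:
    issue(new_token)
    old_hash = hashlib.sha256(old_token.encode()).hexdigest()
    store[old_hash]["revokedAt"] = datetime.now(timezone.utc)

issue("rt-1")
rotate("rt-1", "rt-2")
live = [h for h, row in store.items() if row["revokedAt"] is None]
assert len(live) == 1  # exactly one live refresh token after rotation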
    async def revoke_token(
        self,
        token: str,
        token_type_hint: Optional[str] = None,
    ) -> bool:
        """
        Revoke an access or refresh token.

        Args:
            token: Token to revoke
            token_type_hint: Hint about token type

        Returns:
            True if token was found and revoked
        """
        token_hash = self.token_service.hash_token(token)
        now = datetime.now(timezone.utc)

        # Try refresh token first if hinted or no hint
        if token_type_hint in (None, "refresh_token"):
            result = await prisma.oauthrefreshtoken.update_many(
                where={"tokenHash": token_hash, "revokedAt": None},
                data={"revokedAt": now},
            )
            if result > 0:
                return True

        # Try access token
        if token_type_hint in (None, "access_token"):
            result = await prisma.oauthaccesstoken.update_many(
                where={"tokenHash": token_hash, "revokedAt": None},
                data={"revokedAt": now},
            )
            if result > 0:
                return True

        return False

    async def _revoke_tokens_for_client_user(
        self,
        client_id: str,
        user_id: str,
    ) -> None:
        """Revoke all tokens for a client-user pair (security incident response)."""
        client = await self.get_client(client_id)
        if not client:
            return

        now = datetime.now(timezone.utc)

        await prisma.oauthaccesstoken.update_many(
            where={"clientId": client.id, "userId": user_id, "revokedAt": None},
            data={"revokedAt": now},
        )

        await prisma.oauthrefreshtoken.update_many(
            where={"clientId": client.id, "userId": user_id, "revokedAt": None},
            data={"revokedAt": now},
        )

        await self._audit_log(
            event_type="tokens.revoked.security",
            user_id=user_id,
            client_id=client_id,
            details={"reason": "authorization_code_reuse"},
        )

    # ================================================================
    # Authorization (Consent) Operations
    # ================================================================

    async def get_authorization(
        self,
        user_id: str,
        client_id: str,
    ) -> Optional[OAuthAuthorization]:
        """Get existing authorization for user-client pair."""
        client = await self.get_client(client_id)
        if not client:
            return None

        return await prisma.oauthauthorization.find_unique(
            where={
                "userId_clientId": {
                    "userId": user_id,
                    "clientId": client.id,
                }
            }
        )

    async def has_valid_authorization(
        self,
        user_id: str,
        client_id: str,
        scopes: list[str],
    ) -> bool:
        """
        Check if user has already authorized these scopes for this client.

        Args:
            user_id: User ID
            client_id: Client identifier
            scopes: Requested scopes

        Returns:
            True if user has already authorized all requested scopes
        """
        auth = await self.get_authorization(user_id, client_id)
        if not auth or auth.revokedAt:
            return False

        # Check if all requested scopes are already authorized
        return set(scopes).issubset(set(auth.scopes))

    async def _upsert_authorization(
        self,
        user_id: str,
        client_db_id: str,
        scopes: list[str],
    ) -> None:
        """Create or update an authorization record."""
        existing = await prisma.oauthauthorization.find_unique(
            where={
                "userId_clientId": {
                    "userId": user_id,
                    "clientId": client_db_id,
                }
            }
        )

        if existing:
            # Merge scopes
            merged_scopes = list(set(existing.scopes) | set(scopes))
            await prisma.oauthauthorization.update(
                where={"id": existing.id},
                data={"scopes": merged_scopes, "revokedAt": None},
            )
        else:
            await prisma.oauthauthorization.create(
                data={  # type: ignore[typeddict-item]
                    "userId": user_id,
                    "clientId": client_db_id,
                    "scopes": scopes,
                }
            )

    # ================================================================
    # Audit Logging
    # ================================================================

    async def _audit_log(
        self,
        event_type: str,
        user_id: Optional[str] = None,
        client_id: Optional[str] = None,
        grant_id: Optional[str] = None,
        ip_address: Optional[str] = None,
        user_agent: Optional[str] = None,
        details: Optional[dict[str, Any]] = None,
    ) -> None:
        """Create an audit log entry."""
        # Convert details to JSON for Prisma's Json field
        details_json = json.dumps(details or {})
        await prisma.oauthauditlog.create(
            data={
                "eventType": event_type,
                "userId": user_id,
                "clientId": client_id,
                "grantId": grant_id,
                "ipAddress": ip_address,
                "userAgent": user_agent,
                "details": json.loads(details_json),  # type: ignore[arg-type]
            }
        )


# Module-level singleton
_oauth_service: Optional[OAuthService] = None


def get_oauth_service() -> OAuthService:
    """Get the singleton OAuth service instance."""
    global _oauth_service
    if _oauth_service is None:
        _oauth_service = OAuthService()
    return _oauth_service
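For reference, a minimal sketch of the PKCE check that `exchange_code` performs above via `verify_code_challenge` (defined elsewhere in this package). This illustrative version assumes the standard RFC 7636 transform, where the S256 challenge is the unpadded base64url SHA-256 of the verifier, compared in constant time.

import base64
import hashlib
import hmac


def verify_code_challenge_sketch(
    code_verifier: str, code_challenge: str, method: str = "S256"
) -> bool:
    if method == "plain":
        # Plain method: challenge is the verifier itself
        return hmac.compare_digest(code_verifier, code_challenge)
    if method == "S256":
        # S256: BASE64URL(SHA256(verifier)) with padding stripped
        digest = hashlib.sha256(code_verifier.encode("ascii")).digest()
        computed = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")
        return hmac.compare_digest(computed, code_challenge)
    return False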
@@ -1,298 +0,0 @@
"""
JWT Token Service for OAuth 2.0 Provider.

Handles generation and validation of:
- Access tokens (JWT)
- Refresh tokens (opaque)
- ID tokens (JWT, OIDC)
"""

import base64
import hashlib
import secrets
from datetime import datetime, timedelta, timezone
from typing import Optional

import jwt
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric.rsa import (
    RSAPrivateKey,
    RSAPublicKey,
    generate_private_key,
)
from pydantic import BaseModel

from backend.util.settings import Settings


class TokenClaims(BaseModel):
    """Decoded token claims."""

    iss: str  # Issuer
    sub: str  # Subject (user ID)
    aud: str  # Audience (client ID)
    exp: int  # Expiration timestamp
    iat: int  # Issued at timestamp
    jti: str  # JWT ID
    scope: str  # Space-separated scopes
    client_id: str  # Client ID


class OAuthTokenService:
    """
    Service for generating and validating OAuth tokens.

    Uses RS256 (RSA with SHA-256) for JWT signing.
    """

    def __init__(self, settings: Optional[Settings] = None):
        self._settings = settings or Settings()
        self._private_key: Optional[RSAPrivateKey] = None
        self._public_key: Optional[RSAPublicKey] = None
        self._algorithm = "RS256"

    @property
    def issuer(self) -> str:
        """Get the token issuer URL."""
        return self._settings.config.platform_base_url or "https://platform.agpt.co"

    @property
    def key_id(self) -> str:
        """Get the key ID for JWKS."""
        return self._settings.secrets.oauth_jwt_key_id or "default-key-id"

    def _get_private_key(self) -> RSAPrivateKey:
        """Load or generate the private key."""
        if self._private_key is not None:
            return self._private_key

        key_pem = self._settings.secrets.oauth_jwt_private_key
        if key_pem:
            loaded_key = serialization.load_pem_private_key(
                key_pem.encode(), password=None
            )
            if not isinstance(loaded_key, RSAPrivateKey):
                raise ValueError("OAuth JWT private key must be RSA")
            self._private_key = loaded_key
        else:
            # Generate a key for development (should not be used in production)
            self._private_key = generate_private_key(
                public_exponent=65537,
                key_size=2048,
            )
        return self._private_key

    def _get_public_key(self) -> RSAPublicKey:
        """Get the public key from the private key."""
        if self._public_key is not None:
            return self._public_key

        key_pem = self._settings.secrets.oauth_jwt_public_key
        if key_pem:
            loaded_key = serialization.load_pem_public_key(key_pem.encode())
            if not isinstance(loaded_key, RSAPublicKey):
                raise ValueError("OAuth JWT public key must be RSA")
            self._public_key = loaded_key
        else:
            self._public_key = self._get_private_key().public_key()
        return self._public_key

    def generate_access_token(
        self,
        user_id: str,
        client_id: str,
        scopes: list[str],
        expires_in: int = 3600,
    ) -> tuple[str, datetime]:
        """
        Generate a JWT access token.

        Args:
            user_id: User ID (subject)
            client_id: Client ID (audience)
            scopes: List of granted scopes
            expires_in: Token lifetime in seconds

        Returns:
            Tuple of (token string, expiration datetime)
        """
        now = datetime.now(timezone.utc)
        expires_at = now + timedelta(seconds=expires_in)

        payload = {
            "iss": self.issuer,
            "sub": user_id,
            "aud": client_id,
            "exp": int(expires_at.timestamp()),
            "iat": int(now.timestamp()),
            "jti": secrets.token_urlsafe(16),
            "scope": " ".join(scopes),
            "client_id": client_id,
        }

        token = jwt.encode(
            payload,
            self._get_private_key(),
            algorithm=self._algorithm,
            headers={"kid": self.key_id},
        )
        return token, expires_at

    def generate_refresh_token(self) -> str:
        """
        Generate an opaque refresh token.

        Returns:
            URL-safe random token string
        """
        return secrets.token_urlsafe(48)

    def generate_id_token(
        self,
        user_id: str,
        client_id: str,
        email: Optional[str] = None,
        name: Optional[str] = None,
        nonce: Optional[str] = None,
        expires_in: int = 3600,
    ) -> str:
        """
        Generate an OIDC ID token.

        Args:
            user_id: User ID (subject)
            client_id: Client ID (audience)
            email: User's email (optional)
            name: User's name (optional)
            nonce: OIDC nonce for replay protection (optional)
            expires_in: Token lifetime in seconds

        Returns:
            JWT ID token string
        """
        now = datetime.now(timezone.utc)
        expires_at = now + timedelta(seconds=expires_in)

        payload = {
            "iss": self.issuer,
            "sub": user_id,
            "aud": client_id,
            "exp": int(expires_at.timestamp()),
            "iat": int(now.timestamp()),
            "auth_time": int(now.timestamp()),
        }

        if email:
            payload["email"] = email
            payload["email_verified"] = True
        if name:
            payload["name"] = name
        if nonce:
            payload["nonce"] = nonce

        return jwt.encode(
            payload,
            self._get_private_key(),
            algorithm=self._algorithm,
            headers={"kid": self.key_id},
        )

    def verify_access_token(
        self,
        token: str,
        expected_client_id: Optional[str] = None,
    ) -> TokenClaims:
        """
        Verify and decode a JWT access token.

        Args:
            token: JWT token string
            expected_client_id: Expected client ID (audience)

        Returns:
            Decoded token claims

        Raises:
            jwt.ExpiredSignatureError: Token has expired
            jwt.InvalidTokenError: Token is invalid
        """
        options = {}
        if expected_client_id:
            options["audience"] = expected_client_id

        payload = jwt.decode(
            token,
            self._get_public_key(),
            algorithms=[self._algorithm],
            issuer=self.issuer,
            options={"verify_aud": bool(expected_client_id)},
            **options,
        )

        return TokenClaims(
            iss=payload["iss"],
            sub=payload["sub"],
            aud=payload.get("aud", payload.get("client_id", "")),
            exp=payload["exp"],
            iat=payload["iat"],
            jti=payload["jti"],
            scope=payload.get("scope", ""),
            client_id=payload.get("client_id", payload.get("aud", "")),
        )

    @staticmethod
    def hash_token(token: str) -> str:
        """
        Hash a token for secure storage.

        Args:
            token: Token string to hash

        Returns:
            SHA-256 hash of the token
        """
        return hashlib.sha256(token.encode()).hexdigest()

    def get_jwks(self) -> dict:
        """
        Get the JSON Web Key Set (JWKS) for public key distribution.

        Returns:
            JWKS dictionary with public key(s)
        """
        public_key = self._get_public_key()
        public_numbers = public_key.public_numbers()

        # Convert to base64url encoding without padding
        def int_to_base64url(n: int, length: int) -> str:
            data = n.to_bytes(length, byteorder="big")
            return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")

        # RSA modulus and exponent
        n = int_to_base64url(public_numbers.n, (public_numbers.n.bit_length() + 7) // 8)
        e = int_to_base64url(public_numbers.e, 3)

        return {
            "keys": [
                {
                    "kty": "RSA",
                    "use": "sig",
                    "kid": self.key_id,
                    "alg": self._algorithm,
                    "n": n,
                    "e": e,
                }
            ]
        }


# Module-level singleton
_token_service: Optional[OAuthTokenService] = None


def get_token_service() -> OAuthTokenService:
    """Get the singleton token service instance."""
    global _token_service
    if _token_service is None:
        _token_service = OAuthTokenService()
    return _token_service
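A usage sketch of the service above: because the same instance signs with its private key and verifies with the matching public key, a generate/verify round trip should hold. The user ID, client ID, and scopes here are placeholders.

service = OAuthTokenService()
token, expires_at = service.generate_access_token(
    user_id="user-123",
    client_id="client-abc",
    scopes=["openid", "email"],
    expires_in=600,
)
# Verification enforces issuer always, and audience only when a client is given
claims = service.verify_access_token(token, expected_client_id="client-abc")
assert claims.sub == "user-123"
assert claims.scope == "openid email"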
@@ -21,7 +21,6 @@ import backend.data.db
 import backend.data.graph
 import backend.data.user
 import backend.integrations.webhooks.utils
-import backend.server.integrations.connect_router
 import backend.server.routers.postmark.postmark
 import backend.server.routers.v1
 import backend.server.v2.admin.credit_admin_routes
@@ -45,7 +44,6 @@ from backend.integrations.providers import ProviderName
 from backend.monitoring.instrumentation import instrument_fastapi
 from backend.server.external.api import external_app
 from backend.server.middleware.security import SecurityHeadersMiddleware
-from backend.server.oauth import client_router, discovery_router, oauth_router
 from backend.server.utils.cors import build_cors_params
 from backend.util import json
 from backend.util.cloud_storage import shutdown_cloud_storage_handler
@@ -302,18 +300,6 @@ app.include_router(
 
 app.mount("/external-api", external_app)
 
-# OAuth Provider routes
-app.include_router(oauth_router, tags=["oauth"], prefix="")
-app.include_router(discovery_router, tags=["oidc-discovery"], prefix="")
-app.include_router(client_router, tags=["oauth-clients"], prefix="")
-
-# Integration Connect popup routes (for Credential Broker)
-app.include_router(
-    backend.server.integrations.connect_router.connect_router,
-    tags=["integration-connect"],
-    prefix="",
-)
-
 
 @app.get(path="/health", tags=["health"], dependencies=[])
 async def health():
@@ -1,9 +1,16 @@
 import logging
+from dataclasses import dataclass
 from datetime import datetime, timedelta, timezone
+from difflib import SequenceMatcher
+from typing import Sequence
 
 import prisma
 
 import backend.data.block
+import backend.server.v2.library.db as library_db
+import backend.server.v2.library.model as library_model
+import backend.server.v2.store.db as store_db
+import backend.server.v2.store.model as store_model
 from backend.blocks import load_all_blocks
 from backend.blocks.llm import LlmModel
 from backend.data.block import AnyBlockSchema, BlockCategory, BlockInfo, BlockSchema
@@ -14,17 +21,36 @@ from backend.server.v2.builder.model import (
     BlockResponse,
     BlockType,
     CountResponse,
+    FilterType,
     Provider,
     ProviderResponse,
-    SearchBlocksResponse,
+    SearchEntry,
 )
 from backend.util.cache import cached
 from backend.util.models import Pagination
 
 logger = logging.getLogger(__name__)
 llm_models = [name.name.lower().replace("_", " ") for name in LlmModel]
-_static_counts_cache: dict | None = None
-_suggested_blocks: list[BlockInfo] | None = None
+MAX_LIBRARY_AGENT_RESULTS = 100
+MAX_MARKETPLACE_AGENT_RESULTS = 100
+MIN_SCORE_FOR_FILTERED_RESULTS = 10.0
+
+SearchResultItem = BlockInfo | library_model.LibraryAgent | store_model.StoreAgent
+
+
+@dataclass
+class _ScoredItem:
+    item: SearchResultItem
+    filter_type: FilterType
+    score: float
+    sort_key: str
+
+
+@dataclass
+class _SearchCacheEntry:
+    items: list[SearchResultItem]
+    total_items: dict[FilterType, int]
 
 
 def get_block_categories(category_blocks: int = 3) -> list[BlockCategoryResponse]:
@@ -130,71 +156,244 @@ def get_block_by_id(block_id: str) -> BlockInfo | None:
     return None
 
 
-def search_blocks(
-    include_blocks: bool = True,
-    include_integrations: bool = True,
-    query: str = "",
-    page: int = 1,
-    page_size: int = 50,
-) -> SearchBlocksResponse:
+async def update_search(user_id: str, search: SearchEntry) -> str:
     """
-    Get blocks based on the filter and query.
-    `providers` only applies for `integrations` filter.
+    Upsert a search request for the user and return the search ID.
     """
-    blocks: list[AnyBlockSchema] = []
-    query = query.lower()
+    if search.search_id:
+        # Update existing search
+        await prisma.models.BuilderSearchHistory.prisma().update(
+            where={
+                "id": search.search_id,
+            },
+            data={
+                "searchQuery": search.search_query or "",
+                "filter": search.filter or [],  # type: ignore
+                "byCreator": search.by_creator or [],
+            },
+        )
+        return search.search_id
+    else:
+        # Create new search
+        new_search = await prisma.models.BuilderSearchHistory.prisma().create(
+            data={
+                "userId": user_id,
+                "searchQuery": search.search_query or "",
+                "filter": search.filter or [],  # type: ignore
+                "byCreator": search.by_creator or [],
+            }
+        )
+        return new_search.id
 
-    total = 0
-    skip = (page - 1) * page_size
-    take = page_size
+
+async def get_recent_searches(user_id: str, limit: int = 5) -> list[SearchEntry]:
+    """
+    Get the user's most recent search requests.
+    """
+    searches = await prisma.models.BuilderSearchHistory.prisma().find_many(
+        where={
+            "userId": user_id,
+        },
+        order={
+            "updatedAt": "desc",
+        },
+        take=limit,
+    )
+    return [
+        SearchEntry(
+            search_query=s.searchQuery,
+            filter=s.filter,  # type: ignore
+            by_creator=s.byCreator,
+            search_id=s.id,
+        )
+        for s in searches
+    ]
+
+
+async def get_sorted_search_results(
+    *,
+    user_id: str,
+    search_query: str | None,
+    filters: Sequence[FilterType],
+    by_creator: Sequence[str] | None = None,
+) -> _SearchCacheEntry:
+    normalized_filters: tuple[FilterType, ...] = tuple(sorted(set(filters or [])))
+    normalized_creators: tuple[str, ...] = tuple(sorted(set(by_creator or [])))
+    return await _build_cached_search_results(
+        user_id=user_id,
+        search_query=search_query or "",
+        filters=normalized_filters,
+        by_creator=normalized_creators,
+    )
+
+
+@cached(ttl_seconds=300, shared_cache=True)
+async def _build_cached_search_results(
+    user_id: str,
+    search_query: str,
+    filters: tuple[FilterType, ...],
+    by_creator: tuple[str, ...],
+) -> _SearchCacheEntry:
+    normalized_query = (search_query or "").strip().lower()
+
+    include_blocks = "blocks" in filters
+    include_integrations = "integrations" in filters
+    include_library_agents = "my_agents" in filters
+    include_marketplace_agents = "marketplace_agents" in filters
+
+    scored_items: list[_ScoredItem] = []
+    total_items: dict[FilterType, int] = {
+        "blocks": 0,
+        "integrations": 0,
+        "marketplace_agents": 0,
+        "my_agents": 0,
+    }
+
+    block_results, block_total, integration_total = _collect_block_results(
+        normalized_query=normalized_query,
+        include_blocks=include_blocks,
+        include_integrations=include_integrations,
+    )
+    scored_items.extend(block_results)
+    total_items["blocks"] = block_total
+    total_items["integrations"] = integration_total
+
+    if include_library_agents:
+        library_response = await library_db.list_library_agents(
+            user_id=user_id,
+            search_term=search_query or None,
+            page=1,
+            page_size=MAX_LIBRARY_AGENT_RESULTS,
+        )
+        total_items["my_agents"] = library_response.pagination.total_items
+        scored_items.extend(
+            _build_library_items(
+                agents=library_response.agents,
+                normalized_query=normalized_query,
+            )
+        )
+
+    if include_marketplace_agents:
+        marketplace_response = await store_db.get_store_agents(
+            creators=list(by_creator) or None,
+            search_query=search_query or None,
+            page=1,
+            page_size=MAX_MARKETPLACE_AGENT_RESULTS,
+        )
+        total_items["marketplace_agents"] = marketplace_response.pagination.total_items
+        scored_items.extend(
+            _build_marketplace_items(
+                agents=marketplace_response.agents,
+                normalized_query=normalized_query,
+            )
+        )
+
+    sorted_items = sorted(
+        scored_items,
+        key=lambda entry: (-entry.score, entry.sort_key, entry.filter_type),
+    )
+
+    return _SearchCacheEntry(
+        items=[entry.item for entry in sorted_items],
+        total_items=total_items,
+    )
+
+
+def _collect_block_results(
+    *,
+    normalized_query: str,
+    include_blocks: bool,
+    include_integrations: bool,
+) -> tuple[list[_ScoredItem], int, int]:
+    results: list[_ScoredItem] = []
     block_count = 0
     integration_count = 0
+
+    if not include_blocks and not include_integrations:
+        return results, block_count, integration_count
+
     for block_type in load_all_blocks().values():
         block: AnyBlockSchema = block_type()
-        # Skip disabled blocks
         if block.disabled:
             continue
-        # Skip blocks that don't match the query
-        if (
-            query not in block.name.lower()
-            and query not in block.description.lower()
-            and not _matches_llm_model(block.input_schema, query)
-        ):
-            continue
-        keep = False
+        block_info = block.get_info()
         credentials = list(block.input_schema.get_credentials_fields().values())
-        if include_integrations and len(credentials) > 0:
-            keep = True
+        is_integration = len(credentials) > 0
+
+        if is_integration and not include_integrations:
+            continue
+        if not is_integration and not include_blocks:
+            continue
+
+        score = _score_block(block, block_info, normalized_query)
+        if not _should_include_item(score, normalized_query):
+            continue
+
+        filter_type: FilterType = "integrations" if is_integration else "blocks"
+        if is_integration:
             integration_count += 1
-        if include_blocks and len(credentials) == 0:
-            keep = True
+        else:
             block_count += 1
 
-        if not keep:
-            continue
-
-        total += 1
-        if skip > 0:
-            skip -= 1
-            continue
-        if take > 0:
-            take -= 1
-            blocks.append(block)
-
-    return SearchBlocksResponse(
-        blocks=BlockResponse(
-            blocks=[b.get_info() for b in blocks],
-            pagination=Pagination(
-                total_items=total,
-                total_pages=(total + page_size - 1) // page_size,
-                current_page=page,
-                page_size=page_size,
-            ),
-        ),
-        total_block_count=block_count,
-        total_integration_count=integration_count,
-    )
+        results.append(
+            _ScoredItem(
+                item=block_info,
+                filter_type=filter_type,
+                score=score,
+                sort_key=_get_item_name(block_info),
+            )
+        )
+
+    return results, block_count, integration_count
+
+
+def _build_library_items(
+    *,
+    agents: list[library_model.LibraryAgent],
+    normalized_query: str,
+) -> list[_ScoredItem]:
+    results: list[_ScoredItem] = []
+
+    for agent in agents:
+        score = _score_library_agent(agent, normalized_query)
+        if not _should_include_item(score, normalized_query):
+            continue
+
+        results.append(
+            _ScoredItem(
+                item=agent,
+                filter_type="my_agents",
+                score=score,
+                sort_key=_get_item_name(agent),
+            )
+        )
+
+    return results
+
+
+def _build_marketplace_items(
+    *,
+    agents: list[store_model.StoreAgent],
+    normalized_query: str,
+) -> list[_ScoredItem]:
+    results: list[_ScoredItem] = []
+
+    for agent in agents:
+        score = _score_store_agent(agent, normalized_query)
+        if not _should_include_item(score, normalized_query):
+            continue
+
+        results.append(
+            _ScoredItem(
+                item=agent,
+                filter_type="marketplace_agents",
+                score=score,
+                sort_key=_get_item_name(agent),
+            )
+        )
+
+    return results
 
 
 def get_providers(
@@ -251,16 +450,12 @@ async def get_counts(user_id: str) -> CountResponse:
     )
 
 
+@cached(ttl_seconds=3600)
 async def _get_static_counts():
     """
    Get counts of blocks, integrations, and marketplace agents.
    This is cached to avoid unnecessary database queries and calculations.
-    Can't use functools.cache here because the function is async.
    """
-    global _static_counts_cache
-    if _static_counts_cache is not None:
-        return _static_counts_cache
-
     all_blocks = 0
     input_blocks = 0
     action_blocks = 0
@@ -287,7 +482,7 @@ async def _get_static_counts():
 
     marketplace_agents = await prisma.models.StoreAgent.prisma().count()
 
-    _static_counts_cache = {
+    return {
         "all_blocks": all_blocks,
         "input_blocks": input_blocks,
        "action_blocks": action_blocks,
@@ -296,8 +491,6 @@ async def _get_static_counts():
         "marketplace_agents": marketplace_agents,
     }
 
-    return _static_counts_cache
-
 
 def _matches_llm_model(schema_cls: type[BlockSchema], query: str) -> bool:
     for field in schema_cls.model_fields.values():
@@ -308,6 +501,123 @@ def _matches_llm_model(schema_cls: type[BlockSchema], query: str) -> bool:
     return False
 
 
+def _score_block(
+    block: AnyBlockSchema,
+    block_info: BlockInfo,
+    normalized_query: str,
+) -> float:
+    if not normalized_query:
+        return 0.0
+
+    name = block_info.name.lower()
+    description = block_info.description.lower()
+    score = _score_primary_fields(name, description, normalized_query)
+
+    category_text = " ".join(
+        category.get("category", "").lower() for category in block_info.categories
+    )
+    score += _score_additional_field(category_text, normalized_query, 12, 6)
+
+    credentials_info = block.input_schema.get_credentials_fields_info().values()
+    provider_names = [
+        provider.value.lower()
+        for info in credentials_info
+        for provider in info.provider
+    ]
+    provider_text = " ".join(provider_names)
+    score += _score_additional_field(provider_text, normalized_query, 15, 6)
+
+    if _matches_llm_model(block.input_schema, normalized_query):
+        score += 20
+
+    return score
+
+
+def _score_library_agent(
+    agent: library_model.LibraryAgent,
+    normalized_query: str,
+) -> float:
+    if not normalized_query:
+        return 0.0
+
+    name = agent.name.lower()
+    description = (agent.description or "").lower()
+    instructions = (agent.instructions or "").lower()
+
+    score = _score_primary_fields(name, description, normalized_query)
+    score += _score_additional_field(instructions, normalized_query, 15, 6)
+    score += _score_additional_field(
+        agent.creator_name.lower(), normalized_query, 10, 5
+    )
+
+    return score
+
+
+def _score_store_agent(
+    agent: store_model.StoreAgent,
+    normalized_query: str,
+) -> float:
+    if not normalized_query:
+        return 0.0
+
+    name = agent.agent_name.lower()
+    description = agent.description.lower()
+    sub_heading = agent.sub_heading.lower()
+
+    score = _score_primary_fields(name, description, normalized_query)
+    score += _score_additional_field(sub_heading, normalized_query, 12, 6)
+    score += _score_additional_field(agent.creator.lower(), normalized_query, 10, 5)
+
+    return score
+
+
+def _score_primary_fields(name: str, description: str, query: str) -> float:
+    score = 0.0
+    if name == query:
+        score += 120
+    elif name.startswith(query):
+        score += 90
+    elif query in name:
+        score += 60
+
+    score += SequenceMatcher(None, name, query).ratio() * 50
+    if description:
+        if query in description:
+            score += 30
+        score += SequenceMatcher(None, description, query).ratio() * 25
+    return score
+
+
+def _score_additional_field(
+    value: str,
+    query: str,
+    contains_weight: float,
+    similarity_weight: float,
+) -> float:
+    if not value or not query:
+        return 0.0
+
+    score = 0.0
+    if query in value:
+        score += contains_weight
+    score += SequenceMatcher(None, value, query).ratio() * similarity_weight
+    return score
+
+
+def _should_include_item(score: float, normalized_query: str) -> bool:
+    if not normalized_query:
+        return True
+    return score >= MIN_SCORE_FOR_FILTERED_RESULTS
+
+
+def _get_item_name(item: SearchResultItem) -> str:
+    if isinstance(item, BlockInfo):
+        return item.name.lower()
+    if isinstance(item, library_model.LibraryAgent):
+        return item.name.lower()
+    return item.agent_name.lower()
+
+
 @cached(ttl_seconds=3600)
 def _get_all_providers() -> dict[ProviderName, Provider]:
     providers: dict[ProviderName, Provider] = {}
@@ -329,13 +639,9 @@ def _get_all_providers() -> dict[ProviderName, Provider]:
     return providers
 
 
+@cached(ttl_seconds=3600)
 async def get_suggested_blocks(count: int = 5) -> list[BlockInfo]:
-    global _suggested_blocks
-
-    if _suggested_blocks is not None and len(_suggested_blocks) >= count:
-        return _suggested_blocks[:count]
-
-    _suggested_blocks = []
+    suggested_blocks = []
     # Sum the number of executions for each block type
     # Prisma cannot group by nested relations, so we do a raw query
     # Calculate the cutoff timestamp
@@ -376,7 +682,7 @@ async def get_suggested_blocks(count: int = 5) -> list[BlockInfo]:
     # Sort blocks by execution count
     blocks.sort(key=lambda x: x[1], reverse=True)
 
-    _suggested_blocks = [block[0] for block in blocks]
+    suggested_blocks = [block[0] for block in blocks]
 
     # Return the top blocks
-    return _suggested_blocks[:count]
+    return suggested_blocks[:count]
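A standalone sketch of the fuzzy-scoring scheme introduced above, using only the stdlib. The weights mirror _score_primary_fields: exact name match 120, prefix 90, substring 60, plus SequenceMatcher-scaled similarity bonuses; the inputs are made-up examples.

from difflib import SequenceMatcher


def score_primary_fields(name: str, description: str, query: str) -> float:
    score = 0.0
    if name == query:
        score += 120
    elif name.startswith(query):
        score += 90
    elif query in name:
        score += 60
    # Similarity bonus on the name, then a smaller one on the description
    score += SequenceMatcher(None, name, query).ratio() * 50
    if description:
        if query in description:
            score += 30
        score += SequenceMatcher(None, description, query).ratio() * 25
    return score


print(score_primary_fields("http request", "make an http request", "http"))
# A prefix match plus similarity bonuses comfortably clears the
# MIN_SCORE_FOR_FILTERED_RESULTS threshold of 10.0.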
@@ -18,10 +18,17 @@ FilterType = Literal[
 BlockType = Literal["all", "input", "action", "output"]
 
 
+class SearchEntry(BaseModel):
+    search_query: str | None = None
+    filter: list[FilterType] | None = None
+    by_creator: list[str] | None = None
+    search_id: str | None = None
+
+
 # Suggestions
 class SuggestionsResponse(BaseModel):
     otto_suggestions: list[str]
-    recent_searches: list[str]
+    recent_searches: list[SearchEntry]
     providers: list[ProviderName]
     top_blocks: list[BlockInfo]
 
@@ -32,7 +39,7 @@ class BlockCategoryResponse(BaseModel):
     total_blocks: int
     blocks: list[BlockInfo]
 
-    model_config = {"use_enum_values": False}  # <== use enum names like "AI"
+    model_config = {"use_enum_values": False}  # Use enum names like "AI"
 
 
 # Input/Action/Output and see all for block categories
@@ -53,17 +60,11 @@ class ProviderResponse(BaseModel):
     pagination: Pagination
 
 
-class SearchBlocksResponse(BaseModel):
-    blocks: BlockResponse
-    total_block_count: int
-    total_integration_count: int
-
-
 class SearchResponse(BaseModel):
     items: list[BlockInfo | library_model.LibraryAgent | store_model.StoreAgent]
+    search_id: str
     total_items: dict[FilterType, int]
-    page: int
-    more_pages: bool
+    pagination: Pagination
 
 
 class CountResponse(BaseModel):
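A small sketch of the new SearchEntry model as it would round-trip through the search-history endpoints. Field names come from the model above; the values are placeholders.

entry = SearchEntry(
    search_query="image generation",
    filter=["blocks", "marketplace_agents"],
    by_creator=None,
    search_id=None,
)
print(entry.model_dump())
# {'search_query': 'image generation',
#  'filter': ['blocks', 'marketplace_agents'],
#  'by_creator': None, 'search_id': None}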
@@ -6,10 +6,6 @@ from autogpt_libs.auth.dependencies import get_user_id, requires_user
 
 import backend.server.v2.builder.db as builder_db
 import backend.server.v2.builder.model as builder_model
-import backend.server.v2.library.db as library_db
-import backend.server.v2.library.model as library_model
-import backend.server.v2.store.db as store_db
-import backend.server.v2.store.model as store_model
 from backend.integrations.providers import ProviderName
 from backend.util.models import Pagination
 
@@ -45,7 +41,9 @@ def sanitize_query(query: str | None) -> str | None:
     summary="Get Builder suggestions",
     response_model=builder_model.SuggestionsResponse,
 )
-async def get_suggestions() -> builder_model.SuggestionsResponse:
+async def get_suggestions(
+    user_id: Annotated[str, fastapi.Security(get_user_id)],
+) -> builder_model.SuggestionsResponse:
     """
     Get all suggestions for the Blocks Menu.
     """
@@ -55,11 +53,7 @@ async def get_suggestions() -> builder_model.SuggestionsResponse:
             "Help me create a list",
             "Help me feed my data to Google Maps",
         ],
-        recent_searches=[
-            "image generation",
-            "deepfake",
-            "competitor analysis",
-        ],
+        recent_searches=await builder_db.get_recent_searches(user_id),
         providers=[
             ProviderName.TWITTER,
             ProviderName.GITHUB,
@@ -147,7 +141,6 @@ async def get_providers(
     )
 
 
-# Not using post method because on frontend, orval doesn't support Infinite Query with POST method.
 @router.get(
     "/search",
     summary="Builder search",
@@ -157,7 +150,7 @@ async def get_providers(
 async def search(
     user_id: Annotated[str, fastapi.Security(get_user_id)],
     search_query: Annotated[str | None, fastapi.Query()] = None,
-    filter: Annotated[list[str] | None, fastapi.Query()] = None,
+    filter: Annotated[list[builder_model.FilterType] | None, fastapi.Query()] = None,
     search_id: Annotated[str | None, fastapi.Query()] = None,
     by_creator: Annotated[list[str] | None, fastapi.Query()] = None,
     page: Annotated[int, fastapi.Query()] = 1,
@@ -176,69 +169,43 @@ async def search(
     ]
     search_query = sanitize_query(search_query)
 
-    # Blocks&Integrations
-    blocks = builder_model.SearchBlocksResponse(
-        blocks=builder_model.BlockResponse(
-            blocks=[],
-            pagination=Pagination.empty(),
-        ),
-        total_block_count=0,
-        total_integration_count=0,
+    # Get all possible results
+    cached_results = await builder_db.get_sorted_search_results(
+        user_id=user_id,
+        search_query=search_query,
+        filters=filter,
+        by_creator=by_creator,
     )
-    if "blocks" in filter or "integrations" in filter:
-        blocks = builder_db.search_blocks(
-            include_blocks="blocks" in filter,
-            include_integrations="integrations" in filter,
-            query=search_query or "",
-            page=page,
-            page_size=page_size,
-        )
 
-    # Library Agents
-    my_agents = library_model.LibraryAgentResponse(
-        agents=[],
-        pagination=Pagination.empty(),
+    # Paginate results
+    total_combined_items = len(cached_results.items)
+    pagination = Pagination(
+        total_items=total_combined_items,
+        total_pages=(total_combined_items + page_size - 1) // page_size,
+        current_page=page,
+        page_size=page_size,
     )
-    if "my_agents" in filter:
-        my_agents = await library_db.list_library_agents(
-            user_id=user_id,
-            search_term=search_query,
-            page=page,
-            page_size=page_size,
-        )
 
-    # Marketplace Agents
-    marketplace_agents = store_model.StoreAgentsResponse(
-        agents=[],
-        pagination=Pagination.empty(),
-    )
-    if "marketplace_agents" in filter:
-        marketplace_agents = await store_db.get_store_agents(
-            creators=by_creator,
-            search_query=search_query,
-            page=page,
-            page_size=page_size,
-        )
-
-    more_pages = False
-    if (
-        blocks.blocks.pagination.current_page < blocks.blocks.pagination.total_pages
-        or my_agents.pagination.current_page < my_agents.pagination.total_pages
-        or marketplace_agents.pagination.current_page
-        < marketplace_agents.pagination.total_pages
-    ):
-        more_pages = True
+    start_idx = (page - 1) * page_size
+    end_idx = start_idx + page_size
+    paginated_items = cached_results.items[start_idx:end_idx]
+
+    # Update the search entry by id
+    search_id = await builder_db.update_search(
+        user_id,
+        builder_model.SearchEntry(
+            search_query=search_query,
+            filter=filter,
+            by_creator=by_creator,
+            search_id=search_id,
+        ),
+    )
 
     return builder_model.SearchResponse(
-        items=blocks.blocks.blocks + my_agents.agents + marketplace_agents.agents,
-        total_items={
-            "blocks": blocks.total_block_count,
-            "integrations": blocks.total_integration_count,
-            "marketplace_agents": marketplace_agents.pagination.total_items,
-            "my_agents": my_agents.pagination.total_items,
-        },
-        page=page,
-        more_pages=more_pages,
+        items=paginated_items,
+        search_id=search_id,
+        total_items=cached_results.total_items,
+        pagination=pagination,
     )
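The pagination arithmetic in the new search handler, worked as a standalone example: ceil-division for total_pages and a slice window for the current page. The numbers are illustrative.

total_combined_items = 123
page, page_size = 3, 50

total_pages = (total_combined_items + page_size - 1) // page_size  # -> 3
start_idx = (page - 1) * page_size                                  # -> 100
end_idx = start_idx + page_size                                     # -> 150
# items[100:150] yields the final 23 results on page 3.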
@@ -3,7 +3,6 @@ from datetime import UTC, datetime
 from os import getenv
 
 import pytest
-from prisma.types import ProfileCreateInput
 from pydantic import SecretStr
 
 from backend.blocks.firecrawl.scrape import FirecrawlScrapeBlock
@@ -50,13 +49,13 @@ async def setup_test_data():
     # 1b. Create a profile with username for the user (required for store agent lookup)
     username = user.email.split("@")[0]
     await prisma.profile.create(
-        data=ProfileCreateInput(
-            userId=user.id,
-            username=username,
-            name=f"Test User {username}",
-            description="Test user profile",
-            links=[],  # Required field - empty array for test profiles
-        )
+        data={
+            "userId": user.id,
+            "username": username,
+            "name": f"Test User {username}",
+            "description": "Test user profile",
+            "links": [],  # Required field - empty array for test profiles
+        }
     )
 
     # 2. Create a test graph with agent input -> agent output
@@ -173,13 +172,13 @@ async def setup_llm_test_data():
     # 1b. Create a profile with username for the user (required for store agent lookup)
     username = user.email.split("@")[0]
     await prisma.profile.create(
-        data=ProfileCreateInput(
-            userId=user.id,
-            username=username,
-            name=f"Test User {username}",
-            description="Test user profile for LLM tests",
-            links=[],  # Required field - empty array for test profiles
-        )
+        data={
+            "userId": user.id,
+            "username": username,
+            "name": f"Test User {username}",
+            "description": "Test user profile for LLM tests",
+            "links": [],  # Required field - empty array for test profiles
+        }
    )
 
     # 2. Create test OpenAI credentials for the user
@@ -333,13 +332,13 @@ async def setup_firecrawl_test_data():
     # 1b. Create a profile with username for the user (required for store agent lookup)
     username = user.email.split("@")[0]
     await prisma.profile.create(
-        data=ProfileCreateInput(
-            userId=user.id,
-            username=username,
-            name=f"Test User {username}",
-            description="Test user profile for Firecrawl tests",
-            links=[],  # Required field - empty array for test profiles
-        )
+        data={
+            "userId": user.id,
+            "username": username,
+            "name": f"Test User {username}",
+            "description": "Test user profile for Firecrawl tests",
+            "links": [],  # Required field - empty array for test profiles
+        }
     )
 
     # NOTE: We deliberately do NOT create Firecrawl credentials for this user
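The fixture change above swaps Prisma's generated ProfileCreateInput for a dict literal. A sketch of why that is behavior-preserving, assuming (as in Prisma Client Python) that the generated input type is a TypedDict: calling a TypedDict with keyword arguments just produces a plain dict, so the literal form drops the import without changing what reaches the database.

from typing import TypedDict


class ProfileCreateInputSketch(TypedDict):
    # Hypothetical stand-in for prisma.types.ProfileCreateInput
    userId: str
    username: str


typed_form = ProfileCreateInputSketch(userId="u1", username="alice")
literal_form = {"userId": "u1", "username": "alice"}
assert typed_form == literal_form  # identical at runtime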
@@ -802,16 +802,18 @@ async def add_store_agent_to_library(
 
     # Create LibraryAgent entry
     added_agent = await prisma.models.LibraryAgent.prisma().create(
-        data=prisma.types.LibraryAgentCreateInput(
-            User={"connect": {"id": user_id}},
-            AgentGraph={
-                "connect": {
-                    "graphVersionId": {"id": graph.id, "version": graph.version}
-                }
-            },
-            isCreatedByUser=False,
-            settings=SafeJson(_initialize_graph_settings(graph_model).model_dump()),
-        ),
+        data={
+            "User": {"connect": {"id": user_id}},
+            "AgentGraph": {
+                "connect": {
+                    "graphVersionId": {"id": graph.id, "version": graph.version}
+                }
+            },
+            "isCreatedByUser": False,
+            "settings": SafeJson(
+                _initialize_graph_settings(graph_model).model_dump()
+            ),
+        },
         include=library_agent_include(
             user_id, include_nodes=False, include_executions=False
         ),
@@ -248,9 +248,7 @@ async def log_search_term(search_query: str):
|
|||||||
date = datetime.now(timezone.utc).replace(hour=0, minute=0, second=0, microsecond=0)
|
date = datetime.now(timezone.utc).replace(hour=0, minute=0, second=0, microsecond=0)
|
||||||
try:
|
try:
|
||||||
await prisma.models.SearchTerms.prisma().create(
|
await prisma.models.SearchTerms.prisma().create(
|
||||||
data=prisma.types.SearchTermsCreateInput(
|
data={"searchTerm": search_query, "createdDate": date}
|
||||||
searchTerm=search_query, createdDate=date
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Fail silently here so that logging search terms doesn't break the app
|
# Fail silently here so that logging search terms doesn't break the app
|
||||||
@@ -1432,10 +1430,13 @@ async def _approve_sub_agent(
|
|||||||
 
     # Create new version if no matching version found
     next_version = max((v.version for v in listing.Versions or []), default=0) + 1
-    sub_agent_data = _create_sub_agent_version_data(sub_graph, heading, main_agent_name)
-    sub_agent_data["version"] = next_version
-    sub_agent_data["storeListingId"] = listing.id
-    await prisma.models.StoreListingVersion.prisma(tx).create(data=sub_agent_data)
+    await prisma.models.StoreListingVersion.prisma(tx).create(
+        data={
+            **_create_sub_agent_version_data(sub_graph, heading, main_agent_name),
+            "version": next_version,
+            "storeListingId": listing.id,
+        }
+    )
     await prisma.models.StoreListing.prisma(tx).update(
         where={"id": listing.id}, data={"hasApprovedVersion": True}
     )
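
The new form merges the generated base payload with per-version overrides through dict unpacking, where later keys win; a small sketch with hypothetical values:

    base = {"name": "Sub agent", "version": 1}                 # hypothetical base payload
    data = {**base, "version": 3, "storeListingId": "sl_123"}  # overrides applied last
    # data == {"name": "Sub agent", "version": 3, "storeListingId": "sl_123"}
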
@@ -5,6 +5,13 @@ from tiktoken import encoding_for_model
 
 from backend.util import json
 
+# ---------------------------------------------------------------------------#
+#  CONSTANTS                                                                  #
+# ---------------------------------------------------------------------------#
+
+# Message prefixes for important system messages that should be protected during compression
+MAIN_OBJECTIVE_PREFIX = "[Main Objective Prompt]: "
+
 # ---------------------------------------------------------------------------#
 #  INTERNAL UTILITIES                                                         #
 # ---------------------------------------------------------------------------#
@@ -63,6 +70,55 @@ def _msg_tokens(msg: dict, enc) -> int:
     return WRAPPER + content_tokens + tool_call_tokens
 
 
+def _is_tool_message(msg: dict) -> bool:
+    """Check if a message contains tool calls or results that should be protected."""
+    content = msg.get("content")
+
+    # Check for Anthropic-style tool messages
+    if isinstance(content, list) and any(
+        isinstance(item, dict) and item.get("type") in ("tool_use", "tool_result")
+        for item in content
+    ):
+        return True
+
+    # Check for OpenAI-style tool calls in the message
+    if "tool_calls" in msg or msg.get("role") == "tool":
+        return True
+
+    return False
+
+
+def _is_objective_message(msg: dict) -> bool:
+    """Check if a message contains objective/system prompts that should be absolutely protected."""
+    content = msg.get("content", "")
+    if isinstance(content, str):
+        # Protect any message with the main objective prefix
+        return content.startswith(MAIN_OBJECTIVE_PREFIX)
+    return False
+
+
+def _truncate_tool_message_content(msg: dict, enc, max_tokens: int) -> None:
+    """
+    Carefully truncate tool message content while preserving tool structure.
+    Only truncates tool_result content, leaves tool_use intact.
+    """
+    content = msg.get("content")
+    if not isinstance(content, list):
+        return
+
+    for item in content:
+        # Only process tool_result items, leave tool_use blocks completely intact
+        if not (isinstance(item, dict) and item.get("type") == "tool_result"):
+            continue
+
+        result_content = item.get("content", "")
+        if (
+            isinstance(result_content, str)
+            and _tok_len(result_content, enc) > max_tokens
+        ):
+            item["content"] = _truncate_middle_tokens(result_content, enc, max_tokens)
+
+
 def _truncate_middle_tokens(text: str, enc, max_tok: int) -> str:
     """
     Return *text* shortened to ≈max_tok tokens by keeping the head & tail
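
The protection check keys off message shape rather than role alone, covering both provider formats; a quick sketch of what it classifies (payloads hypothetical and simplified):

    anthropic_msg = {"role": "user", "content": [{"type": "tool_result", "content": "ok"}]}
    openai_msg = {"role": "assistant", "tool_calls": [{"id": "call_1"}]}
    plain_msg = {"role": "user", "content": "hello"}
    # _is_tool_message(anthropic_msg) -> True
    # _is_tool_message(openai_msg)    -> True
    # _is_tool_message(plain_msg)     -> False
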
@@ -140,13 +196,21 @@ def compress_prompt(
         return sum(_msg_tokens(m, enc) for m in msgs)
 
     original_token_count = total_tokens()
 
     if original_token_count + reserve <= target_tokens:
         return msgs
 
     # ---- STEP 0 : normalise content --------------------------------------
     # Convert non-string payloads to strings so token counting is coherent.
-    for m in msgs[1:-1]:  # keep the first & last intact
+    for i, m in enumerate(msgs):
         if not isinstance(m.get("content"), str) and m.get("content") is not None:
+            if _is_tool_message(m):
+                continue
+
+            # Keep first and last messages intact (unless they're tool messages)
+            if i == 0 or i == len(msgs) - 1:
+                continue
+
             # Reasonable 20k-char ceiling prevents pathological blobs
             content_str = json.dumps(m["content"], separators=(",", ":"))
             if len(content_str) > 20_000:
@@ -157,34 +221,45 @@ def compress_prompt(
     cap = start_cap
     while total_tokens() + reserve > target_tokens and cap >= floor_cap:
         for m in msgs[1:-1]:  # keep first & last intact
-            if _tok_len(m.get("content") or "", enc) > cap:
-                m["content"] = _truncate_middle_tokens(m["content"], enc, cap)
+            if _is_tool_message(m):
+                # For tool messages, only truncate tool result content, preserve structure
+                _truncate_tool_message_content(m, enc, cap)
+                continue
+
+            if _is_objective_message(m):
+                # Never truncate objective messages - they contain the core task
+                continue
+
+            content = m.get("content") or ""
+            if _tok_len(content, enc) > cap:
+                m["content"] = _truncate_middle_tokens(content, enc, cap)
         cap //= 2  # tighten the screw
 
     # ---- STEP 2 : middle-out deletion -----------------------------------
     while total_tokens() + reserve > target_tokens and len(msgs) > 2:
+        # Identify all deletable messages (not first/last, not tool messages, not objective messages)
+        deletable_indices = []
+        for i in range(1, len(msgs) - 1):  # Skip first and last
+            if not _is_tool_message(msgs[i]) and not _is_objective_message(msgs[i]):
+                deletable_indices.append(i)
+
+        if not deletable_indices:
+            break  # nothing more we can drop
+
+        # Delete from center outward - find the index closest to center
         centre = len(msgs) // 2
-        # Build a symmetrical centre-out index walk: centre, centre+1, centre-1, ...
-        order = [centre] + [
-            i
-            for pair in zip(range(centre + 1, len(msgs) - 1), range(centre - 1, 0, -1))
-            for i in pair
-        ]
-        removed = False
-        for i in order:
-            msg = msgs[i]
-            if "tool_calls" in msg or msg.get("role") == "tool":
-                continue  # protect tool shells
-            del msgs[i]
-            removed = True
-            break
-        if not removed:  # nothing more we can drop
-            break
+        to_delete = min(deletable_indices, key=lambda i: abs(i - centre))
+        del msgs[to_delete]
 
     # ---- STEP 3 : final safety-net trim on first & last ------------------
     cap = start_cap
     while total_tokens() + reserve > target_tokens and cap >= floor_cap:
         for idx in (0, -1):  # first and last
+            if _is_tool_message(msgs[idx]):
+                # For tool messages at first/last position, truncate tool result content only
+                _truncate_tool_message_content(msgs[idx], enc, cap)
+                continue
+
             text = msgs[idx].get("content") or ""
             if _tok_len(text, enc) > cap:
                 msgs[idx]["content"] = _truncate_middle_tokens(text, enc, cap)
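
The rewritten STEP 2 first collects deletable indices, then removes the one nearest the centre on each pass; a worked sketch (messages hypothetical):

    msgs = ["objective", "a", "b", "c", "d", "final"]  # len 6, centre = 3
    deletable_indices = [1, 2, 3, 4]                   # assume none are protected
    to_delete = min(deletable_indices, key=lambda i: abs(i - 3))
    # -> 3, so "c" is dropped first; the loop repeats until the budget fits
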
@@ -1,228 +0,0 @@
"""
|
|
||||||
Rate Limiting for External API.
|
|
||||||
|
|
||||||
Implements sliding window rate limiting using Redis for distributed systems.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import time
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from backend.data.redis_client import get_redis_async
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class RateLimitResult:
|
|
||||||
"""Result of a rate limit check."""
|
|
||||||
|
|
||||||
allowed: bool
|
|
||||||
remaining: int
|
|
||||||
reset_at: float
|
|
||||||
retry_after: Optional[float] = None
|
|
||||||
|
|
||||||
|
|
||||||
class RateLimiter:
|
|
||||||
"""
|
|
||||||
Redis-based sliding window rate limiter.
|
|
||||||
|
|
||||||
Supports multiple limit tiers (per-minute, per-hour, per-day).
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, prefix: str = "ratelimit"):
|
|
||||||
self.prefix = prefix
|
|
||||||
|
|
||||||
def _make_key(self, identifier: str, window: str) -> str:
|
|
||||||
"""Create a Redis key for the rate limit counter."""
|
|
||||||
return f"{self.prefix}:{identifier}:{window}"
|
|
||||||
|
|
||||||
async def check_and_increment(
|
|
||||||
self,
|
|
||||||
identifier: str,
|
|
||||||
limits: dict[str, tuple[int, int]], # window_name -> (limit, window_seconds)
|
|
||||||
) -> RateLimitResult:
|
|
||||||
"""
|
|
||||||
Check rate limits and increment counters if allowed.
|
|
||||||
|
|
||||||
Uses atomic increment-first approach to prevent race conditions:
|
|
||||||
1. Increment all counters atomically
|
|
||||||
2. Check if any limit exceeded
|
|
||||||
3. If exceeded, decrement and return rate limit error
|
|
||||||
|
|
||||||
Args:
|
|
||||||
identifier: Unique identifier (e.g., client_id, client_id:user_id)
|
|
||||||
limits: Dictionary of limit configurations
|
|
||||||
e.g., {"minute": (60, 60), "hour": (1000, 3600)}
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
RateLimitResult with allowed status and remaining quota
|
|
||||||
"""
|
|
||||||
if not limits:
|
|
||||||
# No limits configured, allow request
|
|
||||||
return RateLimitResult(
|
|
||||||
allowed=True,
|
|
||||||
remaining=999999,
|
|
||||||
reset_at=time.time() + 60,
|
|
||||||
)
|
|
||||||
|
|
||||||
redis = await get_redis_async()
|
|
||||||
current_time = time.time()
|
|
||||||
|
|
||||||
# Increment all counters atomically first
|
|
||||||
incremented_keys: list[tuple[str, int, int, int]] = (
|
|
||||||
[]
|
|
||||||
) # (key, new_count, limit, window_seconds)
|
|
||||||
|
|
||||||
for window_name, (limit, window_seconds) in limits.items():
|
|
||||||
key = self._make_key(identifier, window_name)
|
|
||||||
|
|
||||||
# Atomic increment
|
|
||||||
new_count = await redis.incr(key)
|
|
||||||
|
|
||||||
# Set expiry if this is a new key
|
|
||||||
if new_count == 1:
|
|
||||||
await redis.expire(key, window_seconds)
|
|
||||||
|
|
||||||
incremented_keys.append((key, new_count, limit, window_seconds))
|
|
||||||
|
|
||||||
# Check if any limit exceeded
|
|
||||||
for key, new_count, limit, window_seconds in incremented_keys:
|
|
||||||
if new_count > limit:
|
|
||||||
# Rate limit exceeded - decrement all counters we just incremented
|
|
||||||
for decr_key, _, _, _ in incremented_keys:
|
|
||||||
await redis.decr(decr_key)
|
|
||||||
|
|
||||||
ttl = await redis.ttl(key)
|
|
||||||
reset_at = current_time + (ttl if ttl > 0 else window_seconds)
|
|
||||||
|
|
||||||
return RateLimitResult(
|
|
||||||
allowed=False,
|
|
||||||
remaining=0,
|
|
||||||
reset_at=reset_at,
|
|
||||||
retry_after=ttl if ttl > 0 else window_seconds,
|
|
||||||
)
|
|
||||||
|
|
||||||
# All limits passed
|
|
||||||
min_remaining = float("inf")
|
|
||||||
earliest_reset = current_time
|
|
||||||
|
|
||||||
for key, new_count, limit, window_seconds in incremented_keys:
|
|
||||||
remaining = max(0, limit - new_count)
|
|
||||||
min_remaining = min(min_remaining, remaining)
|
|
||||||
|
|
||||||
ttl = await redis.ttl(key)
|
|
||||||
reset_at = current_time + (ttl if ttl > 0 else window_seconds)
|
|
||||||
earliest_reset = max(earliest_reset, reset_at)
|
|
||||||
|
|
||||||
return RateLimitResult(
|
|
||||||
allowed=True,
|
|
||||||
remaining=int(min_remaining),
|
|
||||||
reset_at=earliest_reset,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def get_remaining(
|
|
||||||
self,
|
|
||||||
identifier: str,
|
|
||||||
limits: dict[str, tuple[int, int]],
|
|
||||||
) -> dict[str, int]:
|
|
||||||
"""
|
|
||||||
Get remaining quota for all windows without incrementing.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
identifier: Unique identifier
|
|
||||||
limits: Dictionary of limit configurations
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary of remaining quota per window
|
|
||||||
"""
|
|
||||||
redis = await get_redis_async()
|
|
||||||
remaining = {}
|
|
||||||
|
|
||||||
for window_name, (limit, _) in limits.items():
|
|
||||||
key = self._make_key(identifier, window_name)
|
|
||||||
count = await redis.get(key)
|
|
||||||
current_count = int(count) if count else 0
|
|
||||||
remaining[window_name] = max(0, limit - current_count)
|
|
||||||
|
|
||||||
return remaining
|
|
||||||
|
|
||||||
async def reset(self, identifier: str, window: Optional[str] = None) -> None:
|
|
||||||
"""
|
|
||||||
Reset rate limit counters.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
identifier: Unique identifier
|
|
||||||
window: Optional specific window to reset (resets all if None)
|
|
||||||
"""
|
|
||||||
redis = await get_redis_async()
|
|
||||||
|
|
||||||
if window:
|
|
||||||
key = self._make_key(identifier, window)
|
|
||||||
await redis.delete(key)
|
|
||||||
else:
|
|
||||||
# Delete known window keys instead of scanning
|
|
||||||
# This avoids potentially slow scan operations with many keys
|
|
||||||
known_windows = ["minute", "hour", "day"]
|
|
||||||
keys_to_delete = [self._make_key(identifier, w) for w in known_windows]
|
|
||||||
# Delete all in one call (Redis handles non-existent keys gracefully)
|
|
||||||
if keys_to_delete:
|
|
||||||
await redis.delete(*keys_to_delete)
|
|
||||||
|
|
||||||
|
|
||||||
# Default rate limits for different endpoints
|
|
||||||
DEFAULT_RATE_LIMITS = {
|
|
||||||
# OAuth endpoints
|
|
||||||
"oauth_authorize": {"minute": (30, 60)}, # 30/min per IP
|
|
||||||
"oauth_token": {"minute": (20, 60)}, # 20/min per client
|
|
||||||
"oauth_consent": {"minute": (20, 60)}, # 20/min per IP for consent submission
|
|
||||||
# External API endpoints
|
|
||||||
"api_execute": {
|
|
||||||
"minute": (10, 60),
|
|
||||||
"hour": (100, 3600),
|
|
||||||
}, # 10/min, 100/hour per client+user
|
|
||||||
"api_read": {
|
|
||||||
"minute": (60, 60),
|
|
||||||
"hour": (1000, 3600),
|
|
||||||
}, # 60/min, 1000/hour per client+user
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# Module-level singleton
|
|
||||||
_rate_limiter: Optional[RateLimiter] = None
|
|
||||||
|
|
||||||
|
|
||||||
def get_rate_limiter() -> RateLimiter:
|
|
||||||
"""Get the singleton rate limiter instance."""
|
|
||||||
global _rate_limiter
|
|
||||||
if _rate_limiter is None:
|
|
||||||
_rate_limiter = RateLimiter()
|
|
||||||
return _rate_limiter
|
|
||||||
|
|
||||||
|
|
||||||
async def check_rate_limit(
|
|
||||||
identifier: str,
|
|
||||||
limit_type: str,
|
|
||||||
) -> RateLimitResult:
|
|
||||||
"""
|
|
||||||
Convenience function to check rate limits.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
identifier: Unique identifier for the rate limit
|
|
||||||
limit_type: Type of limit from DEFAULT_RATE_LIMITS
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
RateLimitResult
|
|
||||||
"""
|
|
||||||
limits = DEFAULT_RATE_LIMITS.get(limit_type)
|
|
||||||
if not limits:
|
|
||||||
# No rate limit configured, allow
|
|
||||||
return RateLimitResult(
|
|
||||||
allowed=True,
|
|
||||||
remaining=999999,
|
|
||||||
reset_at=time.time() + 60,
|
|
||||||
)
|
|
||||||
|
|
||||||
rate_limiter = get_rate_limiter()
|
|
||||||
return await rate_limiter.check_and_increment(identifier, limits)
|
|
||||||
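
For reference, the deleted module was driven through its convenience wrapper; a minimal usage sketch (identifier and the error helper are hypothetical):

    result = await check_rate_limit("client_abc:user_1", "api_execute")
    if not result.allowed:
        # surface result.retry_after to the caller, e.g. as a Retry-After header
        raise_rate_limit_error(result)  # hypothetical error helper
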
@@ -651,23 +651,6 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
 
     ayrshare_api_key: str = Field(default="", description="Ayrshare API Key")
     ayrshare_jwt_key: str = Field(default="", description="Ayrshare private Key")
-
-    # OAuth Provider JWT keys
-    oauth_jwt_private_key: str = Field(
-        default="",
-        description="RSA private key for signing OAuth tokens (PEM format). "
-        "If not set, a development key will be auto-generated.",
-    )
-    oauth_jwt_public_key: str = Field(
-        default="",
-        description="RSA public key for verifying OAuth tokens (PEM format). "
-        "If not set, derived from private key.",
-    )
-    oauth_jwt_key_id: str = Field(
-        default="autogpt-oauth-key-1",
-        description="Key ID (kid) for JWKS. Used to identify the signing key.",
-    )
-
     # Add more secret fields as needed
     model_config = SettingsConfigDict(
         env_file=".env",
@@ -1,43 +0,0 @@
"""
|
|
||||||
Time utilities for the backend.
|
|
||||||
|
|
||||||
Common datetime operations used across the codebase.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from datetime import datetime, timedelta, timezone
|
|
||||||
|
|
||||||
|
|
||||||
def expiration_datetime(seconds: int) -> datetime:
|
|
||||||
"""
|
|
||||||
Calculate an expiration datetime from now.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
seconds: Number of seconds until expiration
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Datetime when the item will expire (UTC)
|
|
||||||
"""
|
|
||||||
return datetime.now(timezone.utc) + timedelta(seconds=seconds)
|
|
||||||
|
|
||||||
|
|
||||||
def is_expired(dt: datetime) -> bool:
|
|
||||||
"""
|
|
||||||
Check if a datetime has passed.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dt: The datetime to check (should be timezone-aware)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if the datetime is in the past
|
|
||||||
"""
|
|
||||||
return dt < datetime.now(timezone.utc)
|
|
||||||
|
|
||||||
|
|
||||||
def utc_now() -> datetime:
|
|
||||||
"""
|
|
||||||
Get the current UTC time.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Current datetime in UTC
|
|
||||||
"""
|
|
||||||
return datetime.now(timezone.utc)
|
|
||||||
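
The two deleted helpers composed naturally for expiry checks; a quick sketch:

    expires_at = expiration_datetime(3600)  # one hour from now, timezone-aware UTC
    is_expired(expires_at)                  # False until that hour passes
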
@@ -1,46 +0,0 @@
"""
|
|
||||||
URL and domain validation utilities.
|
|
||||||
|
|
||||||
Common URL validation operations used across the codebase.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def matches_domain_pattern(hostname: str, domain_pattern: str) -> bool:
|
|
||||||
"""
|
|
||||||
Check if a hostname matches a domain pattern.
|
|
||||||
|
|
||||||
Supports wildcard patterns (*.example.com) which match:
|
|
||||||
- The base domain (example.com)
|
|
||||||
- Any subdomain (sub.example.com, deep.sub.example.com)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
hostname: The hostname to check (e.g., "api.example.com")
|
|
||||||
domain_pattern: The pattern to match against (e.g., "*.example.com" or "example.com")
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if the hostname matches the pattern
|
|
||||||
"""
|
|
||||||
hostname = hostname.lower()
|
|
||||||
domain_pattern = domain_pattern.lower()
|
|
||||||
|
|
||||||
if domain_pattern.startswith("*."):
|
|
||||||
# Wildcard domain - matches base and any subdomains
|
|
||||||
base_domain = domain_pattern[2:]
|
|
||||||
return hostname == base_domain or hostname.endswith("." + base_domain)
|
|
||||||
|
|
||||||
# Exact match
|
|
||||||
return hostname == domain_pattern
|
|
||||||
|
|
||||||
|
|
||||||
def hostname_matches_any_domain(hostname: str, allowed_domains: list[str]) -> bool:
|
|
||||||
"""
|
|
||||||
Check if a hostname matches any of the allowed domain patterns.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
hostname: The hostname to check
|
|
||||||
allowed_domains: List of allowed domain patterns (supports wildcards)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if the hostname matches any pattern
|
|
||||||
"""
|
|
||||||
return any(matches_domain_pattern(hostname, domain) for domain in allowed_domains)
|
|
||||||
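
The wildcard rule matched the bare domain and any subdomain depth, but never a mere string suffix; for example:

    matches_domain_pattern("example.com", "*.example.com")           # True (base domain)
    matches_domain_pattern("deep.sub.example.com", "*.example.com")  # True (subdomain)
    matches_domain_pattern("notexample.com", "*.example.com")        # False (suffix only)
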
@@ -1,249 +0,0 @@
--- CreateEnum
-CREATE TYPE "OAuthClientStatus" AS ENUM ('ACTIVE', 'SUSPENDED');
-
--- CreateEnum
-CREATE TYPE "CredentialGrantPermission" AS ENUM ('USE', 'DELETE');
-
--- CreateTable
-CREATE TABLE "OAuthClient" (
-    "id" TEXT NOT NULL,
-    "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
-    "updatedAt" TIMESTAMP(3) NOT NULL,
-    "clientId" TEXT NOT NULL,
-    "clientSecretHash" TEXT,
-    "clientSecretSalt" TEXT,
-    "clientType" TEXT NOT NULL,
-    "name" TEXT NOT NULL,
-    "description" TEXT,
-    "logoUrl" TEXT,
-    "homepageUrl" TEXT,
-    "privacyPolicyUrl" TEXT,
-    "termsOfServiceUrl" TEXT,
-    "redirectUris" TEXT[],
-    "allowedScopes" TEXT[],
-    "webhookDomains" TEXT[],
-    "requirePkce" BOOLEAN NOT NULL DEFAULT true,
-    "tokenLifetimeSecs" INTEGER NOT NULL DEFAULT 3600,
-    "refreshTokenLifetimeSecs" INTEGER NOT NULL DEFAULT 2592000,
-    "status" "OAuthClientStatus" NOT NULL DEFAULT 'ACTIVE',
-    "ownerId" TEXT NOT NULL,
-
-    CONSTRAINT "OAuthClient_pkey" PRIMARY KEY ("id")
-);
-
--- CreateTable
-CREATE TABLE "OAuthAuthorization" (
-    "id" TEXT NOT NULL,
-    "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
-    "updatedAt" TIMESTAMP(3) NOT NULL,
-    "userId" TEXT NOT NULL,
-    "clientId" TEXT NOT NULL,
-    "scopes" TEXT[],
-    "revokedAt" TIMESTAMP(3),
-
-    CONSTRAINT "OAuthAuthorization_pkey" PRIMARY KEY ("id")
-);
-
--- CreateTable
-CREATE TABLE "OAuthAuthorizationCode" (
-    "id" TEXT NOT NULL,
-    "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
-    "codeHash" TEXT NOT NULL,
-    "userId" TEXT NOT NULL,
-    "clientId" TEXT NOT NULL,
-    "redirectUri" TEXT NOT NULL,
-    "scopes" TEXT[],
-    "nonce" TEXT,
-    "codeChallenge" TEXT NOT NULL,
-    "codeChallengeMethod" TEXT NOT NULL DEFAULT 'S256',
-    "expiresAt" TIMESTAMP(3) NOT NULL,
-    "usedAt" TIMESTAMP(3),
-
-    CONSTRAINT "OAuthAuthorizationCode_pkey" PRIMARY KEY ("id")
-);
-
--- CreateTable
-CREATE TABLE "OAuthAccessToken" (
-    "id" TEXT NOT NULL,
-    "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
-    "tokenHash" TEXT NOT NULL,
-    "userId" TEXT NOT NULL,
-    "clientId" TEXT NOT NULL,
-    "scopes" TEXT[],
-    "expiresAt" TIMESTAMP(3) NOT NULL,
-    "revokedAt" TIMESTAMP(3),
-    "lastUsedAt" TIMESTAMP(3),
-
-    CONSTRAINT "OAuthAccessToken_pkey" PRIMARY KEY ("id")
-);
-
--- CreateTable
-CREATE TABLE "OAuthRefreshToken" (
-    "id" TEXT NOT NULL,
-    "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
-    "tokenHash" TEXT NOT NULL,
-    "userId" TEXT NOT NULL,
-    "clientId" TEXT NOT NULL,
-    "scopes" TEXT[],
-    "expiresAt" TIMESTAMP(3) NOT NULL,
-    "revokedAt" TIMESTAMP(3),
-
-    CONSTRAINT "OAuthRefreshToken_pkey" PRIMARY KEY ("id")
-);
-
--- CreateTable
-CREATE TABLE "CredentialGrant" (
-    "id" TEXT NOT NULL,
-    "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
-    "updatedAt" TIMESTAMP(3) NOT NULL,
-    "userId" TEXT NOT NULL,
-    "clientId" TEXT NOT NULL,
-    "credentialId" TEXT NOT NULL,
-    "provider" TEXT NOT NULL,
-    "grantedScopes" TEXT[],
-    "permissions" "CredentialGrantPermission"[],
-    "expiresAt" TIMESTAMP(3),
-    "revokedAt" TIMESTAMP(3),
-    "lastUsedAt" TIMESTAMP(3),
-
-    CONSTRAINT "CredentialGrant_pkey" PRIMARY KEY ("id")
-);
-
--- CreateTable
-CREATE TABLE "OAuthAuditLog" (
-    "id" TEXT NOT NULL,
-    "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
-    "eventType" TEXT NOT NULL,
-    "userId" TEXT,
-    "clientId" TEXT,
-    "grantId" TEXT,
-    "ipAddress" TEXT,
-    "userAgent" TEXT,
-    "details" JSONB NOT NULL DEFAULT '{}',
-
-    CONSTRAINT "OAuthAuditLog_pkey" PRIMARY KEY ("id")
-);
-
--- CreateTable
-CREATE TABLE "ExecutionWebhook" (
-    "id" TEXT NOT NULL,
-    "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
-    "executionId" TEXT NOT NULL,
-    "webhookUrl" TEXT NOT NULL,
-    "clientId" TEXT NOT NULL,
-    "userId" TEXT NOT NULL,
-    "secret" TEXT,
-
-    CONSTRAINT "ExecutionWebhook_pkey" PRIMARY KEY ("id")
-);
-
--- CreateIndex
-CREATE UNIQUE INDEX "OAuthClient_clientId_key" ON "OAuthClient"("clientId");
-
--- CreateIndex
-CREATE INDEX "OAuthClient_clientId_idx" ON "OAuthClient"("clientId");
-
--- CreateIndex
-CREATE INDEX "OAuthClient_ownerId_idx" ON "OAuthClient"("ownerId");
-
--- CreateIndex
-CREATE INDEX "OAuthClient_status_idx" ON "OAuthClient"("status");
-
--- CreateIndex
-CREATE INDEX "OAuthAuthorization_userId_idx" ON "OAuthAuthorization"("userId");
-
--- CreateIndex
-CREATE INDEX "OAuthAuthorization_clientId_idx" ON "OAuthAuthorization"("clientId");
-
--- CreateIndex
-CREATE UNIQUE INDEX "OAuthAuthorization_userId_clientId_key" ON "OAuthAuthorization"("userId", "clientId");
-
--- CreateIndex
-CREATE UNIQUE INDEX "OAuthAuthorizationCode_codeHash_key" ON "OAuthAuthorizationCode"("codeHash");
-
--- CreateIndex
-CREATE INDEX "OAuthAuthorizationCode_codeHash_idx" ON "OAuthAuthorizationCode"("codeHash");
-
--- CreateIndex
-CREATE INDEX "OAuthAuthorizationCode_expiresAt_idx" ON "OAuthAuthorizationCode"("expiresAt");
-
--- CreateIndex
-CREATE UNIQUE INDEX "OAuthAccessToken_tokenHash_key" ON "OAuthAccessToken"("tokenHash");
-
--- CreateIndex
-CREATE INDEX "OAuthAccessToken_tokenHash_idx" ON "OAuthAccessToken"("tokenHash");
-
--- CreateIndex
-CREATE INDEX "OAuthAccessToken_userId_clientId_idx" ON "OAuthAccessToken"("userId", "clientId");
-
--- CreateIndex
-CREATE INDEX "OAuthAccessToken_expiresAt_idx" ON "OAuthAccessToken"("expiresAt");
-
--- CreateIndex
-CREATE UNIQUE INDEX "OAuthRefreshToken_tokenHash_key" ON "OAuthRefreshToken"("tokenHash");
-
--- CreateIndex
-CREATE INDEX "OAuthRefreshToken_tokenHash_idx" ON "OAuthRefreshToken"("tokenHash");
-
--- CreateIndex
-CREATE INDEX "OAuthRefreshToken_expiresAt_idx" ON "OAuthRefreshToken"("expiresAt");
-
--- CreateIndex
-CREATE INDEX "CredentialGrant_userId_clientId_idx" ON "CredentialGrant"("userId", "clientId");
-
--- CreateIndex
-CREATE INDEX "CredentialGrant_clientId_idx" ON "CredentialGrant"("clientId");
-
--- CreateIndex
-CREATE UNIQUE INDEX "CredentialGrant_userId_clientId_credentialId_key" ON "CredentialGrant"("userId", "clientId", "credentialId");
-
--- CreateIndex
-CREATE INDEX "OAuthAuditLog_createdAt_idx" ON "OAuthAuditLog"("createdAt");
-
--- CreateIndex
-CREATE INDEX "OAuthAuditLog_eventType_idx" ON "OAuthAuditLog"("eventType");
-
--- CreateIndex
-CREATE INDEX "OAuthAuditLog_userId_idx" ON "OAuthAuditLog"("userId");
-
--- CreateIndex
-CREATE INDEX "OAuthAuditLog_clientId_idx" ON "OAuthAuditLog"("clientId");
-
--- CreateIndex
-CREATE INDEX "ExecutionWebhook_executionId_idx" ON "ExecutionWebhook"("executionId");
-
--- CreateIndex
-CREATE INDEX "ExecutionWebhook_clientId_idx" ON "ExecutionWebhook"("clientId");
-
--- AddForeignKey
-ALTER TABLE "OAuthClient" ADD CONSTRAINT "OAuthClient_ownerId_fkey" FOREIGN KEY ("ownerId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-
--- AddForeignKey
-ALTER TABLE "OAuthAuthorization" ADD CONSTRAINT "OAuthAuthorization_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-
--- AddForeignKey
-ALTER TABLE "OAuthAuthorization" ADD CONSTRAINT "OAuthAuthorization_clientId_fkey" FOREIGN KEY ("clientId") REFERENCES "OAuthClient"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-
--- AddForeignKey
-ALTER TABLE "OAuthAuthorizationCode" ADD CONSTRAINT "OAuthAuthorizationCode_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-
--- AddForeignKey
-ALTER TABLE "OAuthAuthorizationCode" ADD CONSTRAINT "OAuthAuthorizationCode_clientId_fkey" FOREIGN KEY ("clientId") REFERENCES "OAuthClient"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-
--- AddForeignKey
-ALTER TABLE "OAuthAccessToken" ADD CONSTRAINT "OAuthAccessToken_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-
--- AddForeignKey
-ALTER TABLE "OAuthAccessToken" ADD CONSTRAINT "OAuthAccessToken_clientId_fkey" FOREIGN KEY ("clientId") REFERENCES "OAuthClient"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-
--- AddForeignKey
-ALTER TABLE "OAuthRefreshToken" ADD CONSTRAINT "OAuthRefreshToken_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-
--- AddForeignKey
-ALTER TABLE "OAuthRefreshToken" ADD CONSTRAINT "OAuthRefreshToken_clientId_fkey" FOREIGN KEY ("clientId") REFERENCES "OAuthClient"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-
--- AddForeignKey
-ALTER TABLE "CredentialGrant" ADD CONSTRAINT "CredentialGrant_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-
--- AddForeignKey
-ALTER TABLE "CredentialGrant" ADD CONSTRAINT "CredentialGrant_clientId_fkey" FOREIGN KEY ("clientId") REFERENCES "OAuthClient"("id") ON DELETE CASCADE ON UPDATE CASCADE;
@@ -1,2 +0,0 @@
--- AlterTable
-ALTER TABLE "platform"."OAuthClient" ADD COLUMN "webhookSecret" TEXT;
@@ -0,0 +1,15 @@
+-- Create BuilderSearchHistory table
+CREATE TABLE "BuilderSearchHistory" (
+    "id" TEXT NOT NULL,
+    "userId" TEXT NOT NULL,
+    "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
+    "updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
+    "searchQuery" TEXT NOT NULL,
+    "filter" TEXT[] DEFAULT ARRAY[]::TEXT[],
+    "byCreator" TEXT[] DEFAULT ARRAY[]::TEXT[],
+
+    CONSTRAINT "BuilderSearchHistory_pkey" PRIMARY KEY ("id")
+);
+
+-- Define User foreign relation
+ALTER TABLE "BuilderSearchHistory" ADD CONSTRAINT "BuilderSearchHistory_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
@@ -53,6 +53,7 @@ model User {
 
   Profile               Profile[]
   UserOnboarding        UserOnboarding?
+  BuilderSearchHistory  BuilderSearchHistory[]
   StoreListings         StoreListing[]
   StoreListingReviews   StoreListingReview[]
   StoreVersionsReviewed StoreListingVersion[]
@@ -60,14 +61,6 @@ model User {
   IntegrationWebhooks IntegrationWebhook[]
   NotificationBatches UserNotificationBatch[]
   PendingHumanReviews PendingHumanReview[]
-
-  // OAuth Provider relations
-  OAuthClientsOwned       OAuthClient[]            @relation("OAuthClientOwner")
-  OAuthAuthorizations     OAuthAuthorization[]
-  OAuthAuthorizationCodes OAuthAuthorizationCode[]
-  OAuthAccessTokens       OAuthAccessToken[]
-  OAuthRefreshTokens      OAuthRefreshToken[]
-  CredentialGrants        CredentialGrant[]
 }
 
 enum OnboardingStep {
@@ -122,6 +115,19 @@ model UserOnboarding {
   User User @relation(fields: [userId], references: [id], onDelete: Cascade)
 }
 
+model BuilderSearchHistory {
+  id        String   @id @default(uuid())
+  createdAt DateTime @default(now())
+  updatedAt DateTime @default(now()) @updatedAt
+
+  searchQuery String
+  filter      String[] @default([])
+  byCreator   String[] @default([])
+
+  userId String
+  User   User @relation(fields: [userId], references: [id], onDelete: Cascade)
+}
+
 // This model describes the Agent Graph/Flow (Multi Agent System).
 model AgentGraph {
   id String @default(uuid())
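
A record for the new model would be written the same dict-based way as the other creates in this changeset; a hypothetical sketch:

    await db.buildersearchhistory.create(
        data={
            "userId": user.id,
            "searchQuery": "weather agent",  # hypothetical search text
            "filter": ["blocks"],            # hypothetical filter values
            "byCreator": [],
        }
    )
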
@@ -709,11 +715,11 @@ view StoreAgent {
   storeListingVersionId String
   updated_at            DateTime
 
   slug              String
   agent_name        String
   agent_video       String?
   agent_output_demo String?
   agent_image       String[]
 
   featured         Boolean @default(false)
   creator_username String?
@@ -842,14 +848,14 @@ model StoreListingVersion {
   AgentGraph AgentGraph @relation(fields: [agentGraphId, agentGraphVersion], references: [id, version])
 
   // Content fields
   name               String
   subHeading         String
   videoUrl           String?
   agentOutputDemoUrl String?
   imageUrls          String[]
   description        String
   instructions       String?
   categories         String[]
 
   isFeatured Boolean @default(false)
 
@@ -969,226 +975,3 @@ enum APIKeyStatus {
   REVOKED
   SUSPENDED
 }
-
-// ============================================================
-// OAuth Provider & Credential Broker Models
-// ============================================================
-
-enum OAuthClientStatus {
-  ACTIVE
-  SUSPENDED
-}
-
-enum CredentialGrantPermission {
-  USE // Can use credential for agent execution
-  DELETE // Can delete the credential
-}
-
-// OAuth Client - Registered external applications
-model OAuthClient {
-  id        String   @id @default(uuid())
-  createdAt DateTime @default(now())
-  updatedAt DateTime @updatedAt
-
-  // Client identification
-  clientId         String  @unique // Public identifier (e.g., "app_abc123")
-  clientSecretHash String? // Hashed (null for public clients)
-  clientSecretSalt String?
-  clientType       String // "public" or "confidential"
-
-  // Metadata (shown on consent screen)
-  name              String
-  description       String?
-  logoUrl           String?
-  homepageUrl       String?
-  privacyPolicyUrl  String?
-  termsOfServiceUrl String?
-
-  // Configuration
-  redirectUris   String[]
-  allowedScopes  String[]
-  webhookDomains String[] // For webhook URL validation
-  webhookSecret  String? // Secret for HMAC signing webhooks
-
-  // Security
-  requirePkce              Boolean @default(true)
-  tokenLifetimeSecs        Int     @default(3600)
-  refreshTokenLifetimeSecs Int     @default(2592000) // 30 days
-
-  // Status
-  status OAuthClientStatus @default(ACTIVE)
-
-  // Owner
-  ownerId String
-  Owner   User   @relation("OAuthClientOwner", fields: [ownerId], references: [id], onDelete: Cascade)
-
-  // Relations
-  Authorizations     OAuthAuthorization[]
-  AuthorizationCodes OAuthAuthorizationCode[]
-  AccessTokens       OAuthAccessToken[]
-  RefreshTokens      OAuthRefreshToken[]
-  CredentialGrants   CredentialGrant[]
-
-  @@index([clientId])
-  @@index([ownerId])
-  @@index([status])
-}
-
-// OAuth Authorization - User consent record
-model OAuthAuthorization {
-  id        String   @id @default(uuid())
-  createdAt DateTime @default(now())
-  updatedAt DateTime @updatedAt
-
-  userId String
-  User   User   @relation(fields: [userId], references: [id], onDelete: Cascade)
-
-  clientId String
-  Client   OAuthClient @relation(fields: [clientId], references: [id], onDelete: Cascade)
-
-  scopes    String[]
-  revokedAt DateTime?
-
-  @@unique([userId, clientId])
-  @@index([userId])
-  @@index([clientId])
-}
-
-// OAuth Authorization Code - Short-lived, single-use
-model OAuthAuthorizationCode {
-  id        String   @id @default(uuid())
-  createdAt DateTime @default(now())
-
-  codeHash String @unique
-
-  userId String
-  User   User   @relation(fields: [userId], references: [id], onDelete: Cascade)
-
-  clientId String
-  Client   OAuthClient @relation(fields: [clientId], references: [id], onDelete: Cascade)
-
-  redirectUri String
-  scopes      String[]
-  nonce       String? // OIDC nonce
-
-  // PKCE
-  codeChallenge       String
-  codeChallengeMethod String @default("S256")
-
-  expiresAt DateTime // 10 minutes
-  usedAt    DateTime?
-
-  @@index([codeHash])
-  @@index([expiresAt])
-}
-
-// OAuth Access Token
-model OAuthAccessToken {
-  id        String   @id @default(uuid())
-  createdAt DateTime @default(now())
-
-  tokenHash String @unique // SHA256 of token
-
-  userId String
-  User   User   @relation(fields: [userId], references: [id], onDelete: Cascade)
-
-  clientId String
-  Client   OAuthClient @relation(fields: [clientId], references: [id], onDelete: Cascade)
-
-  scopes     String[]
-  expiresAt  DateTime
-  revokedAt  DateTime?
-  lastUsedAt DateTime?
-
-  @@index([tokenHash])
-  @@index([userId, clientId])
-  @@index([expiresAt])
-}
-
-// OAuth Refresh Token
-model OAuthRefreshToken {
-  id        String   @id @default(uuid())
-  createdAt DateTime @default(now())
-
-  tokenHash String @unique
-
-  userId String
-  User   User   @relation(fields: [userId], references: [id], onDelete: Cascade)
-
-  clientId String
-  Client   OAuthClient @relation(fields: [clientId], references: [id], onDelete: Cascade)
-
-  scopes    String[]
-  expiresAt DateTime
-  revokedAt DateTime?
-
-  @@index([tokenHash])
-  @@index([expiresAt])
-}
-
-// Credential Grant - Links external app to user's credential with scoped access
-model CredentialGrant {
-  id        String   @id @default(uuid())
-  createdAt DateTime @default(now())
-  updatedAt DateTime @updatedAt
-
-  userId String
-  User   User   @relation(fields: [userId], references: [id], onDelete: Cascade)
-
-  clientId String
-  Client   OAuthClient @relation(fields: [clientId], references: [id], onDelete: Cascade)
-
-  credentialId String // Reference to credential in User.integrations
-  provider     String
-
-  // Fine-grained integration scopes (e.g., "google:gmail.readonly")
-  grantedScopes String[]
-
-  // Permissions for the credential itself
-  permissions CredentialGrantPermission[]
-
-  expiresAt  DateTime?
-  revokedAt  DateTime?
-  lastUsedAt DateTime?
-
-  @@unique([userId, clientId, credentialId])
-  @@index([userId, clientId])
-  @@index([clientId])
-}
-
-// OAuth Audit Log
-model OAuthAuditLog {
-  id        String   @id @default(uuid())
-  createdAt DateTime @default(now())
-
-  eventType String // e.g., "token.issued", "grant.created"
-
-  userId   String?
-  clientId String?
-  grantId  String?
-
-  ipAddress String?
-  userAgent String?
-
-  details Json @default("{}")
-
-  @@index([createdAt])
-  @@index([eventType])
-  @@index([userId])
-  @@index([clientId])
-}
-
-// Execution Webhook - Webhook registration for external API executions
-model ExecutionWebhook {
-  id        String   @id @default(uuid())
-  createdAt DateTime @default(now())
-
-  executionId String // The graph execution ID
-  webhookUrl  String // URL to send notifications to
-  clientId    String // The OAuth client database ID
-  userId      String // The user who started the execution
-  secret      String? // Optional webhook secret for HMAC signing
-
-  @@index([executionId])
-  @@index([clientId])
-}
@@ -22,7 +22,6 @@ import random
 from typing import Any, Dict, List
 
 from faker import Faker
-from prisma.types import AgentBlockCreateInput
 
 from backend.data.api_key import create_api_key
 from backend.data.credit import get_user_credit_model
@@ -178,12 +177,12 @@ class TestDataCreator:
         for block in blocks_to_create:
             try:
                 await prisma.agentblock.create(
-                    data=AgentBlockCreateInput(
-                        id=block.id,
-                        name=block.name,
-                        inputSchema="{}",
-                        outputSchema="{}",
-                    )
+                    data={
+                        "id": block.id,
+                        "name": block.name,
+                        "inputSchema": "{}",
+                        "outputSchema": "{}",
+                    }
                 )
             except Exception as e:
                 print(f"Error creating block {block.name}: {e}")
@@ -30,19 +30,13 @@ from prisma.types import (
     AgentGraphCreateInput,
     AgentNodeCreateInput,
     AgentNodeLinkCreateInput,
-    AgentPresetCreateInput,
     AnalyticsDetailsCreateInput,
     AnalyticsMetricsCreateInput,
-    APIKeyCreateInput,
     CreditTransactionCreateInput,
     IntegrationWebhookCreateInput,
-    LibraryAgentCreateInput,
     ProfileCreateInput,
-    StoreListingCreateInput,
     StoreListingReviewCreateInput,
-    StoreListingVersionCreateInput,
     UserCreateInput,
-    UserOnboardingCreateInput,
 )
 
 faker = Faker()
@@ -178,14 +172,14 @@ async def main():
     for _ in range(num_presets):  # Create 1 AgentPreset per user
         graph = random.choice(agent_graphs)
         preset = await db.agentpreset.create(
-            data=AgentPresetCreateInput(
-                name=faker.sentence(nb_words=3),
-                description=faker.text(max_nb_chars=200),
-                userId=user.id,
-                agentGraphId=graph.id,
-                agentGraphVersion=graph.version,
-                isActive=True,
-            )
+            data={
+                "name": faker.sentence(nb_words=3),
+                "description": faker.text(max_nb_chars=200),
+                "userId": user.id,
+                "agentGraphId": graph.id,
+                "agentGraphVersion": graph.version,
+                "isActive": True,
+            }
         )
         agent_presets.append(preset)
 
@@ -226,18 +220,18 @@ async def main():
         )
 
         library_agent = await db.libraryagent.create(
-            data=LibraryAgentCreateInput(
-                userId=user.id,
-                agentGraphId=graph.id,
-                agentGraphVersion=graph.version,
-                creatorId=creator_profile.id if creator_profile else None,
-                imageUrl=get_image() if random.random() < 0.5 else None,
-                useGraphIsActiveVersion=random.choice([True, False]),
-                isFavorite=random.choice([True, False]),
-                isCreatedByUser=random.choice([True, False]),
-                isArchived=random.choice([True, False]),
-                isDeleted=random.choice([True, False]),
-            )
+            data={
+                "userId": user.id,
+                "agentGraphId": graph.id,
+                "agentGraphVersion": graph.version,
+                "creatorId": creator_profile.id if creator_profile else None,
+                "imageUrl": get_image() if random.random() < 0.5 else None,
+                "useGraphIsActiveVersion": random.choice([True, False]),
+                "isFavorite": random.choice([True, False]),
+                "isCreatedByUser": random.choice([True, False]),
+                "isArchived": random.choice([True, False]),
+                "isDeleted": random.choice([True, False]),
+            }
         )
         library_agents.append(library_agent)
 
@@ -398,13 +392,13 @@ async def main():
         user = random.choice(users)
         slug = faker.slug()
         listing = await db.storelisting.create(
-            data=StoreListingCreateInput(
-                agentGraphId=graph.id,
-                agentGraphVersion=graph.version,
-                owningUserId=user.id,
-                hasApprovedVersion=random.choice([True, False]),
-                slug=slug,
-            )
+            data={
+                "agentGraphId": graph.id,
+                "agentGraphVersion": graph.version,
+                "owningUserId": user.id,
+                "hasApprovedVersion": random.choice([True, False]),
+                "slug": slug,
+            }
         )
        store_listings.append(listing)
 
@@ -414,26 +408,26 @@ async def main():
     for listing in store_listings:
         graph = [g for g in agent_graphs if g.id == listing.agentGraphId][0]
         version = await db.storelistingversion.create(
-            data=StoreListingVersionCreateInput(
-                agentGraphId=graph.id,
-                agentGraphVersion=graph.version,
-                name=graph.name or faker.sentence(nb_words=3),
-                subHeading=faker.sentence(),
-                videoUrl=get_video_url() if random.random() < 0.3 else None,
-                imageUrls=[get_image() for _ in range(3)],
-                description=faker.text(),
-                categories=[faker.word() for _ in range(3)],
-                isFeatured=random.choice([True, False]),
-                isAvailable=True,
-                storeListingId=listing.id,
-                submissionStatus=random.choice(
+            data={
+                "agentGraphId": graph.id,
+                "agentGraphVersion": graph.version,
+                "name": graph.name or faker.sentence(nb_words=3),
+                "subHeading": faker.sentence(),
+                "videoUrl": get_video_url() if random.random() < 0.3 else None,
+                "imageUrls": [get_image() for _ in range(3)],
+                "description": faker.text(),
+                "categories": [faker.word() for _ in range(3)],
+                "isFeatured": random.choice([True, False]),
+                "isAvailable": True,
+                "storeListingId": listing.id,
+                "submissionStatus": random.choice(
                     [
                         prisma.enums.SubmissionStatus.PENDING,
                         prisma.enums.SubmissionStatus.APPROVED,
                         prisma.enums.SubmissionStatus.REJECTED,
                     ]
                 ),
-            )
+            }
         )
         store_listing_versions.append(version)
 
@@ -475,47 +469,51 @@ async def main():
 
         try:
             await db.useronboarding.create(
-                data=UserOnboardingCreateInput(
-                    userId=user.id,
-                    completedSteps=completed_steps,
-                    walletShown=random.choice([True, False]),
-                    notified=(
+                data={
+                    "userId": user.id,
+                    "completedSteps": completed_steps,
+                    "walletShown": random.choice([True, False]),
+                    "notified": (
                         random.sample(completed_steps, k=min(3, len(completed_steps)))
                         if completed_steps
                         else []
                     ),
-                    rewardedFor=(
+                    "rewardedFor": (
                        random.sample(completed_steps, k=min(2, len(completed_steps)))
                         if completed_steps
                         else []
                     ),
-                    usageReason=(
+                    "usageReason": (
                         random.choice(["personal", "business", "research", "learning"])
                         if random.random() < 0.7
                         else None
                     ),
-                    integrations=random.sample(
+                    "integrations": random.sample(
                         ["github", "google", "discord", "slack"], k=random.randint(0, 2)
                     ),
-                    otherIntegrations=(faker.word() if random.random() < 0.2 else None),
-                    selectedStoreListingVersionId=(
+                    "otherIntegrations": (
+                        faker.word() if random.random() < 0.2 else None
+                    ),
+                    "selectedStoreListingVersionId": (
                         random.choice(store_listing_versions).id
                         if store_listing_versions and random.random() < 0.5
                         else None
                     ),
-                    onboardingAgentExecutionId=(
+                    "onboardingAgentExecutionId": (
                         random.choice(agent_graph_executions).id
                         if agent_graph_executions and random.random() < 0.3
                         else None
                     ),
-                    agentRuns=random.randint(0, 10),
-                )
+                    "agentRuns": random.randint(0, 10),
+                }
             )
         except Exception as e:
             print(f"Error creating onboarding for user {user.id}: {e}")
             # Try simpler version
             await db.useronboarding.create(
-                data=UserOnboardingCreateInput(userId=user.id)
+                data={
+                    "userId": user.id,
+                }
             )
 
     # Insert IntegrationWebhooks for some users
@@ -546,20 +544,20 @@ async def main():
     for user in users:
         api_key = APIKeySmith().generate_key()
         await db.apikey.create(
-            data=APIKeyCreateInput(
-                name=faker.word(),
-                head=api_key.head,
-                tail=api_key.tail,
-                hash=api_key.hash,
-                salt=api_key.salt,
-                status=prisma.enums.APIKeyStatus.ACTIVE,
-                permissions=[
+            data={
+                "name": faker.word(),
+                "head": api_key.head,
+                "tail": api_key.tail,
+                "hash": api_key.hash,
+                "salt": api_key.salt,
+                "status": prisma.enums.APIKeyStatus.ACTIVE,
+                "permissions": [
                     prisma.enums.APIKeyPermission.EXECUTE_GRAPH,
                     prisma.enums.APIKeyPermission.READ_GRAPH,
                 ],
-                description=faker.text(),
-                userId=user.id,
-            )
+                "description": faker.text(),
+                "userId": user.id,
+            }
         )
 
     # Refresh materialized views
@@ -16,7 +16,6 @@ from datetime import datetime, timedelta
 import prisma.enums
 from faker import Faker
 from prisma import Json, Prisma
-from prisma.types import CreditTransactionCreateInput, StoreListingReviewCreateInput
 
 faker = Faker()
 
@@ -167,16 +166,16 @@ async def main():
     score = random.choices([1, 2, 3, 4, 5], weights=[5, 10, 20, 40, 25])[0]
 
     await db.storelistingreview.create(
-        data=StoreListingReviewCreateInput(
-            storeListingVersionId=version.id,
-            reviewByUserId=reviewer.id,
-            score=score,
-            comments=(
+        data={
+            "storeListingVersionId": version.id,
+            "reviewByUserId": reviewer.id,
+            "score": score,
+            "comments": (
                 faker.text(max_nb_chars=200)
                 if random.random() < 0.7
                 else None
             ),
-        )
+        }
     )
     new_reviews_count += 1
 
@@ -245,17 +244,17 @@ async def main():
     )
 
     await db.credittransaction.create(
-        data=CreditTransactionCreateInput(
-            userId=user.id,
-            amount=amount,
-            type=transaction_type,
-            metadata=Json(
+        data={
+            "userId": user.id,
+            "amount": amount,
+            "type": transaction_type,
+            "metadata": Json(
                 {
                     "source": "test_updater",
                     "timestamp": datetime.now().isoformat(),
                 }
             ),
-        )
+        }
     )
     transaction_count += 1
 
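Note on the Python hunks above: prisma-client-py generates its `*CreateInput` types as TypedDicts, so replacing the constructor-style calls with plain dict literals (and dropping the `prisma.types` import) is behavior-preserving; calling a TypedDict returns an ordinary dict at runtime. A minimal TypeScript sketch of the same structural-typing idea, for illustration only (field names mirror the diff, values are placeholders):

```ts
// Illustration only, not code from this repository. With structural typing,
// a plain object literal satisfies the named input type directly, which is
// why the dict-literal style above type-checks the same as the constructor.
interface StoreListingReviewCreateInput {
  storeListingVersionId: string;
  reviewByUserId: string;
  score: number;
  comments?: string | null;
}

const review: StoreListingReviewCreateInput = {
  storeListingVersionId: "version-id", // placeholder
  reviewByUserId: "user-id", // placeholder
  score: 5,
  comments: null,
};
```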
@@ -82,7 +82,7 @@
     "lodash": "4.17.21",
     "lucide-react": "0.552.0",
     "moment": "2.30.1",
-    "next": "15.4.8",
+    "next": "15.4.10",
     "next-themes": "0.4.6",
     "nuqs": "2.7.2",
     "party-js": "2.2.0",
autogpt_platform/frontend/pnpm-lock.yaml (generated, 60 changes)
@@ -16,7 +16,7 @@ importers:
       version: 5.2.2(react-hook-form@7.66.0(react@18.3.1))
     '@next/third-parties':
       specifier: 15.4.6
-      version: 15.4.6(next@15.4.8(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)
+      version: 15.4.6(next@15.4.10(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)
     '@phosphor-icons/react':
       specifier: 2.1.10
       version: 2.1.10(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
@@ -88,7 +88,7 @@ importers:
       version: 5.24.13(@rjsf/utils@5.24.13(react@18.3.1))
     '@sentry/nextjs':
       specifier: 10.27.0
-      version: 10.27.0(@opentelemetry/context-async-hooks@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/core@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-trace-base@2.2.0(@opentelemetry/api@1.9.0))(next@15.4.8(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)(webpack@5.101.3(esbuild@0.25.9))
+      version: 10.27.0(@opentelemetry/context-async-hooks@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/core@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-trace-base@2.2.0(@opentelemetry/api@1.9.0))(next@15.4.10(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)(webpack@5.101.3(esbuild@0.25.9))
     '@supabase/ssr':
       specifier: 0.7.0
       version: 0.7.0(@supabase/supabase-js@2.78.0)
@@ -106,10 +106,10 @@ importers:
       version: 0.2.4
     '@vercel/analytics':
       specifier: 1.5.0
-      version: 1.5.0(next@15.4.8(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)
+      version: 1.5.0(next@15.4.10(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)
     '@vercel/speed-insights':
       specifier: 1.2.0
-      version: 1.2.0(next@15.4.8(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)
+      version: 1.2.0(next@15.4.10(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)
     '@xyflow/react':
       specifier: 12.9.2
       version: 12.9.2(@types/react@18.3.17)(immer@10.1.3)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
@@ -148,7 +148,7 @@ importers:
       version: 12.23.24(@emotion/is-prop-valid@1.2.2)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
     geist:
       specifier: 1.5.1
-      version: 1.5.1(next@15.4.8(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))
+      version: 1.5.1(next@15.4.10(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))
     highlight.js:
       specifier: 11.11.1
       version: 11.11.1
@@ -171,14 +171,14 @@ importers:
       specifier: 2.30.1
       version: 2.30.1
     next:
-      specifier: 15.4.8
-      version: 15.4.8(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
+      specifier: 15.4.10
+      version: 15.4.10(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
     next-themes:
       specifier: 0.4.6
       version: 0.4.6(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
     nuqs:
       specifier: 2.7.2
-      version: 2.7.2(next@15.4.8(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)
+      version: 2.7.2(next@15.4.10(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)
     party-js:
       specifier: 2.2.0
       version: 2.2.0
@@ -284,7 +284,7 @@ importers:
       version: 9.1.5(storybook@9.1.5(@testing-library/dom@10.4.1)(msw@2.11.6(@types/node@24.10.0)(typescript@5.9.3))(prettier@3.6.2))
     '@storybook/nextjs':
       specifier: 9.1.5
-      version: 9.1.5(esbuild@0.25.9)(next@15.4.8(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(storybook@9.1.5(@testing-library/dom@10.4.1)(msw@2.11.6(@types/node@24.10.0)(typescript@5.9.3))(prettier@3.6.2))(type-fest@4.41.0)(typescript@5.9.3)(webpack-hot-middleware@2.26.1)(webpack@5.101.3(esbuild@0.25.9))
+      version: 9.1.5(esbuild@0.25.9)(next@15.4.10(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(storybook@9.1.5(@testing-library/dom@10.4.1)(msw@2.11.6(@types/node@24.10.0)(typescript@5.9.3))(prettier@3.6.2))(type-fest@4.41.0)(typescript@5.9.3)(webpack-hot-middleware@2.26.1)(webpack@5.101.3(esbuild@0.25.9))
     '@tanstack/eslint-plugin-query':
       specifier: 5.91.2
       version: 5.91.2(eslint@8.57.1)(typescript@5.9.3)
@@ -1602,8 +1602,8 @@ packages:
   '@neoconfetti/react@1.0.0':
     resolution: {integrity: sha512-klcSooChXXOzIm+SE5IISIAn3bYzYfPjbX7D7HoqZL84oAfgREeSg5vSIaSFH+DaGzzvImTyWe1OyrJ67vik4A==}
 
-  '@next/env@15.4.8':
-    resolution: {integrity: sha512-LydLa2MDI1NMrOFSkO54mTc8iIHSttj6R6dthITky9ylXV2gCGi0bHQjVCtLGRshdRPjyh2kXbxJukDtBWQZtQ==}
+  '@next/env@15.4.10':
+    resolution: {integrity: sha512-knhmoJ0Vv7VRf6pZEPSnciUG1S4bIhWx+qTYBW/AjxEtlzsiNORPk8sFDCEvqLfmKuey56UB9FL1UdHEV3uBrg==}
 
   '@next/eslint-plugin-next@15.5.2':
     resolution: {integrity: sha512-lkLrRVxcftuOsJNhWatf1P2hNVfh98k/omQHrCEPPriUypR6RcS13IvLdIrEvkm9AH2Nu2YpR5vLqBuy6twH3Q==}
@@ -5920,8 +5920,8 @@ packages:
       react: ^16.8 || ^17 || ^18 || ^19 || ^19.0.0-rc
       react-dom: ^16.8 || ^17 || ^18 || ^19 || ^19.0.0-rc
 
-  next@15.4.8:
-    resolution: {integrity: sha512-jwOXTz/bo0Pvlf20FSb6VXVeWRssA2vbvq9SdrOPEg9x8E1B27C2rQtvriAn600o9hH61kjrVRexEffv3JybuA==}
+  next@15.4.10:
+    resolution: {integrity: sha512-itVlc79QjpKMFMRhP+kbGKaSG/gZM6RCvwhEbwmCNF06CdDiNaoHcbeg0PqkEa2GOcn8KJ0nnc7+yL7EjoYLHQ==}
     engines: {node: ^18.18.0 || ^19.8.0 || >= 20.0.0}
     hasBin: true
     peerDependencies:
@@ -9003,7 +9003,7 @@ snapshots:
 
   '@neoconfetti/react@1.0.0': {}
 
-  '@next/env@15.4.8': {}
+  '@next/env@15.4.10': {}
 
   '@next/eslint-plugin-next@15.5.2':
     dependencies:
@@ -9033,9 +9033,9 @@ snapshots:
   '@next/swc-win32-x64-msvc@15.4.8':
     optional: true
 
-  '@next/third-parties@15.4.6(next@15.4.8(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)':
+  '@next/third-parties@15.4.6(next@15.4.10(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)':
     dependencies:
-      next: 15.4.8(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
+      next: 15.4.10(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
       react: 18.3.1
       third-party-capital: 1.0.20
 
@@ -10267,7 +10267,7 @@ snapshots:
 
   '@sentry/core@10.27.0': {}
 
-  '@sentry/nextjs@10.27.0(@opentelemetry/context-async-hooks@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/core@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-trace-base@2.2.0(@opentelemetry/api@1.9.0))(next@15.4.8(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)(webpack@5.101.3(esbuild@0.25.9))':
+  '@sentry/nextjs@10.27.0(@opentelemetry/context-async-hooks@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/core@2.2.0(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-trace-base@2.2.0(@opentelemetry/api@1.9.0))(next@15.4.10(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)(webpack@5.101.3(esbuild@0.25.9))':
     dependencies:
      '@opentelemetry/api': 1.9.0
      '@opentelemetry/semantic-conventions': 1.37.0
@@ -10280,7 +10280,7 @@ snapshots:
       '@sentry/react': 10.27.0(react@18.3.1)
       '@sentry/vercel-edge': 10.27.0
       '@sentry/webpack-plugin': 4.3.0(webpack@5.101.3(esbuild@0.25.9))
-      next: 15.4.8(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
+      next: 15.4.10(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
       resolve: 1.22.8
       rollup: 4.52.2
       stacktrace-parser: 0.1.11
@@ -10642,7 +10642,7 @@ snapshots:
       react: 18.3.1
       react-dom: 18.3.1(react@18.3.1)
 
-  '@storybook/nextjs@9.1.5(esbuild@0.25.9)(next@15.4.8(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(storybook@9.1.5(@testing-library/dom@10.4.1)(msw@2.11.6(@types/node@24.10.0)(typescript@5.9.3))(prettier@3.6.2))(type-fest@4.41.0)(typescript@5.9.3)(webpack-hot-middleware@2.26.1)(webpack@5.101.3(esbuild@0.25.9))':
+  '@storybook/nextjs@9.1.5(esbuild@0.25.9)(next@15.4.10(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(storybook@9.1.5(@testing-library/dom@10.4.1)(msw@2.11.6(@types/node@24.10.0)(typescript@5.9.3))(prettier@3.6.2))(type-fest@4.41.0)(typescript@5.9.3)(webpack-hot-middleware@2.26.1)(webpack@5.101.3(esbuild@0.25.9))':
     dependencies:
       '@babel/core': 7.28.4
       '@babel/plugin-syntax-bigint': 7.8.3(@babel/core@7.28.4)
@@ -10666,7 +10666,7 @@ snapshots:
       css-loader: 6.11.0(webpack@5.101.3(esbuild@0.25.9))
       image-size: 2.0.2
       loader-utils: 3.3.1
-      next: 15.4.8(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
+      next: 15.4.10(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
       node-polyfill-webpack-plugin: 2.0.1(webpack@5.101.3(esbuild@0.25.9))
       postcss: 8.5.6
       postcss-loader: 8.2.0(postcss@8.5.6)(typescript@5.9.3)(webpack@5.101.3(esbuild@0.25.9))
@@ -11271,14 +11271,14 @@ snapshots:
   '@unrs/resolver-binding-win32-x64-msvc@1.11.1':
     optional: true
 
-  '@vercel/analytics@1.5.0(next@15.4.8(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)':
+  '@vercel/analytics@1.5.0(next@15.4.10(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)':
     optionalDependencies:
-      next: 15.4.8(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
+      next: 15.4.10(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
       react: 18.3.1
 
-  '@vercel/speed-insights@1.2.0(next@15.4.8(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)':
+  '@vercel/speed-insights@1.2.0(next@15.4.10(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1)':
     optionalDependencies:
-      next: 15.4.8(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
+      next: 15.4.10(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
       react: 18.3.1
 
   '@vitest/expect@3.2.4':
@@ -12954,9 +12954,9 @@ snapshots:
 
   functions-have-names@1.2.3: {}
 
-  geist@1.5.1(next@15.4.8(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)):
+  geist@1.5.1(next@15.4.10(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)):
     dependencies:
-      next: 15.4.8(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
+      next: 15.4.10(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
 
   gensync@1.0.0-beta.2: {}
 
@@ -14226,9 +14226,9 @@ snapshots:
       react: 18.3.1
       react-dom: 18.3.1(react@18.3.1)
 
-  next@15.4.8(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1):
+  next@15.4.10(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1):
     dependencies:
-      '@next/env': 15.4.8
+      '@next/env': 15.4.10
       '@swc/helpers': 0.5.15
       caniuse-lite: 1.0.30001741
       postcss: 8.4.31
@@ -14321,12 +14321,12 @@ snapshots:
     dependencies:
       boolbase: 1.0.0
 
-  nuqs@2.7.2(next@15.4.8(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1):
+  nuqs@2.7.2(next@15.4.10(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react@18.3.1):
     dependencies:
       '@standard-schema/spec': 1.0.0
       react: 18.3.1
     optionalDependencies:
-      next: 15.4.8(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
+      next: 15.4.10(@babel/core@7.28.4)(@opentelemetry/api@1.9.0)(@playwright/test@1.56.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
 
   oas-kit-common@1.0.8:
     dependencies:
@@ -8,8 +8,6 @@ import { shouldShowOnboarding } from "@/app/api/helpers";
 export async function GET(request: Request) {
   const { searchParams, origin } = new URL(request.url);
   const code = searchParams.get("code");
-  const oauthSession = searchParams.get("oauth_session");
-  const connectSession = searchParams.get("connect_session");
 
   let next = "/marketplace";
 
@@ -27,22 +25,6 @@ export async function GET(request: Request) {
     const api = new BackendAPI();
     await api.createUser();
 
-    // Handle oauth_session redirect - resume OAuth flow after login
-    // Redirect to a frontend page that will handle the OAuth resume with proper auth
-    if (oauthSession) {
-      return NextResponse.redirect(
-        `${origin}/auth/oauth-resume?session_id=${encodeURIComponent(oauthSession)}`,
-      );
-    }
-
-    // Handle connect_session redirect - resume connect flow after login
-    // Redirect to a frontend page that will handle the connect resume with proper auth
-    if (connectSession) {
-      return NextResponse.redirect(
-        `${origin}/auth/connect-resume?session_id=${encodeURIComponent(connectSession)}`,
-      );
-    }
-
     if (await shouldShowOnboarding()) {
       next = "/onboarding";
       revalidatePath("/onboarding", "layout");
@@ -1,400 +0,0 @@
-"use client";
-
-import { useEffect, useState, useRef, useCallback } from "react";
-import { useSearchParams } from "next/navigation";
-import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
-import { getWebSocketToken } from "@/lib/supabase/actions";
-
-// Module-level flag to prevent duplicate requests across React StrictMode re-renders
-const attemptedSessions = new Set<string>();
-
-interface ScopeInfo {
-  scope: string;
-  description: string;
-}
-
-interface CredentialInfo {
-  id: string;
-  title: string;
-  username: string;
-}
-
-interface ClientInfo {
-  name: string;
-  logo_url: string | null;
-}
-
-interface ConnectData {
-  connect_token: string;
-  client: ClientInfo;
-  provider: string;
-  scopes: ScopeInfo[];
-  credentials: CredentialInfo[];
-  action_url: string;
-}
-
-interface ErrorData {
-  error: string;
-  error_description: string;
-}
-
-type ResumeResponse = ConnectData | ErrorData;
-
-function isConnectData(data: ResumeResponse): data is ConnectData {
-  return "connect_token" in data;
-}
-
-function isErrorData(data: ResumeResponse): data is ErrorData {
-  return "error" in data;
-}
-
-/**
- * Connect Consent Form Component
- *
- * Renders a proper React component for the integration connect consent form
- */
-function ConnectForm({
-  client,
-  provider,
-  scopes,
-  credentials,
-  connectToken,
-  actionUrl,
-}: {
-  client: ClientInfo;
-  provider: string;
-  scopes: ScopeInfo[];
-  credentials: CredentialInfo[];
-  connectToken: string;
-  actionUrl: string;
-}) {
-  const [isSubmitting, setIsSubmitting] = useState(false);
-  const [selectedCredential, setSelectedCredential] = useState<string>(
-    credentials.length > 0 ? credentials[0].id : "",
-  );
-
-  const backendUrl = process.env.NEXT_PUBLIC_AGPT_SERVER_URL;
-  const backendOrigin = backendUrl
-    ? new URL(backendUrl).origin
-    : "http://localhost:8006";
-
-  const fullActionUrl = `${backendOrigin}${actionUrl}`;
-
-  function handleSubmit() {
-    setIsSubmitting(true);
-  }
-
-  return (
-    <div className="flex min-h-screen items-center justify-center bg-gradient-to-br from-slate-900 to-slate-800 p-5">
-      <div className="w-full max-w-md rounded-2xl bg-zinc-800 p-8 shadow-2xl">
-        {/* Header */}
-        <div className="mb-6 text-center">
-          <h1 className="text-xl font-semibold text-zinc-100">
-            Connect{" "}
-            <span className="rounded bg-zinc-700 px-2 py-1 text-sm capitalize">
-              {provider}
-            </span>
-          </h1>
-          <p className="mt-2 text-sm text-zinc-400">
-            <span className="font-semibold text-cyan-400">{client.name}</span>{" "}
-            wants to use your {provider} integration
-          </p>
-        </div>
-
-        {/* Divider */}
-        <div className="my-6 h-px bg-zinc-700" />
-
-        {/* Scopes Section */}
-        <div className="mb-6">
-          <h2 className="mb-4 text-sm font-medium text-zinc-400">
-            This will allow {client.name} to:
-          </h2>
-          <div className="space-y-2">
-            {scopes.map((scope) => (
-              <div key={scope.scope} className="flex items-start gap-2 py-2">
-                <span className="flex-shrink-0 text-cyan-400">✓</span>
-                <span className="text-sm text-zinc-300">
-                  {scope.description}
-                </span>
-              </div>
-            ))}
-          </div>
-        </div>
-
-        {/* Divider */}
-        <div className="my-6 h-px bg-zinc-700" />
-
-        {/* Form */}
-        <form method="POST" action={fullActionUrl} onSubmit={handleSubmit}>
-          <input type="hidden" name="connect_token" value={connectToken} />
-
-          {/* Existing credentials selection */}
-          {credentials.length > 0 && (
-            <>
-              <h3 className="mb-3 text-sm font-medium text-zinc-400">
-                Select an existing credential:
-              </h3>
-              <div className="mb-4 space-y-2">
-                {credentials.map((cred) => (
-                  <label
-                    key={cred.id}
-                    className={`flex cursor-pointer items-center gap-3 rounded-lg border p-3 transition-colors ${
-                      selectedCredential === cred.id
-                        ? "border-cyan-400 bg-cyan-400/10"
-                        : "border-zinc-700 hover:border-cyan-400/50"
-                    }`}
-                  >
-                    <input
-                      type="radio"
-                      name="credential_id"
-                      value={cred.id}
-                      checked={selectedCredential === cred.id}
-                      onChange={() => setSelectedCredential(cred.id)}
-                      className="hidden"
-                    />
-                    <div>
-                      <div className="text-sm font-medium text-zinc-200">
-                        {cred.title}
-                      </div>
-                      {cred.username && (
-                        <div className="text-xs text-zinc-500">
-                          {cred.username}
-                        </div>
-                      )}
-                    </div>
-                  </label>
-                ))}
-              </div>
-              <div className="my-4 h-px bg-zinc-700" />
-            </>
-          )}
-
-          {/* Connect new account */}
-          <div className="mb-4">
-            {credentials.length > 0 ? (
-              <h3 className="mb-3 text-sm font-medium text-zinc-400">
-                Or connect a new account:
-              </h3>
-            ) : (
-              <p className="mb-3 text-sm text-zinc-400">
-                You don't have any {provider} credentials yet.
-              </p>
-            )}
-            <button
-              type="submit"
-              name="action"
-              value="connect_new"
-              disabled={isSubmitting}
-              className="w-full rounded-lg bg-blue-500 px-6 py-3 text-sm font-medium text-white transition-colors hover:bg-blue-400 disabled:cursor-not-allowed disabled:opacity-50"
-            >
-              Connect {provider.charAt(0).toUpperCase() + provider.slice(1)}{" "}
-              Account
-            </button>
-          </div>
-
-          {/* Action buttons */}
-          <div className="flex gap-3">
-            <button
-              type="submit"
-              name="action"
-              value="deny"
-              disabled={isSubmitting}
-              className="flex-1 rounded-lg bg-zinc-700 px-6 py-3 text-sm font-medium text-zinc-200 transition-colors hover:bg-zinc-600 disabled:cursor-not-allowed disabled:opacity-50"
-            >
-              Cancel
-            </button>
-            {credentials.length > 0 && (
-              <button
-                type="submit"
-                name="action"
-                value="approve"
-                disabled={isSubmitting}
-                className="flex-1 rounded-lg bg-cyan-400 px-6 py-3 text-sm font-medium text-slate-900 transition-colors hover:bg-cyan-300 disabled:cursor-not-allowed disabled:opacity-50"
-              >
-                {isSubmitting ? "Approving..." : "Approve"}
-              </button>
-            )}
-          </div>
-        </form>
-      </div>
-    </div>
-  );
-}
-
-/**
- * Connect Resume Page
- *
- * This page handles resuming the integration connect flow after a user logs in.
- * It fetches the connect data from the backend via JSON API and renders the consent form.
- */
-export default function ConnectResumePage() {
-  const searchParams = useSearchParams();
-  const sessionId = searchParams.get("session_id");
-  const { isUserLoading, refreshSession } = useSupabase();
-
-  const [connectData, setConnectData] = useState<ConnectData | null>(null);
-  const [error, setError] = useState<string | null>(null);
-  const [isLoading, setIsLoading] = useState(true);
-  const retryCountRef = useRef(0);
-  const maxRetries = 5;
-
-  const resumeConnectFlow = useCallback(async () => {
-    if (!sessionId) {
-      setError(
-        "Missing session ID. Please start the connection process again.",
-      );
-      setIsLoading(false);
-      return;
-    }
-
-    if (attemptedSessions.has(sessionId)) {
-      return;
-    }
-
-    if (isUserLoading) {
-      return;
-    }
-
-    attemptedSessions.add(sessionId);
-
-    try {
-      let tokenResult = await getWebSocketToken();
-      let accessToken = tokenResult.token;
-
-      while (!accessToken && retryCountRef.current < maxRetries) {
-        retryCountRef.current += 1;
-        console.log(
-          `Retrying to get access token (attempt ${retryCountRef.current}/${maxRetries})...`,
-        );
-        await refreshSession();
-        await new Promise((resolve) => setTimeout(resolve, 1000));
-        tokenResult = await getWebSocketToken();
-        accessToken = tokenResult.token;
-      }
-
-      if (!accessToken) {
-        setError(
-          "Unable to retrieve authentication token. Please log in again.",
-        );
-        setIsLoading(false);
-        return;
-      }
-
-      const backendUrl = process.env.NEXT_PUBLIC_AGPT_SERVER_URL;
-      if (!backendUrl) {
-        setError("Backend URL not configured.");
-        setIsLoading(false);
-        return;
-      }
-
-      let backendOrigin: string;
-      try {
-        const url = new URL(backendUrl);
-        backendOrigin = url.origin;
-      } catch {
-        setError("Invalid backend URL configuration.");
-        setIsLoading(false);
-        return;
-      }
-
-      const response = await fetch(
-        `${backendOrigin}/connect/resume?session_id=${encodeURIComponent(sessionId)}`,
-        {
-          method: "GET",
-          headers: {
-            Authorization: `Bearer ${accessToken}`,
-            Accept: "application/json",
-          },
-        },
-      );
-
-      const data: ResumeResponse = await response.json();
-
-      if (!response.ok) {
-        if (isErrorData(data)) {
-          setError(data.error_description || data.error);
-        } else {
-          setError(`Connection failed (${response.status}). Please try again.`);
-        }
-        setIsLoading(false);
-        return;
-      }
-
-      if (isConnectData(data)) {
-        setConnectData(data);
-        setIsLoading(false);
-        return;
-      }
-
-      setError("Unexpected response from server. Please try again.");
-      setIsLoading(false);
-    } catch (err) {
-      console.error("Connect resume error:", err);
-      setError(
-        "An error occurred while resuming connection. Please try again.",
-      );
-      setIsLoading(false);
-    }
-  }, [sessionId, isUserLoading, refreshSession]);
-
-  useEffect(() => {
-    resumeConnectFlow();
-  }, [resumeConnectFlow]);
-
-  if (isLoading || isUserLoading) {
-    return (
-      <div className="flex min-h-screen items-center justify-center bg-gradient-to-br from-slate-900 to-slate-800">
-        <div className="text-center">
-          <div className="mx-auto mb-4 h-8 w-8 animate-spin rounded-full border-4 border-zinc-600 border-t-cyan-400"></div>
-          <p className="text-zinc-400">Resuming connection...</p>
-        </div>
-      </div>
-    );
-  }
-
-  if (error) {
-    return (
-      <div className="flex min-h-screen items-center justify-center bg-gradient-to-br from-slate-900 to-slate-800">
-        <div className="mx-auto max-w-md rounded-2xl bg-zinc-800 p-8 text-center shadow-2xl">
-          <div className="mx-auto mb-4 h-16 w-16 text-red-500">
-            <svg
-              viewBox="0 0 24 24"
-              fill="none"
-              stroke="currentColor"
-              strokeWidth="2"
-            >
-              <circle cx="12" cy="12" r="10" />
-              <line x1="15" y1="9" x2="9" y2="15" />
-              <line x1="9" y1="9" x2="15" y2="15" />
-            </svg>
-          </div>
-          <h1 className="mb-2 text-xl font-semibold text-red-400">
-            Connection Error
-          </h1>
-          <p className="mb-6 text-zinc-400">{error}</p>
-          <button
-            onClick={() => window.close()}
-            className="rounded-lg bg-zinc-700 px-6 py-3 text-sm font-medium text-zinc-200 transition-colors hover:bg-zinc-600"
-          >
-            Close
-          </button>
-        </div>
-      </div>
-    );
-  }
-
-  if (connectData) {
-    return (
-      <ConnectForm
-        client={connectData.client}
-        provider={connectData.provider}
-        scopes={connectData.scopes}
-        credentials={connectData.credentials}
-        connectToken={connectData.connect_token}
-        actionUrl={connectData.action_url}
-      />
-    );
-  }
-
-  return null;
-}
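The deleted page above guards its network effect with a module-level `Set` keyed by session ID, so React 18 StrictMode's double-invoked effects cannot fire the resume request twice. That guard, reduced to a standalone sketch with hypothetical names:

```ts
import { useEffect } from "react";

// Module-level, so it survives StrictMode's mount/unmount/mount cycle
// (a ref would be reset when the component remounts).
const attempted = new Set<string>();

// Hypothetical helper name; the deleted page inlined this logic.
function useRunOncePerKey(key: string | null, run: () => void) {
  useEffect(() => {
    if (!key || attempted.has(key)) return;
    attempted.add(key); // mark before the async work starts to close the race
    run();
  }, [key, run]);
}
```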
@@ -22,28 +22,20 @@ export async function GET(request: Request) {
 
   console.debug("Sending message to opener:", message);
 
-  // Escape JSON to prevent XSS attacks via </script> injection
-  const safeJson = JSON.stringify(message)
-    .replace(/</g, "\\u003c")
-    .replace(/>/g, "\\u003e");
-
   // Return a response with the message as JSON and a script to close the window
   return new NextResponse(
-    `<!DOCTYPE html>
+    `
 <html>
   <body>
     <script>
-      window.opener.postMessage(${safeJson}, '*');
+      window.opener.postMessage(${JSON.stringify(message)});
       window.close();
     </script>
   </body>
-</html>`,
+</html>
+`,
     {
-      headers: {
-        "Content-Type": "text/html",
-        "Content-Security-Policy":
-          "default-src 'none'; script-src 'unsafe-inline'",
-      },
+      headers: { "Content-Type": "text/html" },
     },
   );
 }
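This hunk reverts the callback page to interpolating `JSON.stringify(message)` straight into an inline `<script>`, dropping the angle-bracket escaping, the explicit `'*'` target origin, and the CSP header shown on the removed side. For reference, the escaping pattern the removed code used (a sketch, not repo code) keeps a payload containing `</script>` from terminating the tag early:

```ts
// Sketch of the removed code's escaping approach. Assumption: `message` may
// carry attacker-influenced strings, e.g. an OAuth error description.
function popupCloseHtml(message: unknown, targetOrigin: string): string {
  // \u003c / \u003e are valid JSON string escapes, so the payload parses
  // identically but can no longer contain a literal "</script>".
  const safeJson = JSON.stringify(message)
    .replace(/</g, "\\u003c")
    .replace(/>/g, "\\u003e");
  return `<!DOCTYPE html>
<html><body><script>
  window.opener.postMessage(${safeJson}, ${JSON.stringify(targetOrigin)});
  window.close();
</script></body></html>`;
}
```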
@@ -1,399 +0,0 @@
-"use client";
-
-import { useEffect, useState, useRef, useCallback } from "react";
-import { useSearchParams } from "next/navigation";
-import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
-import { getWebSocketToken } from "@/lib/supabase/actions";
-
-// Module-level flag to prevent duplicate requests across React StrictMode re-renders
-// This is keyed by session_id to allow different sessions
-const attemptedSessions = new Set<string>();
-
-interface ScopeInfo {
-  scope: string;
-  description: string;
-}
-
-interface ClientInfo {
-  name: string;
-  logo_url: string | null;
-  privacy_policy_url: string | null;
-  terms_url: string | null;
-}
-
-interface ConsentData {
-  needs_consent: true;
-  consent_token: string;
-  client: ClientInfo;
-  scopes: ScopeInfo[];
-  action_url: string;
-}
-
-interface RedirectData {
-  redirect_url: string;
-  needs_consent: false;
-}
-
-interface ErrorData {
-  error: string;
-  error_description: string;
-  redirect_url?: string;
-}
-
-type ResumeResponse = ConsentData | RedirectData | ErrorData;
-
-function isConsentData(data: ResumeResponse): data is ConsentData {
-  return "needs_consent" in data && data.needs_consent === true;
-}
-
-function isRedirectData(data: ResumeResponse): data is RedirectData {
-  return "redirect_url" in data && !("error" in data);
-}
-
-function isErrorData(data: ResumeResponse): data is ErrorData {
-  return "error" in data;
-}
-
-/**
- * OAuth Consent Form Component
- *
- * Renders a proper React component for the consent form instead of dangerouslySetInnerHTML
- */
-function ConsentForm({
-  client,
-  scopes,
-  consentToken,
-  actionUrl,
-}: {
-  client: ClientInfo;
-  scopes: ScopeInfo[];
-  consentToken: string;
-  actionUrl: string;
-}) {
-  const [isSubmitting, setIsSubmitting] = useState(false);
-  const backendUrl = process.env.NEXT_PUBLIC_AGPT_SERVER_URL;
-  const backendOrigin = backendUrl
-    ? new URL(backendUrl).origin
-    : "http://localhost:8006";
-
-  // Full action URL for form submission
-  const fullActionUrl = `${backendOrigin}${actionUrl}`;
-
-  function handleSubmit() {
-    setIsSubmitting(true);
-  }
-
-  return (
-    <div className="flex min-h-screen items-center justify-center bg-gradient-to-br from-slate-900 to-slate-800 p-5">
-      <div className="w-full max-w-md rounded-2xl bg-zinc-800 p-8 shadow-2xl">
-        {/* Header */}
-        <div className="mb-6 text-center">
-          <div className="mx-auto mb-4 flex h-16 w-16 items-center justify-center rounded-xl bg-zinc-700">
-            {client.logo_url ? (
-              <img
-                src={client.logo_url}
-                alt={client.name}
-                className="h-12 w-12 rounded-lg"
-              />
-            ) : (
-              <span className="text-3xl text-zinc-400">
-                {client.name.charAt(0).toUpperCase()}
-              </span>
-            )}
-          </div>
-          <h1 className="text-xl font-semibold text-zinc-100">
-            Authorize <span className="text-cyan-400">{client.name}</span>
-          </h1>
-          <p className="mt-2 text-sm text-zinc-400">
-            wants to access your AutoGPT account
-          </p>
-        </div>
-
-        {/* Divider */}
-        <div className="my-6 h-px bg-zinc-700" />
-
-        {/* Scopes Section */}
-        <div className="mb-6">
-          <h2 className="mb-4 text-sm font-medium text-zinc-400">
-            This will allow {client.name} to:
-          </h2>
-          <div className="space-y-3">
-            {scopes.map((scope) => (
-              <div
-                key={scope.scope}
-                className="flex items-start gap-3 border-b border-zinc-700 pb-3 last:border-0"
-              >
-                <svg
-                  className="mt-0.5 h-5 w-5 flex-shrink-0 text-cyan-400"
-                  viewBox="0 0 20 20"
-                  fill="currentColor"
-                >
-                  <path
-                    fillRule="evenodd"
-                    d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z"
-                    clipRule="evenodd"
-                  />
-                </svg>
-                <span className="text-sm leading-relaxed text-zinc-300">
-                  {scope.description}
-                </span>
-              </div>
-            ))}
-          </div>
-        </div>
-
-        {/* Form */}
-        <form method="POST" action={fullActionUrl} onSubmit={handleSubmit}>
-          <input type="hidden" name="consent_token" value={consentToken} />
-          <div className="flex gap-3">
-            <button
-              type="submit"
-              name="authorize"
-              value="false"
-              disabled={isSubmitting}
-              className="flex-1 rounded-lg bg-zinc-700 px-6 py-3 text-sm font-medium text-zinc-200 transition-colors hover:bg-zinc-600 disabled:cursor-not-allowed disabled:opacity-50"
-            >
-              Cancel
-            </button>
-            <button
-              type="submit"
-              name="authorize"
-              value="true"
-              disabled={isSubmitting}
-              className="flex-1 rounded-lg bg-cyan-400 px-6 py-3 text-sm font-medium text-slate-900 transition-colors hover:bg-cyan-300 disabled:cursor-not-allowed disabled:opacity-50"
-            >
-              {isSubmitting ? "Authorizing..." : "Allow"}
-            </button>
-          </div>
-        </form>
-
-        {/* Footer Links */}
-        {(client.privacy_policy_url || client.terms_url) && (
-          <div className="mt-6 text-center text-xs text-zinc-500">
-            {client.privacy_policy_url && (
-              <a
-                href={client.privacy_policy_url}
-                target="_blank"
-                rel="noopener noreferrer"
-                className="text-zinc-400 hover:underline"
-              >
-                Privacy Policy
-              </a>
-            )}
-            {client.privacy_policy_url && client.terms_url && (
-              <span className="mx-2">•</span>
-            )}
-            {client.terms_url && (
-              <a
-                href={client.terms_url}
-                target="_blank"
-                rel="noopener noreferrer"
-                className="text-zinc-400 hover:underline"
-              >
-                Terms of Service
-              </a>
-            )}
-          </div>
-        )}
-      </div>
-    </div>
-  );
-}
-
-/**
- * OAuth Resume Page
- *
- * This page handles resuming the OAuth authorization flow after a user logs in.
- * It fetches the consent data from the backend via JSON API and renders the consent form.
- */
-export default function OAuthResumePage() {
-  const searchParams = useSearchParams();
-  const sessionId = searchParams.get("session_id");
-  const { isUserLoading, refreshSession } = useSupabase();
-
-  const [consentData, setConsentData] = useState<ConsentData | null>(null);
-  const [error, setError] = useState<string | null>(null);
-  const [isLoading, setIsLoading] = useState(true);
-  const retryCountRef = useRef(0);
-  const maxRetries = 5;
-
-  const resumeOAuthFlow = useCallback(async () => {
-    // Prevent multiple attempts for the same session (handles React StrictMode)
-    if (!sessionId) {
-      setError(
-        "Missing session ID. Please start the authorization process again.",
-      );
-      setIsLoading(false);
-      return;
-    }
-
-    if (attemptedSessions.has(sessionId)) {
-      // Already attempted this session, don't retry
-      return;
-    }
-
-    if (isUserLoading) {
-      return; // Wait for auth state to load
-    }
-
-    // Mark this session as attempted IMMEDIATELY to prevent duplicate requests
-    attemptedSessions.add(sessionId);
-
-    try {
-      // Get the access token from server action (which reads cookies properly)
-      let tokenResult = await getWebSocketToken();
-      let accessToken = tokenResult.token;
-
-      // If no token, retry a few times with delays
-      while (!accessToken && retryCountRef.current < maxRetries) {
-        retryCountRef.current += 1;
-        console.log(
-          `Retrying to get access token (attempt ${retryCountRef.current}/${maxRetries})...`,
-        );
-
-        // Try refreshing the session
-        await refreshSession();
-        await new Promise((resolve) => setTimeout(resolve, 1000));
-
-        tokenResult = await getWebSocketToken();
-        accessToken = tokenResult.token;
-      }
-
-      if (!accessToken) {
-        setError(
-          "Unable to retrieve authentication token. Please log in again.",
-        );
-        setIsLoading(false);
-        return;
-      }
-
-      // Call the backend resume endpoint with JSON accept header
-      const backendUrl = process.env.NEXT_PUBLIC_AGPT_SERVER_URL;
-      if (!backendUrl) {
-        setError("Backend URL not configured.");
-        setIsLoading(false);
-        return;
-      }
-
-      // Extract the origin from the backend URL
-      let backendOrigin: string;
-      try {
-        const url = new URL(backendUrl);
-        backendOrigin = url.origin;
-      } catch {
-        setError("Invalid backend URL configuration.");
-        setIsLoading(false);
-        return;
-      }
-
-      // Use Accept: application/json to get JSON response instead of HTML
-      // This solves the CORS/redirect issue by letting us handle redirects client-side
-      const response = await fetch(
-        `${backendOrigin}/oauth/authorize/resume?session_id=${encodeURIComponent(sessionId)}`,
-        {
-          method: "GET",
-          headers: {
-            Authorization: `Bearer ${accessToken}`,
-            Accept: "application/json",
-          },
-        },
-      );
-
-      const data: ResumeResponse = await response.json();
-
-      if (!response.ok) {
-        if (isErrorData(data)) {
-          setError(data.error_description || data.error);
-        } else {
-          setError(
-            `Authorization failed (${response.status}). Please try again.`,
-          );
-        }
-        setIsLoading(false);
-        return;
-      }
-
-      // Handle redirect response (user already authorized these scopes)
-      if (isRedirectData(data)) {
-        window.location.href = data.redirect_url;
-        return;
-      }
-
-      // Handle consent required
-      if (isConsentData(data)) {
-        setConsentData(data);
-        setIsLoading(false);
-        return;
-      }
-
-      // Unexpected response
-      setError("Unexpected response from server. Please try again.");
-      setIsLoading(false);
-    } catch (err) {
-      console.error("OAuth resume error:", err);
-      setError(
-        "An error occurred while resuming authorization. Please try again.",
-      );
-      setIsLoading(false);
-    }
-  }, [sessionId, isUserLoading, refreshSession]);
-
-  useEffect(() => {
-    resumeOAuthFlow();
-  }, [resumeOAuthFlow]);
-
-  if (isLoading || isUserLoading) {
-    return (
-      <div className="flex min-h-screen items-center justify-center bg-gradient-to-br from-slate-900 to-slate-800">
-        <div className="text-center">
-          <div className="mx-auto mb-4 h-8 w-8 animate-spin rounded-full border-4 border-zinc-600 border-t-cyan-400"></div>
-          <p className="text-zinc-400">Resuming authorization...</p>
-        </div>
-      </div>
-    );
-  }
-
-  if (error) {
-    return (
-      <div className="flex min-h-screen items-center justify-center bg-gradient-to-br from-slate-900 to-slate-800">
-        <div className="mx-auto max-w-md rounded-2xl bg-zinc-800 p-8 text-center shadow-2xl">
-          <div className="mx-auto mb-4 h-16 w-16 text-red-500">
-            <svg
-              viewBox="0 0 24 24"
-              fill="none"
-              stroke="currentColor"
-              strokeWidth="2"
-            >
-              <circle cx="12" cy="12" r="10" />
-              <line x1="15" y1="9" x2="9" y2="15" />
-              <line x1="9" y1="9" x2="15" y2="15" />
-            </svg>
-          </div>
-          <h1 className="mb-2 text-xl font-semibold text-red-400">
-            Authorization Error
-          </h1>
-          <p className="mb-6 text-zinc-400">{error}</p>
-          <button
-            onClick={() => window.close()}
-            className="rounded-lg bg-zinc-700 px-6 py-3 text-sm font-medium text-zinc-200 transition-colors hover:bg-zinc-600"
-          >
-            Close
-          </button>
-        </div>
-      </div>
-    );
-  }
-
-  if (consentData) {
-    return (
-      <ConsentForm
-        client={consentData.client}
-        scopes={consentData.scopes}
-        consentToken={consentData.consent_token}
-        actionUrl={consentData.action_url}
-      />
-    );
-  }
-
-  return null;
-}
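Both deleted pages share the same backend handshake: request the resume endpoint with a bearer token and `Accept: application/json`, so the backend answers with JSON the page can branch on instead of an HTML redirect that would trip CORS. Condensed into one hypothetical helper (endpoint path and header usage taken from the deleted code):

```ts
// Condensed sketch of the deleted pages' fetch logic; `backendOrigin`,
// `sessionId`, and `accessToken` are supplied by the caller.
async function fetchResumeData(
  backendOrigin: string,
  sessionId: string,
  accessToken: string,
): Promise<unknown> {
  const response = await fetch(
    `${backendOrigin}/oauth/authorize/resume?session_id=${encodeURIComponent(sessionId)}`,
    {
      method: "GET",
      headers: {
        Authorization: `Bearer ${accessToken}`,
        Accept: "application/json", // ask for JSON; handle redirects client-side
      },
    },
  );
  if (!response.ok) {
    throw new Error(`Resume failed (${response.status})`);
  }
  return response.json();
}
```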
@@ -1,24 +1,25 @@
|
|||||||
import { useCallback } from "react";
|
import { useCallback } from "react";
|
||||||
import { useReactFlow } from "@xyflow/react";
|
import { useReactFlow } from "@xyflow/react";
|
||||||
import { Key, storage } from "@/services/storage/local-storage";
|
|
||||||
import { v4 as uuidv4 } from "uuid";
|
import { v4 as uuidv4 } from "uuid";
|
||||||
import { useNodeStore } from "../../../stores/nodeStore";
|
import { useNodeStore } from "../../../stores/nodeStore";
|
||||||
import { useEdgeStore } from "../../../stores/edgeStore";
|
import { useEdgeStore } from "../../../stores/edgeStore";
|
||||||
import { CustomNode } from "../nodes/CustomNode/CustomNode";
|
import { CustomNode } from "../nodes/CustomNode/CustomNode";
|
||||||
import { CustomEdge } from "../edges/CustomEdge";
|
import { CustomEdge } from "../edges/CustomEdge";
|
||||||
|
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||||
|
|
||||||
interface CopyableData {
|
interface CopyableData {
|
||||||
nodes: CustomNode[];
|
nodes: CustomNode[];
|
||||||
edges: CustomEdge[];
|
edges: CustomEdge[];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const CLIPBOARD_PREFIX = "autogpt-flow-data:";
|
||||||
|
|
||||||
export function useCopyPaste() {
|
export function useCopyPaste() {
|
||||||
// Only use useReactFlow for viewport (not managed by stores)
|
|
||||||
const { getViewport } = useReactFlow();
|
const { getViewport } = useReactFlow();
|
||||||
|
const { toast } = useToast();
|
||||||
|
|
||||||
const handleCopyPaste = useCallback(
|
const handleCopyPaste = useCallback(
|
||||||
(event: KeyboardEvent) => {
|
(event: KeyboardEvent) => {
|
||||||
// Prevent copy/paste if any modal is open or if the focus is on an input element
|
|
||||||
const activeElement = document.activeElement;
|
       const activeElement = document.activeElement;
       const isInputField =
         activeElement?.tagName === "INPUT" ||
@@ -28,7 +29,6 @@ export function useCopyPaste() {
       if (isInputField) return;
 
       if (event.ctrlKey || event.metaKey) {
-        // COPY: Ctrl+C or Cmd+C
         if (event.key === "c" || event.key === "C") {
           const { nodes } = useNodeStore.getState();
           const { edges } = useEdgeStore.getState();
@@ -53,81 +53,102 @@ export function useCopyPaste() {
             edges: selectedEdges,
           };
 
-          storage.set(Key.COPIED_FLOW_DATA, JSON.stringify(copiedData));
+          const clipboardText = `${CLIPBOARD_PREFIX}${JSON.stringify(copiedData)}`;
+          navigator.clipboard
+            .writeText(clipboardText)
+            .then(() => {
+              toast({
+                title: "Copied successfully",
+                description: `${selectedNodes.length} node(s) copied to clipboard`,
+              });
+            })
+            .catch((error) => {
+              console.error("Failed to copy to clipboard:", error);
+            });
         }
 
-        // PASTE: Ctrl+V or Cmd+V
         if (event.key === "v" || event.key === "V") {
-          const copiedDataString = storage.get(Key.COPIED_FLOW_DATA);
-          if (copiedDataString) {
-            const copiedData = JSON.parse(copiedDataString) as CopyableData;
-            const oldToNewIdMap: Record<string, string> = {};
-
-            // Get fresh viewport values at paste time to ensure correct positioning
-            const { x, y, zoom } = getViewport();
-            const viewportCenter = {
-              x: (window.innerWidth / 2 - x) / zoom,
-              y: (window.innerHeight / 2 - y) / zoom,
-            };
-
-            let minX = Infinity,
-              minY = Infinity,
-              maxX = -Infinity,
-              maxY = -Infinity;
-            copiedData.nodes.forEach((node) => {
-              minX = Math.min(minX, node.position.x);
-              minY = Math.min(minY, node.position.y);
-              maxX = Math.max(maxX, node.position.x);
-              maxY = Math.max(maxY, node.position.y);
-            });
-
-            const offsetX = viewportCenter.x - (minX + maxX) / 2;
-            const offsetY = viewportCenter.y - (minY + maxY) / 2;
-
-            // Deselect existing nodes first
-            useNodeStore.setState((state) => ({
-              nodes: state.nodes.map((node) => ({ ...node, selected: false })),
-            }));
-
-            // Create and add new nodes with UNIQUE IDs using UUID
-            copiedData.nodes.forEach((node) => {
-              const newNodeId = uuidv4();
-              oldToNewIdMap[node.id] = newNodeId;
-
-              const newNode: CustomNode = {
-                ...node,
-                id: newNodeId,
-                selected: true,
-                position: {
-                  x: node.position.x + offsetX,
-                  y: node.position.y + offsetY,
-                },
-              };
-
-              useNodeStore.getState().addNode(newNode);
-            });
-
-            // Add edges with updated source/target IDs
-            const { addEdge } = useEdgeStore.getState();
-            copiedData.edges.forEach((edge) => {
-              const newSourceId = oldToNewIdMap[edge.source] ?? edge.source;
-              const newTargetId = oldToNewIdMap[edge.target] ?? edge.target;
-
-              addEdge({
-                source: newSourceId,
-                target: newTargetId,
-                sourceHandle: edge.sourceHandle ?? "",
-                targetHandle: edge.targetHandle ?? "",
-                data: {
-                  ...edge.data,
-                },
-              });
-            });
-          }
+          navigator.clipboard
+            .readText()
+            .then((clipboardText) => {
+              if (!clipboardText.startsWith(CLIPBOARD_PREFIX)) {
+                return; // Not our data, ignore
+              }
+
+              const jsonString = clipboardText.slice(CLIPBOARD_PREFIX.length);
+              const copiedData = JSON.parse(jsonString) as CopyableData;
+              const oldToNewIdMap: Record<string, string> = {};
+
+              const { x, y, zoom } = getViewport();
+              const viewportCenter = {
+                x: (window.innerWidth / 2 - x) / zoom,
+                y: (window.innerHeight / 2 - y) / zoom,
+              };
+
+              let minX = Infinity,
+                minY = Infinity,
+                maxX = -Infinity,
+                maxY = -Infinity;
+              copiedData.nodes.forEach((node) => {
+                minX = Math.min(minX, node.position.x);
+                minY = Math.min(minY, node.position.y);
+                maxX = Math.max(maxX, node.position.x);
+                maxY = Math.max(maxY, node.position.y);
+              });
+
+              const offsetX = viewportCenter.x - (minX + maxX) / 2;
+              const offsetY = viewportCenter.y - (minY + maxY) / 2;
+
+              // Deselect existing nodes first
+              useNodeStore.setState((state) => ({
+                nodes: state.nodes.map((node) => ({
+                  ...node,
+                  selected: false,
+                })),
+              }));
+
+              // Create and add new nodes with UNIQUE IDs using UUID
+              copiedData.nodes.forEach((node) => {
+                const newNodeId = uuidv4();
+                oldToNewIdMap[node.id] = newNodeId;
+
+                const newNode: CustomNode = {
+                  ...node,
+                  id: newNodeId,
+                  selected: true,
+                  position: {
+                    x: node.position.x + offsetX,
+                    y: node.position.y + offsetY,
+                  },
+                };
+
+                useNodeStore.getState().addNode(newNode);
+              });
+
+              // Add edges with updated source/target IDs
+              const { addEdge } = useEdgeStore.getState();
+              copiedData.edges.forEach((edge) => {
+                const newSourceId = oldToNewIdMap[edge.source] ?? edge.source;
+                const newTargetId = oldToNewIdMap[edge.target] ?? edge.target;
+
+                addEdge({
+                  source: newSourceId,
+                  target: newTargetId,
+                  sourceHandle: edge.sourceHandle ?? "",
+                  targetHandle: edge.targetHandle ?? "",
+                  data: {
+                    ...edge.data,
+                  },
+                });
+              });
+            })
+            .catch((error) => {
+              console.error("Failed to read from clipboard:", error);
            });
         }
       }
     },
-    [getViewport],
+    [getViewport, toast],
  );

  return handleCopyPaste;
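
Note: the change above replaces the old `storage.set`/`storage.get` round-trip with the async Clipboard API, tagging the serialized graph with `CLIPBOARD_PREFIX` so the paste handler can ignore unrelated clipboard text. A minimal sketch of that guard pattern — the prefix value and helper names here are illustrative, not taken from the codebase:

```ts
const CLIPBOARD_PREFIX = "autogpt-flow:"; // assumed shape; the real value lives in the hook's module

async function copyFlow(data: unknown): Promise<void> {
  // Tagging the payload lets the paste handler skip unrelated clipboard text.
  await navigator.clipboard.writeText(
    `${CLIPBOARD_PREFIX}${JSON.stringify(data)}`,
  );
}

async function readFlow<T>(): Promise<T | null> {
  const text = await navigator.clipboard.readText();
  if (!text.startsWith(CLIPBOARD_PREFIX)) return null; // not our data, ignore
  return JSON.parse(text.slice(CLIPBOARD_PREFIX.length)) as T;
}
```

Prefix-tagging is cheaper than attempting to parse every paste: anything that does not start with the marker is rejected before `JSON.parse` can throw.
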
@@ -42,11 +42,12 @@ export const useFlow = () => {
   const setBlockMenuOpen = useControlPanelStore(
     useShallow((state) => state.setBlockMenuOpen),
   );
-  const [{ flowID, flowVersion, flowExecutionID }] = useQueryStates({
-    flowID: parseAsString,
-    flowVersion: parseAsInteger,
-    flowExecutionID: parseAsString,
-  });
+  const [{ flowID, flowVersion, flowExecutionID }, setQueryStates] =
+    useQueryStates({
+      flowID: parseAsString,
+      flowVersion: parseAsInteger,
+      flowExecutionID: parseAsString,
+    });
 
   const { data: executionDetails } = useGetV1GetExecutionDetails(
     flowID || "",
@@ -102,6 +103,9 @@ export const useFlow = () => {
   // load graph schemas
   useEffect(() => {
     if (graph) {
+      setQueryStates({
+        flowVersion: graph.version ?? 1,
+      });
       setGraphSchemas(
         graph.input_schema as Record<string, any> | null,
         graph.credentials_input_schema as Record<string, any> | null,
@@ -1,7 +1,7 @@
 import { useBlockMenuStore } from "../../../../stores/blockMenuStore";
 import { useGetV2BuilderSearchInfinite } from "@/app/api/__generated__/endpoints/store/store";
 import { SearchResponse } from "@/app/api/__generated__/models/searchResponse";
-import { useState } from "react";
+import { useCallback, useEffect, useState } from "react";
 import { useAddAgentToBuilder } from "../hooks/useAddAgentToBuilder";
 import { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
 import { getV2GetSpecificAgent } from "@/app/api/__generated__/endpoints/store/store";
@@ -9,16 +9,27 @@ import {
   getGetV2ListLibraryAgentsQueryKey,
   usePostV2AddMarketplaceAgent,
 } from "@/app/api/__generated__/endpoints/library/library";
-import { getGetV2GetBuilderItemCountsQueryKey } from "@/app/api/__generated__/endpoints/default/default";
+import {
+  getGetV2GetBuilderItemCountsQueryKey,
+  getGetV2GetBuilderSuggestionsQueryKey,
+} from "@/app/api/__generated__/endpoints/default/default";
 import { getQueryClient } from "@/lib/react-query/queryClient";
 import { useToast } from "@/components/molecules/Toast/use-toast";
 import * as Sentry from "@sentry/nextjs";
 
 export const useBlockMenuSearch = () => {
-  const { searchQuery } = useBlockMenuStore();
+  const { searchQuery, searchId, setSearchId } = useBlockMenuStore();
   const { toast } = useToast();
   const { addAgentToBuilder, addLibraryAgentToBuilder } =
     useAddAgentToBuilder();
+  const queryClient = getQueryClient();
+
+  const resetSearchSession = useCallback(() => {
+    setSearchId(undefined);
+    queryClient.invalidateQueries({
+      queryKey: getGetV2GetBuilderSuggestionsQueryKey(),
+    });
+  }, [queryClient, setSearchId]);
 
   const [addingLibraryAgentId, setAddingLibraryAgentId] = useState<
     string | null
@@ -38,13 +49,19 @@ export const useBlockMenuSearch = () => {
       page: 1,
       page_size: 8,
       search_query: searchQuery,
+      search_id: searchId,
     },
     {
       query: {
-        getNextPageParam: (lastPage, allPages) => {
-          const pagination = lastPage.data as SearchResponse;
-          const isMore = pagination.more_pages;
-          return isMore ? allPages.length + 1 : undefined;
+        getNextPageParam: (lastPage) => {
+          const response = lastPage.data as SearchResponse;
+          const { pagination } = response;
+          if (!pagination) {
+            return undefined;
+          }
+
+          const { current_page, total_pages } = pagination;
+          return current_page < total_pages ? current_page + 1 : undefined;
         },
       },
     },
@@ -53,7 +70,6 @@ export const useBlockMenuSearch = () => {
   const { mutateAsync: addMarketplaceAgent } = usePostV2AddMarketplaceAgent({
     mutation: {
       onSuccess: () => {
-        const queryClient = getQueryClient();
         queryClient.invalidateQueries({
           queryKey: getGetV2ListLibraryAgentsQueryKey(),
         });
@@ -75,6 +91,24 @@ export const useBlockMenuSearch = () => {
     },
   });
 
+  useEffect(() => {
+    if (!searchData?.pages?.length) {
+      return;
+    }
+
+    const latestPage = searchData.pages[searchData.pages.length - 1];
+    const response = latestPage?.data as SearchResponse;
+    if (response?.search_id && response.search_id !== searchId) {
+      setSearchId(response.search_id);
+    }
+  }, [searchData, searchId, setSearchId]);
+
+  useEffect(() => {
+    if (searchId && !searchQuery) {
+      resetSearchSession();
+    }
+  }, [resetSearchSession, searchId, searchQuery]);
+
   const allSearchData =
     searchData?.pages?.flatMap((page) => {
       const response = page.data as SearchResponse;
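
Note: the pagination rule changes from the server's old `more_pages` flag to an explicit `current_page < total_pages` comparison. For orientation, a standalone sketch of the same rule in a TanStack Query v5 infinite query — the response shape and the fetcher are hypothetical stand-ins for the generated client:

```ts
import { useInfiniteQuery } from "@tanstack/react-query";

// Hypothetical response shape mirroring the pagination object used above.
interface SearchPage {
  results: string[];
  pagination?: { current_page: number; total_pages: number };
}

declare function fetchSearchPage(page: number): Promise<SearchPage>; // illustrative fetcher

export function useSearchPages(query: string) {
  return useInfiniteQuery({
    queryKey: ["search", query],
    queryFn: ({ pageParam }) => fetchSearchPage(pageParam),
    initialPageParam: 1,
    // Stop fetching once the server reports we are on the last page.
    getNextPageParam: (lastPage) => {
      const p = lastPage.pagination;
      if (!p) return undefined;
      return p.current_page < p.total_pages ? p.current_page + 1 : undefined;
    },
  });
}
```

Counting `allPages.length + 1`, as the old code did, can drift if a page fails or is refetched; deriving the next page from the server's own counters avoids that.
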
@@ -1,30 +1,32 @@
 import { debounce } from "lodash";
 import { useCallback, useEffect, useRef, useState } from "react";
 import { useBlockMenuStore } from "../../../../stores/blockMenuStore";
+import { getQueryClient } from "@/lib/react-query/queryClient";
+import { getGetV2GetBuilderSuggestionsQueryKey } from "@/app/api/__generated__/endpoints/default/default";
 
 const SEARCH_DEBOUNCE_MS = 300;
 
 export const useBlockMenuSearchBar = () => {
   const inputRef = useRef<HTMLInputElement>(null);
   const [localQuery, setLocalQuery] = useState("");
-  const { setSearchQuery, setSearchId, searchId, searchQuery } =
-    useBlockMenuStore();
+  const { setSearchQuery, setSearchId, searchQuery } = useBlockMenuStore();
+  const queryClient = getQueryClient();
 
-  const searchIdRef = useRef(searchId);
-  useEffect(() => {
-    searchIdRef.current = searchId;
-  }, [searchId]);
+  const clearSearchSession = useCallback(() => {
+    setSearchId(undefined);
+    queryClient.invalidateQueries({
+      queryKey: getGetV2GetBuilderSuggestionsQueryKey(),
+    });
+  }, [queryClient, setSearchId]);
 
   const debouncedSetSearchQuery = useCallback(
     debounce((value: string) => {
       setSearchQuery(value);
       if (value.length === 0) {
-        setSearchId(undefined);
-      } else if (!searchIdRef.current) {
-        setSearchId(crypto.randomUUID());
+        clearSearchSession();
       }
     }, SEARCH_DEBOUNCE_MS),
-    [setSearchQuery, setSearchId],
+    [clearSearchSession, setSearchQuery],
   );
 
   useEffect(() => {
@@ -36,13 +38,13 @@ export const useBlockMenuSearchBar = () => {
   const handleClear = () => {
     setLocalQuery("");
     setSearchQuery("");
-    setSearchId(undefined);
+    clearSearchSession();
     debouncedSetSearchQuery.cancel();
   };
 
   useEffect(() => {
     setLocalQuery(searchQuery);
-  }, []);
+  }, [searchQuery]);
 
   return {
     handleClear,
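
Note: the search bar keeps one 300 ms lodash `debounce` instance alive across renders and explicitly `cancel()`s it when the field is cleared, so a stale trailing call cannot resurrect a query the user just erased. A minimal standalone sketch of that pattern, with illustrative names (this is not the project's hook, just the shape of it):

```ts
import { debounce } from "lodash";
import { useEffect, useMemo } from "react";

export function useDebouncedSetter(
  set: (value: string) => void,
  waitMs = 300,
) {
  // useMemo keeps a single debounced instance per `set`/`waitMs` pair.
  const debounced = useMemo(() => debounce(set, waitMs), [set, waitMs]);

  // Cancel any pending trailing call on unmount so it cannot fire late.
  useEffect(() => () => debounced.cancel(), [debounced]);

  return debounced;
}
```
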
@@ -0,0 +1,109 @@
+import React, { useEffect, useRef, useState } from "react";
+import { ArrowLeftIcon, ArrowRightIcon } from "@phosphor-icons/react";
+import { cn } from "@/lib/utils";
+
+interface HorizontalScrollAreaProps {
+  children: React.ReactNode;
+  wrapperClassName?: string;
+  scrollContainerClassName?: string;
+  scrollAmount?: number;
+  dependencyList?: React.DependencyList;
+}
+
+const defaultDependencies: React.DependencyList = [];
+const baseScrollClasses =
+  "flex gap-2 overflow-x-auto px-8 [scrollbar-width:none] [-ms-overflow-style:'none'] [&::-webkit-scrollbar]:hidden";
+
+export const HorizontalScroll: React.FC<HorizontalScrollAreaProps> = ({
+  children,
+  wrapperClassName,
+  scrollContainerClassName,
+  scrollAmount = 300,
+  dependencyList = defaultDependencies,
+}) => {
+  const scrollRef = useRef<HTMLDivElement | null>(null);
+  const [canScrollLeft, setCanScrollLeft] = useState(false);
+  const [canScrollRight, setCanScrollRight] = useState(false);
+
+  const scrollByDelta = (delta: number) => {
+    if (!scrollRef.current) {
+      return;
+    }
+    scrollRef.current.scrollBy({ left: delta, behavior: "smooth" });
+  };
+
+  const updateScrollState = () => {
+    const element = scrollRef.current;
+    if (!element) {
+      setCanScrollLeft(false);
+      setCanScrollRight(false);
+      return;
+    }
+    setCanScrollLeft(element.scrollLeft > 0);
+    setCanScrollRight(
+      Math.ceil(element.scrollLeft + element.clientWidth) < element.scrollWidth,
+    );
+  };
+
+  useEffect(() => {
+    updateScrollState();
+    const element = scrollRef.current;
+    if (!element) {
+      return;
+    }
+    const handleScroll = () => updateScrollState();
+    element.addEventListener("scroll", handleScroll);
+    window.addEventListener("resize", handleScroll);
+    return () => {
+      element.removeEventListener("scroll", handleScroll);
+      window.removeEventListener("resize", handleScroll);
+    };
+  }, dependencyList);
+
+  return (
+    <div className={wrapperClassName}>
+      <div className="group relative">
+        <div
+          ref={scrollRef}
+          className={cn(baseScrollClasses, scrollContainerClassName)}
+        >
+          {children}
+        </div>
+        {canScrollLeft && (
+          <div className="pointer-events-none absolute inset-y-0 left-0 w-8 bg-gradient-to-r from-white via-white/80 to-white/0" />
+        )}
+        {canScrollRight && (
+          <div className="pointer-events-none absolute inset-y-0 right-0 w-8 bg-gradient-to-l from-white via-white/80 to-white/0" />
+        )}
+        {canScrollLeft && (
+          <button
+            type="button"
+            aria-label="Scroll left"
+            className="pointer-events-none absolute left-2 top-5 -translate-y-1/2 opacity-0 transition-opacity duration-200 group-hover:pointer-events-auto group-hover:opacity-100"
+            onClick={() => scrollByDelta(-scrollAmount)}
+          >
+            <ArrowLeftIcon
+              size={28}
+              className="rounded-full bg-zinc-700 p-1 text-white drop-shadow"
+              weight="light"
+            />
+          </button>
+        )}
+        {canScrollRight && (
+          <button
+            type="button"
+            aria-label="Scroll right"
+            className="pointer-events-none absolute right-2 top-5 -translate-y-1/2 opacity-0 transition-opacity duration-200 group-hover:pointer-events-auto group-hover:opacity-100"
+            onClick={() => scrollByDelta(scrollAmount)}
+          >
+            <ArrowRightIcon
+              size={28}
+              className="rounded-full bg-zinc-700 p-1 text-white drop-shadow"
+              weight="light"
+            />
+          </button>
+        )}
+      </div>
+    </div>
+  );
+};
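
Note on the new component: `canScrollRight` stays true while `Math.ceil(scrollLeft + clientWidth) < scrollWidth`, i.e. while content remains beyond the right edge (the `ceil` absorbs fractional scroll positions on high-DPI displays). A hypothetical usage, with made-up chip data:

```tsx
const items = ["weather digest", "gmail summary", "rss monitor"]; // illustrative data

// Assumes HorizontalScroll is imported from the module defined above.
<HorizontalScroll wrapperClassName="-mx-8" scrollAmount={200}>
  {items.map((item) => (
    <span key={item} className="whitespace-nowrap">
      {item}
    </span>
  ))}
</HorizontalScroll>;
```
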
@@ -6,10 +6,15 @@ import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard";
 import { blockMenuContainerStyle } from "../style";
 import { useBlockMenuStore } from "../../../../stores/blockMenuStore";
 import { DefaultStateType } from "../types";
+import { SearchHistoryChip } from "../SearchHistoryChip";
+import { HorizontalScroll } from "../HorizontalScroll";
 
 export const SuggestionContent = () => {
-  const { setIntegration, setDefaultState } = useBlockMenuStore();
+  const { setIntegration, setDefaultState, setSearchQuery, setSearchId } =
+    useBlockMenuStore();
   const { data, isLoading, isError, error, refetch } = useSuggestionContent();
+  const suggestions = data?.suggestions;
+  const hasRecentSearches = (suggestions?.recent_searches?.length ?? 0) > 0;
 
   if (isError) {
     return (
@@ -29,11 +34,45 @@ export const SuggestionContent = () => {
     );
   }
 
-  const suggestions = data?.suggestions;
-
   return (
     <div className={blockMenuContainerStyle}>
       <div className="w-full space-y-6 pb-4">
+        {/* Recent searches */}
+        {hasRecentSearches && (
+          <div className="space-y-2.5 px-4">
+            <p className="font-sans text-sm font-medium leading-[1.375rem] text-zinc-800">
+              Recent searches
+            </p>
+            <HorizontalScroll
+              wrapperClassName="-mx-8"
+              scrollContainerClassName="flex gap-2 overflow-x-auto px-8 [scrollbar-width:none] [-ms-overflow-style:'none'] [&::-webkit-scrollbar]:hidden"
+              dependencyList={[
+                suggestions?.recent_searches?.length ?? 0,
+                isLoading,
+              ]}
+            >
+              {!isLoading && suggestions
+                ? suggestions.recent_searches.map((entry, index) => (
+                    <SearchHistoryChip
+                      key={entry.search_id || `${entry.search_query}-${index}`}
+                      content={entry.search_query || "Untitled search"}
+                      onClick={() => {
+                        setSearchQuery(entry.search_query || "");
+                        setSearchId(entry.search_id || undefined);
+                      }}
+                    />
+                  ))
+                : Array(3)
+                    .fill(0)
+                    .map((_, index) => (
+                      <SearchHistoryChip.Skeleton
+                        key={`recent-search-skeleton-${index}`}
+                      />
+                    ))}
+            </HorizontalScroll>
+          </div>
+        )}
+
         {/* Integrations */}
         <div className="space-y-2.5 px-4">
           <p className="font-sans text-sm font-medium leading-[1.375rem] text-zinc-800">
@@ -24,11 +24,13 @@ import { useNewAgentLibraryView } from "./useNewAgentLibraryView";
 
 export function NewAgentLibraryView() {
   const {
-    agent,
-    hasAnyItems,
-    ready,
-    error,
     agentId,
+    agent,
+    ready,
+    activeTemplate,
+    isTemplateLoading,
+    error,
+    hasAnyItems,
     activeItem,
     sidebarLoading,
     activeTab,
@@ -36,6 +38,9 @@ export function NewAgentLibraryView() {
     handleSelectRun,
     handleCountsChange,
     handleClearSelectedRun,
+    onRunInitiated,
+    onTriggerSetup,
+    onScheduleCreated,
   } = useNewAgentLibraryView();
 
   if (error) {
@@ -65,14 +70,19 @@ export function NewAgentLibraryView() {
           />
         </div>
         <div className="flex min-h-0 flex-1">
-          <EmptyTasks agent={agent} />
+          <EmptyTasks
+            agent={agent}
+            onRun={onRunInitiated}
+            onTriggerSetup={onTriggerSetup}
+            onScheduleCreated={onScheduleCreated}
+          />
         </div>
       </div>
     );
   }
 
   return (
-    <div className="ml-4 grid h-full grid-cols-1 gap-0 pt-3 md:gap-4 lg:grid-cols-[25%_70%]">
+    <div className="mx-4 grid h-full grid-cols-1 gap-0 pt-3 md:ml-4 md:mr-0 md:gap-4 lg:grid-cols-[25%_70%]">
       <SectionWrap className="mb-3 block">
         <div
           className={cn(
@@ -82,16 +92,21 @@ export function NewAgentLibraryView() {
         >
           <RunAgentModal
             triggerSlot={
-              <Button variant="primary" size="large" className="w-full">
+              <Button
+                variant="primary"
+                size="large"
+                className="w-full"
+                disabled={isTemplateLoading && activeTab === "templates"}
+              >
                 <PlusIcon size={20} /> New task
               </Button>
             }
             agent={agent}
-            agentId={agent.id.toString()}
-            onRunCreated={(execution) => handleSelectRun(execution.id, "runs")}
-            onScheduleCreated={(schedule) =>
-              handleSelectRun(schedule.id, "scheduled")
-            }
+            onRunCreated={onRunInitiated}
+            onScheduleCreated={onScheduleCreated}
+            onTriggerSetup={onTriggerSetup}
+            initialInputValues={activeTemplate?.inputs}
+            initialInputCredentials={activeTemplate?.credentials}
           />
         </div>
 
@@ -151,7 +166,12 @@ export function NewAgentLibraryView() {
             </SelectedViewLayout>
           ) : (
             <SelectedViewLayout agentName={agent.name} agentId={agent.id}>
-              <EmptyTasks agent={agent} />
+              <EmptyTasks
+                agent={agent}
+                onRun={onRunInitiated}
+                onTriggerSetup={onTriggerSetup}
+                onScheduleCreated={onScheduleCreated}
+              />
            </SelectedViewLayout>
          )}
        </div>
@@ -1,7 +1,10 @@
 "use client";
 
 import type { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
-import type { CredentialsMetaInput } from "@/lib/autogpt-server-api/types";
+import type {
+  BlockIOSubSchema,
+  CredentialsMetaInput,
+} from "@/lib/autogpt-server-api/types";
 import { CredentialsInput } from "../CredentialsInputs/CredentialsInputs";
 import {
   getAgentCredentialsFields,
@@ -20,13 +23,21 @@ export function AgentInputsReadOnly({
   inputs,
   credentialInputs,
 }: Props) {
-  const fields = getAgentInputFields(agent);
-  const credentialFields = getAgentCredentialsFields(agent);
-  const inputEntries = Object.entries(fields);
-  const credentialEntries = Object.entries(credentialFields);
+  const inputFields = getAgentInputFields(agent);
+  const credentialFieldEntries = Object.entries(
+    getAgentCredentialsFields(agent),
+  );
 
-  const hasInputs = inputs && inputEntries.length > 0;
-  const hasCredentials = credentialInputs && credentialEntries.length > 0;
+  // Take actual input entries as leading; augment with schema from input fields.
+  // TODO: ensure consistent ordering.
+  const inputEntries =
+    inputs &&
+    Object.entries(inputs).map<[string, [BlockIOSubSchema | undefined, any]]>(
+      ([k, v]) => [k, [inputFields[k], v]],
+    );
+
+  const hasInputs = inputEntries && inputEntries.length > 0;
+  const hasCredentials = credentialInputs && credentialFieldEntries.length > 0;
 
   if (!hasInputs && !hasCredentials) {
     return <div className="text-neutral-600">No input for this run.</div>;
@@ -37,11 +48,13 @@ export function AgentInputsReadOnly({
       {/* Regular inputs */}
       {hasInputs && (
         <div className="flex flex-col gap-4">
-          {inputEntries.map(([key, sub]) => (
+          {inputEntries.map(([key, [schema, value]]) => (
             <div key={key} className="flex flex-col gap-1.5">
-              <label className="text-sm font-medium">{sub?.title || key}</label>
+              <label className="text-sm font-medium">
+                {schema?.title || key}
+              </label>
               <p className="whitespace-pre-wrap break-words text-sm text-neutral-700">
-                {renderValue((inputs as Record<string, any>)[key])}
+                {renderValue(value)}
               </p>
             </div>
           ))}
@@ -52,7 +65,7 @@ export function AgentInputsReadOnly({
       {hasCredentials && (
         <div className="flex flex-col gap-6">
           {hasInputs && <div className="border-t border-neutral-200 pt-4" />}
-          {credentialEntries.map(([key, inputSubSchema]) => {
+          {credentialFieldEntries.map(([key, inputSubSchema]) => {
             const credential = credentialInputs![key];
             if (!credential) return null;
 
@@ -13,7 +13,8 @@ export function getCredentialTypeDisplayName(type: string): string {
 }
 
 export function getAgentInputFields(agent: LibraryAgent): Record<string, any> {
-  const schema = agent.input_schema as unknown as {
+  const schema = (agent.trigger_setup_info?.config_schema ??
+    agent.input_schema) as unknown as {
     properties?: Record<string, any>;
   } | null;
   if (!schema || !schema.properties) return {};
@@ -3,6 +3,7 @@
 import { GraphExecutionJobInfo } from "@/app/api/__generated__/models/graphExecutionJobInfo";
 import { GraphExecutionMeta } from "@/app/api/__generated__/models/graphExecutionMeta";
 import { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
+import { LibraryAgentPreset } from "@/app/api/__generated__/models/libraryAgentPreset";
 import { Button } from "@/components/atoms/Button/Button";
 import {
   Tooltip,
@@ -22,16 +23,20 @@ import { useAgentRunModal } from "./useAgentRunModal";
 interface Props {
   triggerSlot: React.ReactNode;
   agent: LibraryAgent;
-  agentId: string;
-  agentVersion?: number;
+  initialInputValues?: Record<string, any>;
+  initialInputCredentials?: Record<string, any>;
   onRunCreated?: (execution: GraphExecutionMeta) => void;
+  onTriggerSetup?: (preset: LibraryAgentPreset) => void;
   onScheduleCreated?: (schedule: GraphExecutionJobInfo) => void;
 }
 
 export function RunAgentModal({
   triggerSlot,
   agent,
+  initialInputValues,
+  initialInputCredentials,
   onRunCreated,
+  onTriggerSetup,
   onScheduleCreated,
 }: Props) {
   const {
@@ -71,6 +76,9 @@ export function RunAgentModal({
     handleRun,
   } = useAgentRunModal(agent, {
     onRun: onRunCreated,
+    onSetupTrigger: onTriggerSetup,
+    initialInputValues,
+    initialInputCredentials,
   });
 
   const [isScheduleModalOpen, setIsScheduleModalOpen] = useState(false);
@@ -79,6 +87,8 @@ export function RunAgentModal({
     Object.keys(agentInputFields || {}).length > 0 ||
     Object.keys(agentCredentialsInputFields || {}).length > 0;
 
+  const isTriggerRunType = defaultRunType.includes("trigger");
+
   function handleInputChange(key: string, value: string) {
     setInputValues((prev) => ({
       ...prev,
@@ -153,7 +163,7 @@ export function RunAgentModal({
 
       <Dialog.Footer className="mt-6 bg-white pt-4">
         <div className="flex items-center justify-end gap-3">
-          {!allRequiredInputsAreSet ? (
+          {isTriggerRunType ? null : !allRequiredInputsAreSet ? (
            <TooltipProvider>
              <Tooltip>
                <TooltipTrigger asChild>
@@ -26,7 +26,8 @@ export function ModalRunSection() {
 
   return (
     <div className="flex flex-col gap-4">
-      {defaultRunType === "automatic-trigger" ? (
+      {defaultRunType === "automatic-trigger" ||
+      defaultRunType === "manual-trigger" ? (
         <ModalSection
           title="Task Trigger"
           subtitle="Set up a trigger for the agent to run this task automatically"
@@ -24,7 +24,8 @@ export function RunActions({
       disabled={!isRunReady || isExecuting || isSettingUpTrigger}
       loading={isExecuting || isSettingUpTrigger}
     >
-      {defaultRunType === "automatic-trigger"
+      {defaultRunType === "automatic-trigger" ||
+      defaultRunType === "manual-trigger"
         ? "Set up Trigger"
         : "Start Task"}
     </Button>
@@ -1,12 +1,11 @@
 import {
-  getGetV1ListGraphExecutionsInfiniteQueryOptions,
+  getGetV1ListGraphExecutionsQueryKey,
   usePostV1ExecuteGraphAgent,
 } from "@/app/api/__generated__/endpoints/graphs/graphs";
 import {
   getGetV2ListPresetsQueryKey,
   usePostV2SetupTrigger,
 } from "@/app/api/__generated__/endpoints/presets/presets";
-import { GraphExecutionJobInfo } from "@/app/api/__generated__/models/graphExecutionJobInfo";
 import { GraphExecutionMeta } from "@/app/api/__generated__/models/graphExecutionMeta";
 import { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
 import { LibraryAgentPreset } from "@/app/api/__generated__/models/libraryAgentPreset";
@@ -14,7 +13,7 @@ import { useToast } from "@/components/molecules/Toast/use-toast";
 import { isEmpty } from "@/lib/utils";
 import { analytics } from "@/services/analytics";
 import { useQueryClient } from "@tanstack/react-query";
-import { useCallback, useMemo, useState } from "react";
+import { useCallback, useEffect, useMemo, useState } from "react";
 import { showExecutionErrorToast } from "./errorHelpers";
 
 export type RunVariant =
@@ -25,8 +24,9 @@ export type RunVariant =
 
 interface UseAgentRunModalCallbacks {
   onRun?: (execution: GraphExecutionMeta) => void;
-  onCreateSchedule?: (schedule: GraphExecutionJobInfo) => void;
   onSetupTrigger?: (preset: LibraryAgentPreset) => void;
+  initialInputValues?: Record<string, any>;
+  initialInputCredentials?: Record<string, any>;
 }
 
 export function useAgentRunModal(
@@ -36,18 +36,28 @@ export function useAgentRunModal(
   const { toast } = useToast();
   const queryClient = useQueryClient();
   const [isOpen, setIsOpen] = useState(false);
-  const [inputValues, setInputValues] = useState<Record<string, any>>({});
+  const [inputValues, setInputValues] = useState<Record<string, any>>(
+    callbacks?.initialInputValues || {},
+  );
   const [inputCredentials, setInputCredentials] = useState<Record<string, any>>(
-    {},
+    callbacks?.initialInputCredentials || {},
   );
   const [presetName, setPresetName] = useState<string>("");
   const [presetDescription, setPresetDescription] = useState<string>("");
 
   // Determine the default run type based on agent capabilities
-  const defaultRunType: RunVariant = agent.has_external_trigger
-    ? "automatic-trigger"
+  const defaultRunType: RunVariant = agent.trigger_setup_info
+    ? agent.trigger_setup_info.credentials_input_name
+      ? "automatic-trigger"
+      : "manual-trigger"
     : "manual";
 
+  // Update input values/credentials if template is selected/unselected
+  useEffect(() => {
+    setInputValues(callbacks?.initialInputValues || {});
+    setInputCredentials(callbacks?.initialInputCredentials || {});
+  }, [callbacks?.initialInputValues, callbacks?.initialInputCredentials]);
+
   // API mutations
   const executeGraphMutation = usePostV1ExecuteGraphAgent({
     mutation: {
@@ -56,13 +66,11 @@ export function useAgentRunModal(
         toast({
           title: "Agent execution started",
         });
-        callbacks?.onRun?.(response.data as unknown as GraphExecutionMeta);
         // Invalidate runs list for this graph
         queryClient.invalidateQueries({
-          queryKey: getGetV1ListGraphExecutionsInfiniteQueryOptions(
-            agent.graph_id,
-          ).queryKey,
+          queryKey: getGetV1ListGraphExecutionsQueryKey(agent.graph_id),
         });
+        callbacks?.onRun?.(response.data);
         analytics.sendDatafastEvent("run_agent", {
           name: agent.name,
           id: agent.graph_id,
@@ -81,17 +89,15 @@ export function useAgentRunModal(
 
   const setupTriggerMutation = usePostV2SetupTrigger({
     mutation: {
-      onSuccess: (response: any) => {
+      onSuccess: (response) => {
         if (response.status === 200) {
           toast({
             title: "Trigger setup complete",
           });
-          callbacks?.onSetupTrigger?.(response.data);
           queryClient.invalidateQueries({
-            queryKey: getGetV2ListPresetsQueryKey({
-              graph_id: agent.graph_id,
-            }),
+            queryKey: getGetV2ListPresetsQueryKey({ graph_id: agent.graph_id }),
           });
+          callbacks?.onSetupTrigger?.(response.data);
           setIsOpen(false);
         }
       },
@@ -105,11 +111,13 @@ export function useAgentRunModal(
     },
   });
 
-  // Input schema validation
-  const agentInputSchema = useMemo(
-    () => agent.input_schema || { properties: {}, required: [] },
-    [agent.input_schema],
-  );
+  // Input schema validation (use trigger schema for triggered agents)
+  const agentInputSchema = useMemo(() => {
+    if (agent.trigger_setup_info?.config_schema) {
+      return agent.trigger_setup_info.config_schema;
+    }
+    return agent.input_schema || { properties: {}, required: [] };
+  }, [agent.input_schema, agent.trigger_setup_info]);
 
   const agentInputFields = useMemo(() => {
     if (
@@ -205,7 +213,10 @@ export function useAgentRunModal(
       return;
     }
 
-    if (defaultRunType === "automatic-trigger") {
+    if (
+      defaultRunType === "automatic-trigger" ||
+      defaultRunType === "manual-trigger"
+    ) {
       // Setup trigger
       if (!presetName.trim()) {
         toast({
@@ -262,7 +273,7 @@ export function useAgentRunModal(
     setIsOpen,
 
     // Run mode
-    defaultRunType,
+    defaultRunType: defaultRunType as RunVariant,
 
     // Form: regular inputs
     inputValues,
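
Note: the new `defaultRunType` derivation keys off `trigger_setup_info` instead of the removed `has_external_trigger` flag — a trigger whose setup info declares a `credentials_input_name` becomes an automatic (provider-driven) trigger, any other trigger is manual, and everything else stays a plain manual run. A standalone sketch of that decision, with a deliberately pared-down agent shape (the real `LibraryAgent` model has many more fields):

```ts
type RunVariant = "manual" | "manual-trigger" | "automatic-trigger";

// Pared-down stand-in for the generated LibraryAgent model.
interface AgentLike {
  trigger_setup_info?: { credentials_input_name?: string } | null;
}

function defaultRunTypeFor(agent: AgentLike): RunVariant {
  if (!agent.trigger_setup_info) return "manual";
  // A credentials input implies the provider can invoke the agent on its own.
  return agent.trigger_setup_info.credentials_input_name
    ? "automatic-trigger"
    : "manual-trigger";
}
```
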
@@ -1,17 +1,58 @@
+"use client";
+
+import { getV1GetGraphVersion } from "@/app/api/__generated__/endpoints/graphs/graphs";
+import { GraphExecutionJobInfo } from "@/app/api/__generated__/models/graphExecutionJobInfo";
+import { GraphExecutionMeta } from "@/app/api/__generated__/models/graphExecutionMeta";
 import { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent";
+import { LibraryAgentPreset } from "@/app/api/__generated__/models/libraryAgentPreset";
 import { Button } from "@/components/atoms/Button/Button";
 import { Text } from "@/components/atoms/Text/Text";
 import { ShowMoreText } from "@/components/molecules/ShowMoreText/ShowMoreText";
+import { useToast } from "@/components/molecules/Toast/use-toast";
+import { exportAsJSONFile } from "@/lib/utils";
 import { formatDate } from "@/lib/utils/time";
+import Link from "next/link";
 import { RunAgentModal } from "../modals/RunAgentModal/RunAgentModal";
 import { RunDetailCard } from "../selected-views/RunDetailCard/RunDetailCard";
 import { EmptyTasksIllustration } from "./EmptyTasksIllustration";
 
 type Props = {
   agent: LibraryAgent;
+  onRun?: (run: GraphExecutionMeta) => void;
+  onTriggerSetup?: (preset: LibraryAgentPreset) => void;
+  onScheduleCreated?: (schedule: GraphExecutionJobInfo) => void;
 };
 
-export function EmptyTasks({ agent }: Props) {
+export function EmptyTasks({
+  agent,
+  onRun,
+  onTriggerSetup,
+  onScheduleCreated,
+}: Props) {
+  const { toast } = useToast();
+
+  async function handleExport() {
+    try {
+      const res = await getV1GetGraphVersion(
+        agent.graph_id,
+        agent.graph_version,
+        { for_export: true },
+      );
+      if (res.status === 200) {
+        const filename = `${agent.name}_v${agent.graph_version}.json`;
+        exportAsJSONFile(res.data as any, filename);
+        toast({ title: "Agent exported" });
+      } else {
+        toast({ title: "Failed to export agent", variant: "destructive" });
+      }
+    } catch (e: any) {
+      toast({
+        title: "Failed to export agent",
+        description: e?.message,
+        variant: "destructive",
+      });
+    }
+  }
   const isPublished = Boolean(agent.marketplace_listing);
   const createdAt = formatDate(agent.created_at);
   const updatedAt = formatDate(agent.updated_at);
@@ -45,7 +86,9 @@ export function EmptyTasks({ agent }: Props) {
           </Button>
         }
         agent={agent}
-        agentId={agent.id.toString()}
+        onRunCreated={onRun}
+        onTriggerSetup={onTriggerSetup}
+        onScheduleCreated={onScheduleCreated}
       />
     </div>
   </div>
@@ -93,10 +136,15 @@ export function EmptyTasks({ agent }: Props) {
         ) : null}
       </div>
       <div className="mt-4 flex items-center gap-2">
-        <Button variant="secondary" size="small">
-          Edit agent
+        <Button variant="secondary" size="small" asChild>
+          <Link
+            href={`/build?flowID=${agent.graph_id}&flowVersion=${agent.graph_version}`}
+            target="_blank"
+          >
+            Edit agent
+          </Link>
         </Button>
-        <Button variant="secondary" size="small">
+        <Button variant="secondary" size="small" onClick={handleExport}>
           Export agent to file
         </Button>
       </div>
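
Note: `exportAsJSONFile` is a project utility whose implementation is not part of this diff. For orientation only, a browser-side JSON download is typically a `Blob` plus a temporary anchor, along these lines — a hedged sketch, not the project's actual code:

```ts
function downloadAsJSONFile(data: unknown, filename: string): void {
  const blob = new Blob([JSON.stringify(data, null, 2)], {
    type: "application/json",
  });
  const url = URL.createObjectURL(blob);
  const anchor = document.createElement("a");
  anchor.href = url;
  anchor.download = filename; // suggested filename, e.g. "MyAgent_v3.json"
  anchor.click();
  URL.revokeObjectURL(url); // release the object URL once the download starts
}
```
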
@@ -0,0 +1,14 @@
+import { cn } from "@/lib/utils";
+import { AGENT_LIBRARY_SECTION_PADDING_X } from "../../helpers";
+
+type Props = {
+  children: React.ReactNode;
+};
+
+export function AnchorLinksWrap({ children }: Props) {
+  return (
+    <div className={cn(AGENT_LIBRARY_SECTION_PADDING_X, "hidden lg:block")}>
+      <nav className="flex gap-8 px-3 pb-1">{children}</nav>
+    </div>
+  );
+}
@@ -166,7 +166,7 @@ function renderMarkdown(
       className="prose prose-sm dark:prose-invert max-w-none"
       remarkPlugins={[
         remarkGfm, // GitHub Flavored Markdown (tables, task lists, strikethrough)
-        remarkMath, // Math support for LaTeX
+        [remarkMath, { singleDollarTextMath: false }], // Math support for LaTeX
       ]}
       rehypePlugins={[
         rehypeKatex, // Render math with KaTeX
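
Note: passing the plugin as a `[plugin, options]` tuple is the standard unified/remark convention. With `singleDollarTextMath: false`, a lone `$` in ordinary prose (prices, shell variables) no longer opens an inline-math span; only `$$…$$` still reaches KaTeX. A minimal sketch of the same option outside React:

```ts
import { unified } from "unified";
import remarkParse from "remark-parse";
import remarkMath from "remark-math";

// "$5 or $10" stays plain text; "$$x^2$$" still becomes a math node.
const tree = unified()
  .use(remarkParse)
  .use(remarkMath, { singleDollarTextMath: false })
  .parse("$5 or $10, but $$x^2$$ renders.");
```
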
@@ -0,0 +1,11 @@
+type Props = {
+  children: React.ReactNode;
+};
+
+export function SelectedActionsWrap({ children }: Props) {
+  return (
+    <div className="my-0 ml-4 flex flex-row items-center gap-3 lg:mx-0 lg:my-4 lg:flex-col">
+      {children}
+    </div>
+  );
+}
@@ -13,10 +13,11 @@ import {
 import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard";
 import { PendingReviewsList } from "@/components/organisms/PendingReviewsList/PendingReviewsList";
 import { usePendingReviewsForExecution } from "@/hooks/usePendingReviews";
+import { isLargeScreen, useBreakpoint } from "@/lib/hooks/useBreakpoint";
 import { InfoIcon } from "@phosphor-icons/react";
 import { useEffect } from "react";
-import { AGENT_LIBRARY_SECTION_PADDING_X } from "../../../helpers";
 import { AgentInputsReadOnly } from "../../modals/AgentInputsReadOnly/AgentInputsReadOnly";
+import { AnchorLinksWrap } from "../AnchorLinksWrap";
 import { LoadingSelectedContent } from "../LoadingSelectedContent";
 import { RunDetailCard } from "../RunDetailCard/RunDetailCard";
 import { RunDetailHeader } from "../RunDetailHeader/RunDetailHeader";
@@ -46,6 +47,9 @@ export function SelectedRunView({
   const { run, preset, isLoading, responseError, httpError } =
     useSelectedRunView(agent.graph_id, runId);
 
+  const breakpoint = useBreakpoint();
+  const isLgScreenUp = isLargeScreen(breakpoint);
+
   const {
     pendingReviews,
     isLoading: reviewsLoading,
@@ -89,6 +93,15 @@ export function SelectedRunView({
       <div className="flex flex-col gap-4">
         <RunDetailHeader agent={agent} run={run} />
 
+        {!isLgScreenUp ? (
+          <SelectedRunActions
+            agent={agent}
+            run={run}
+            onSelectRun={onSelectRun}
+            onClearSelectedRun={onClearSelectedRun}
+          />
+        ) : null}
+
         {preset &&
           agent.trigger_setup_info &&
           preset.webhook_id &&
@@ -100,38 +113,36 @@ export function SelectedRunView({
         )}
 
         {/* Navigation Links */}
-        <div className={AGENT_LIBRARY_SECTION_PADDING_X}>
-          <nav className="flex gap-8 px-3 pb-1">
-            {withSummary && (
-              <button
-                onClick={() => scrollToSection("summary")}
-                className={anchorStyles}
-              >
-                Summary
-              </button>
-            )}
-            <button
-              onClick={() => scrollToSection("output")}
-              className={anchorStyles}
-            >
-              Output
-            </button>
-            <button
-              onClick={() => scrollToSection("input")}
-              className={anchorStyles}
-            >
-              Your input
-            </button>
-            {withReviews && (
-              <button
-                onClick={() => scrollToSection("reviews")}
-                className={anchorStyles}
-              >
-                Reviews ({pendingReviews.length})
-              </button>
-            )}
-          </nav>
-        </div>
+        <AnchorLinksWrap>
+          {withSummary && (
+            <button
+              onClick={() => scrollToSection("summary")}
+              className={anchorStyles}
+            >
+              Summary
+            </button>
+          )}
+          <button
+            onClick={() => scrollToSection("output")}
+            className={anchorStyles}
+          >
+            Output
+          </button>
+          <button
+            onClick={() => scrollToSection("input")}
+            className={anchorStyles}
+          >
+            Your input
+          </button>
+          {withReviews && (
+            <button
+              onClick={() => scrollToSection("reviews")}
+              className={anchorStyles}
+            >
+              Reviews ({pendingReviews.length})
+            </button>
+          )}
+        </AnchorLinksWrap>
 
         {/* Summary Section */}
         {withSummary && (
@@ -187,8 +198,8 @@ export function SelectedRunView({
               <RunDetailCard title="Your input">
                 <AgentInputsReadOnly
                   agent={agent}
-                  inputs={(run as any)?.inputs}
-                  credentialInputs={(run as any)?.credential_inputs}
+                  inputs={run?.inputs}
+                  credentialInputs={run?.credential_inputs}
                 />
               </RunDetailCard>
             </div>
@@ -216,14 +227,16 @@ export function SelectedRunView({
           </div>
         </SelectedViewLayout>
       </div>
-      <div className="-mt-2 max-w-[3.75rem] flex-shrink-0">
-        <SelectedRunActions
-          agent={agent}
-          run={run}
-          onSelectRun={onSelectRun}
-          onClearSelectedRun={onClearSelectedRun}
-        />
-      </div>
+      {isLgScreenUp ? (
+        <div className="max-w-[3.75rem] flex-shrink-0">
+          <SelectedRunActions
+            agent={agent}
+            run={run}
+            onSelectRun={onSelectRun}
+            onClearSelectedRun={onClearSelectedRun}
+          />
+        </div>
+      ) : null}
     </div>
   );
 }
@@ -12,6 +12,7 @@ import {
|
|||||||
StopIcon,
|
StopIcon,
|
||||||
} from "@phosphor-icons/react";
|
} from "@phosphor-icons/react";
|
||||||
import { AgentActionsDropdown } from "../../../AgentActionsDropdown";
|
import { AgentActionsDropdown } from "../../../AgentActionsDropdown";
|
||||||
|
import { SelectedActionsWrap } from "../../../SelectedActionsWrap";
|
||||||
import { ShareRunButton } from "../../../ShareRunButton/ShareRunButton";
|
import { ShareRunButton } from "../../../ShareRunButton/ShareRunButton";
|
||||||
import { CreateTemplateModal } from "../CreateTemplateModal/CreateTemplateModal";
|
import { CreateTemplateModal } from "../CreateTemplateModal/CreateTemplateModal";
|
||||||
import { useSelectedRunActions } from "./useSelectedRunActions";
|
import { useSelectedRunActions } from "./useSelectedRunActions";
|
||||||
@@ -19,13 +20,18 @@ import { useSelectedRunActions } from "./useSelectedRunActions";
|
|||||||
type Props = {
|
type Props = {
|
||||||
agent: LibraryAgent;
|
agent: LibraryAgent;
|
||||||
run: GraphExecution | undefined;
|
run: GraphExecution | undefined;
|
||||||
scheduleRecurrence?: string;
|
|
||||||
onSelectRun?: (id: string) => void;
|
onSelectRun?: (id: string) => void;
|
||||||
onClearSelectedRun?: () => void;
|
onClearSelectedRun?: () => void;
|
||||||
};
|
};
|
||||||
|
|
||||||
export function SelectedRunActions(props: Props) {
|
export function SelectedRunActions({
|
||||||
|
agent,
|
||||||
|
run,
|
||||||
|
onSelectRun,
|
||||||
|
onClearSelectedRun,
|
||||||
|
}: Props) {
|
||||||
const {
|
const {
|
||||||
|
canRunManually,
|
||||||
handleRunAgain,
|
handleRunAgain,
|
||||||
handleStopRun,
|
handleStopRun,
|
||||||
isRunningAgain,
|
isRunningAgain,
|
||||||
@@ -36,21 +42,20 @@ export function SelectedRunActions(props: Props) {
|
|||||||
isCreateTemplateModalOpen,
|
isCreateTemplateModalOpen,
|
||||||
setIsCreateTemplateModalOpen,
|
setIsCreateTemplateModalOpen,
|
||||||
} = useSelectedRunActions({
|
} = useSelectedRunActions({
|
||||||
agentGraphId: props.agent.graph_id,
|
agentGraphId: agent.graph_id,
|
||||||
run: props.run,
|
run: run,
|
||||||
agent: props.agent,
|
agent: agent,
|
||||||
onSelectRun: props.onSelectRun,
|
onSelectRun: onSelectRun,
|
||||||
onClearSelectedRun: props.onClearSelectedRun,
|
|
||||||
});
|
});
|
||||||
|
|
||||||
const shareExecutionResultsEnabled = useGetFlag(Flag.SHARE_EXECUTION_RESULTS);
|
const shareExecutionResultsEnabled = useGetFlag(Flag.SHARE_EXECUTION_RESULTS);
|
||||||
const isRunning = props.run?.status === "RUNNING";
|
const isRunning = run?.status === "RUNNING";
|
||||||
|
|
||||||
if (!props.run || !props.agent) return null;
|
if (!run || !agent) return null;
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="my-4 flex flex-col items-center gap-3">
|
<SelectedActionsWrap>
|
||||||
{!isRunning ? (
|
{canRunManually && !isRunning ? (
|
||||||
<Button
|
<Button
|
||||||
variant="icon"
|
variant="icon"
|
||||||
size="icon"
|
size="icon"
|
||||||
@@ -102,38 +107,38 @@ export function SelectedRunActions(props: Props) {
|
|||||||
) : null}
|
) : null}
|
||||||
{shareExecutionResultsEnabled && (
|
{shareExecutionResultsEnabled && (
|
||||||
<ShareRunButton
|
<ShareRunButton
|
||||||
graphId={props.agent.graph_id}
|
graphId={agent.graph_id}
|
||||||
executionId={props.run.id}
|
executionId={run.id}
|
||||||
isShared={props.run.is_shared}
|
isShared={run.is_shared}
|
||||||
shareToken={props.run.share_token}
|
shareToken={run.share_token}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
<FloatingSafeModeToggle
|
<FloatingSafeModeToggle graph={agent} variant="white" fullWidth={false} />
|
||||||
graph={props.agent}
|
{canRunManually && (
|
||||||
variant="white"
|
<>
|
||||||
fullWidth={false}
|
<Button
|
||||||
/>
|
variant="icon"
|
||||||
<Button
|
size="icon"
|
||||||
variant="icon"
|
aria-label="Save task as template"
|
||||||
size="icon"
|
onClick={() => setIsCreateTemplateModalOpen(true)}
|
||||||
aria-label="Save task as template"
|
title="Create template"
|
||||||
onClick={() => setIsCreateTemplateModalOpen(true)}
|
>
|
||||||
title="Create template"
|
<CardsThreeIcon weight="bold" size={18} className="text-zinc-700" />
|
||||||
>
|
</Button>
|
||||||
<CardsThreeIcon weight="bold" size={18} className="text-zinc-700" />
|
<CreateTemplateModal
|
||||||
</Button>
|
isOpen={isCreateTemplateModalOpen}
|
||||||
|
onClose={() => setIsCreateTemplateModalOpen(false)}
|
||||||
|
onCreate={handleCreateTemplate}
|
||||||
|
run={run}
|
||||||
|
/>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
<AgentActionsDropdown
|
<AgentActionsDropdown
|
||||||
agent={props.agent}
|
agent={agent}
|
||||||
run={props.run}
|
run={run}
|
||||||
agentGraphId={props.agent.graph_id}
|
agentGraphId={agent.graph_id}
|
||||||
onClearSelectedRun={props.onClearSelectedRun}
|
onClearSelectedRun={onClearSelectedRun}
|
||||||
/>
|
/>
|
||||||
<CreateTemplateModal
|
</SelectedActionsWrap>
|
||||||
isOpen={isCreateTemplateModalOpen}
|
|
||||||
onClose={() => setIsCreateTemplateModalOpen(false)}
|
|
||||||
onCreate={handleCreateTemplate}
|
|
||||||
run={props.run}
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
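
The wrapper swap above replaces a repeated utility-class `<div>` with a shared `SelectedActionsWrap` component. Its source is not part of this diff; a minimal sketch, assuming it simply centralizes the exact markup it replaces, would be:

// Hypothetical sketch of SelectedActionsWrap -- its source is not shown in
// this diff. Assumption: it centralizes the wrapper div it replaces so the
// run, schedule, template, and trigger action columns share one layout.
import type { ReactNode } from "react";

export function SelectedActionsWrap({ children }: { children: ReactNode }) {
  // Same utility classes as the <div> removed by this diff.
  return (
    <div className="my-4 flex flex-col items-center gap-3">{children}</div>
  );
}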

@@ -15,15 +15,19 @@ import { useToast } from "@/components/molecules/Toast/use-toast";
 import { useQueryClient } from "@tanstack/react-query";
 import { useState } from "react";
 
-interface Args {
+interface Params {
   agentGraphId: string;
   run?: GraphExecution;
   agent?: LibraryAgent;
   onSelectRun?: (id: string) => void;
-  onClearSelectedRun?: () => void;
 }
 
-export function useSelectedRunActions(args: Args) {
+export function useSelectedRunActions({
+  agentGraphId,
+  run,
+  agent,
+  onSelectRun,
+}: Params) {
   const queryClient = useQueryClient();
   const { toast } = useToast();
 
@@ -31,8 +35,9 @@ export function useSelectedRunActions(args: Args) {
   const [isCreateTemplateModalOpen, setIsCreateTemplateModalOpen] =
     useState(false);
 
-  const canStop =
-    args.run?.status === "RUNNING" || args.run?.status === "QUEUED";
+  const canStop = run?.status === "RUNNING" || run?.status === "QUEUED";
+  const canRunManually = !agent?.trigger_setup_info;
 
   const { mutateAsync: stopRun, isPending: isStopping } =
     usePostV1StopGraphExecution();
@@ -46,16 +51,16 @@ export function useSelectedRunActions(args: Args) {
   async function handleStopRun() {
     try {
       await stopRun({
-        graphId: args.run?.graph_id ?? "",
-        graphExecId: args.run?.id ?? "",
+        graphId: run?.graph_id ?? "",
+        graphExecId: run?.id ?? "",
       });
 
       toast({ title: "Run stopped" });
 
       await queryClient.invalidateQueries({
-        queryKey: getGetV1ListGraphExecutionsInfiniteQueryOptions(
-          args.agentGraphId,
-        ).queryKey,
+        queryKey:
+          getGetV1ListGraphExecutionsInfiniteQueryOptions(agentGraphId)
+            .queryKey,
       });
     } catch (error: unknown) {
       toast({
@@ -70,7 +75,7 @@ export function useSelectedRunActions(args: Args) {
   }
 
   async function handleRunAgain() {
-    if (!args.run) {
+    if (!run) {
       toast({
         title: "Run not found",
         description: "Run not found",
@@ -83,11 +88,11 @@ export function useSelectedRunActions(args: Args) {
       toast({ title: "Run started" });
 
       const res = await executeRun({
-        graphId: args.run.graph_id,
-        graphVersion: args.run.graph_version,
+        graphId: run.graph_id,
+        graphVersion: run.graph_version,
         data: {
-          inputs: args.run.inputs || {},
-          credentials_inputs: args.run.credential_inputs || {},
+          inputs: run.inputs || {},
+          credentials_inputs: run.credential_inputs || {},
           source: "library",
         },
       });
@@ -95,12 +100,12 @@ export function useSelectedRunActions(args: Args) {
       const newRunId = res?.status === 200 ? (res?.data?.id ?? "") : "";
 
       await queryClient.invalidateQueries({
-        queryKey: getGetV1ListGraphExecutionsInfiniteQueryOptions(
-          args.agentGraphId,
-        ).queryKey,
+        queryKey:
+          getGetV1ListGraphExecutionsInfiniteQueryOptions(agentGraphId)
+            .queryKey,
       });
 
-      if (newRunId && args.onSelectRun) args.onSelectRun(newRunId);
+      if (newRunId && onSelectRun) onSelectRun(newRunId);
     } catch (error: unknown) {
       toast({
         title: "Failed to start run",
@@ -118,7 +123,7 @@ export function useSelectedRunActions(args: Args) {
   }
 
   async function handleCreateTemplate(name: string, description: string) {
-    if (!args.run) {
+    if (!run) {
       toast({
         title: "Run not found",
         description: "Cannot create template from missing run",
@@ -132,7 +137,7 @@ export function useSelectedRunActions(args: Args) {
         data: {
           name,
           description,
-          graph_execution_id: args.run.id,
+          graph_execution_id: run.id,
         },
       });
 
@@ -141,10 +146,10 @@ export function useSelectedRunActions(args: Args) {
         title: "Template created",
       });
 
-      if (args.agent) {
+      if (agent) {
         queryClient.invalidateQueries({
           queryKey: getGetV2ListPresetsQueryKey({
-            graph_id: args.agent.graph_id,
+            graph_id: agent.graph_id,
           }),
         });
       }
@@ -164,8 +169,8 @@ export function useSelectedRunActions(args: Args) {
   }
 
   // Open in builder URL helper
-  const openInBuilderHref = args.run
-    ? `/build?flowID=${args.run.graph_id}&flowVersion=${args.run.graph_version}&flowExecutionID=${args.run.id}`
+  const openInBuilderHref = run
+    ? `/build?flowID=${run.graph_id}&flowVersion=${run.graph_version}&flowExecutionID=${run.id}`
     : undefined;
 
   return {
@@ -173,6 +178,7 @@ export function useSelectedRunActions(args: Args) {
     showDeleteDialog,
    canStop,
    isStopping,
+    canRunManually,
    isRunningAgain,
    handleShowDeleteDialog,
    handleStopRun,
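
Both `handleStopRun` and `handleRunAgain` refresh the executions list by invalidating the cache entry whose key comes from the generated query-options factory. Extracted into a standalone helper, the pattern looks like this (a sketch; `queryOptionsFactory` stands in for the generated `getGetV1ListGraphExecutionsInfiniteQueryOptions`, whose import path is not shown here):

// Sketch of the invalidation pattern used above. Deriving the key from the
// same factory the list hook uses keeps the key shapes in sync.
import { QueryClient } from "@tanstack/react-query";

type QueryOptionsFactory = (graphId: string) => {
  queryKey: readonly unknown[];
};

async function invalidateExecutionsList(
  queryClient: QueryClient,
  queryOptionsFactory: QueryOptionsFactory, // e.g. the generated factory
  graphId: string,
) {
  // invalidateQueries marks every cached entry under this key as stale;
  // mounted useInfiniteQuery subscribers then refetch automatically.
  await queryClient.invalidateQueries({
    queryKey: queryOptionsFactory(graphId).queryKey,
  });
}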

@@ -6,9 +6,10 @@ import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner
 import { Text } from "@/components/atoms/Text/Text";
 import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard";
 import { humanizeCronExpression } from "@/lib/cron-expression-utils";
+import { isLargeScreen, useBreakpoint } from "@/lib/hooks/useBreakpoint";
 import { formatInTimezone, getTimezoneDisplayName } from "@/lib/timezone-utils";
-import { AGENT_LIBRARY_SECTION_PADDING_X } from "../../../helpers";
 import { AgentInputsReadOnly } from "../../modals/AgentInputsReadOnly/AgentInputsReadOnly";
+import { AnchorLinksWrap } from "../AnchorLinksWrap";
 import { LoadingSelectedContent } from "../LoadingSelectedContent";
 import { RunDetailCard } from "../RunDetailCard/RunDetailCard";
 import { RunDetailHeader } from "../RunDetailHeader/RunDetailHeader";
@@ -41,6 +42,9 @@ export function SelectedScheduleView({
     },
   });
 
+  const breakpoint = useBreakpoint();
+  const isLgScreenUp = isLargeScreen(breakpoint);
+
   function scrollToSection(id: string) {
     const element = document.getElementById(id);
     if (element) {
@@ -83,37 +87,42 @@ export function SelectedScheduleView({
     <div className="flex min-h-0 min-w-0 flex-1 flex-col">
       <SelectedViewLayout agentName={agent.name} agentId={agent.id}>
         <div className="flex flex-col gap-4">
-          <div className="flex w-full items-center justify-between">
-            <div className="flex w-full flex-col gap-0">
-              <RunDetailHeader
-                agent={agent}
-                run={undefined}
-                scheduleRecurrence={
-                  schedule
-                    ? `${humanizeCronExpression(schedule.cron || "")} · ${getTimezoneDisplayName(schedule.timezone || userTzRes || "UTC")}`
-                    : undefined
-                }
-              />
-            </div>
+          <div className="flex w-full flex-col gap-0">
+            <RunDetailHeader
+              agent={agent}
+              run={undefined}
+              scheduleRecurrence={
+                schedule
+                  ? `${humanizeCronExpression(schedule.cron || "")} · ${getTimezoneDisplayName(schedule.timezone || userTzRes || "UTC")}`
+                  : undefined
+              }
+            />
+            {schedule && !isLgScreenUp ? (
+              <div className="mt-4">
+                <SelectedScheduleActions
+                  agent={agent}
+                  scheduleId={schedule.id}
+                  onDeleted={onClearSelectedRun}
+                />
+              </div>
+            ) : null}
           </div>
 
           {/* Navigation Links */}
-          <div className={AGENT_LIBRARY_SECTION_PADDING_X}>
-            <nav className="flex gap-8 px-3 pb-1">
-              <button
-                onClick={() => scrollToSection("schedule")}
-                className={anchorStyles}
-              >
-                Schedule
-              </button>
-              <button
-                onClick={() => scrollToSection("input")}
-                className={anchorStyles}
-              >
-                Your input
-              </button>
-            </nav>
-          </div>
+          <AnchorLinksWrap>
+            <button
+              onClick={() => scrollToSection("schedule")}
+              className={anchorStyles}
+            >
+              Schedule
+            </button>
+            <button
+              onClick={() => scrollToSection("input")}
+              className={anchorStyles}
+            >
+              Your input
+            </button>
+          </AnchorLinksWrap>
 
           {/* Schedule Section */}
           <div id="schedule" className="scroll-mt-4">
@@ -172,10 +181,6 @@ export function SelectedScheduleView({
           <div id="input" className="scroll-mt-4">
             <RunDetailCard title="Your input">
               <div className="relative">
-                {/* {// TODO: re-enable edit inputs modal once the API supports it */}
-                {/* {schedule && Object.keys(schedule.input_data).length > 0 && (
-                  <EditInputsModal agent={agent} schedule={schedule} />
-                )} */}
                 <AgentInputsReadOnly
                   agent={agent}
                   inputs={schedule?.input_data}
@@ -187,8 +192,8 @@ export function SelectedScheduleView({
         </div>
       </SelectedViewLayout>
     </div>
-    {schedule ? (
-      <div className="-mt-2 max-w-[3.75rem] flex-shrink-0">
+    {schedule && isLgScreenUp ? (
+      <div className="max-w-[3.75rem] flex-shrink-0">
         <SelectedScheduleActions
           agent={agent}
           scheduleId={schedule.id}
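
The schedule view now mounts `SelectedScheduleActions` in exactly one of two places: inline under the header below the `lg` breakpoint, or in the side column at `lg` and up, both keyed off the same `isLgScreenUp` flag. A minimal sketch of that gating (the `useBreakpoint`/`isLargeScreen` API is assumed from the imports in this diff; `Actions` is a hypothetical placeholder):

// Sketch of the mutually exclusive responsive rendering used above.
// Assumptions: useBreakpoint returns the current breakpoint token and
// isLargeScreen tests it, per the imports in this diff; Actions is a
// stand-in for SelectedScheduleActions.
import { isLargeScreen, useBreakpoint } from "@/lib/hooks/useBreakpoint";

function Actions() {
  return <button>actions</button>; // placeholder
}

export function ResponsiveActionsDemo() {
  const breakpoint = useBreakpoint();
  const isLgScreenUp = isLargeScreen(breakpoint);

  return (
    <>
      {/* Inline under the header on small screens... */}
      {!isLgScreenUp ? <Actions /> : null}
      {/* ...or in the side column on large screens, never both. */}
      {isLgScreenUp ? <Actions /> : null}
    </>
  );
}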

@@ -3,6 +3,7 @@ import { Button } from "@/components/atoms/Button/Button";
 import { EyeIcon } from "@phosphor-icons/react";
 import { AgentActionsDropdown } from "../../AgentActionsDropdown";
 import { useScheduleDetailHeader } from "../../RunDetailHeader/useScheduleDetailHeader";
+import { SelectedActionsWrap } from "../../SelectedActionsWrap";
 
 type Props = {
   agent: LibraryAgent;
@@ -19,7 +20,7 @@ export function SelectedScheduleActions({ agent, scheduleId }: Props) {
 
   return (
     <>
-      <div className="my-4 flex flex-col items-center gap-3">
+      <SelectedActionsWrap>
         {openInBuilderHref && (
           <Button
             variant="icon"
@@ -32,7 +33,7 @@ export function SelectedScheduleActions({ agent, scheduleId }: Props) {
           </Button>
         )}
         <AgentActionsDropdown agent={agent} scheduleId={scheduleId} />
-      </div>
+      </SelectedActionsWrap>
     </>
   );
 }

@@ -95,6 +95,7 @@ export function SelectedTemplateView({
     return null;
   }
 
+  const templateOrTrigger = agent.trigger_setup_info ? "Trigger" : "Template";
   const hasWebhook = !!template.webhook_id && template.webhook;
 
   return (
@@ -111,14 +112,14 @@ export function SelectedTemplateView({
         />
       )}
 
-      <RunDetailCard title="Template Details">
+      <RunDetailCard title={`${templateOrTrigger} Details`}>
         <div className="flex flex-col gap-2">
           <Input
             id="template-name"
             label="Name"
             value={name}
             onChange={(e) => setName(e.target.value)}
-            placeholder="Enter template name"
+            placeholder={`Enter ${templateOrTrigger.toLowerCase()} name`}
           />
 
           <Input
@@ -128,7 +129,7 @@ export function SelectedTemplateView({
             rows={3}
             value={description}
             onChange={(e) => setDescription(e.target.value)}
-            placeholder="Enter template description"
+            placeholder={`Enter ${templateOrTrigger.toLowerCase()} description`}
           />
         </div>
       </RunDetailCard>

@@ -15,6 +15,7 @@ import { useToast } from "@/components/molecules/Toast/use-toast";
 import { FloppyDiskIcon, PlayIcon, TrashIcon } from "@phosphor-icons/react";
 import { useQueryClient } from "@tanstack/react-query";
 import { useState } from "react";
+import { AgentActionsDropdown } from "../../AgentActionsDropdown";
 
 interface Props {
   agent: LibraryAgent;
@@ -134,6 +135,7 @@ export function SelectedTemplateActions({
           <TrashIcon weight="bold" size={18} />
         )}
       </Button>
+      <AgentActionsDropdown agent={agent} />
     </div>
 
     <Dialog

@@ -138,11 +138,21 @@ export function useSelectedTemplateView({
   }
 
   function handleStartTask() {
+    if (!query.data) return;
+
+    const inputsChanged =
+      JSON.stringify(inputs) !== JSON.stringify(query.data.inputs || {});
+
+    const credentialsChanged =
+      JSON.stringify(credentials) !==
+      JSON.stringify(query.data.credentials || {});
+
+    // Use changed unpersisted inputs if applicable
     executeMutation.mutate({
       presetId: templateId,
       data: {
-        inputs: {},
-        credential_inputs: {},
+        inputs: inputsChanged ? inputs : undefined,
+        credential_inputs: credentialsChanged ? credentials : undefined,
       },
     });
   }
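
`handleStartTask` now sends input and credential overrides only when they differ from what the preset already persists, using `JSON.stringify` as a cheap deep-equality check. The same check as a standalone helper (a sketch; note that stringify comparison is key-order sensitive, so objects that are semantically equal but serialized in a different order register as changed):

// Sketch of the change-detection used above: serialize both sides and
// compare. Dependency-free, but key-order sensitive.
function changedOrUndefined<T>(current: T, persisted: T): T | undefined {
  const changed = JSON.stringify(current) !== JSON.stringify(persisted);
  // undefined tells the execute endpoint to keep the persisted values.
  return changed ? current : undefined;
}

// Usage mirroring handleStartTask (names from the diff above):
//   inputs: changedOrUndefined(inputs, query.data.inputs || {}),
//   credential_inputs: changedOrUndefined(credentials, query.data.credentials || {}),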

@@ -15,6 +15,7 @@ import { useToast } from "@/components/molecules/Toast/use-toast";
 import { FloppyDiskIcon, TrashIcon } from "@phosphor-icons/react";
 import { useQueryClient } from "@tanstack/react-query";
 import { useState } from "react";
+import { AgentActionsDropdown } from "../../AgentActionsDropdown";
 
 interface Props {
   agent: LibraryAgent;
@@ -111,6 +112,7 @@ export function SelectedTriggerActions({
           <TrashIcon weight="bold" size={18} />
         )}
       </Button>
+      <AgentActionsDropdown agent={agent} />
     </div>
 
     <Dialog

@@ -12,7 +12,7 @@ export function SelectedViewLayout(props: Props) {
   return (
     <SectionWrap className="relative mb-3 flex min-h-0 flex-1 flex-col">
       <div
-        className={`${AGENT_LIBRARY_SECTION_PADDING_X} flex-shrink-0 border-b border-zinc-100 pb-4`}
+        className={`${AGENT_LIBRARY_SECTION_PADDING_X} flex-shrink-0 border-b border-zinc-100 pb-0 lg:pb-4`}
       >
         <Breadcrumbs
           items={[

Some files were not shown because too many files have changed in this diff.