Compare commits
71 Commits
ntindle-pa
...
claude/iss
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4729dcc898 | ||
|
|
3f19cba28f | ||
|
|
a978e91271 | ||
|
|
f283e6c514 | ||
|
|
9fc2101e7e | ||
|
|
634f826d82 | ||
|
|
6d6bf308fc | ||
|
|
dd84fb5c66 | ||
|
|
33679f3ffe | ||
|
|
fc8c5ccbb6 | ||
|
|
7d2ab61546 | ||
|
|
c2f11dbcfa | ||
|
|
f82adeb959 | ||
|
|
6f08a1cca7 | ||
|
|
1ddf92eed4 | ||
|
|
4c0dd27157 | ||
|
|
17fcf68f2e | ||
|
|
381558342a | ||
|
|
1fdc02467b | ||
|
|
f262bb9307 | ||
|
|
5a6978b07d | ||
|
|
339ec733cb | ||
|
|
6575b655f0 | ||
|
|
7c2df24d7c | ||
|
|
23eafa178c | ||
|
|
27fccdbf31 | ||
|
|
fb8fbc9d1f | ||
|
|
6a86e70fd6 | ||
|
|
6a2d7e0fb0 | ||
|
|
3d6ea3088e | ||
|
|
64b4480b1e | ||
|
|
f490b01abb | ||
|
|
e56a4a135d | ||
|
|
e70c970ab6 | ||
|
|
3bbce71678 | ||
|
|
34fbf4377f | ||
|
|
f682ef885a | ||
|
|
2ffd249aac | ||
|
|
986245ec43 | ||
|
|
f89717153f | ||
|
|
5da41e0753 | ||
|
|
cddeb185a8 | ||
|
|
08a3fd6d26 | ||
|
|
39b30bc82c | ||
|
|
2df0e2b750 | ||
|
|
925f249ce1 | ||
|
|
e8cf3edbf4 | ||
|
|
dc03ea718c | ||
|
|
dbee580d80 | ||
|
|
0325ec0a2c | ||
|
|
3952a1a226 | ||
|
|
cfc975d39b | ||
|
|
46e0f6cc45 | ||
|
|
c03af5c196 | ||
|
|
00cbfb8f80 | ||
|
|
3beafae955 | ||
|
|
9cd186a2f3 | ||
|
|
dcf26bd3d4 | ||
|
|
b97f097c9d | ||
|
|
5be6987d58 | ||
|
|
4bcc73f784 | ||
|
|
c8240a4d6b | ||
|
|
82618aede0 | ||
|
|
c4483fa6c7 | ||
|
|
c2af8c1a6a | ||
|
|
483c399812 | ||
|
|
260dd526c9 | ||
|
|
75a159db01 | ||
|
|
62032e6584 | ||
|
|
105d5dc7e9 | ||
|
|
92515b3683 |
2
.github/workflows/claude-dependabot.yml
vendored
@@ -311,7 +311,7 @@ jobs:
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
claude_args: |
|
||||
--allowedTools "Bash(npm:*),Bash(pnpm:*),Bash(poetry:*),Bash(git:*),Edit,Replace,NotebookEditCell,mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*), Bash(gh pr diff:*), Bash(gh pr view:*)"
|
||||
custom_system_prompt: |
|
||||
prompt: |
|
||||
You are Claude, an AI assistant specialized in reviewing Dependabot dependency update PRs.
|
||||
|
||||
Your primary tasks are:
|
||||
|
||||
3
.github/workflows/claude.yml
vendored
@@ -319,6 +319,7 @@ jobs:
|
||||
with:
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
claude_args: |
|
||||
--allowedTools "Bash(npm:*),Bash(pnpm:*),Bash(poetry:*),Bash(git:*),Edit,Replace,NotebookEditCell,mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*), Bash(gh pr diff:*), Bash(gh pr view:*)"
|
||||
--allowedTools "Bash(npm:*),Bash(pnpm:*),Bash(poetry:*),Bash(git:*),Edit,Replace,NotebookEditCell,mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*), Bash(gh pr diff:*), Bash(gh pr view:*), Bash(gh pr edit:*)"
|
||||
--model opus
|
||||
additional_permissions: |
|
||||
actions: read
|
||||
|
||||
@@ -5,6 +5,13 @@ on:
|
||||
branches: [ dev ]
|
||||
paths:
|
||||
- 'autogpt_platform/**'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
git_ref:
|
||||
description: 'Git ref (branch/tag) of AutoGPT to deploy'
|
||||
required: true
|
||||
default: 'master'
|
||||
type: string
|
||||
|
||||
permissions:
|
||||
contents: 'read'
|
||||
@@ -19,6 +26,8 @@ jobs:
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event.inputs.git_ref || github.ref_name }}
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
@@ -48,4 +57,4 @@ jobs:
|
||||
token: ${{ secrets.DEPLOY_TOKEN }}
|
||||
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
|
||||
event-type: build_deploy_dev
|
||||
client-payload: '{"ref": "${{ github.ref }}", "sha": "${{ github.sha }}", "repository": "${{ github.repository }}"}'
|
||||
client-payload: '{"ref": "${{ github.event.inputs.git_ref || github.ref }}", "repository": "${{ github.repository }}"}'
|
||||
|
||||
@@ -3,6 +3,7 @@ name: AutoGPT Platform - Deploy Prod Environment
|
||||
on:
|
||||
release:
|
||||
types: [published]
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: 'read'
|
||||
@@ -17,6 +18,8 @@ jobs:
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.ref_name || 'master' }}
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
@@ -36,7 +39,7 @@ jobs:
|
||||
DATABASE_URL: ${{ secrets.BACKEND_DATABASE_URL }}
|
||||
DIRECT_URL: ${{ secrets.BACKEND_DATABASE_URL }}
|
||||
|
||||
|
||||
|
||||
trigger:
|
||||
needs: migrate
|
||||
runs-on: ubuntu-latest
|
||||
@@ -47,4 +50,5 @@ jobs:
|
||||
token: ${{ secrets.DEPLOY_TOKEN }}
|
||||
repository: Significant-Gravitas/AutoGPT_cloud_infrastructure
|
||||
event-type: build_deploy_prod
|
||||
client-payload: '{"ref": "${{ github.ref }}", "sha": "${{ github.sha }}", "repository": "${{ github.repository }}"}'
|
||||
client-payload: |
|
||||
{"ref": "${{ github.ref_name || 'master' }}", "repository": "${{ github.repository }}"}
|
||||
5
.github/workflows/platform-backend-ci.yml
vendored
@@ -37,9 +37,7 @@ jobs:
|
||||
|
||||
services:
|
||||
redis:
|
||||
image: bitnami/redis:6.2
|
||||
env:
|
||||
REDIS_PASSWORD: testpassword
|
||||
image: redis:latest
|
||||
ports:
|
||||
- 6379:6379
|
||||
rabbitmq:
|
||||
@@ -204,7 +202,6 @@ jobs:
|
||||
JWT_VERIFY_KEY: ${{ steps.supabase.outputs.JWT_SECRET }}
|
||||
REDIS_HOST: "localhost"
|
||||
REDIS_PORT: "6379"
|
||||
REDIS_PASSWORD: "testpassword"
|
||||
ENCRYPTION_KEY: "dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw=" # DO NOT USE IN PRODUCTION!!
|
||||
|
||||
env:
|
||||
|
||||
@@ -1,35 +0,0 @@
|
||||
import hashlib
|
||||
import secrets
|
||||
from typing import NamedTuple
|
||||
|
||||
|
||||
class APIKeyContainer(NamedTuple):
|
||||
"""Container for API key parts."""
|
||||
|
||||
raw: str
|
||||
prefix: str
|
||||
postfix: str
|
||||
hash: str
|
||||
|
||||
|
||||
class APIKeyManager:
|
||||
PREFIX: str = "agpt_"
|
||||
PREFIX_LENGTH: int = 8
|
||||
POSTFIX_LENGTH: int = 8
|
||||
|
||||
def generate_api_key(self) -> APIKeyContainer:
|
||||
"""Generate a new API key with all its parts."""
|
||||
raw_key = f"{self.PREFIX}{secrets.token_urlsafe(32)}"
|
||||
return APIKeyContainer(
|
||||
raw=raw_key,
|
||||
prefix=raw_key[: self.PREFIX_LENGTH],
|
||||
postfix=raw_key[-self.POSTFIX_LENGTH :],
|
||||
hash=hashlib.sha256(raw_key.encode()).hexdigest(),
|
||||
)
|
||||
|
||||
def verify_api_key(self, provided_key: str, stored_hash: str) -> bool:
|
||||
"""Verify if a provided API key matches the stored hash."""
|
||||
if not provided_key.startswith(self.PREFIX):
|
||||
return False
|
||||
provided_hash = hashlib.sha256(provided_key.encode()).hexdigest()
|
||||
return secrets.compare_digest(provided_hash, stored_hash)
|
||||
@@ -0,0 +1,78 @@
|
||||
import hashlib
|
||||
import secrets
|
||||
from typing import NamedTuple
|
||||
|
||||
from cryptography.hazmat.primitives.kdf.scrypt import Scrypt
|
||||
|
||||
|
||||
class APIKeyContainer(NamedTuple):
|
||||
"""Container for API key parts."""
|
||||
|
||||
key: str
|
||||
head: str
|
||||
tail: str
|
||||
hash: str
|
||||
salt: str
|
||||
|
||||
|
||||
class APIKeySmith:
|
||||
PREFIX: str = "agpt_"
|
||||
HEAD_LENGTH: int = 8
|
||||
TAIL_LENGTH: int = 8
|
||||
|
||||
def generate_key(self) -> APIKeyContainer:
|
||||
"""Generate a new API key with secure hashing."""
|
||||
raw_key = f"{self.PREFIX}{secrets.token_urlsafe(32)}"
|
||||
hash, salt = self.hash_key(raw_key)
|
||||
|
||||
return APIKeyContainer(
|
||||
key=raw_key,
|
||||
head=raw_key[: self.HEAD_LENGTH],
|
||||
tail=raw_key[-self.TAIL_LENGTH :],
|
||||
hash=hash,
|
||||
salt=salt,
|
||||
)
|
||||
|
||||
def verify_key(
|
||||
self, provided_key: str, known_hash: str, known_salt: str | None = None
|
||||
) -> bool:
|
||||
"""
|
||||
Verify an API key against a known hash (+ salt).
|
||||
Supports verifying both legacy SHA256 and secure Scrypt hashes.
|
||||
"""
|
||||
if not provided_key.startswith(self.PREFIX):
|
||||
return False
|
||||
|
||||
# Handle legacy SHA256 hashes (migration support)
|
||||
if known_salt is None:
|
||||
legacy_hash = hashlib.sha256(provided_key.encode()).hexdigest()
|
||||
return secrets.compare_digest(legacy_hash, known_hash)
|
||||
|
||||
try:
|
||||
salt_bytes = bytes.fromhex(known_salt)
|
||||
provided_hash = self._hash_key_with_salt(provided_key, salt_bytes)
|
||||
return secrets.compare_digest(provided_hash, known_hash)
|
||||
except (ValueError, TypeError):
|
||||
return False
|
||||
|
||||
def hash_key(self, raw_key: str) -> tuple[str, str]:
|
||||
"""Migrate a legacy hash to secure hash format."""
|
||||
salt = self._generate_salt()
|
||||
hash = self._hash_key_with_salt(raw_key, salt)
|
||||
return hash, salt.hex()
|
||||
|
||||
def _generate_salt(self) -> bytes:
|
||||
"""Generate a random salt for hashing."""
|
||||
return secrets.token_bytes(32)
|
||||
|
||||
def _hash_key_with_salt(self, raw_key: str, salt: bytes) -> str:
|
||||
"""Hash API key using Scrypt with salt."""
|
||||
kdf = Scrypt(
|
||||
length=32,
|
||||
salt=salt,
|
||||
n=2**14, # CPU/memory cost parameter
|
||||
r=8, # Block size parameter
|
||||
p=1, # Parallelization parameter
|
||||
)
|
||||
key_hash = kdf.derive(raw_key.encode())
|
||||
return key_hash.hex()
|
||||
@@ -0,0 +1,79 @@
|
||||
import hashlib
|
||||
|
||||
from autogpt_libs.api_key.keysmith import APIKeySmith
|
||||
|
||||
|
||||
def test_generate_api_key():
|
||||
keysmith = APIKeySmith()
|
||||
key = keysmith.generate_key()
|
||||
|
||||
assert key.key.startswith(keysmith.PREFIX)
|
||||
assert key.head == key.key[: keysmith.HEAD_LENGTH]
|
||||
assert key.tail == key.key[-keysmith.TAIL_LENGTH :]
|
||||
assert len(key.hash) == 64 # 32 bytes hex encoded
|
||||
assert len(key.salt) == 64 # 32 bytes hex encoded
|
||||
|
||||
|
||||
def test_verify_new_secure_key():
|
||||
keysmith = APIKeySmith()
|
||||
key = keysmith.generate_key()
|
||||
|
||||
# Test correct key validates
|
||||
assert keysmith.verify_key(key.key, key.hash, key.salt) is True
|
||||
|
||||
# Test wrong key fails
|
||||
wrong_key = f"{keysmith.PREFIX}wrongkey123"
|
||||
assert keysmith.verify_key(wrong_key, key.hash, key.salt) is False
|
||||
|
||||
|
||||
def test_verify_legacy_key():
|
||||
keysmith = APIKeySmith()
|
||||
legacy_key = f"{keysmith.PREFIX}legacykey123"
|
||||
legacy_hash = hashlib.sha256(legacy_key.encode()).hexdigest()
|
||||
|
||||
# Test legacy key validates without salt
|
||||
assert keysmith.verify_key(legacy_key, legacy_hash) is True
|
||||
|
||||
# Test wrong legacy key fails
|
||||
wrong_key = f"{keysmith.PREFIX}wronglegacy"
|
||||
assert keysmith.verify_key(wrong_key, legacy_hash) is False
|
||||
|
||||
|
||||
def test_rehash_existing_key():
|
||||
keysmith = APIKeySmith()
|
||||
legacy_key = f"{keysmith.PREFIX}migratekey123"
|
||||
|
||||
# Migrate the legacy key
|
||||
new_hash, new_salt = keysmith.hash_key(legacy_key)
|
||||
|
||||
# Verify migrated key works
|
||||
assert keysmith.verify_key(legacy_key, new_hash, new_salt) is True
|
||||
|
||||
# Verify different key fails with migrated hash
|
||||
wrong_key = f"{keysmith.PREFIX}wrongkey"
|
||||
assert keysmith.verify_key(wrong_key, new_hash, new_salt) is False
|
||||
|
||||
|
||||
def test_invalid_key_prefix():
|
||||
keysmith = APIKeySmith()
|
||||
key = keysmith.generate_key()
|
||||
|
||||
# Test key without proper prefix fails
|
||||
invalid_key = "invalid_prefix_key"
|
||||
assert keysmith.verify_key(invalid_key, key.hash, key.salt) is False
|
||||
|
||||
|
||||
def test_secure_hash_requires_salt():
|
||||
keysmith = APIKeySmith()
|
||||
key = keysmith.generate_key()
|
||||
|
||||
# Secure hash without salt should fail
|
||||
assert keysmith.verify_key(key.key, key.hash) is False
|
||||
|
||||
|
||||
def test_invalid_salt_format():
|
||||
keysmith = APIKeySmith()
|
||||
key = keysmith.generate_key()
|
||||
|
||||
# Invalid salt format should fail gracefully
|
||||
assert keysmith.verify_key(key.key, key.hash, "invalid_hex") is False
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
@@ -13,8 +15,8 @@ class RateLimitSettings(BaseSettings):
|
||||
default="6379", description="Redis port", validation_alias="REDIS_PORT"
|
||||
)
|
||||
|
||||
redis_password: str = Field(
|
||||
default="password",
|
||||
redis_password: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Redis password",
|
||||
validation_alias="REDIS_PASSWORD",
|
||||
)
|
||||
|
||||
@@ -11,7 +11,7 @@ class RateLimiter:
|
||||
self,
|
||||
redis_host: str = RATE_LIMIT_SETTINGS.redis_host,
|
||||
redis_port: str = RATE_LIMIT_SETTINGS.redis_port,
|
||||
redis_password: str = RATE_LIMIT_SETTINGS.redis_password,
|
||||
redis_password: str | None = RATE_LIMIT_SETTINGS.redis_password,
|
||||
requests_per_minute: int = RATE_LIMIT_SETTINGS.requests_per_minute,
|
||||
):
|
||||
self.redis = Redis(
|
||||
|
||||
36
autogpt_platform/autogpt_libs/poetry.lock
generated
@@ -1002,6 +1002,18 @@ dynamodb = ["boto3 (>=1.9.71)"]
|
||||
redis = ["redis (>=2.10.5)"]
|
||||
test-filesource = ["pyyaml (>=5.3.1)", "watchdog (>=3.0.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "nodeenv"
|
||||
version = "1.9.1"
|
||||
description = "Node.js virtual environment builder"
|
||||
optional = false
|
||||
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9"},
|
||||
{file = "nodeenv-1.9.1.tar.gz", hash = "sha256:6ec12890a2dab7946721edbfbcd91f3319c6ccc9aec47be7c7e6b7011ee6645f"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-api"
|
||||
version = "1.35.0"
|
||||
@@ -1347,6 +1359,27 @@ files = [
|
||||
{file = "pyrfc3339-2.0.1.tar.gz", hash = "sha256:e47843379ea35c1296c3b6c67a948a1a490ae0584edfcbdea0eaffb5dd29960b"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyright"
|
||||
version = "1.1.404"
|
||||
description = "Command line wrapper for pyright"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "pyright-1.1.404-py3-none-any.whl", hash = "sha256:c7b7ff1fdb7219c643079e4c3e7d4125f0dafcc19d253b47e898d130ea426419"},
|
||||
{file = "pyright-1.1.404.tar.gz", hash = "sha256:455e881a558ca6be9ecca0b30ce08aa78343ecc031d37a198ffa9a7a1abeb63e"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
nodeenv = ">=1.6.0"
|
||||
typing-extensions = ">=4.1"
|
||||
|
||||
[package.extras]
|
||||
all = ["nodejs-wheel-binaries", "twine (>=3.4.1)"]
|
||||
dev = ["twine (>=3.4.1)"]
|
||||
nodejs = ["nodejs-wheel-binaries"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest"
|
||||
version = "8.4.1"
|
||||
@@ -1740,7 +1773,6 @@ files = [
|
||||
{file = "typing_extensions-4.14.1-py3-none-any.whl", hash = "sha256:d1e1e3b58374dc93031d6eda2420a48ea44a36c2b4766a4fdeb3710755731d76"},
|
||||
{file = "typing_extensions-4.14.1.tar.gz", hash = "sha256:38b39f4aeeab64884ce9f74c94263ef78f3c22467c8724005483154c26648d36"},
|
||||
]
|
||||
markers = {dev = "python_version < \"3.11\""}
|
||||
|
||||
[[package]]
|
||||
name = "typing-inspection"
|
||||
@@ -1897,4 +1929,4 @@ type = ["pytest-mypy"]
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.10,<4.0"
|
||||
content-hash = "d841f62f95180f6ad63ce82ed8e62aa201b9bf89242cc9299ae0f26ff1f72136"
|
||||
content-hash = "0c40b63c3c921846cf05ccfb4e685d4959854b29c2c302245f9832e20aac6954"
|
||||
|
||||
@@ -9,6 +9,7 @@ packages = [{ include = "autogpt_libs" }]
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.10,<4.0"
|
||||
colorama = "^0.4.6"
|
||||
cryptography = "^45.0"
|
||||
expiringdict = "^1.2.2"
|
||||
fastapi = "^0.116.1"
|
||||
google-cloud-logging = "^3.12.1"
|
||||
@@ -21,11 +22,12 @@ supabase = "^2.16.0"
|
||||
uvicorn = "^0.35.0"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
ruff = "^0.12.11"
|
||||
pyright = "^1.1.404"
|
||||
pytest = "^8.4.1"
|
||||
pytest-asyncio = "^1.1.0"
|
||||
pytest-mock = "^3.14.1"
|
||||
pytest-cov = "^6.2.1"
|
||||
ruff = "^0.12.11"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
|
||||
@@ -21,7 +21,7 @@ PRISMA_SCHEMA="postgres/schema.prisma"
|
||||
# Redis Configuration
|
||||
REDIS_HOST=localhost
|
||||
REDIS_PORT=6379
|
||||
REDIS_PASSWORD=password
|
||||
# REDIS_PASSWORD=
|
||||
|
||||
# RabbitMQ Credentials
|
||||
RABBITMQ_DEFAULT_USER=rabbitmq_user_default
|
||||
|
||||
@@ -661,6 +661,167 @@ async def update_field(
|
||||
#################################################################
|
||||
|
||||
|
||||
async def get_table_schema(
|
||||
credentials: Credentials,
|
||||
base_id: str,
|
||||
table_id_or_name: str,
|
||||
) -> dict:
|
||||
"""
|
||||
Get the schema for a specific table, including all field definitions.
|
||||
|
||||
Args:
|
||||
credentials: Airtable API credentials
|
||||
base_id: The base ID
|
||||
table_id_or_name: The table ID or name
|
||||
|
||||
Returns:
|
||||
Dict containing table schema with fields information
|
||||
"""
|
||||
# First get all tables to find the right one
|
||||
response = await Requests().get(
|
||||
f"https://api.airtable.com/v0/meta/bases/{base_id}/tables",
|
||||
headers={"Authorization": credentials.auth_header()},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
tables = data.get("tables", [])
|
||||
|
||||
# Find the matching table
|
||||
for table in tables:
|
||||
if table.get("id") == table_id_or_name or table.get("name") == table_id_or_name:
|
||||
return table
|
||||
|
||||
raise ValueError(f"Table '{table_id_or_name}' not found in base '{base_id}'")
|
||||
|
||||
|
||||
def get_empty_value_for_field(field_type: str) -> Any:
|
||||
"""
|
||||
Return the appropriate empty value for a given Airtable field type.
|
||||
|
||||
Args:
|
||||
field_type: The Airtable field type
|
||||
|
||||
Returns:
|
||||
The appropriate empty value for that field type
|
||||
"""
|
||||
# Fields that should be false when empty
|
||||
if field_type == "checkbox":
|
||||
return False
|
||||
|
||||
# Fields that should be empty arrays
|
||||
if field_type in [
|
||||
"multipleSelects",
|
||||
"multipleRecordLinks",
|
||||
"multipleAttachments",
|
||||
"multipleLookupValues",
|
||||
"multipleCollaborators",
|
||||
]:
|
||||
return []
|
||||
|
||||
# Fields that should be 0 when empty (numeric types)
|
||||
if field_type in [
|
||||
"number",
|
||||
"percent",
|
||||
"currency",
|
||||
"rating",
|
||||
"duration",
|
||||
"count",
|
||||
"autoNumber",
|
||||
]:
|
||||
return 0
|
||||
|
||||
# Fields that should be empty strings
|
||||
if field_type in [
|
||||
"singleLineText",
|
||||
"multilineText",
|
||||
"email",
|
||||
"url",
|
||||
"phoneNumber",
|
||||
"richText",
|
||||
"barcode",
|
||||
]:
|
||||
return ""
|
||||
|
||||
# Everything else gets null (dates, single selects, formulas, etc.)
|
||||
return None
|
||||
|
||||
|
||||
async def normalize_records(
|
||||
records: list[dict],
|
||||
table_schema: dict,
|
||||
include_field_metadata: bool = False,
|
||||
) -> dict:
|
||||
"""
|
||||
Normalize Airtable records to include all fields with proper empty values.
|
||||
|
||||
Args:
|
||||
records: List of record objects from Airtable API
|
||||
table_schema: Table schema containing field definitions
|
||||
include_field_metadata: Whether to include field metadata in response
|
||||
|
||||
Returns:
|
||||
Dict with normalized records and optionally field metadata
|
||||
"""
|
||||
fields = table_schema.get("fields", [])
|
||||
|
||||
# Normalize each record
|
||||
normalized_records = []
|
||||
for record in records:
|
||||
normalized = {
|
||||
"id": record.get("id"),
|
||||
"createdTime": record.get("createdTime"),
|
||||
"fields": {},
|
||||
}
|
||||
|
||||
# Add existing fields
|
||||
existing_fields = record.get("fields", {})
|
||||
|
||||
# Add all fields from schema, using empty values for missing ones
|
||||
for field in fields:
|
||||
field_name = field["name"]
|
||||
field_type = field["type"]
|
||||
|
||||
if field_name in existing_fields:
|
||||
# Field exists, use its value
|
||||
normalized["fields"][field_name] = existing_fields[field_name]
|
||||
else:
|
||||
# Field is missing, add appropriate empty value
|
||||
normalized["fields"][field_name] = get_empty_value_for_field(field_type)
|
||||
|
||||
normalized_records.append(normalized)
|
||||
|
||||
# Build result dictionary
|
||||
if include_field_metadata:
|
||||
field_metadata = {}
|
||||
for field in fields:
|
||||
metadata = {"type": field["type"], "id": field["id"]}
|
||||
|
||||
# Add type-specific metadata
|
||||
options = field.get("options", {})
|
||||
if field["type"] == "currency" and "symbol" in options:
|
||||
metadata["symbol"] = options["symbol"]
|
||||
metadata["precision"] = options.get("precision", 2)
|
||||
elif field["type"] == "duration" and "durationFormat" in options:
|
||||
metadata["format"] = options["durationFormat"]
|
||||
elif field["type"] == "percent" and "precision" in options:
|
||||
metadata["precision"] = options["precision"]
|
||||
elif (
|
||||
field["type"] in ["singleSelect", "multipleSelects"]
|
||||
and "choices" in options
|
||||
):
|
||||
metadata["choices"] = [choice["name"] for choice in options["choices"]]
|
||||
elif field["type"] == "rating" and "max" in options:
|
||||
metadata["max"] = options["max"]
|
||||
metadata["icon"] = options.get("icon", "star")
|
||||
metadata["color"] = options.get("color", "yellowBright")
|
||||
|
||||
field_metadata[field["name"]] = metadata
|
||||
|
||||
return {"records": normalized_records, "field_metadata": field_metadata}
|
||||
else:
|
||||
return {"records": normalized_records}
|
||||
|
||||
|
||||
async def list_records(
|
||||
credentials: Credentials,
|
||||
base_id: str,
|
||||
@@ -1249,3 +1410,26 @@ async def list_bases(
|
||||
)
|
||||
|
||||
return response.json()
|
||||
|
||||
|
||||
async def get_base_tables(
|
||||
credentials: Credentials,
|
||||
base_id: str,
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Get all tables for a specific base.
|
||||
|
||||
Args:
|
||||
credentials: Airtable API credentials
|
||||
base_id: The ID of the base
|
||||
|
||||
Returns:
|
||||
list[dict]: List of table objects with their schemas
|
||||
"""
|
||||
response = await Requests().get(
|
||||
f"https://api.airtable.com/v0/meta/bases/{base_id}/tables",
|
||||
headers={"Authorization": credentials.auth_header()},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
return data.get("tables", [])
|
||||
|
||||
@@ -14,13 +14,13 @@ from backend.sdk import (
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._api import create_base, list_bases
|
||||
from ._api import create_base, get_base_tables, list_bases
|
||||
from ._config import airtable
|
||||
|
||||
|
||||
class AirtableCreateBaseBlock(Block):
|
||||
"""
|
||||
Creates a new base in an Airtable workspace.
|
||||
Creates a new base in an Airtable workspace, or returns existing base if one with the same name exists.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
@@ -31,6 +31,10 @@ class AirtableCreateBaseBlock(Block):
|
||||
description="The workspace ID where the base will be created"
|
||||
)
|
||||
name: str = SchemaField(description="The name of the new base")
|
||||
find_existing: bool = SchemaField(
|
||||
description="If true, return existing base with same name instead of creating duplicate",
|
||||
default=True,
|
||||
)
|
||||
tables: list[dict] = SchemaField(
|
||||
description="At least one table and field must be specified. Array of table objects to create in the base. Each table should have 'name' and 'fields' properties",
|
||||
default=[
|
||||
@@ -50,14 +54,18 @@ class AirtableCreateBaseBlock(Block):
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
base_id: str = SchemaField(description="The ID of the created base")
|
||||
base_id: str = SchemaField(description="The ID of the created or found base")
|
||||
tables: list[dict] = SchemaField(description="Array of table objects")
|
||||
table: dict = SchemaField(description="A single table object")
|
||||
was_created: bool = SchemaField(
|
||||
description="True if a new base was created, False if existing was found",
|
||||
default=True,
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="f59b88a8-54ce-4676-a508-fd614b4e8dce",
|
||||
description="Create a new base in Airtable",
|
||||
description="Create or find a base in Airtable",
|
||||
categories={BlockCategory.DATA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
@@ -66,6 +74,31 @@ class AirtableCreateBaseBlock(Block):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
# If find_existing is true, check if a base with this name already exists
|
||||
if input_data.find_existing:
|
||||
# List all bases to check for existing one with same name
|
||||
# Note: Airtable API doesn't have a direct search, so we need to list and filter
|
||||
existing_bases = await list_bases(credentials)
|
||||
|
||||
for base in existing_bases.get("bases", []):
|
||||
if base.get("name") == input_data.name:
|
||||
# Base already exists, return it
|
||||
base_id = base.get("id")
|
||||
yield "base_id", base_id
|
||||
yield "was_created", False
|
||||
|
||||
# Get the tables for this base
|
||||
try:
|
||||
tables = await get_base_tables(credentials, base_id)
|
||||
yield "tables", tables
|
||||
for table in tables:
|
||||
yield "table", table
|
||||
except Exception:
|
||||
# If we can't get tables, return empty list
|
||||
yield "tables", []
|
||||
return
|
||||
|
||||
# No existing base found or find_existing is false, create new one
|
||||
data = await create_base(
|
||||
credentials,
|
||||
input_data.workspace_id,
|
||||
@@ -74,6 +107,7 @@ class AirtableCreateBaseBlock(Block):
|
||||
)
|
||||
|
||||
yield "base_id", data.get("id", None)
|
||||
yield "was_created", True
|
||||
yield "tables", data.get("tables", [])
|
||||
for table in data.get("tables", []):
|
||||
yield "table", table
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
Airtable record operation blocks.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from typing import Optional, cast
|
||||
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
@@ -18,7 +18,9 @@ from ._api import (
|
||||
create_record,
|
||||
delete_multiple_records,
|
||||
get_record,
|
||||
get_table_schema,
|
||||
list_records,
|
||||
normalize_records,
|
||||
update_multiple_records,
|
||||
)
|
||||
from ._config import airtable
|
||||
@@ -54,12 +56,24 @@ class AirtableListRecordsBlock(Block):
|
||||
return_fields: list[str] = SchemaField(
|
||||
description="Specific fields to return (comma-separated)", default=[]
|
||||
)
|
||||
normalize_output: bool = SchemaField(
|
||||
description="Normalize output to include all fields with proper empty values (disable to skip schema fetch and get raw Airtable response)",
|
||||
default=True,
|
||||
)
|
||||
include_field_metadata: bool = SchemaField(
|
||||
description="Include field type and configuration metadata (requires normalize_output=true)",
|
||||
default=False,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
records: list[dict] = SchemaField(description="Array of record objects")
|
||||
offset: Optional[str] = SchemaField(
|
||||
description="Offset for next page (null if no more records)", default=None
|
||||
)
|
||||
field_metadata: Optional[dict] = SchemaField(
|
||||
description="Field type and configuration metadata (only when include_field_metadata=true)",
|
||||
default=None,
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -73,6 +87,7 @@ class AirtableListRecordsBlock(Block):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
|
||||
data = await list_records(
|
||||
credentials,
|
||||
input_data.base_id,
|
||||
@@ -88,8 +103,33 @@ class AirtableListRecordsBlock(Block):
|
||||
fields=input_data.return_fields if input_data.return_fields else None,
|
||||
)
|
||||
|
||||
yield "records", data.get("records", [])
|
||||
yield "offset", data.get("offset", None)
|
||||
records = data.get("records", [])
|
||||
|
||||
# Normalize output if requested
|
||||
if input_data.normalize_output:
|
||||
# Fetch table schema
|
||||
table_schema = await get_table_schema(
|
||||
credentials, input_data.base_id, input_data.table_id_or_name
|
||||
)
|
||||
|
||||
# Normalize the records
|
||||
normalized_data = await normalize_records(
|
||||
records,
|
||||
table_schema,
|
||||
include_field_metadata=input_data.include_field_metadata,
|
||||
)
|
||||
|
||||
yield "records", normalized_data["records"]
|
||||
yield "offset", data.get("offset", None)
|
||||
|
||||
if (
|
||||
input_data.include_field_metadata
|
||||
and "field_metadata" in normalized_data
|
||||
):
|
||||
yield "field_metadata", normalized_data["field_metadata"]
|
||||
else:
|
||||
yield "records", records
|
||||
yield "offset", data.get("offset", None)
|
||||
|
||||
|
||||
class AirtableGetRecordBlock(Block):
|
||||
@@ -104,11 +144,23 @@ class AirtableGetRecordBlock(Block):
|
||||
base_id: str = SchemaField(description="The Airtable base ID")
|
||||
table_id_or_name: str = SchemaField(description="Table ID or name")
|
||||
record_id: str = SchemaField(description="The record ID to retrieve")
|
||||
normalize_output: bool = SchemaField(
|
||||
description="Normalize output to include all fields with proper empty values (disable to skip schema fetch and get raw Airtable response)",
|
||||
default=True,
|
||||
)
|
||||
include_field_metadata: bool = SchemaField(
|
||||
description="Include field type and configuration metadata (requires normalize_output=true)",
|
||||
default=False,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
id: str = SchemaField(description="The record ID")
|
||||
fields: dict = SchemaField(description="The record fields")
|
||||
created_time: str = SchemaField(description="The record created time")
|
||||
field_metadata: Optional[dict] = SchemaField(
|
||||
description="Field type and configuration metadata (only when include_field_metadata=true)",
|
||||
default=None,
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -122,6 +174,7 @@ class AirtableGetRecordBlock(Block):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
|
||||
record = await get_record(
|
||||
credentials,
|
||||
input_data.base_id,
|
||||
@@ -129,9 +182,34 @@ class AirtableGetRecordBlock(Block):
|
||||
input_data.record_id,
|
||||
)
|
||||
|
||||
yield "id", record.get("id", None)
|
||||
yield "fields", record.get("fields", None)
|
||||
yield "created_time", record.get("createdTime", None)
|
||||
# Normalize output if requested
|
||||
if input_data.normalize_output:
|
||||
# Fetch table schema
|
||||
table_schema = await get_table_schema(
|
||||
credentials, input_data.base_id, input_data.table_id_or_name
|
||||
)
|
||||
|
||||
# Normalize the single record (wrap in list and unwrap result)
|
||||
normalized_data = await normalize_records(
|
||||
[record],
|
||||
table_schema,
|
||||
include_field_metadata=input_data.include_field_metadata,
|
||||
)
|
||||
|
||||
normalized_record = normalized_data["records"][0]
|
||||
yield "id", normalized_record.get("id", None)
|
||||
yield "fields", normalized_record.get("fields", None)
|
||||
yield "created_time", normalized_record.get("createdTime", None)
|
||||
|
||||
if (
|
||||
input_data.include_field_metadata
|
||||
and "field_metadata" in normalized_data
|
||||
):
|
||||
yield "field_metadata", normalized_data["field_metadata"]
|
||||
else:
|
||||
yield "id", record.get("id", None)
|
||||
yield "fields", record.get("fields", None)
|
||||
yield "created_time", record.get("createdTime", None)
|
||||
|
||||
|
||||
class AirtableCreateRecordsBlock(Block):
|
||||
@@ -148,6 +226,10 @@ class AirtableCreateRecordsBlock(Block):
|
||||
records: list[dict] = SchemaField(
|
||||
description="Array of records to create (each with 'fields' object)"
|
||||
)
|
||||
skip_normalization: bool = SchemaField(
|
||||
description="Skip output normalization to get raw Airtable response (faster but may have missing fields)",
|
||||
default=False,
|
||||
)
|
||||
typecast: bool = SchemaField(
|
||||
description="Automatically convert string values to appropriate types",
|
||||
default=False,
|
||||
@@ -173,7 +255,7 @@ class AirtableCreateRecordsBlock(Block):
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
) -> BlockOutput:
|
||||
# The create_record API expects records in a specific format
|
||||
|
||||
data = await create_record(
|
||||
credentials,
|
||||
input_data.base_id,
|
||||
@@ -182,8 +264,22 @@ class AirtableCreateRecordsBlock(Block):
|
||||
typecast=input_data.typecast if input_data.typecast else None,
|
||||
return_fields_by_field_id=input_data.return_fields_by_field_id,
|
||||
)
|
||||
result_records = cast(list[dict], data.get("records", []))
|
||||
|
||||
yield "records", data.get("records", [])
|
||||
# Normalize output unless explicitly disabled
|
||||
if not input_data.skip_normalization and result_records:
|
||||
# Fetch table schema
|
||||
table_schema = await get_table_schema(
|
||||
credentials, input_data.base_id, input_data.table_id_or_name
|
||||
)
|
||||
|
||||
# Normalize the records
|
||||
normalized_data = await normalize_records(
|
||||
result_records, table_schema, include_field_metadata=False
|
||||
)
|
||||
result_records = normalized_data["records"]
|
||||
|
||||
yield "records", result_records
|
||||
details = data.get("details", None)
|
||||
if details:
|
||||
yield "details", details
|
||||
|
||||
@@ -1094,6 +1094,117 @@ class GmailGetThreadBlock(GmailBase):
|
||||
return thread
|
||||
|
||||
|
||||
async def _build_reply_message(
|
||||
service, input_data, graph_exec_id: str, user_id: str
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Builds a reply MIME message for Gmail threads.
|
||||
|
||||
Returns:
|
||||
tuple: (base64-encoded raw message, threadId)
|
||||
"""
|
||||
# Get parent message for reply context
|
||||
parent = await asyncio.to_thread(
|
||||
lambda: service.users()
|
||||
.messages()
|
||||
.get(
|
||||
userId="me",
|
||||
id=input_data.parentMessageId,
|
||||
format="metadata",
|
||||
metadataHeaders=[
|
||||
"Subject",
|
||||
"References",
|
||||
"Message-ID",
|
||||
"From",
|
||||
"To",
|
||||
"Cc",
|
||||
"Reply-To",
|
||||
],
|
||||
)
|
||||
.execute()
|
||||
)
|
||||
|
||||
# Build headers dictionary, preserving all values for duplicate headers
|
||||
headers = {}
|
||||
for h in parent.get("payload", {}).get("headers", []):
|
||||
name = h["name"].lower()
|
||||
value = h["value"]
|
||||
if name in headers:
|
||||
# For duplicate headers, keep the first occurrence (most relevant for reply context)
|
||||
continue
|
||||
headers[name] = value
|
||||
|
||||
# Determine recipients if not specified
|
||||
if not (input_data.to or input_data.cc or input_data.bcc):
|
||||
if input_data.replyAll:
|
||||
recipients = [parseaddr(headers.get("from", ""))[1]]
|
||||
recipients += [addr for _, addr in getaddresses([headers.get("to", "")])]
|
||||
recipients += [addr for _, addr in getaddresses([headers.get("cc", "")])]
|
||||
# Use dict.fromkeys() for O(n) deduplication while preserving order
|
||||
input_data.to = list(dict.fromkeys(filter(None, recipients)))
|
||||
else:
|
||||
# Check Reply-To header first, fall back to From header
|
||||
reply_to = headers.get("reply-to", "")
|
||||
from_addr = headers.get("from", "")
|
||||
sender = parseaddr(reply_to if reply_to else from_addr)[1]
|
||||
input_data.to = [sender] if sender else []
|
||||
|
||||
# Set subject with Re: prefix if not already present
|
||||
if input_data.subject:
|
||||
subject = input_data.subject
|
||||
else:
|
||||
parent_subject = headers.get("subject", "").strip()
|
||||
# Only add "Re:" if not already present (case-insensitive check)
|
||||
if parent_subject.lower().startswith("re:"):
|
||||
subject = parent_subject
|
||||
else:
|
||||
subject = f"Re: {parent_subject}" if parent_subject else "Re:"
|
||||
|
||||
# Build references header for proper threading
|
||||
references = headers.get("references", "").split()
|
||||
if headers.get("message-id"):
|
||||
references.append(headers["message-id"])
|
||||
|
||||
# Create MIME message
|
||||
msg = MIMEMultipart()
|
||||
if input_data.to:
|
||||
msg["To"] = ", ".join(input_data.to)
|
||||
if input_data.cc:
|
||||
msg["Cc"] = ", ".join(input_data.cc)
|
||||
if input_data.bcc:
|
||||
msg["Bcc"] = ", ".join(input_data.bcc)
|
||||
msg["Subject"] = subject
|
||||
if headers.get("message-id"):
|
||||
msg["In-Reply-To"] = headers["message-id"]
|
||||
if references:
|
||||
msg["References"] = " ".join(references)
|
||||
|
||||
# Use the helper function for consistent content type handling
|
||||
msg.attach(_make_mime_text(input_data.body, input_data.content_type))
|
||||
|
||||
# Handle attachments
|
||||
for attach in input_data.attachments:
|
||||
local_path = await store_media_file(
|
||||
user_id=user_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=attach,
|
||||
return_content=False,
|
||||
)
|
||||
abs_path = get_exec_file_path(graph_exec_id, local_path)
|
||||
part = MIMEBase("application", "octet-stream")
|
||||
with open(abs_path, "rb") as f:
|
||||
part.set_payload(f.read())
|
||||
encoders.encode_base64(part)
|
||||
part.add_header(
|
||||
"Content-Disposition", f"attachment; filename={Path(abs_path).name}"
|
||||
)
|
||||
msg.attach(part)
|
||||
|
||||
# Encode message
|
||||
raw = base64.urlsafe_b64encode(msg.as_bytes()).decode("utf-8")
|
||||
return raw, input_data.threadId
|
||||
|
||||
|
||||
class GmailReplyBlock(GmailBase):
|
||||
"""
|
||||
Replies to Gmail threads with intelligent content type detection.
|
||||
@@ -1230,93 +1341,146 @@ class GmailReplyBlock(GmailBase):
|
||||
async def _reply(
|
||||
self, service, input_data: Input, graph_exec_id: str, user_id: str
|
||||
) -> dict:
|
||||
parent = await asyncio.to_thread(
|
||||
lambda: service.users()
|
||||
.messages()
|
||||
.get(
|
||||
userId="me",
|
||||
id=input_data.parentMessageId,
|
||||
format="metadata",
|
||||
metadataHeaders=[
|
||||
"Subject",
|
||||
"References",
|
||||
"Message-ID",
|
||||
"From",
|
||||
"To",
|
||||
"Cc",
|
||||
"Reply-To",
|
||||
],
|
||||
)
|
||||
.execute()
|
||||
# Build the reply message using the shared helper
|
||||
raw, thread_id = await _build_reply_message(
|
||||
service, input_data, graph_exec_id, user_id
|
||||
)
|
||||
|
||||
headers = {
|
||||
h["name"].lower(): h["value"]
|
||||
for h in parent.get("payload", {}).get("headers", [])
|
||||
}
|
||||
if not (input_data.to or input_data.cc or input_data.bcc):
|
||||
if input_data.replyAll:
|
||||
recipients = [parseaddr(headers.get("from", ""))[1]]
|
||||
recipients += [
|
||||
addr for _, addr in getaddresses([headers.get("to", "")])
|
||||
]
|
||||
recipients += [
|
||||
addr for _, addr in getaddresses([headers.get("cc", "")])
|
||||
]
|
||||
dedup: list[str] = []
|
||||
for r in recipients:
|
||||
if r and r not in dedup:
|
||||
dedup.append(r)
|
||||
input_data.to = dedup
|
||||
else:
|
||||
sender = parseaddr(headers.get("reply-to", headers.get("from", "")))[1]
|
||||
input_data.to = [sender] if sender else []
|
||||
subject = input_data.subject or (f"Re: {headers.get('subject', '')}".strip())
|
||||
references = headers.get("references", "").split()
|
||||
if headers.get("message-id"):
|
||||
references.append(headers["message-id"])
|
||||
|
||||
msg = MIMEMultipart()
|
||||
if input_data.to:
|
||||
msg["To"] = ", ".join(input_data.to)
|
||||
if input_data.cc:
|
||||
msg["Cc"] = ", ".join(input_data.cc)
|
||||
if input_data.bcc:
|
||||
msg["Bcc"] = ", ".join(input_data.bcc)
|
||||
msg["Subject"] = subject
|
||||
if headers.get("message-id"):
|
||||
msg["In-Reply-To"] = headers["message-id"]
|
||||
if references:
|
||||
msg["References"] = " ".join(references)
|
||||
# Use the new helper function for consistent content type handling
|
||||
msg.attach(_make_mime_text(input_data.body, input_data.content_type))
|
||||
|
||||
for attach in input_data.attachments:
|
||||
local_path = await store_media_file(
|
||||
user_id=user_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=attach,
|
||||
return_content=False,
|
||||
)
|
||||
abs_path = get_exec_file_path(graph_exec_id, local_path)
|
||||
part = MIMEBase("application", "octet-stream")
|
||||
with open(abs_path, "rb") as f:
|
||||
part.set_payload(f.read())
|
||||
encoders.encode_base64(part)
|
||||
part.add_header(
|
||||
"Content-Disposition", f"attachment; filename={Path(abs_path).name}"
|
||||
)
|
||||
msg.attach(part)
|
||||
|
||||
raw = base64.urlsafe_b64encode(msg.as_bytes()).decode("utf-8")
|
||||
# Send the message
|
||||
return await asyncio.to_thread(
|
||||
lambda: service.users()
|
||||
.messages()
|
||||
.send(userId="me", body={"threadId": input_data.threadId, "raw": raw})
|
||||
.send(userId="me", body={"threadId": thread_id, "raw": raw})
|
||||
.execute()
|
||||
)
|
||||
|
||||
|
||||
class GmailDraftReplyBlock(GmailBase):
|
||||
"""
|
||||
Creates draft replies to Gmail threads with intelligent content type detection.
|
||||
|
||||
Features:
|
||||
- Automatic HTML detection: Draft replies containing HTML tags are formatted as text/html
|
||||
- No hard-wrap for plain text: Plain text draft replies preserve natural line flow
|
||||
- Manual content type override: Use content_type parameter to force specific format
|
||||
- Reply-all functionality: Option to reply to all original recipients
|
||||
- Thread preservation: Maintains proper email threading with headers
|
||||
- Full Unicode/emoji support with UTF-8 encoding
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: GoogleCredentialsInput = GoogleCredentialsField(
|
||||
[
|
||||
"https://www.googleapis.com/auth/gmail.modify",
|
||||
"https://www.googleapis.com/auth/gmail.readonly",
|
||||
]
|
||||
)
|
||||
threadId: str = SchemaField(description="Thread ID to reply in")
|
||||
parentMessageId: str = SchemaField(
|
||||
description="ID of the message being replied to"
|
||||
)
|
||||
to: list[str] = SchemaField(description="To recipients", default_factory=list)
|
||||
cc: list[str] = SchemaField(description="CC recipients", default_factory=list)
|
||||
bcc: list[str] = SchemaField(description="BCC recipients", default_factory=list)
|
||||
replyAll: bool = SchemaField(
|
||||
description="Reply to all original recipients", default=False
|
||||
)
|
||||
subject: str = SchemaField(description="Email subject", default="")
|
||||
body: str = SchemaField(description="Email body (plain text or HTML)")
|
||||
content_type: Optional[Literal["auto", "plain", "html"]] = SchemaField(
|
||||
description="Content type: 'auto' (default - detects HTML), 'plain', or 'html'",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
attachments: list[MediaFileType] = SchemaField(
|
||||
description="Files to attach", default_factory=list, advanced=True
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
draftId: str = SchemaField(description="Created draft ID")
|
||||
messageId: str = SchemaField(description="Draft message ID")
|
||||
threadId: str = SchemaField(description="Thread ID")
|
||||
status: str = SchemaField(description="Draft creation status")
|
||||
error: str = SchemaField(description="Error message if any")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="d7a9f3e2-8b4c-4d6f-9e1a-3c5b7f8d2a6e",
|
||||
description="Create draft replies to Gmail threads with automatic HTML detection and proper text formatting. Plain text draft replies maintain natural paragraph flow without 78-character line wrapping. HTML content is automatically detected and formatted correctly.",
|
||||
categories={BlockCategory.COMMUNICATION},
|
||||
input_schema=GmailDraftReplyBlock.Input,
|
||||
output_schema=GmailDraftReplyBlock.Output,
|
||||
disabled=not GOOGLE_OAUTH_IS_CONFIGURED,
|
||||
test_input={
|
||||
"threadId": "t1",
|
||||
"parentMessageId": "m1",
|
||||
"body": "Thanks for your message. I'll review and get back to you.",
|
||||
"replyAll": False,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("draftId", "draft1"),
|
||||
("messageId", "m2"),
|
||||
("threadId", "t1"),
|
||||
("status", "draft_created"),
|
||||
],
|
||||
test_mock={
|
||||
"_create_draft_reply": lambda *args, **kwargs: {
|
||||
"id": "draft1",
|
||||
"message": {"id": "m2", "threadId": "t1"},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GoogleCredentials,
|
||||
graph_exec_id: str,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
service = self._build_service(credentials, **kwargs)
|
||||
draft = await self._create_draft_reply(
|
||||
service,
|
||||
input_data,
|
||||
graph_exec_id,
|
||||
user_id,
|
||||
)
|
||||
yield "draftId", draft["id"]
|
||||
yield "messageId", draft["message"]["id"]
|
||||
yield "threadId", draft["message"].get("threadId", input_data.threadId)
|
||||
yield "status", "draft_created"
|
||||
|
||||
async def _create_draft_reply(
|
||||
self, service, input_data: Input, graph_exec_id: str, user_id: str
|
||||
) -> dict:
|
||||
# Build the reply message using the shared helper
|
||||
raw, thread_id = await _build_reply_message(
|
||||
service, input_data, graph_exec_id, user_id
|
||||
)
|
||||
|
||||
# Create draft with proper thread association
|
||||
draft = await asyncio.to_thread(
|
||||
lambda: service.users()
|
||||
.drafts()
|
||||
.create(
|
||||
userId="me",
|
||||
body={
|
||||
"message": {
|
||||
"threadId": thread_id,
|
||||
"raw": raw,
|
||||
}
|
||||
},
|
||||
)
|
||||
.execute()
|
||||
)
|
||||
|
||||
return draft
|
||||
|
||||
|
||||
class GmailGetProfileBlock(GmailBase):
|
||||
class Input(BlockSchema):
|
||||
credentials: GoogleCredentialsInput = GoogleCredentialsField(
|
||||
|
||||
@@ -896,6 +896,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
prompt = [json.to_dict(p) for p in input_data.conversation_history]
|
||||
|
||||
def trim_prompt(s: str) -> str:
|
||||
"""Removes indentation up to and including `|` from a multi-line prompt."""
|
||||
lines = s.strip().split("\n")
|
||||
return "\n".join([line.strip().lstrip("|") for line in lines])
|
||||
|
||||
@@ -909,24 +910,25 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
|
||||
if input_data.expected_format:
|
||||
expected_format = [
|
||||
f'"{k}": "{v}"' for k, v in input_data.expected_format.items()
|
||||
f"{json.dumps(k)}: {json.dumps(v)}"
|
||||
for k, v in input_data.expected_format.items()
|
||||
]
|
||||
if input_data.list_result:
|
||||
format_prompt = (
|
||||
f'"results": [\n {{\n {", ".join(expected_format)}\n }}\n]'
|
||||
)
|
||||
else:
|
||||
format_prompt = "\n ".join(expected_format)
|
||||
format_prompt = ",\n| ".join(expected_format)
|
||||
|
||||
sys_prompt = trim_prompt(
|
||||
f"""
|
||||
|Reply strictly only in the following JSON format:
|
||||
|{{
|
||||
| {format_prompt}
|
||||
|}}
|
||||
|
|
||||
|Ensure the response is valid JSON. Do not include any additional text outside of the JSON.
|
||||
|If you cannot provide all the keys, provide an empty string for the values you cannot answer.
|
||||
|Reply with pure JSON strictly following this JSON format:
|
||||
|{{
|
||||
| {format_prompt}
|
||||
|}}
|
||||
|
|
||||
|Ensure the response is valid JSON. DO NOT include any additional text (e.g. markdown code block fences) outside of the JSON.
|
||||
|If you cannot provide all the keys, provide an empty string for the values you cannot answer.
|
||||
"""
|
||||
)
|
||||
prompt.append({"role": "system", "content": sys_prompt})
|
||||
@@ -946,7 +948,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
return f"JSON decode error: {e}"
|
||||
|
||||
logger.debug(f"LLM request: {prompt}")
|
||||
retry_prompt = ""
|
||||
error_feedback_message = ""
|
||||
llm_model = input_data.model
|
||||
|
||||
for retry_count in range(input_data.retry):
|
||||
@@ -970,8 +972,25 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
logger.debug(f"LLM attempt-{retry_count} response: {response_text}")
|
||||
|
||||
if input_data.expected_format:
|
||||
try:
|
||||
response_obj = json.loads(response_text)
|
||||
except JSONDecodeError as json_error:
|
||||
prompt.append({"role": "assistant", "content": response_text})
|
||||
|
||||
response_obj = json.loads(response_text)
|
||||
indented_json_error = str(json_error).replace("\n", "\n|")
|
||||
error_feedback_message = trim_prompt(
|
||||
f"""
|
||||
|Your previous response could not be parsed as valid JSON:
|
||||
|
|
||||
|{indented_json_error}
|
||||
|
|
||||
|Please provide a valid JSON response that matches the expected format.
|
||||
"""
|
||||
)
|
||||
prompt.append(
|
||||
{"role": "user", "content": error_feedback_message}
|
||||
)
|
||||
continue
|
||||
|
||||
if input_data.list_result and isinstance(response_obj, dict):
|
||||
if "results" in response_obj:
|
||||
@@ -979,7 +998,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
elif len(response_obj) == 1:
|
||||
response_obj = list(response_obj.values())
|
||||
|
||||
response_error = "\n".join(
|
||||
validation_errors = "\n".join(
|
||||
[
|
||||
validation_error
|
||||
for response_item in (
|
||||
@@ -991,7 +1010,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
]
|
||||
)
|
||||
|
||||
if not response_error:
|
||||
if not validation_errors:
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
llm_call_count=retry_count + 1,
|
||||
@@ -1001,6 +1020,16 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
yield "response", response_obj
|
||||
yield "prompt", self.prompt
|
||||
return
|
||||
|
||||
prompt.append({"role": "assistant", "content": response_text})
|
||||
error_feedback_message = trim_prompt(
|
||||
f"""
|
||||
|Your response did not match the expected format:
|
||||
|
|
||||
|{validation_errors}
|
||||
"""
|
||||
)
|
||||
prompt.append({"role": "user", "content": error_feedback_message})
|
||||
else:
|
||||
self.merge_stats(
|
||||
NodeExecutionStats(
|
||||
@@ -1011,21 +1040,6 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
yield "response", {"response": response_text}
|
||||
yield "prompt", self.prompt
|
||||
return
|
||||
|
||||
retry_prompt = trim_prompt(
|
||||
f"""
|
||||
|This is your previous error response:
|
||||
|--
|
||||
|{response_text}
|
||||
|--
|
||||
|
|
||||
|And this is the error:
|
||||
|--
|
||||
|{response_error}
|
||||
|--
|
||||
"""
|
||||
)
|
||||
prompt.append({"role": "user", "content": retry_prompt})
|
||||
except Exception as e:
|
||||
logger.exception(f"Error calling LLM: {e}")
|
||||
if (
|
||||
@@ -1038,9 +1052,12 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
logger.debug(
|
||||
f"Reducing max_tokens to {input_data.max_tokens} for next attempt"
|
||||
)
|
||||
retry_prompt = f"Error calling LLM: {e}"
|
||||
# Don't add retry prompt for token limit errors,
|
||||
# just retry with lower maximum output tokens
|
||||
|
||||
raise RuntimeError(retry_prompt)
|
||||
error_feedback_message = f"Error calling LLM: {e}"
|
||||
|
||||
raise RuntimeError(error_feedback_message)
|
||||
|
||||
|
||||
class AITextGeneratorBlock(AIBlockBase):
|
||||
|
||||
@@ -1,57 +1,31 @@
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from autogpt_libs.api_key.key_manager import APIKeyManager
|
||||
from autogpt_libs.api_key.keysmith import APIKeySmith
|
||||
from prisma.enums import APIKeyPermission, APIKeyStatus
|
||||
from prisma.errors import PrismaError
|
||||
from prisma.models import APIKey as PrismaAPIKey
|
||||
from prisma.types import (
|
||||
APIKeyCreateInput,
|
||||
APIKeyUpdateInput,
|
||||
APIKeyWhereInput,
|
||||
APIKeyWhereUniqueInput,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
from prisma.types import APIKeyWhereUniqueInput
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.data.db import BaseDbModel
|
||||
from backend.util.exceptions import NotAuthorizedError, NotFoundError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
keysmith = APIKeySmith()
|
||||
|
||||
|
||||
# Some basic exceptions
|
||||
class APIKeyError(Exception):
|
||||
"""Base exception for API key operations"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class APIKeyNotFoundError(APIKeyError):
|
||||
"""Raised when an API key is not found"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class APIKeyPermissionError(APIKeyError):
|
||||
"""Raised when there are permission issues with API key operations"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class APIKeyValidationError(APIKeyError):
|
||||
"""Raised when API key validation fails"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class APIKey(BaseDbModel):
|
||||
class APIKeyInfo(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
prefix: str
|
||||
key: str
|
||||
status: APIKeyStatus = APIKeyStatus.ACTIVE
|
||||
permissions: List[APIKeyPermission]
|
||||
postfix: str
|
||||
head: str = Field(
|
||||
description=f"The first {APIKeySmith.HEAD_LENGTH} characters of the key"
|
||||
)
|
||||
tail: str = Field(
|
||||
description=f"The last {APIKeySmith.TAIL_LENGTH} characters of the key"
|
||||
)
|
||||
status: APIKeyStatus
|
||||
permissions: list[APIKeyPermission]
|
||||
created_at: datetime
|
||||
last_used_at: Optional[datetime] = None
|
||||
revoked_at: Optional[datetime] = None
|
||||
@@ -60,266 +34,211 @@ class APIKey(BaseDbModel):
|
||||
|
||||
@staticmethod
|
||||
def from_db(api_key: PrismaAPIKey):
|
||||
try:
|
||||
return APIKey(
|
||||
id=api_key.id,
|
||||
name=api_key.name,
|
||||
prefix=api_key.prefix,
|
||||
postfix=api_key.postfix,
|
||||
key=api_key.key,
|
||||
status=APIKeyStatus(api_key.status),
|
||||
permissions=[APIKeyPermission(p) for p in api_key.permissions],
|
||||
created_at=api_key.createdAt,
|
||||
last_used_at=api_key.lastUsedAt,
|
||||
revoked_at=api_key.revokedAt,
|
||||
description=api_key.description,
|
||||
user_id=api_key.userId,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating APIKey from db: {str(e)}")
|
||||
raise APIKeyError(f"Failed to create API key object: {str(e)}")
|
||||
return APIKeyInfo(
|
||||
id=api_key.id,
|
||||
name=api_key.name,
|
||||
head=api_key.head,
|
||||
tail=api_key.tail,
|
||||
status=APIKeyStatus(api_key.status),
|
||||
permissions=[APIKeyPermission(p) for p in api_key.permissions],
|
||||
created_at=api_key.createdAt,
|
||||
last_used_at=api_key.lastUsedAt,
|
||||
revoked_at=api_key.revokedAt,
|
||||
description=api_key.description,
|
||||
user_id=api_key.userId,
|
||||
)
|
||||
|
||||
|
||||
class APIKeyWithoutHash(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
prefix: str
|
||||
postfix: str
|
||||
status: APIKeyStatus
|
||||
permissions: List[APIKeyPermission]
|
||||
created_at: datetime
|
||||
last_used_at: Optional[datetime]
|
||||
revoked_at: Optional[datetime]
|
||||
description: Optional[str]
|
||||
user_id: str
|
||||
class APIKeyInfoWithHash(APIKeyInfo):
|
||||
hash: str
|
||||
salt: str | None = None # None for legacy keys
|
||||
|
||||
def match(self, plaintext_key: str) -> bool:
|
||||
"""Returns whether the given key matches this API key object."""
|
||||
return keysmith.verify_key(plaintext_key, self.hash, self.salt)
|
||||
|
||||
@staticmethod
|
||||
def from_db(api_key: PrismaAPIKey):
|
||||
try:
|
||||
return APIKeyWithoutHash(
|
||||
id=api_key.id,
|
||||
name=api_key.name,
|
||||
prefix=api_key.prefix,
|
||||
postfix=api_key.postfix,
|
||||
status=APIKeyStatus(api_key.status),
|
||||
permissions=[APIKeyPermission(p) for p in api_key.permissions],
|
||||
created_at=api_key.createdAt,
|
||||
last_used_at=api_key.lastUsedAt,
|
||||
revoked_at=api_key.revokedAt,
|
||||
description=api_key.description,
|
||||
user_id=api_key.userId,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating APIKeyWithoutHash from db: {str(e)}")
|
||||
raise APIKeyError(f"Failed to create API key object: {str(e)}")
|
||||
return APIKeyInfoWithHash(
|
||||
**APIKeyInfo.from_db(api_key).model_dump(),
|
||||
hash=api_key.hash,
|
||||
salt=api_key.salt,
|
||||
)
|
||||
|
||||
def without_hash(self) -> APIKeyInfo:
|
||||
return APIKeyInfo(**self.model_dump(exclude={"hash", "salt"}))
|
||||
|
||||
|
||||
async def generate_api_key(
|
||||
async def create_api_key(
|
||||
name: str,
|
||||
user_id: str,
|
||||
permissions: List[APIKeyPermission],
|
||||
permissions: list[APIKeyPermission],
|
||||
description: Optional[str] = None,
|
||||
) -> tuple[APIKeyWithoutHash, str]:
|
||||
) -> tuple[APIKeyInfo, str]:
|
||||
"""
|
||||
Generate a new API key and store it in the database.
|
||||
Returns the API key object (without hash) and the plain text key.
|
||||
"""
|
||||
try:
|
||||
api_manager = APIKeyManager()
|
||||
key = api_manager.generate_api_key()
|
||||
generated_key = keysmith.generate_key()
|
||||
|
||||
api_key = await PrismaAPIKey.prisma().create(
|
||||
data=APIKeyCreateInput(
|
||||
id=str(uuid.uuid4()),
|
||||
name=name,
|
||||
prefix=key.prefix,
|
||||
postfix=key.postfix,
|
||||
key=key.hash,
|
||||
permissions=[p for p in permissions],
|
||||
description=description,
|
||||
userId=user_id,
|
||||
)
|
||||
)
|
||||
saved_key_obj = await PrismaAPIKey.prisma().create(
|
||||
data={
|
||||
"id": str(uuid.uuid4()),
|
||||
"name": name,
|
||||
"head": generated_key.head,
|
||||
"tail": generated_key.tail,
|
||||
"hash": generated_key.hash,
|
||||
"salt": generated_key.salt,
|
||||
"permissions": [p for p in permissions],
|
||||
"description": description,
|
||||
"userId": user_id,
|
||||
}
|
||||
)
|
||||
|
||||
api_key_without_hash = APIKeyWithoutHash.from_db(api_key)
|
||||
return api_key_without_hash, key.raw
|
||||
except PrismaError as e:
|
||||
logger.error(f"Database error while generating API key: {str(e)}")
|
||||
raise APIKeyError(f"Failed to generate API key: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error while generating API key: {str(e)}")
|
||||
raise APIKeyError(f"Failed to generate API key: {str(e)}")
|
||||
return APIKeyInfo.from_db(saved_key_obj), generated_key.key
|
||||
|
||||
|
||||
async def validate_api_key(plain_text_key: str) -> Optional[APIKey]:
|
||||
async def get_active_api_keys_by_head(head: str) -> list[APIKeyInfoWithHash]:
|
||||
results = await PrismaAPIKey.prisma().find_many(
|
||||
where={"head": head, "status": APIKeyStatus.ACTIVE}
|
||||
)
|
||||
return [APIKeyInfoWithHash.from_db(key) for key in results]
|
||||
|
||||
|
||||
async def validate_api_key(plaintext_key: str) -> Optional[APIKeyInfo]:
|
||||
"""
|
||||
Validate an API key and return the API key object if valid.
|
||||
Validate an API key and return the API key object if valid and active.
|
||||
"""
|
||||
try:
|
||||
if not plain_text_key.startswith(APIKeyManager.PREFIX):
|
||||
if not plaintext_key.startswith(APIKeySmith.PREFIX):
|
||||
logger.warning("Invalid API key format")
|
||||
return None
|
||||
|
||||
prefix = plain_text_key[: APIKeyManager.PREFIX_LENGTH]
|
||||
api_manager = APIKeyManager()
|
||||
head = plaintext_key[: APIKeySmith.HEAD_LENGTH]
|
||||
potential_matches = await get_active_api_keys_by_head(head)
|
||||
|
||||
api_key = await PrismaAPIKey.prisma().find_first(
|
||||
where=APIKeyWhereInput(prefix=prefix, status=(APIKeyStatus.ACTIVE))
|
||||
matched_api_key = next(
|
||||
(pm for pm in potential_matches if pm.match(plaintext_key)),
|
||||
None,
|
||||
)
|
||||
|
||||
if not api_key:
|
||||
logger.warning(f"No active API key found with prefix {prefix}")
|
||||
if not matched_api_key:
|
||||
# API key not found or invalid
|
||||
return None
|
||||
|
||||
is_valid = api_manager.verify_api_key(plain_text_key, api_key.key)
|
||||
if not is_valid:
|
||||
logger.warning("API key verification failed")
|
||||
return None
|
||||
|
||||
return APIKey.from_db(api_key)
|
||||
except Exception as e:
|
||||
logger.error(f"Error validating API key: {str(e)}")
|
||||
raise APIKeyValidationError(f"Failed to validate API key: {str(e)}")
|
||||
|
||||
|
||||
async def revoke_api_key(key_id: str, user_id: str) -> Optional[APIKeyWithoutHash]:
|
||||
try:
|
||||
api_key = await PrismaAPIKey.prisma().find_unique(where={"id": key_id})
|
||||
|
||||
if not api_key:
|
||||
raise APIKeyNotFoundError(f"API key with id {key_id} not found")
|
||||
|
||||
if api_key.userId != user_id:
|
||||
raise APIKeyPermissionError(
|
||||
"You do not have permission to revoke this API key."
|
||||
# Migrate legacy keys to secure format on successful validation
|
||||
if matched_api_key.salt is None:
|
||||
matched_api_key = await _migrate_key_to_secure_hash(
|
||||
plaintext_key, matched_api_key
|
||||
)
|
||||
|
||||
where_clause: APIKeyWhereUniqueInput = {"id": key_id}
|
||||
updated_api_key = await PrismaAPIKey.prisma().update(
|
||||
where=where_clause,
|
||||
data=APIKeyUpdateInput(
|
||||
status=APIKeyStatus.REVOKED, revokedAt=datetime.now(timezone.utc)
|
||||
),
|
||||
)
|
||||
return matched_api_key.without_hash()
|
||||
except Exception as e:
|
||||
logger.error(f"Error while validating API key: {e}")
|
||||
raise RuntimeError("Failed to validate API key") from e
|
||||
|
||||
if updated_api_key:
|
||||
return APIKeyWithoutHash.from_db(updated_api_key)
|
||||
|
||||
async def _migrate_key_to_secure_hash(
|
||||
plaintext_key: str, key_obj: APIKeyInfoWithHash
|
||||
) -> APIKeyInfoWithHash:
|
||||
"""Replace the SHA256 hash of a legacy API key with a salted Scrypt hash."""
|
||||
try:
|
||||
new_hash, new_salt = keysmith.hash_key(plaintext_key)
|
||||
await PrismaAPIKey.prisma().update(
|
||||
where={"id": key_obj.id}, data={"hash": new_hash, "salt": new_salt}
|
||||
)
|
||||
logger.info(f"Migrated legacy API key #{key_obj.id} to secure format")
|
||||
# Update the API key object with new values for return
|
||||
key_obj.hash = new_hash
|
||||
key_obj.salt = new_salt
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to migrate legacy API key #{key_obj.id}: {e}")
|
||||
|
||||
return key_obj
|
||||
|
||||
|
||||
async def revoke_api_key(key_id: str, user_id: str) -> APIKeyInfo:
|
||||
api_key = await PrismaAPIKey.prisma().find_unique(where={"id": key_id})
|
||||
|
||||
if not api_key:
|
||||
raise NotFoundError(f"API key with id {key_id} not found")
|
||||
|
||||
if api_key.userId != user_id:
|
||||
raise NotAuthorizedError("You do not have permission to revoke this API key.")
|
||||
|
||||
updated_api_key = await PrismaAPIKey.prisma().update(
|
||||
where={"id": key_id},
|
||||
data={
|
||||
"status": APIKeyStatus.REVOKED,
|
||||
"revokedAt": datetime.now(timezone.utc),
|
||||
},
|
||||
)
|
||||
if not updated_api_key:
|
||||
raise NotFoundError(f"API key #{key_id} vanished while trying to revoke.")
|
||||
|
||||
return APIKeyInfo.from_db(updated_api_key)
|
||||
|
||||
|
||||
async def list_user_api_keys(user_id: str) -> list[APIKeyInfo]:
|
||||
api_keys = await PrismaAPIKey.prisma().find_many(
|
||||
where={"userId": user_id}, order={"createdAt": "desc"}
|
||||
)
|
||||
|
||||
return [APIKeyInfo.from_db(key) for key in api_keys]
|
||||
|
||||
|
||||
async def suspend_api_key(key_id: str, user_id: str) -> APIKeyInfo:
|
||||
selector: APIKeyWhereUniqueInput = {"id": key_id}
|
||||
api_key = await PrismaAPIKey.prisma().find_unique(where=selector)
|
||||
|
||||
if not api_key:
|
||||
raise NotFoundError(f"API key with id {key_id} not found")
|
||||
|
||||
if api_key.userId != user_id:
|
||||
raise NotAuthorizedError("You do not have permission to suspend this API key.")
|
||||
|
||||
updated_api_key = await PrismaAPIKey.prisma().update(
|
||||
where=selector, data={"status": APIKeyStatus.SUSPENDED}
|
||||
)
|
||||
if not updated_api_key:
|
||||
raise NotFoundError(f"API key #{key_id} vanished while trying to suspend.")
|
||||
|
||||
return APIKeyInfo.from_db(updated_api_key)
|
||||
|
||||
|
||||
def has_permission(api_key: APIKeyInfo, required_permission: APIKeyPermission) -> bool:
|
||||
return required_permission in api_key.permissions
|
||||
|
||||
|
||||
async def get_api_key_by_id(key_id: str, user_id: str) -> Optional[APIKeyInfo]:
|
||||
api_key = await PrismaAPIKey.prisma().find_first(
|
||||
where={"id": key_id, "userId": user_id}
|
||||
)
|
||||
|
||||
if not api_key:
|
||||
return None
|
||||
except (APIKeyNotFoundError, APIKeyPermissionError) as e:
|
||||
raise e
|
||||
except PrismaError as e:
|
||||
logger.error(f"Database error while revoking API key: {str(e)}")
|
||||
raise APIKeyError(f"Failed to revoke API key: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error while revoking API key: {str(e)}")
|
||||
raise APIKeyError(f"Failed to revoke API key: {str(e)}")
|
||||
|
||||
|
||||
async def list_user_api_keys(user_id: str) -> List[APIKeyWithoutHash]:
|
||||
try:
|
||||
where_clause: APIKeyWhereInput = {"userId": user_id}
|
||||
|
||||
api_keys = await PrismaAPIKey.prisma().find_many(
|
||||
where=where_clause, order={"createdAt": "desc"}
|
||||
)
|
||||
|
||||
return [APIKeyWithoutHash.from_db(key) for key in api_keys]
|
||||
except PrismaError as e:
|
||||
logger.error(f"Database error while listing API keys: {str(e)}")
|
||||
raise APIKeyError(f"Failed to list API keys: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error while listing API keys: {str(e)}")
|
||||
raise APIKeyError(f"Failed to list API keys: {str(e)}")
|
||||
|
||||
|
||||
async def suspend_api_key(key_id: str, user_id: str) -> Optional[APIKeyWithoutHash]:
|
||||
try:
|
||||
api_key = await PrismaAPIKey.prisma().find_unique(where={"id": key_id})
|
||||
|
||||
if not api_key:
|
||||
raise APIKeyNotFoundError(f"API key with id {key_id} not found")
|
||||
|
||||
if api_key.userId != user_id:
|
||||
raise APIKeyPermissionError(
|
||||
"You do not have permission to suspend this API key."
|
||||
)
|
||||
|
||||
where_clause: APIKeyWhereUniqueInput = {"id": key_id}
|
||||
updated_api_key = await PrismaAPIKey.prisma().update(
|
||||
where=where_clause,
|
||||
data=APIKeyUpdateInput(status=APIKeyStatus.SUSPENDED),
|
||||
)
|
||||
|
||||
if updated_api_key:
|
||||
return APIKeyWithoutHash.from_db(updated_api_key)
|
||||
return None
|
||||
except (APIKeyNotFoundError, APIKeyPermissionError) as e:
|
||||
raise e
|
||||
except PrismaError as e:
|
||||
logger.error(f"Database error while suspending API key: {str(e)}")
|
||||
raise APIKeyError(f"Failed to suspend API key: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error while suspending API key: {str(e)}")
|
||||
raise APIKeyError(f"Failed to suspend API key: {str(e)}")
|
||||
|
||||
|
||||
def has_permission(api_key: APIKey, required_permission: APIKeyPermission) -> bool:
|
||||
try:
|
||||
return required_permission in api_key.permissions
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking API key permissions: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
async def get_api_key_by_id(key_id: str, user_id: str) -> Optional[APIKeyWithoutHash]:
|
||||
try:
|
||||
api_key = await PrismaAPIKey.prisma().find_first(
|
||||
where=APIKeyWhereInput(id=key_id, userId=user_id)
|
||||
)
|
||||
|
||||
if not api_key:
|
||||
return None
|
||||
|
||||
return APIKeyWithoutHash.from_db(api_key)
|
||||
except PrismaError as e:
|
||||
logger.error(f"Database error while getting API key: {str(e)}")
|
||||
raise APIKeyError(f"Failed to get API key: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error while getting API key: {str(e)}")
|
||||
raise APIKeyError(f"Failed to get API key: {str(e)}")
|
||||
return APIKeyInfo.from_db(api_key)
|
||||
|
||||
|
||||
async def update_api_key_permissions(
|
||||
key_id: str, user_id: str, permissions: List[APIKeyPermission]
|
||||
) -> Optional[APIKeyWithoutHash]:
|
||||
key_id: str, user_id: str, permissions: list[APIKeyPermission]
|
||||
) -> APIKeyInfo:
|
||||
"""
|
||||
Update the permissions of an API key.
|
||||
"""
|
||||
try:
|
||||
api_key = await PrismaAPIKey.prisma().find_unique(where={"id": key_id})
|
||||
api_key = await PrismaAPIKey.prisma().find_unique(where={"id": key_id})
|
||||
|
||||
if api_key is None:
|
||||
raise APIKeyNotFoundError("No such API key found.")
|
||||
if api_key is None:
|
||||
raise NotFoundError("No such API key found.")
|
||||
|
||||
if api_key.userId != user_id:
|
||||
raise APIKeyPermissionError(
|
||||
"You do not have permission to update this API key."
|
||||
)
|
||||
if api_key.userId != user_id:
|
||||
raise NotAuthorizedError("You do not have permission to update this API key.")
|
||||
|
||||
where_clause: APIKeyWhereUniqueInput = {"id": key_id}
|
||||
updated_api_key = await PrismaAPIKey.prisma().update(
|
||||
where=where_clause,
|
||||
data=APIKeyUpdateInput(permissions=permissions),
|
||||
)
|
||||
updated_api_key = await PrismaAPIKey.prisma().update(
|
||||
where={"id": key_id},
|
||||
data={"permissions": permissions},
|
||||
)
|
||||
if not updated_api_key:
|
||||
raise NotFoundError(f"API key #{key_id} vanished while trying to update.")
|
||||
|
||||
if updated_api_key:
|
||||
return APIKeyWithoutHash.from_db(updated_api_key)
|
||||
return None
|
||||
except (APIKeyNotFoundError, APIKeyPermissionError) as e:
|
||||
raise e
|
||||
except PrismaError as e:
|
||||
logger.error(f"Database error while updating API key permissions: {str(e)}")
|
||||
raise APIKeyError(f"Failed to update API key permissions: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error while updating API key permissions: {str(e)}")
|
||||
raise APIKeyError(f"Failed to update API key permissions: {str(e)}")
|
||||
return APIKeyInfo.from_db(updated_api_key)
|
||||
|
||||
@@ -91,6 +91,45 @@ class BlockCategory(Enum):
|
||||
return {"category": self.name, "description": self.value}
|
||||
|
||||
|
||||
class BlockCostType(str, Enum):
|
||||
RUN = "run" # cost X credits per run
|
||||
BYTE = "byte" # cost X credits per byte
|
||||
SECOND = "second" # cost X credits per second
|
||||
|
||||
|
||||
class BlockCost(BaseModel):
|
||||
cost_amount: int
|
||||
cost_filter: BlockInput
|
||||
cost_type: BlockCostType
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cost_amount: int,
|
||||
cost_type: BlockCostType = BlockCostType.RUN,
|
||||
cost_filter: Optional[BlockInput] = None,
|
||||
**data: Any,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
cost_amount=cost_amount,
|
||||
cost_filter=cost_filter or {},
|
||||
cost_type=cost_type,
|
||||
**data,
|
||||
)
|
||||
|
||||
|
||||
class BlockInfo(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
inputSchema: dict[str, Any]
|
||||
outputSchema: dict[str, Any]
|
||||
costs: list[BlockCost]
|
||||
description: str
|
||||
categories: list[dict[str, str]]
|
||||
contributors: list[dict[str, Any]]
|
||||
staticOutput: bool
|
||||
uiType: str
|
||||
|
||||
|
||||
class BlockSchema(BaseModel):
|
||||
cached_jsonschema: ClassVar[dict[str, Any]]
|
||||
|
||||
@@ -454,6 +493,24 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
"uiType": self.block_type.value,
|
||||
}
|
||||
|
||||
def get_info(self) -> BlockInfo:
|
||||
from backend.data.credit import get_block_cost
|
||||
|
||||
return BlockInfo(
|
||||
id=self.id,
|
||||
name=self.name,
|
||||
inputSchema=self.input_schema.jsonschema(),
|
||||
outputSchema=self.output_schema.jsonschema(),
|
||||
costs=get_block_cost(self),
|
||||
description=self.description,
|
||||
categories=[category.dict() for category in self.categories],
|
||||
contributors=[
|
||||
contributor.model_dump() for contributor in self.contributors
|
||||
],
|
||||
staticOutput=self.static_output,
|
||||
uiType=self.block_type.value,
|
||||
)
|
||||
|
||||
async def execute(self, input_data: BlockInput, **kwargs) -> BlockOutput:
|
||||
if error := self.input_schema.validate_data(input_data):
|
||||
raise ValueError(
|
||||
|
||||
@@ -29,8 +29,7 @@ from backend.blocks.replicate.replicate_block import ReplicateModelBlock
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.talking_head import CreateTalkingAvatarVideoBlock
|
||||
from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
|
||||
from backend.data.block import Block
|
||||
from backend.data.cost import BlockCost, BlockCostType
|
||||
from backend.data.block import Block, BlockCost, BlockCostType
|
||||
from backend.integrations.credentials_store import (
|
||||
aiml_api_credentials,
|
||||
anthropic_credentials,
|
||||
|
||||
@@ -1,32 +0,0 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.block import BlockInput
|
||||
|
||||
|
||||
class BlockCostType(str, Enum):
|
||||
RUN = "run" # cost X credits per run
|
||||
BYTE = "byte" # cost X credits per byte
|
||||
SECOND = "second" # cost X credits per second
|
||||
|
||||
|
||||
class BlockCost(BaseModel):
|
||||
cost_amount: int
|
||||
cost_filter: BlockInput
|
||||
cost_type: BlockCostType
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cost_amount: int,
|
||||
cost_type: BlockCostType = BlockCostType.RUN,
|
||||
cost_filter: Optional[BlockInput] = None,
|
||||
**data: Any,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
cost_amount=cost_amount,
|
||||
cost_filter=cost_filter or {},
|
||||
cost_type=cost_type,
|
||||
**data,
|
||||
)
|
||||
@@ -2,7 +2,7 @@ import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, cast
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
import stripe
|
||||
from prisma import Json
|
||||
@@ -23,7 +23,6 @@ from pydantic import BaseModel
|
||||
|
||||
from backend.data import db
|
||||
from backend.data.block_cost_config import BLOCK_COSTS
|
||||
from backend.data.cost import BlockCost
|
||||
from backend.data.model import (
|
||||
AutoTopUpConfig,
|
||||
RefundRequest,
|
||||
@@ -41,6 +40,9 @@ from backend.util.models import Pagination
|
||||
from backend.util.retry import func_retry
|
||||
from backend.util.settings import Settings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.block import Block, BlockCost
|
||||
|
||||
settings = Settings()
|
||||
stripe.api_key = settings.secrets.stripe_api_key
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -997,10 +999,14 @@ def get_user_credit_model() -> UserCreditBase:
|
||||
return UserCredit()
|
||||
|
||||
|
||||
def get_block_costs() -> dict[str, list[BlockCost]]:
|
||||
def get_block_costs() -> dict[str, list["BlockCost"]]:
|
||||
return {block().id: costs for block, costs in BLOCK_COSTS.items()}
|
||||
|
||||
|
||||
def get_block_cost(block: "Block") -> list["BlockCost"]:
|
||||
return BLOCK_COSTS.get(block.__class__, [])
|
||||
|
||||
|
||||
async def get_stripe_customer_id(user_id: str) -> str:
|
||||
user = await get_user_by_id(user_id)
|
||||
|
||||
|
||||
@@ -92,6 +92,31 @@ ExecutionStatus = AgentExecutionStatus
|
||||
NodeInputMask = Mapping[str, JsonValue]
|
||||
NodesInputMasks = Mapping[str, NodeInputMask]
|
||||
|
||||
# dest: source
|
||||
VALID_STATUS_TRANSITIONS = {
|
||||
ExecutionStatus.QUEUED: [
|
||||
ExecutionStatus.INCOMPLETE,
|
||||
],
|
||||
ExecutionStatus.RUNNING: [
|
||||
ExecutionStatus.INCOMPLETE,
|
||||
ExecutionStatus.QUEUED,
|
||||
ExecutionStatus.TERMINATED, # For resuming halted execution
|
||||
],
|
||||
ExecutionStatus.COMPLETED: [
|
||||
ExecutionStatus.RUNNING,
|
||||
],
|
||||
ExecutionStatus.FAILED: [
|
||||
ExecutionStatus.INCOMPLETE,
|
||||
ExecutionStatus.QUEUED,
|
||||
ExecutionStatus.RUNNING,
|
||||
],
|
||||
ExecutionStatus.TERMINATED: [
|
||||
ExecutionStatus.INCOMPLETE,
|
||||
ExecutionStatus.QUEUED,
|
||||
ExecutionStatus.RUNNING,
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
class GraphExecutionMeta(BaseDbModel):
|
||||
id: str # type: ignore # Override base class to make this required
|
||||
@@ -105,6 +130,8 @@ class GraphExecutionMeta(BaseDbModel):
|
||||
status: ExecutionStatus
|
||||
started_at: datetime
|
||||
ended_at: datetime
|
||||
is_shared: bool = False
|
||||
share_token: Optional[str] = None
|
||||
|
||||
class Stats(BaseModel):
|
||||
model_config = ConfigDict(
|
||||
@@ -221,6 +248,8 @@ class GraphExecutionMeta(BaseDbModel):
|
||||
if stats
|
||||
else None
|
||||
),
|
||||
is_shared=_graph_exec.isShared,
|
||||
share_token=_graph_exec.shareToken,
|
||||
)
|
||||
|
||||
|
||||
@@ -580,7 +609,7 @@ async def create_graph_execution(
|
||||
data={
|
||||
"agentGraphId": graph_id,
|
||||
"agentGraphVersion": graph_version,
|
||||
"executionStatus": ExecutionStatus.QUEUED,
|
||||
"executionStatus": ExecutionStatus.INCOMPLETE,
|
||||
"inputs": SafeJson(inputs),
|
||||
"credentialInputs": (
|
||||
SafeJson(credential_inputs) if credential_inputs else Json({})
|
||||
@@ -727,6 +756,11 @@ async def update_graph_execution_stats(
|
||||
status: ExecutionStatus | None = None,
|
||||
stats: GraphExecutionStats | None = None,
|
||||
) -> GraphExecution | None:
|
||||
if not status and not stats:
|
||||
raise ValueError(
|
||||
f"Must provide either status or stats to update for execution {graph_exec_id}"
|
||||
)
|
||||
|
||||
update_data: AgentGraphExecutionUpdateManyMutationInput = {}
|
||||
|
||||
if stats:
|
||||
@@ -738,20 +772,25 @@ async def update_graph_execution_stats(
|
||||
if status:
|
||||
update_data["executionStatus"] = status
|
||||
|
||||
updated_count = await AgentGraphExecution.prisma().update_many(
|
||||
where={
|
||||
"id": graph_exec_id,
|
||||
"OR": [
|
||||
{"executionStatus": ExecutionStatus.RUNNING},
|
||||
{"executionStatus": ExecutionStatus.QUEUED},
|
||||
# Terminated graph can be resumed.
|
||||
{"executionStatus": ExecutionStatus.TERMINATED},
|
||||
],
|
||||
},
|
||||
where_clause: AgentGraphExecutionWhereInput = {"id": graph_exec_id}
|
||||
|
||||
if status:
|
||||
if allowed_from := VALID_STATUS_TRANSITIONS.get(status, []):
|
||||
# Add OR clause to check if current status is one of the allowed source statuses
|
||||
where_clause["AND"] = [
|
||||
{"id": graph_exec_id},
|
||||
{"OR": [{"executionStatus": s} for s in allowed_from]},
|
||||
]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Status {status} cannot be set via update for execution {graph_exec_id}. "
|
||||
f"This status can only be set at creation or is not a valid target status."
|
||||
)
|
||||
|
||||
await AgentGraphExecution.prisma().update_many(
|
||||
where=where_clause,
|
||||
data=update_data,
|
||||
)
|
||||
if updated_count == 0:
|
||||
return None
|
||||
|
||||
graph_exec = await AgentGraphExecution.prisma().find_unique_or_raise(
|
||||
where={"id": graph_exec_id},
|
||||
@@ -759,6 +798,7 @@ async def update_graph_execution_stats(
|
||||
[*get_io_block_ids(), *get_webhook_block_ids()]
|
||||
),
|
||||
)
|
||||
|
||||
return GraphExecution.from_db(graph_exec)
|
||||
|
||||
|
||||
@@ -985,6 +1025,18 @@ class NodeExecutionEvent(NodeExecutionResult):
|
||||
)
|
||||
|
||||
|
||||
class SharedExecutionResponse(BaseModel):
|
||||
"""Public-safe response for shared executions"""
|
||||
|
||||
id: str
|
||||
graph_name: str
|
||||
graph_description: Optional[str]
|
||||
status: ExecutionStatus
|
||||
created_at: datetime
|
||||
outputs: CompletedBlockOutput # Only the final outputs, no intermediate data
|
||||
# Deliberately exclude: user_id, inputs, credentials, node details
|
||||
|
||||
|
||||
ExecutionEvent = Annotated[
|
||||
GraphExecutionEvent | NodeExecutionEvent, Field(discriminator="event_type")
|
||||
]
|
||||
@@ -1162,3 +1214,98 @@ async def get_block_error_stats(
|
||||
)
|
||||
for row in result
|
||||
]
|
||||
|
||||
|
||||
async def update_graph_execution_share_status(
|
||||
execution_id: str,
|
||||
user_id: str,
|
||||
is_shared: bool,
|
||||
share_token: str | None,
|
||||
shared_at: datetime | None,
|
||||
) -> None:
|
||||
"""Update the sharing status of a graph execution."""
|
||||
await AgentGraphExecution.prisma().update(
|
||||
where={"id": execution_id},
|
||||
data={
|
||||
"isShared": is_shared,
|
||||
"shareToken": share_token,
|
||||
"sharedAt": shared_at,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def get_graph_execution_by_share_token(
|
||||
share_token: str,
|
||||
) -> SharedExecutionResponse | None:
|
||||
"""Get a shared execution with limited public-safe data."""
|
||||
execution = await AgentGraphExecution.prisma().find_first(
|
||||
where={
|
||||
"shareToken": share_token,
|
||||
"isShared": True,
|
||||
"isDeleted": False,
|
||||
},
|
||||
include={
|
||||
"AgentGraph": True,
|
||||
"NodeExecutions": {
|
||||
"include": {
|
||||
"Output": True,
|
||||
"Node": {
|
||||
"include": {
|
||||
"AgentBlock": True,
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
if not execution:
|
||||
return None
|
||||
|
||||
# Extract outputs from OUTPUT blocks only (consistent with GraphExecution.from_db)
|
||||
outputs: CompletedBlockOutput = defaultdict(list)
|
||||
if execution.NodeExecutions:
|
||||
for node_exec in execution.NodeExecutions:
|
||||
if node_exec.Node and node_exec.Node.agentBlockId:
|
||||
# Get the block definition to check its type
|
||||
block = get_block(node_exec.Node.agentBlockId)
|
||||
|
||||
if block and block.block_type == BlockType.OUTPUT:
|
||||
# For OUTPUT blocks, the data is stored in executionData or Input
|
||||
# The executionData contains the structured input with 'name' and 'value' fields
|
||||
if hasattr(node_exec, "executionData") and node_exec.executionData:
|
||||
exec_data = type_utils.convert(
|
||||
node_exec.executionData, dict[str, Any]
|
||||
)
|
||||
if "name" in exec_data:
|
||||
name = exec_data["name"]
|
||||
value = exec_data.get("value")
|
||||
outputs[name].append(value)
|
||||
elif node_exec.Input:
|
||||
# Build input_data from Input relation
|
||||
input_data = {}
|
||||
for data in node_exec.Input:
|
||||
if data.name and data.data is not None:
|
||||
input_data[data.name] = type_utils.convert(
|
||||
data.data, JsonValue
|
||||
)
|
||||
|
||||
if "name" in input_data:
|
||||
name = input_data["name"]
|
||||
value = input_data.get("value")
|
||||
outputs[name].append(value)
|
||||
|
||||
return SharedExecutionResponse(
|
||||
id=execution.id,
|
||||
graph_name=(
|
||||
execution.AgentGraph.name
|
||||
if (execution.AgentGraph and execution.AgentGraph.name)
|
||||
else "Untitled Agent"
|
||||
),
|
||||
graph_description=(
|
||||
execution.AgentGraph.description if execution.AgentGraph else None
|
||||
),
|
||||
status=ExecutionStatus(execution.executionStatus),
|
||||
created_at=execution.createdAt,
|
||||
outputs=outputs,
|
||||
)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import logging
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
|
||||
|
||||
from prisma.enums import SubmissionStatus
|
||||
@@ -160,6 +161,8 @@ class BaseGraph(BaseDbModel):
|
||||
is_active: bool = True
|
||||
name: str
|
||||
description: str
|
||||
instructions: str | None = None
|
||||
recommended_schedule_cron: str | None = None
|
||||
nodes: list[Node] = []
|
||||
links: list[Link] = []
|
||||
forked_from_id: str | None = None
|
||||
@@ -380,6 +383,8 @@ class GraphModel(Graph):
|
||||
user_id: str
|
||||
nodes: list[NodeModel] = [] # type: ignore
|
||||
|
||||
created_at: datetime
|
||||
|
||||
@property
|
||||
def starting_nodes(self) -> list[NodeModel]:
|
||||
outbound_nodes = {link.sink_id for link in self.links}
|
||||
@@ -392,6 +397,10 @@ class GraphModel(Graph):
|
||||
if node.id not in outbound_nodes or node.id in input_nodes
|
||||
]
|
||||
|
||||
@property
|
||||
def webhook_input_node(self) -> NodeModel | None: # type: ignore
|
||||
return cast(NodeModel, super().webhook_input_node)
|
||||
|
||||
def meta(self) -> "GraphMeta":
|
||||
"""
|
||||
Returns a GraphMeta object with metadata about the graph.
|
||||
@@ -693,9 +702,12 @@ class GraphModel(Graph):
|
||||
version=graph.version,
|
||||
forked_from_id=graph.forkedFromId,
|
||||
forked_from_version=graph.forkedFromVersion,
|
||||
created_at=graph.createdAt,
|
||||
is_active=graph.isActive,
|
||||
name=graph.name or "",
|
||||
description=graph.description or "",
|
||||
instructions=graph.instructions,
|
||||
recommended_schedule_cron=graph.recommendedScheduleCron,
|
||||
nodes=[NodeModel.from_db(node, for_export) for node in graph.Nodes or []],
|
||||
links=list(
|
||||
{
|
||||
@@ -1083,6 +1095,7 @@ async def __create_graph(tx, graph: Graph, user_id: str):
|
||||
version=graph.version,
|
||||
name=graph.name,
|
||||
description=graph.description,
|
||||
recommendedScheduleCron=graph.recommended_schedule_cron,
|
||||
isActive=graph.is_active,
|
||||
userId=user_id,
|
||||
forkedFromId=graph.forked_from_id,
|
||||
@@ -1141,6 +1154,7 @@ def make_graph_model(creatable_graph: Graph, user_id: str) -> GraphModel:
|
||||
return GraphModel(
|
||||
**creatable_graph.model_dump(exclude={"nodes"}),
|
||||
user_id=user_id,
|
||||
created_at=datetime.now(tz=timezone.utc),
|
||||
nodes=[
|
||||
NodeModel(
|
||||
**creatable_node.model_dump(),
|
||||
|
||||
@@ -13,7 +13,7 @@ load_dotenv()
|
||||
|
||||
HOST = os.getenv("REDIS_HOST", "localhost")
|
||||
PORT = int(os.getenv("REDIS_PORT", "6379"))
|
||||
PASSWORD = os.getenv("REDIS_PASSWORD", "password")
|
||||
PASSWORD = os.getenv("REDIS_PASSWORD", None)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -107,7 +107,7 @@ async def generate_activity_status_for_execution(
|
||||
# Check if we have OpenAI API key
|
||||
try:
|
||||
settings = Settings()
|
||||
if not settings.secrets.openai_api_key:
|
||||
if not settings.secrets.openai_internal_api_key:
|
||||
logger.debug(
|
||||
"OpenAI API key not configured, skipping activity status generation"
|
||||
)
|
||||
@@ -187,7 +187,7 @@ async def generate_activity_status_for_execution(
|
||||
credentials = APIKeyCredentials(
|
||||
id="openai",
|
||||
provider="openai",
|
||||
api_key=SecretStr(settings.secrets.openai_api_key),
|
||||
api_key=SecretStr(settings.secrets.openai_internal_api_key),
|
||||
title="System OpenAI",
|
||||
)
|
||||
|
||||
|
||||
@@ -468,7 +468,7 @@ class TestGenerateActivityStatusForExecution:
|
||||
):
|
||||
|
||||
mock_get_block.side_effect = lambda block_id: mock_blocks.get(block_id)
|
||||
mock_settings.return_value.secrets.openai_api_key = "test_key"
|
||||
mock_settings.return_value.secrets.openai_internal_api_key = "test_key"
|
||||
mock_llm.return_value = (
|
||||
"I analyzed your data and provided the requested insights."
|
||||
)
|
||||
@@ -520,7 +520,7 @@ class TestGenerateActivityStatusForExecution:
|
||||
"backend.executor.activity_status_generator.is_feature_enabled",
|
||||
return_value=True,
|
||||
):
|
||||
mock_settings.return_value.secrets.openai_api_key = ""
|
||||
mock_settings.return_value.secrets.openai_internal_api_key = ""
|
||||
|
||||
result = await generate_activity_status_for_execution(
|
||||
graph_exec_id="test_exec",
|
||||
@@ -546,7 +546,7 @@ class TestGenerateActivityStatusForExecution:
|
||||
"backend.executor.activity_status_generator.is_feature_enabled",
|
||||
return_value=True,
|
||||
):
|
||||
mock_settings.return_value.secrets.openai_api_key = "test_key"
|
||||
mock_settings.return_value.secrets.openai_internal_api_key = "test_key"
|
||||
|
||||
result = await generate_activity_status_for_execution(
|
||||
graph_exec_id="test_exec",
|
||||
@@ -581,7 +581,7 @@ class TestGenerateActivityStatusForExecution:
|
||||
):
|
||||
|
||||
mock_get_block.side_effect = lambda block_id: mock_blocks.get(block_id)
|
||||
mock_settings.return_value.secrets.openai_api_key = "test_key"
|
||||
mock_settings.return_value.secrets.openai_internal_api_key = "test_key"
|
||||
mock_llm.return_value = "Agent completed execution."
|
||||
|
||||
result = await generate_activity_status_for_execution(
|
||||
@@ -633,7 +633,7 @@ class TestIntegration:
|
||||
):
|
||||
|
||||
mock_get_block.side_effect = lambda block_id: mock_blocks.get(block_id)
|
||||
mock_settings.return_value.secrets.openai_api_key = "test_key"
|
||||
mock_settings.return_value.secrets.openai_internal_api_key = "test_key"
|
||||
|
||||
mock_response = LLMResponse(
|
||||
raw_response={},
|
||||
|
||||
@@ -605,7 +605,7 @@ class ExecutionProcessor:
|
||||
)
|
||||
return
|
||||
|
||||
if exec_meta.status == ExecutionStatus.QUEUED:
|
||||
if exec_meta.status in [ExecutionStatus.QUEUED, ExecutionStatus.INCOMPLETE]:
|
||||
log_metadata.info(f"⚙️ Starting graph execution #{graph_exec.graph_exec_id}")
|
||||
exec_meta.status = ExecutionStatus.RUNNING
|
||||
send_execution_update(
|
||||
|
||||
@@ -191,15 +191,22 @@ class GraphExecutionJobInfo(GraphExecutionJobArgs):
|
||||
id: str
|
||||
name: str
|
||||
next_run_time: str
|
||||
timezone: str = Field(default="UTC", description="Timezone used for scheduling")
|
||||
|
||||
@staticmethod
|
||||
def from_db(
|
||||
job_args: GraphExecutionJobArgs, job_obj: JobObj
|
||||
) -> "GraphExecutionJobInfo":
|
||||
# Extract timezone from the trigger if it's a CronTrigger
|
||||
timezone_str = "UTC"
|
||||
if hasattr(job_obj.trigger, "timezone"):
|
||||
timezone_str = str(job_obj.trigger.timezone)
|
||||
|
||||
return GraphExecutionJobInfo(
|
||||
id=job_obj.id,
|
||||
name=job_obj.name,
|
||||
next_run_time=job_obj.next_run_time.isoformat(),
|
||||
timezone=timezone_str,
|
||||
**job_args.model_dump(),
|
||||
)
|
||||
|
||||
@@ -395,6 +402,7 @@ class Scheduler(AppService):
|
||||
input_data: BlockInput,
|
||||
input_credentials: dict[str, CredentialsMetaInput],
|
||||
name: Optional[str] = None,
|
||||
user_timezone: str | None = None,
|
||||
) -> GraphExecutionJobInfo:
|
||||
# Validate the graph before scheduling to prevent runtime failures
|
||||
# We don't need the return value, just want the validation to run
|
||||
@@ -408,7 +416,18 @@ class Scheduler(AppService):
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(f"Scheduling job for user {user_id} in UTC (cron: {cron})")
|
||||
# Use provided timezone or default to UTC
|
||||
# Note: Timezone should be passed from the client to avoid database lookups
|
||||
if not user_timezone:
|
||||
user_timezone = "UTC"
|
||||
logger.warning(
|
||||
f"No timezone provided for user {user_id}, using UTC for scheduling. "
|
||||
f"Client should pass user's timezone for correct scheduling."
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Scheduling job for user {user_id} with timezone {user_timezone} (cron: {cron})"
|
||||
)
|
||||
|
||||
job_args = GraphExecutionJobArgs(
|
||||
user_id=user_id,
|
||||
@@ -422,12 +441,12 @@ class Scheduler(AppService):
|
||||
execute_graph,
|
||||
kwargs=job_args.model_dump(),
|
||||
name=name,
|
||||
trigger=CronTrigger.from_crontab(cron, timezone="UTC"),
|
||||
trigger=CronTrigger.from_crontab(cron, timezone=user_timezone),
|
||||
jobstore=Jobstores.EXECUTION.value,
|
||||
replace_existing=True,
|
||||
)
|
||||
logger.info(
|
||||
f"Added job {job.id} with cron schedule '{cron}' in UTC, input data: {input_data}"
|
||||
f"Added job {job.id} with cron schedule '{cron}' in timezone {user_timezone}, input data: {input_data}"
|
||||
)
|
||||
return GraphExecutionJobInfo.from_db(job_args, job)
|
||||
|
||||
|
||||
@@ -10,9 +10,15 @@ from pydantic import BaseModel, JsonValue, ValidationError
|
||||
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data.block import Block, BlockInput, BlockOutputEntry, BlockType, get_block
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCostType,
|
||||
BlockInput,
|
||||
BlockOutputEntry,
|
||||
BlockType,
|
||||
get_block,
|
||||
)
|
||||
from backend.data.block_cost_config import BLOCK_COSTS
|
||||
from backend.data.cost import BlockCostType
|
||||
from backend.data.db import prisma
|
||||
from backend.data.execution import (
|
||||
ExecutionStatus,
|
||||
@@ -908,29 +914,30 @@ async def add_graph_execution(
|
||||
preset_id=preset_id,
|
||||
)
|
||||
|
||||
# Fetch user context for the graph execution
|
||||
user_context = await get_user_context(user_id)
|
||||
|
||||
queue = await get_async_execution_queue()
|
||||
graph_exec_entry = graph_exec.to_graph_execution_entry(
|
||||
user_context, compiled_nodes_input_masks
|
||||
user_context=await get_user_context(user_id),
|
||||
compiled_nodes_input_masks=compiled_nodes_input_masks,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Created graph execution #{graph_exec.id} for graph "
|
||||
f"#{graph_id} with {len(starting_nodes_input)} starting nodes. "
|
||||
f"Now publishing to execution queue."
|
||||
)
|
||||
|
||||
await queue.publish_message(
|
||||
exec_queue = await get_async_execution_queue()
|
||||
await exec_queue.publish_message(
|
||||
routing_key=GRAPH_EXECUTION_ROUTING_KEY,
|
||||
message=graph_exec_entry.model_dump_json(),
|
||||
exchange=GRAPH_EXECUTION_EXCHANGE,
|
||||
)
|
||||
logger.info(f"Published execution {graph_exec.id} to RabbitMQ queue")
|
||||
|
||||
bus = get_async_execution_event_bus()
|
||||
await bus.publish(graph_exec)
|
||||
graph_exec.status = ExecutionStatus.QUEUED
|
||||
await edb.update_graph_execution_stats(
|
||||
graph_exec_id=graph_exec.id,
|
||||
status=graph_exec.status,
|
||||
)
|
||||
await get_async_execution_event_bus().publish(graph_exec)
|
||||
|
||||
return graph_exec
|
||||
except BaseException as e:
|
||||
|
||||
@@ -316,6 +316,7 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
|
||||
# Mock the graph execution object
|
||||
mock_graph_exec = mocker.MagicMock(spec=GraphExecutionWithNodes)
|
||||
mock_graph_exec.id = "execution-id-123"
|
||||
mock_graph_exec.node_executions = [] # Add this to avoid AttributeError
|
||||
mock_graph_exec.to_graph_execution_entry.return_value = mocker.MagicMock()
|
||||
|
||||
# Mock user context
|
||||
@@ -346,6 +347,10 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
|
||||
)
|
||||
mock_prisma.is_connected.return_value = True
|
||||
mock_edb.create_graph_execution = mocker.AsyncMock(return_value=mock_graph_exec)
|
||||
mock_edb.update_graph_execution_stats = mocker.AsyncMock(
|
||||
return_value=mock_graph_exec
|
||||
)
|
||||
mock_edb.update_node_execution_status_batch = mocker.AsyncMock()
|
||||
mock_get_user_context.return_value = mock_user_context
|
||||
mock_get_queue.return_value = mock_queue
|
||||
mock_get_event_bus.return_value = mock_event_bus
|
||||
|
||||
@@ -7,10 +7,9 @@ from backend.data.graph import set_node_webhook
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
|
||||
from . import get_webhook_manager, supports_webhooks
|
||||
from .utils import setup_webhook_for_block
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.graph import BaseGraph, GraphModel, Node, NodeModel
|
||||
from backend.data.graph import BaseGraph, GraphModel, NodeModel
|
||||
from backend.data.model import Credentials
|
||||
|
||||
from ._base import BaseWebhooksManager
|
||||
@@ -43,32 +42,19 @@ async def _on_graph_activate(graph: "BaseGraph", user_id: str) -> "BaseGraph": .
|
||||
|
||||
async def _on_graph_activate(graph: "BaseGraph | GraphModel", user_id: str):
|
||||
get_credentials = credentials_manager.cached_getter(user_id)
|
||||
updated_nodes = []
|
||||
for new_node in graph.nodes:
|
||||
block_input_schema = cast(BlockSchema, new_node.block.input_schema)
|
||||
|
||||
node_credentials = None
|
||||
if (
|
||||
# Webhook-triggered blocks are only allowed to have 1 credentials input
|
||||
(
|
||||
creds_field_name := next(
|
||||
iter(block_input_schema.get_credentials_fields()), None
|
||||
for creds_field_name in block_input_schema.get_credentials_fields().keys():
|
||||
# Prevent saving graph with non-existent credentials
|
||||
if (
|
||||
creds_meta := new_node.input_default.get(creds_field_name)
|
||||
) and not await get_credentials(creds_meta["id"]):
|
||||
raise ValueError(
|
||||
f"Node #{new_node.id} input '{creds_field_name}' updated with "
|
||||
f"non-existent credentials #{creds_meta['id']}"
|
||||
)
|
||||
)
|
||||
and (creds_meta := new_node.input_default.get(creds_field_name))
|
||||
and not (node_credentials := await get_credentials(creds_meta["id"]))
|
||||
):
|
||||
raise ValueError(
|
||||
f"Node #{new_node.id} input '{creds_field_name}' updated with "
|
||||
f"non-existent credentials #{creds_meta['id']}"
|
||||
)
|
||||
|
||||
updated_node = await on_node_activate(
|
||||
user_id, graph.id, new_node, credentials=node_credentials
|
||||
)
|
||||
updated_nodes.append(updated_node)
|
||||
|
||||
graph.nodes = updated_nodes
|
||||
return graph
|
||||
|
||||
|
||||
@@ -85,20 +71,14 @@ async def on_graph_deactivate(graph: "GraphModel", user_id: str):
|
||||
block_input_schema = cast(BlockSchema, node.block.input_schema)
|
||||
|
||||
node_credentials = None
|
||||
if (
|
||||
# Webhook-triggered blocks are only allowed to have 1 credentials input
|
||||
(
|
||||
creds_field_name := next(
|
||||
iter(block_input_schema.get_credentials_fields()), None
|
||||
for creds_field_name in block_input_schema.get_credentials_fields().keys():
|
||||
if (creds_meta := node.input_default.get(creds_field_name)) and not (
|
||||
node_credentials := await get_credentials(creds_meta["id"])
|
||||
):
|
||||
logger.warning(
|
||||
f"Node #{node.id} input '{creds_field_name}' referenced "
|
||||
f"non-existent credentials #{creds_meta['id']}"
|
||||
)
|
||||
)
|
||||
and (creds_meta := node.input_default.get(creds_field_name))
|
||||
and not (node_credentials := await get_credentials(creds_meta["id"]))
|
||||
):
|
||||
logger.error(
|
||||
f"Node #{node.id} input '{creds_field_name}' referenced non-existent "
|
||||
f"credentials #{creds_meta['id']}"
|
||||
)
|
||||
|
||||
updated_node = await on_node_deactivate(
|
||||
user_id, node, credentials=node_credentials
|
||||
@@ -109,32 +89,6 @@ async def on_graph_deactivate(graph: "GraphModel", user_id: str):
|
||||
return graph
|
||||
|
||||
|
||||
async def on_node_activate(
|
||||
user_id: str,
|
||||
graph_id: str,
|
||||
node: "Node",
|
||||
*,
|
||||
credentials: Optional["Credentials"] = None,
|
||||
) -> "Node":
|
||||
"""Hook to be called when the node is activated/created"""
|
||||
|
||||
if node.block.webhook_config:
|
||||
new_webhook, feedback = await setup_webhook_for_block(
|
||||
user_id=user_id,
|
||||
trigger_block=node.block,
|
||||
trigger_config=node.input_default,
|
||||
for_graph_id=graph_id,
|
||||
)
|
||||
if new_webhook:
|
||||
node = await set_node_webhook(node.id, new_webhook.id)
|
||||
else:
|
||||
logger.debug(
|
||||
f"Node #{node.id} does not have everything for a webhook: {feedback}"
|
||||
)
|
||||
|
||||
return node
|
||||
|
||||
|
||||
async def on_node_deactivate(
|
||||
user_id: str,
|
||||
node: "NodeModel",
|
||||
|
||||
@@ -4,7 +4,6 @@ from typing import TYPE_CHECKING, Optional, cast
|
||||
from pydantic import JsonValue
|
||||
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.settings import Config
|
||||
|
||||
from . import get_webhook_manager, supports_webhooks
|
||||
@@ -13,6 +12,7 @@ if TYPE_CHECKING:
|
||||
from backend.data.block import Block, BlockSchema
|
||||
from backend.data.integrations import Webhook
|
||||
from backend.data.model import Credentials
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
app_config = Config()
|
||||
@@ -20,7 +20,7 @@ credentials_manager = IntegrationCredentialsManager()
|
||||
|
||||
|
||||
# TODO: add test to assert this matches the actual API route
|
||||
def webhook_ingress_url(provider_name: ProviderName, webhook_id: str) -> str:
|
||||
def webhook_ingress_url(provider_name: "ProviderName", webhook_id: str) -> str:
|
||||
return (
|
||||
f"{app_config.platform_base_url}/api/integrations/{provider_name.value}"
|
||||
f"/webhooks/{webhook_id}/ingress"
|
||||
@@ -144,3 +144,62 @@ async def setup_webhook_for_block(
|
||||
)
|
||||
logger.debug(f"Acquired webhook: {webhook}")
|
||||
return webhook, None
|
||||
|
||||
|
||||
async def migrate_legacy_triggered_graphs():
|
||||
from prisma.models import AgentGraph
|
||||
|
||||
from backend.data.graph import AGENT_GRAPH_INCLUDE, GraphModel, set_node_webhook
|
||||
from backend.data.model import is_credentials_field_name
|
||||
from backend.server.v2.library.db import create_preset
|
||||
from backend.server.v2.library.model import LibraryAgentPresetCreatable
|
||||
|
||||
triggered_graphs = [
|
||||
GraphModel.from_db(_graph)
|
||||
for _graph in await AgentGraph.prisma().find_many(
|
||||
where={
|
||||
"isActive": True,
|
||||
"Nodes": {"some": {"NOT": [{"webhookId": None}]}},
|
||||
},
|
||||
include=AGENT_GRAPH_INCLUDE,
|
||||
)
|
||||
]
|
||||
|
||||
n_migrated_webhooks = 0
|
||||
|
||||
for graph in triggered_graphs:
|
||||
if not ((trigger_node := graph.webhook_input_node) and trigger_node.webhook_id):
|
||||
continue
|
||||
|
||||
# Use trigger node's inputs for the preset
|
||||
preset_credentials = {
|
||||
field_name: creds_meta
|
||||
for field_name, creds_meta in trigger_node.input_default.items()
|
||||
if is_credentials_field_name(field_name)
|
||||
}
|
||||
preset_inputs = {
|
||||
field_name: value
|
||||
for field_name, value in trigger_node.input_default.items()
|
||||
if not is_credentials_field_name(field_name)
|
||||
}
|
||||
|
||||
# Create a triggered preset for the graph
|
||||
await create_preset(
|
||||
graph.user_id,
|
||||
LibraryAgentPresetCreatable(
|
||||
graph_id=graph.id,
|
||||
graph_version=graph.version,
|
||||
inputs=preset_inputs,
|
||||
credentials=preset_credentials,
|
||||
name=graph.name,
|
||||
description=graph.description,
|
||||
webhook_id=trigger_node.webhook_id,
|
||||
is_active=True,
|
||||
),
|
||||
)
|
||||
# Detach webhook from the graph node
|
||||
await set_node_webhook(trigger_node.id, None)
|
||||
|
||||
n_migrated_webhooks += 1
|
||||
|
||||
logger.info(f"Migrated {n_migrated_webhooks} node triggers to triggered presets")
|
||||
|
||||
287
autogpt_platform/backend/backend/monitoring/instrumentation.py
Normal file
@@ -0,0 +1,287 @@
|
||||
"""
|
||||
Prometheus instrumentation for FastAPI services.
|
||||
|
||||
This module provides centralized metrics collection and instrumentation
|
||||
for all FastAPI services in the AutoGPT platform.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import FastAPI
|
||||
from prometheus_client import Counter, Gauge, Histogram, Info
|
||||
from prometheus_fastapi_instrumentator import Instrumentator, metrics
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Custom business metrics with controlled cardinality
|
||||
GRAPH_EXECUTIONS = Counter(
|
||||
"autogpt_graph_executions_total",
|
||||
"Total number of graph executions",
|
||||
labelnames=[
|
||||
"status"
|
||||
], # Removed graph_id and user_id to prevent cardinality explosion
|
||||
)
|
||||
|
||||
GRAPH_EXECUTIONS_BY_USER = Counter(
|
||||
"autogpt_graph_executions_by_user_total",
|
||||
"Total number of graph executions by user (sampled)",
|
||||
labelnames=["status"], # Only status, user_id tracked separately when needed
|
||||
)
|
||||
|
||||
BLOCK_EXECUTIONS = Counter(
|
||||
"autogpt_block_executions_total",
|
||||
"Total number of block executions",
|
||||
labelnames=["block_type", "status"], # block_type is bounded
|
||||
)
|
||||
|
||||
BLOCK_DURATION = Histogram(
|
||||
"autogpt_block_duration_seconds",
|
||||
"Duration of block executions in seconds",
|
||||
labelnames=["block_type"],
|
||||
buckets=[0.1, 0.25, 0.5, 1, 2.5, 5, 10, 30, 60],
|
||||
)
|
||||
|
||||
WEBSOCKET_CONNECTIONS = Gauge(
|
||||
"autogpt_websocket_connections_total",
|
||||
"Total number of active WebSocket connections",
|
||||
# Removed user_id label - track total only to prevent cardinality explosion
|
||||
)
|
||||
|
||||
SCHEDULER_JOBS = Gauge(
|
||||
"autogpt_scheduler_jobs",
|
||||
"Current number of scheduled jobs",
|
||||
labelnames=["job_type", "status"],
|
||||
)
|
||||
|
||||
DATABASE_QUERIES = Histogram(
|
||||
"autogpt_database_query_duration_seconds",
|
||||
"Duration of database queries in seconds",
|
||||
labelnames=["operation", "table"],
|
||||
buckets=[0.01, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5],
|
||||
)
|
||||
|
||||
RABBITMQ_MESSAGES = Counter(
|
||||
"autogpt_rabbitmq_messages_total",
|
||||
"Total number of RabbitMQ messages",
|
||||
labelnames=["queue", "status"],
|
||||
)
|
||||
|
||||
AUTHENTICATION_ATTEMPTS = Counter(
|
||||
"autogpt_auth_attempts_total",
|
||||
"Total number of authentication attempts",
|
||||
labelnames=["method", "status"],
|
||||
)
|
||||
|
||||
API_KEY_USAGE = Counter(
|
||||
"autogpt_api_key_usage_total",
|
||||
"API key usage by provider",
|
||||
labelnames=["provider", "block_type", "status"],
|
||||
)
|
||||
|
||||
# Function/operation level metrics with controlled cardinality
|
||||
GRAPH_OPERATIONS = Counter(
|
||||
"autogpt_graph_operations_total",
|
||||
"Graph operations by type",
|
||||
labelnames=["operation", "status"], # create, update, delete, execute, etc.
|
||||
)
|
||||
|
||||
USER_OPERATIONS = Counter(
|
||||
"autogpt_user_operations_total",
|
||||
"User operations by type",
|
||||
labelnames=["operation", "status"], # login, register, update_profile, etc.
|
||||
)
|
||||
|
||||
RATE_LIMIT_HITS = Counter(
|
||||
"autogpt_rate_limit_hits_total",
|
||||
"Number of rate limit hits",
|
||||
labelnames=["endpoint"], # Removed user_id to prevent cardinality explosion
|
||||
)
|
||||
|
||||
SERVICE_INFO = Info(
|
||||
"autogpt_service",
|
||||
"Service information",
|
||||
)
|
||||
|
||||
|
||||
def instrument_fastapi(
|
||||
app: FastAPI,
|
||||
service_name: str,
|
||||
expose_endpoint: bool = True,
|
||||
endpoint: str = "/metrics",
|
||||
include_in_schema: bool = False,
|
||||
excluded_handlers: Optional[list] = None,
|
||||
) -> Instrumentator:
|
||||
"""
|
||||
Instrument a FastAPI application with Prometheus metrics.
|
||||
|
||||
Args:
|
||||
app: FastAPI application instance
|
||||
service_name: Name of the service for metrics labeling
|
||||
expose_endpoint: Whether to expose /metrics endpoint
|
||||
endpoint: Path for metrics endpoint
|
||||
include_in_schema: Whether to include metrics endpoint in OpenAPI schema
|
||||
excluded_handlers: List of paths to exclude from metrics
|
||||
|
||||
Returns:
|
||||
Configured Instrumentator instance
|
||||
"""
|
||||
|
||||
# Set service info
|
||||
try:
|
||||
from importlib.metadata import version
|
||||
|
||||
service_version = version("autogpt-platform-backend")
|
||||
except Exception:
|
||||
service_version = "unknown"
|
||||
|
||||
SERVICE_INFO.info(
|
||||
{
|
||||
"service": service_name,
|
||||
"version": service_version,
|
||||
}
|
||||
)
|
||||
|
||||
# Create instrumentator with default metrics
|
||||
instrumentator = Instrumentator(
|
||||
should_group_status_codes=True,
|
||||
should_ignore_untemplated=True,
|
||||
should_respect_env_var=True,
|
||||
should_instrument_requests_inprogress=True,
|
||||
excluded_handlers=excluded_handlers or ["/health", "/readiness"],
|
||||
env_var_name="ENABLE_METRICS",
|
||||
inprogress_name="autogpt_http_requests_inprogress",
|
||||
inprogress_labels=True,
|
||||
)
|
||||
|
||||
# Add default HTTP metrics
|
||||
instrumentator.add(
|
||||
metrics.default(
|
||||
metric_namespace="autogpt",
|
||||
metric_subsystem=service_name.replace("-", "_"),
|
||||
)
|
||||
)
|
||||
|
||||
# Add request size metrics
|
||||
instrumentator.add(
|
||||
metrics.request_size(
|
||||
metric_namespace="autogpt",
|
||||
metric_subsystem=service_name.replace("-", "_"),
|
||||
)
|
||||
)
|
||||
|
||||
# Add response size metrics
|
||||
instrumentator.add(
|
||||
metrics.response_size(
|
||||
metric_namespace="autogpt",
|
||||
metric_subsystem=service_name.replace("-", "_"),
|
||||
)
|
||||
)
|
||||
|
||||
# Add latency metrics with custom buckets for better granularity
|
||||
instrumentator.add(
|
||||
metrics.latency(
|
||||
metric_namespace="autogpt",
|
||||
metric_subsystem=service_name.replace("-", "_"),
|
||||
buckets=[0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10, 30, 60],
|
||||
)
|
||||
)
|
||||
|
||||
# Add combined metrics (requests by method and status)
|
||||
instrumentator.add(
|
||||
metrics.combined_size(
|
||||
metric_namespace="autogpt",
|
||||
metric_subsystem=service_name.replace("-", "_"),
|
||||
)
|
||||
)
|
||||
|
||||
# Instrument the app
|
||||
instrumentator.instrument(app)
|
||||
|
||||
# Expose metrics endpoint if requested
|
||||
if expose_endpoint:
|
||||
instrumentator.expose(
|
||||
app,
|
||||
endpoint=endpoint,
|
||||
include_in_schema=include_in_schema,
|
||||
tags=["monitoring"] if include_in_schema else None,
|
||||
)
|
||||
logger.info(f"Metrics endpoint exposed at {endpoint} for {service_name}")
|
||||
|
||||
return instrumentator
|
||||
|
||||
|
||||
def record_graph_execution(graph_id: str, status: str, user_id: str):
|
||||
"""Record a graph execution event.
|
||||
|
||||
Args:
|
||||
graph_id: Graph identifier (kept for future sampling/debugging)
|
||||
status: Execution status (success/error/validation_error)
|
||||
user_id: User identifier (kept for future sampling/debugging)
|
||||
"""
|
||||
# Track overall executions without high-cardinality labels
|
||||
GRAPH_EXECUTIONS.labels(status=status).inc()
|
||||
|
||||
# Optionally track per-user executions (implement sampling if needed)
|
||||
# For now, just track status to avoid cardinality explosion
|
||||
GRAPH_EXECUTIONS_BY_USER.labels(status=status).inc()
|
||||
|
||||
|
||||
def record_block_execution(block_type: str, status: str, duration: float):
|
||||
"""Record a block execution event with duration."""
|
||||
BLOCK_EXECUTIONS.labels(block_type=block_type, status=status).inc()
|
||||
BLOCK_DURATION.labels(block_type=block_type).observe(duration)
|
||||
|
||||
|
||||
def update_websocket_connections(user_id: str, delta: int):
|
||||
"""Update the number of active WebSocket connections.
|
||||
|
||||
Args:
|
||||
user_id: User identifier (kept for future sampling/debugging)
|
||||
delta: Change in connection count (+1 for connect, -1 for disconnect)
|
||||
"""
|
||||
# Track total connections without user_id to prevent cardinality explosion
|
||||
if delta > 0:
|
||||
WEBSOCKET_CONNECTIONS.inc(delta)
|
||||
else:
|
||||
WEBSOCKET_CONNECTIONS.dec(abs(delta))
|
||||
|
||||
|
||||
def record_database_query(operation: str, table: str, duration: float):
|
||||
"""Record a database query with duration."""
|
||||
DATABASE_QUERIES.labels(operation=operation, table=table).observe(duration)
|
||||
|
||||
|
||||
def record_rabbitmq_message(queue: str, status: str):
|
||||
"""Record a RabbitMQ message event."""
|
||||
RABBITMQ_MESSAGES.labels(queue=queue, status=status).inc()
|
||||
|
||||
|
||||
def record_authentication_attempt(method: str, status: str):
|
||||
"""Record an authentication attempt."""
|
||||
AUTHENTICATION_ATTEMPTS.labels(method=method, status=status).inc()
|
||||
|
||||
|
||||
def record_api_key_usage(provider: str, block_type: str, status: str):
|
||||
"""Record API key usage by provider and block."""
|
||||
API_KEY_USAGE.labels(provider=provider, block_type=block_type, status=status).inc()
|
||||
|
||||
|
||||
def record_rate_limit_hit(endpoint: str, user_id: str):
|
||||
"""Record a rate limit hit.
|
||||
|
||||
Args:
|
||||
endpoint: API endpoint that was rate limited
|
||||
user_id: User identifier (kept for future sampling/debugging)
|
||||
"""
|
||||
RATE_LIMIT_HITS.labels(endpoint=endpoint).inc()
|
||||
|
||||
|
||||
def record_graph_operation(operation: str, status: str):
|
||||
"""Record a graph operation (create, update, delete, execute, etc.)."""
|
||||
GRAPH_OPERATIONS.labels(operation=operation, status=status).inc()
|
||||
|
||||
|
||||
def record_user_operation(operation: str, status: str):
|
||||
"""Record a user operation (login, register, etc.)."""
|
||||
USER_OPERATIONS.labels(operation=operation, status=status).inc()
|
||||
@@ -63,7 +63,7 @@ except ImportError:
|
||||
|
||||
# Cost System
|
||||
try:
|
||||
from backend.data.cost import BlockCost, BlockCostType
|
||||
from backend.data.block import BlockCost, BlockCostType
|
||||
except ImportError:
|
||||
from backend.data.block_cost_config import BlockCost, BlockCostType
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from typing import Callable, List, Optional, Type
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.cost import BlockCost, BlockCostType
|
||||
from backend.data.block import BlockCost, BlockCostType
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
Credentials,
|
||||
|
||||
@@ -8,9 +8,8 @@ BLOCK_COSTS configuration used by the execution system.
|
||||
import logging
|
||||
from typing import List, Type
|
||||
|
||||
from backend.data.block import Block
|
||||
from backend.data.block import Block, BlockCost
|
||||
from backend.data.block_cost_config import BLOCK_COSTS
|
||||
from backend.data.cost import BlockCost
|
||||
from backend.sdk.registry import AutoRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import Any, Callable, List, Optional, Set, Type
|
||||
|
||||
from pydantic import BaseModel, SecretStr
|
||||
|
||||
from backend.data.cost import BlockCost
|
||||
from backend.data.block import BlockCost
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
Credentials,
|
||||
|
||||
@@ -6,10 +6,10 @@ import logging
|
||||
import threading
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
|
||||
|
||||
from pydantic import BaseModel, SecretStr
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.blocks.basic import Block
|
||||
from backend.data.model import APIKeyCredentials, Credentials
|
||||
from backend.data.model import Credentials
|
||||
from backend.integrations.oauth.base import BaseOAuthHandler
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.integrations.webhooks._base import BaseWebhooksManager
|
||||
@@ -17,6 +17,8 @@ from backend.integrations.webhooks._base import BaseWebhooksManager
|
||||
if TYPE_CHECKING:
|
||||
from backend.sdk.provider import Provider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SDKOAuthCredentials(BaseModel):
|
||||
"""OAuth credentials configuration for SDK providers."""
|
||||
@@ -102,21 +104,8 @@ class AutoRegistry:
|
||||
"""Register an environment variable as an API key for a provider."""
|
||||
with cls._lock:
|
||||
cls._api_key_mappings[provider] = env_var_name
|
||||
|
||||
# Dynamically check if the env var exists and create credential
|
||||
import os
|
||||
|
||||
api_key = os.getenv(env_var_name)
|
||||
if api_key:
|
||||
credential = APIKeyCredentials(
|
||||
id=f"{provider}-default",
|
||||
provider=provider,
|
||||
api_key=SecretStr(api_key),
|
||||
title=f"Default {provider} credentials",
|
||||
)
|
||||
# Check if credential already exists to avoid duplicates
|
||||
if not any(c.id == credential.id for c in cls._default_credentials):
|
||||
cls._default_credentials.append(credential)
|
||||
# Note: The credential itself is created by ProviderBuilder.with_api_key()
|
||||
# We only store the mapping here to avoid duplication
|
||||
|
||||
@classmethod
|
||||
def get_all_credentials(cls) -> List[Credentials]:
|
||||
@@ -210,3 +199,43 @@ class AutoRegistry:
|
||||
webhooks.load_webhook_managers = patched_load
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to patch webhook managers: {e}")
|
||||
|
||||
# Patch credentials store to include SDK-registered credentials
|
||||
try:
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
# Get the module from sys.modules to respect mocking
|
||||
if "backend.integrations.credentials_store" in sys.modules:
|
||||
creds_store: Any = sys.modules["backend.integrations.credentials_store"]
|
||||
else:
|
||||
import backend.integrations.credentials_store
|
||||
|
||||
creds_store: Any = backend.integrations.credentials_store
|
||||
|
||||
if hasattr(creds_store, "IntegrationCredentialsStore"):
|
||||
store_class = creds_store.IntegrationCredentialsStore
|
||||
if hasattr(store_class, "get_all_creds"):
|
||||
original_get_all_creds = store_class.get_all_creds
|
||||
|
||||
async def patched_get_all_creds(self, user_id: str):
|
||||
# Get original credentials
|
||||
original_creds = await original_get_all_creds(self, user_id)
|
||||
|
||||
# Add SDK-registered credentials
|
||||
sdk_creds = cls.get_all_credentials()
|
||||
|
||||
# Combine credentials, avoiding duplicates by ID
|
||||
existing_ids = {c.id for c in original_creds}
|
||||
for cred in sdk_creds:
|
||||
if cred.id not in existing_ids:
|
||||
original_creds.append(cred)
|
||||
|
||||
return original_creds
|
||||
|
||||
store_class.get_all_creds = patched_get_all_creds
|
||||
logger.info(
|
||||
"Successfully patched IntegrationCredentialsStore.get_all_creds"
|
||||
)
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to patch credentials store: {e}")
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from fastapi import FastAPI
|
||||
|
||||
from backend.monitoring.instrumentation import instrument_fastapi
|
||||
from backend.server.middleware.security import SecurityHeadersMiddleware
|
||||
|
||||
from .routes.v1 import v1_router
|
||||
@@ -13,3 +14,12 @@ external_app = FastAPI(
|
||||
|
||||
external_app.add_middleware(SecurityHeadersMiddleware)
|
||||
external_app.include_router(v1_router, prefix="/v1")
|
||||
|
||||
# Add Prometheus instrumentation
|
||||
instrument_fastapi(
|
||||
external_app,
|
||||
service_name="external-api",
|
||||
expose_endpoint=True,
|
||||
endpoint="/metrics",
|
||||
include_in_schema=True,
|
||||
)
|
||||
|
||||
@@ -2,12 +2,12 @@ from fastapi import HTTPException, Security
|
||||
from fastapi.security import APIKeyHeader
|
||||
from prisma.enums import APIKeyPermission
|
||||
|
||||
from backend.data.api_key import APIKey, has_permission, validate_api_key
|
||||
from backend.data.api_key import APIKeyInfo, has_permission, validate_api_key
|
||||
|
||||
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
||||
|
||||
|
||||
async def require_api_key(api_key: str | None = Security(api_key_header)) -> APIKey:
|
||||
async def require_api_key(api_key: str | None = Security(api_key_header)) -> APIKeyInfo:
|
||||
"""Base middleware for API key authentication"""
|
||||
if api_key is None:
|
||||
raise HTTPException(status_code=401, detail="Missing API key")
|
||||
@@ -23,7 +23,9 @@ async def require_api_key(api_key: str | None = Security(api_key_header)) -> API
|
||||
def require_permission(permission: APIKeyPermission):
|
||||
"""Dependency function for checking specific permissions"""
|
||||
|
||||
async def check_permission(api_key: APIKey = Security(require_api_key)):
|
||||
async def check_permission(
|
||||
api_key: APIKeyInfo = Security(require_api_key),
|
||||
) -> APIKeyInfo:
|
||||
if not has_permission(api_key, permission):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
|
||||
@@ -9,7 +9,7 @@ from typing_extensions import TypedDict
|
||||
import backend.data.block
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data.api_key import APIKey
|
||||
from backend.data.api_key import APIKeyInfo
|
||||
from backend.data.block import BlockInput, CompletedBlockOutput
|
||||
from backend.executor.utils import add_graph_execution
|
||||
from backend.server.external.middleware import require_permission
|
||||
@@ -62,7 +62,7 @@ def get_graph_blocks() -> Sequence[dict[Any, Any]]:
|
||||
async def execute_graph_block(
|
||||
block_id: str,
|
||||
data: BlockInput,
|
||||
api_key: APIKey = Security(require_permission(APIKeyPermission.EXECUTE_BLOCK)),
|
||||
api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.EXECUTE_BLOCK)),
|
||||
) -> CompletedBlockOutput:
|
||||
obj = backend.data.block.get_block(block_id)
|
||||
if not obj:
|
||||
@@ -82,7 +82,7 @@ async def execute_graph(
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
node_input: Annotated[dict[str, Any], Body(..., embed=True, default_factory=dict)],
|
||||
api_key: APIKey = Security(require_permission(APIKeyPermission.EXECUTE_GRAPH)),
|
||||
api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.EXECUTE_GRAPH)),
|
||||
) -> dict[str, Any]:
|
||||
try:
|
||||
graph_exec = await add_graph_execution(
|
||||
@@ -104,7 +104,7 @@ async def execute_graph(
|
||||
async def get_graph_execution_results(
|
||||
graph_id: str,
|
||||
graph_exec_id: str,
|
||||
api_key: APIKey = Security(require_permission(APIKeyPermission.READ_GRAPH)),
|
||||
api_key: APIKeyInfo = Security(require_permission(APIKeyPermission.READ_GRAPH)),
|
||||
) -> GraphExecutionResult:
|
||||
graph = await graph_db.get_graph(graph_id, user_id=api_key.user_id)
|
||||
if not graph:
|
||||
|
||||
@@ -81,6 +81,10 @@ class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||
response.headers["X-XSS-Protection"] = "1; mode=block"
|
||||
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
|
||||
|
||||
# Add noindex header for shared execution pages
|
||||
if "/public/shared" in request.url.path:
|
||||
response.headers["X-Robots-Tag"] = "noindex, nofollow"
|
||||
|
||||
# Default: Disable caching for all endpoints
|
||||
# Only allow caching for explicitly permitted paths
|
||||
if not self.is_cacheable_path(request.url.path):
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Any, Optional
|
||||
|
||||
import pydantic
|
||||
|
||||
from backend.data.api_key import APIKeyPermission, APIKeyWithoutHash
|
||||
from backend.data.api_key import APIKeyInfo, APIKeyPermission
|
||||
from backend.data.graph import Graph
|
||||
from backend.util.timezone_name import TimeZoneName
|
||||
|
||||
@@ -45,7 +45,7 @@ class CreateAPIKeyRequest(pydantic.BaseModel):
|
||||
|
||||
|
||||
class CreateAPIKeyResponse(pydantic.BaseModel):
|
||||
api_key: APIKeyWithoutHash
|
||||
api_key: APIKeyInfo
|
||||
plain_text_key: str
|
||||
|
||||
|
||||
|
||||
@@ -12,11 +12,13 @@ from autogpt_libs.auth import add_auth_responses_to_openapi
|
||||
from autogpt_libs.auth import verify_settings as verify_auth_settings
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.routing import APIRoute
|
||||
from prisma.errors import PrismaError
|
||||
|
||||
import backend.data.block
|
||||
import backend.data.db
|
||||
import backend.data.graph
|
||||
import backend.data.user
|
||||
import backend.integrations.webhooks.utils
|
||||
import backend.server.routers.postmark.postmark
|
||||
import backend.server.routers.v1
|
||||
import backend.server.v2.admin.credit_admin_routes
|
||||
@@ -35,10 +37,12 @@ import backend.util.settings
|
||||
from backend.blocks.llm import LlmModel
|
||||
from backend.data.model import Credentials
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.monitoring.instrumentation import instrument_fastapi
|
||||
from backend.server.external.api import external_app
|
||||
from backend.server.middleware.security import SecurityHeadersMiddleware
|
||||
from backend.util import json
|
||||
from backend.util.cloud_storage import shutdown_cloud_storage_handler
|
||||
from backend.util.exceptions import NotAuthorizedError, NotFoundError
|
||||
from backend.util.feature_flag import initialize_launchdarkly, shutdown_launchdarkly
|
||||
from backend.util.service import UnhealthyServiceError
|
||||
|
||||
@@ -76,6 +80,8 @@ async def lifespan_context(app: fastapi.FastAPI):
|
||||
await backend.data.user.migrate_and_encrypt_user_integrations()
|
||||
await backend.data.graph.fix_llm_provider_credentials()
|
||||
await backend.data.graph.migrate_llm_models(LlmModel.GPT4O)
|
||||
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
|
||||
|
||||
with launch_darkly_context():
|
||||
yield
|
||||
|
||||
@@ -137,6 +143,16 @@ app.add_middleware(SecurityHeadersMiddleware)
|
||||
# Add 401 responses to authenticated endpoints in OpenAPI spec
|
||||
add_auth_responses_to_openapi(app)
|
||||
|
||||
# Add Prometheus instrumentation
|
||||
instrument_fastapi(
|
||||
app,
|
||||
service_name="rest-api",
|
||||
expose_endpoint=True,
|
||||
endpoint="/metrics",
|
||||
include_in_schema=settings.config.app_env
|
||||
== backend.util.settings.AppEnvironment.LOCAL,
|
||||
)
|
||||
|
||||
|
||||
def handle_internal_http_error(status_code: int = 500, log_error: bool = True):
|
||||
def handler(request: fastapi.Request, exc: Exception):
|
||||
@@ -195,10 +211,14 @@ async def validation_error_handler(
|
||||
)
|
||||
|
||||
|
||||
app.add_exception_handler(PrismaError, handle_internal_http_error(500))
|
||||
app.add_exception_handler(NotFoundError, handle_internal_http_error(404, False))
|
||||
app.add_exception_handler(NotAuthorizedError, handle_internal_http_error(403, False))
|
||||
app.add_exception_handler(RequestValidationError, validation_error_handler)
|
||||
app.add_exception_handler(pydantic.ValidationError, validation_error_handler)
|
||||
app.add_exception_handler(ValueError, handle_internal_http_error(400))
|
||||
app.add_exception_handler(Exception, handle_internal_http_error(500))
|
||||
|
||||
app.include_router(backend.server.routers.v1.v1_router, tags=["v1"], prefix="/api")
|
||||
app.include_router(
|
||||
backend.server.v2.store.routes.router, tags=["v2"], prefix="/api/store"
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
from typing import Annotated, Any, Sequence
|
||||
|
||||
import pydantic
|
||||
@@ -27,20 +29,9 @@ from typing_extensions import Optional, TypedDict
|
||||
import backend.server.integrations.router
|
||||
import backend.server.routers.analytics
|
||||
import backend.server.v2.library.db as library_db
|
||||
from backend.data import api_key as api_key_db
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data.api_key import (
|
||||
APIKeyError,
|
||||
APIKeyNotFoundError,
|
||||
APIKeyPermissionError,
|
||||
APIKeyWithoutHash,
|
||||
generate_api_key,
|
||||
get_api_key_by_id,
|
||||
list_user_api_keys,
|
||||
revoke_api_key,
|
||||
suspend_api_key,
|
||||
update_api_key_permissions,
|
||||
)
|
||||
from backend.data.block import BlockInput, CompletedBlockOutput, get_block, get_blocks
|
||||
from backend.data.credit import (
|
||||
AutoTopUpConfig,
|
||||
@@ -74,6 +65,11 @@ from backend.integrations.webhooks.graph_lifecycle_hooks import (
|
||||
on_graph_activate,
|
||||
on_graph_deactivate,
|
||||
)
|
||||
from backend.monitoring.instrumentation import (
|
||||
record_block_execution,
|
||||
record_graph_execution,
|
||||
record_graph_operation,
|
||||
)
|
||||
from backend.server.model import (
|
||||
CreateAPIKeyRequest,
|
||||
CreateAPIKeyResponse,
|
||||
@@ -90,7 +86,6 @@ from backend.util.cloud_storage import get_cloud_storage_handler
|
||||
from backend.util.exceptions import GraphValidationError, NotFoundError
|
||||
from backend.util.settings import Settings
|
||||
from backend.util.timezone_utils import (
|
||||
convert_cron_to_utc,
|
||||
convert_utc_time_to_user_timezone,
|
||||
get_user_timezone_or_utc,
|
||||
)
|
||||
@@ -108,6 +103,7 @@ def _create_file_size_error(size_bytes: int, max_size_mb: int) -> HTTPException:
|
||||
settings = Settings()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_user_credit_model = get_user_credit_model()
|
||||
|
||||
# Define the API routes
|
||||
@@ -176,7 +172,6 @@ async def get_user_timezone_route(
|
||||
summary="Update user timezone",
|
||||
tags=["auth"],
|
||||
dependencies=[Security(requires_user)],
|
||||
response_model=TimezoneResponse,
|
||||
)
|
||||
async def update_user_timezone_route(
|
||||
user_id: Annotated[str, Security(get_user_id)], request: UpdateTimezoneRequest
|
||||
@@ -292,10 +287,26 @@ async def execute_graph_block(block_id: str, data: BlockInput) -> CompletedBlock
|
||||
if not obj:
|
||||
raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.")
|
||||
|
||||
output = defaultdict(list)
|
||||
async for name, data in obj.execute(data):
|
||||
output[name].append(data)
|
||||
return output
|
||||
start_time = time.time()
|
||||
try:
|
||||
output = defaultdict(list)
|
||||
async for name, data in obj.execute(data):
|
||||
output[name].append(data)
|
||||
|
||||
# Record successful block execution with duration
|
||||
duration = time.time() - start_time
|
||||
block_type = obj.__class__.__name__
|
||||
record_block_execution(
|
||||
block_type=block_type, status="success", duration=duration
|
||||
)
|
||||
|
||||
return output
|
||||
except Exception:
|
||||
# Record failed block execution
|
||||
duration = time.time() - start_time
|
||||
block_type = obj.__class__.__name__
|
||||
record_block_execution(block_type=block_type, status="error", duration=duration)
|
||||
raise
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
@@ -791,7 +802,7 @@ async def execute_graph(
|
||||
)
|
||||
|
||||
try:
|
||||
return await execution_utils.add_graph_execution(
|
||||
result = await execution_utils.add_graph_execution(
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
inputs=inputs,
|
||||
@@ -799,7 +810,16 @@ async def execute_graph(
|
||||
graph_version=graph_version,
|
||||
graph_credentials_inputs=credentials_inputs,
|
||||
)
|
||||
# Record successful graph execution
|
||||
record_graph_execution(graph_id=graph_id, status="success", user_id=user_id)
|
||||
record_graph_operation(operation="execute", status="success")
|
||||
return result
|
||||
except GraphValidationError as e:
|
||||
# Record failed graph execution
|
||||
record_graph_execution(
|
||||
graph_id=graph_id, status="validation_error", user_id=user_id
|
||||
)
|
||||
record_graph_operation(operation="execute", status="validation_error")
|
||||
# Return structured validation errors that the frontend can parse
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
@@ -810,6 +830,11 @@ async def execute_graph(
|
||||
"node_errors": e.node_errors,
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
# Record any other failures
|
||||
record_graph_execution(graph_id=graph_id, status="error", user_id=user_id)
|
||||
record_graph_operation(operation="execute", status="error")
|
||||
raise
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
@@ -934,6 +959,99 @@ async def delete_graph_execution(
|
||||
)
|
||||
|
||||
|
||||
class ShareRequest(pydantic.BaseModel):
|
||||
"""Optional request body for share endpoint."""
|
||||
|
||||
pass # Empty body is fine
|
||||
|
||||
|
||||
class ShareResponse(pydantic.BaseModel):
|
||||
"""Response from share endpoints."""
|
||||
|
||||
share_url: str
|
||||
share_token: str
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
"/graphs/{graph_id}/executions/{graph_exec_id}/share",
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def enable_execution_sharing(
|
||||
graph_id: Annotated[str, Path],
|
||||
graph_exec_id: Annotated[str, Path],
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
_body: ShareRequest = Body(default=ShareRequest()),
|
||||
) -> ShareResponse:
|
||||
"""Enable sharing for a graph execution."""
|
||||
# Verify the execution belongs to the user
|
||||
execution = await execution_db.get_graph_execution(
|
||||
user_id=user_id, execution_id=graph_exec_id
|
||||
)
|
||||
if not execution:
|
||||
raise HTTPException(status_code=404, detail="Execution not found")
|
||||
|
||||
# Generate a unique share token
|
||||
share_token = str(uuid.uuid4())
|
||||
|
||||
# Update the execution with share info
|
||||
await execution_db.update_graph_execution_share_status(
|
||||
execution_id=graph_exec_id,
|
||||
user_id=user_id,
|
||||
is_shared=True,
|
||||
share_token=share_token,
|
||||
shared_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
# Return the share URL
|
||||
frontend_url = Settings().config.frontend_base_url or "http://localhost:3000"
|
||||
share_url = f"{frontend_url}/share/{share_token}"
|
||||
|
||||
return ShareResponse(share_url=share_url, share_token=share_token)
|
||||
|
||||
|
||||
@v1_router.delete(
|
||||
"/graphs/{graph_id}/executions/{graph_exec_id}/share",
|
||||
status_code=HTTP_204_NO_CONTENT,
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def disable_execution_sharing(
|
||||
graph_id: Annotated[str, Path],
|
||||
graph_exec_id: Annotated[str, Path],
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> None:
|
||||
"""Disable sharing for a graph execution."""
|
||||
# Verify the execution belongs to the user
|
||||
execution = await execution_db.get_graph_execution(
|
||||
user_id=user_id, execution_id=graph_exec_id
|
||||
)
|
||||
if not execution:
|
||||
raise HTTPException(status_code=404, detail="Execution not found")
|
||||
|
||||
# Remove share info
|
||||
await execution_db.update_graph_execution_share_status(
|
||||
execution_id=graph_exec_id,
|
||||
user_id=user_id,
|
||||
is_shared=False,
|
||||
share_token=None,
|
||||
shared_at=None,
|
||||
)
|
||||
|
||||
|
||||
@v1_router.get("/public/shared/{share_token}")
|
||||
async def get_shared_execution(
|
||||
share_token: Annotated[
|
||||
str,
|
||||
Path(regex=r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"),
|
||||
],
|
||||
) -> execution_db.SharedExecutionResponse:
|
||||
"""Get a shared graph execution by share token (no auth required)."""
|
||||
execution = await execution_db.get_graph_execution_by_share_token(share_token)
|
||||
if not execution:
|
||||
raise HTTPException(status_code=404, detail="Shared execution not found")
|
||||
|
||||
return execution
|
||||
|
||||
|
||||
########################################################
|
||||
##################### Schedules ########################
|
||||
########################################################
|
||||
@@ -945,6 +1063,10 @@ class ScheduleCreationRequest(pydantic.BaseModel):
|
||||
cron: str
|
||||
inputs: dict[str, Any]
|
||||
credentials: dict[str, CredentialsMetaInput] = pydantic.Field(default_factory=dict)
|
||||
timezone: Optional[str] = pydantic.Field(
|
||||
default=None,
|
||||
description="User's timezone for scheduling (e.g., 'America/New_York'). If not provided, will use user's saved timezone or UTC.",
|
||||
)
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
@@ -969,26 +1091,22 @@ async def create_graph_execution_schedule(
|
||||
detail=f"Graph #{graph_id} v{schedule_params.graph_version} not found.",
|
||||
)
|
||||
|
||||
user = await get_user_by_id(user_id)
|
||||
user_timezone = get_user_timezone_or_utc(user.timezone if user else None)
|
||||
|
||||
# Convert cron expression from user timezone to UTC
|
||||
try:
|
||||
utc_cron = convert_cron_to_utc(schedule_params.cron, user_timezone)
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid cron expression for timezone {user_timezone}: {e}",
|
||||
)
|
||||
# Use timezone from request if provided, otherwise fetch from user profile
|
||||
if schedule_params.timezone:
|
||||
user_timezone = schedule_params.timezone
|
||||
else:
|
||||
user = await get_user_by_id(user_id)
|
||||
user_timezone = get_user_timezone_or_utc(user.timezone if user else None)
|
||||
|
||||
result = await get_scheduler_client().add_execution_schedule(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph.version,
|
||||
name=schedule_params.name,
|
||||
cron=utc_cron, # Send UTC cron to scheduler
|
||||
cron=schedule_params.cron,
|
||||
input_data=schedule_params.inputs,
|
||||
input_credentials=schedule_params.credentials,
|
||||
user_timezone=user_timezone,
|
||||
)
|
||||
|
||||
# Convert the next_run_time back to user timezone for display
|
||||
@@ -1010,24 +1128,11 @@ async def list_graph_execution_schedules(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
graph_id: str = Path(),
|
||||
) -> list[scheduler.GraphExecutionJobInfo]:
|
||||
schedules = await get_scheduler_client().get_execution_schedules(
|
||||
return await get_scheduler_client().get_execution_schedules(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
)
|
||||
|
||||
# Get user timezone for conversion
|
||||
user = await get_user_by_id(user_id)
|
||||
user_timezone = get_user_timezone_or_utc(user.timezone if user else None)
|
||||
|
||||
# Convert next_run_time to user timezone for display
|
||||
for schedule in schedules:
|
||||
if schedule.next_run_time:
|
||||
schedule.next_run_time = convert_utc_time_to_user_timezone(
|
||||
schedule.next_run_time, user_timezone
|
||||
)
|
||||
|
||||
return schedules
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/schedules",
|
||||
@@ -1038,20 +1143,7 @@ async def list_graph_execution_schedules(
|
||||
async def list_all_graphs_execution_schedules(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> list[scheduler.GraphExecutionJobInfo]:
|
||||
schedules = await get_scheduler_client().get_execution_schedules(user_id=user_id)
|
||||
|
||||
# Get user timezone for conversion
|
||||
user = await get_user_by_id(user_id)
|
||||
user_timezone = get_user_timezone_or_utc(user.timezone if user else None)
|
||||
|
||||
# Convert UTC next_run_time to user timezone for display
|
||||
for schedule in schedules:
|
||||
if schedule.next_run_time:
|
||||
schedule.next_run_time = convert_utc_time_to_user_timezone(
|
||||
schedule.next_run_time, user_timezone
|
||||
)
|
||||
|
||||
return schedules
|
||||
return await get_scheduler_client().get_execution_schedules(user_id=user_id)
|
||||
|
||||
|
||||
@v1_router.delete(
|
||||
@@ -1082,7 +1174,6 @@ async def delete_graph_execution_schedule(
|
||||
@v1_router.post(
|
||||
"/api-keys",
|
||||
summary="Create new API key",
|
||||
response_model=CreateAPIKeyResponse,
|
||||
tags=["api-keys"],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
@@ -1090,128 +1181,73 @@ async def create_api_key(
|
||||
request: CreateAPIKeyRequest, user_id: Annotated[str, Security(get_user_id)]
|
||||
) -> CreateAPIKeyResponse:
|
||||
"""Create a new API key"""
|
||||
try:
|
||||
api_key, plain_text = await generate_api_key(
|
||||
name=request.name,
|
||||
user_id=user_id,
|
||||
permissions=request.permissions,
|
||||
description=request.description,
|
||||
)
|
||||
return CreateAPIKeyResponse(api_key=api_key, plain_text_key=plain_text)
|
||||
except APIKeyError as e:
|
||||
logger.error(
|
||||
"Could not create API key for user %s: %s. Review input and permissions.",
|
||||
user_id,
|
||||
e,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"message": str(e), "hint": "Verify request payload and try again."},
|
||||
)
|
||||
api_key_info, plain_text_key = await api_key_db.create_api_key(
|
||||
name=request.name,
|
||||
user_id=user_id,
|
||||
permissions=request.permissions,
|
||||
description=request.description,
|
||||
)
|
||||
return CreateAPIKeyResponse(api_key=api_key_info, plain_text_key=plain_text_key)
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
"/api-keys",
|
||||
summary="List user API keys",
|
||||
response_model=list[APIKeyWithoutHash] | dict[str, str],
|
||||
tags=["api-keys"],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def get_api_keys(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> list[APIKeyWithoutHash]:
|
||||
) -> list[api_key_db.APIKeyInfo]:
|
||||
"""List all API keys for the user"""
|
||||
try:
|
||||
return await list_user_api_keys(user_id)
|
||||
except APIKeyError as e:
|
||||
logger.error("Failed to list API keys for user %s: %s", user_id, e)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"message": str(e), "hint": "Check API key service availability."},
|
||||
)
|
||||
return await api_key_db.list_user_api_keys(user_id)
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
"/api-keys/{key_id}",
|
||||
summary="Get specific API key",
|
||||
response_model=APIKeyWithoutHash,
|
||||
tags=["api-keys"],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def get_api_key(
|
||||
key_id: str, user_id: Annotated[str, Security(get_user_id)]
|
||||
) -> APIKeyWithoutHash:
|
||||
) -> api_key_db.APIKeyInfo:
|
||||
"""Get a specific API key"""
|
||||
try:
|
||||
api_key = await get_api_key_by_id(key_id, user_id)
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=404, detail="API key not found")
|
||||
return api_key
|
||||
except APIKeyError as e:
|
||||
logger.error("Error retrieving API key %s for user %s: %s", key_id, user_id, e)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"message": str(e), "hint": "Ensure the key ID is correct."},
|
||||
)
|
||||
api_key = await api_key_db.get_api_key_by_id(key_id, user_id)
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=404, detail="API key not found")
|
||||
return api_key
|
||||
|
||||
|
||||
@v1_router.delete(
|
||||
"/api-keys/{key_id}",
|
||||
summary="Revoke API key",
|
||||
response_model=APIKeyWithoutHash,
|
||||
tags=["api-keys"],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def delete_api_key(
|
||||
key_id: str, user_id: Annotated[str, Security(get_user_id)]
|
||||
) -> Optional[APIKeyWithoutHash]:
|
||||
) -> api_key_db.APIKeyInfo:
|
||||
"""Revoke an API key"""
|
||||
try:
|
||||
return await revoke_api_key(key_id, user_id)
|
||||
except APIKeyNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="API key not found")
|
||||
except APIKeyPermissionError:
|
||||
raise HTTPException(status_code=403, detail="Permission denied")
|
||||
except APIKeyError as e:
|
||||
logger.error("Failed to revoke API key %s for user %s: %s", key_id, user_id, e)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"message": str(e),
|
||||
"hint": "Verify permissions or try again later.",
|
||||
},
|
||||
)
|
||||
return await api_key_db.revoke_api_key(key_id, user_id)
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
"/api-keys/{key_id}/suspend",
|
||||
summary="Suspend API key",
|
||||
response_model=APIKeyWithoutHash,
|
||||
tags=["api-keys"],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
async def suspend_key(
|
||||
key_id: str, user_id: Annotated[str, Security(get_user_id)]
|
||||
) -> Optional[APIKeyWithoutHash]:
|
||||
) -> api_key_db.APIKeyInfo:
|
||||
"""Suspend an API key"""
|
||||
try:
|
||||
return await suspend_api_key(key_id, user_id)
|
||||
except APIKeyNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="API key not found")
|
||||
except APIKeyPermissionError:
|
||||
raise HTTPException(status_code=403, detail="Permission denied")
|
||||
except APIKeyError as e:
|
||||
logger.error("Failed to suspend API key %s for user %s: %s", key_id, user_id, e)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"message": str(e), "hint": "Check user permissions and retry."},
|
||||
)
|
||||
return await api_key_db.suspend_api_key(key_id, user_id)
|
||||
|
||||
|
||||
@v1_router.put(
|
||||
"/api-keys/{key_id}/permissions",
|
||||
summary="Update key permissions",
|
||||
response_model=APIKeyWithoutHash,
|
||||
tags=["api-keys"],
|
||||
dependencies=[Security(requires_user)],
|
||||
)
|
||||
@@ -1219,22 +1255,8 @@ async def update_permissions(
|
||||
key_id: str,
|
||||
request: UpdatePermissionsRequest,
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> Optional[APIKeyWithoutHash]:
|
||||
) -> api_key_db.APIKeyInfo:
|
||||
"""Update API key permissions"""
|
||||
try:
|
||||
return await update_api_key_permissions(key_id, user_id, request.permissions)
|
||||
except APIKeyNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="API key not found")
|
||||
except APIKeyPermissionError:
|
||||
raise HTTPException(status_code=403, detail="Permission denied")
|
||||
except APIKeyError as e:
|
||||
logger.error(
|
||||
"Failed to update permissions for API key %s of user %s: %s",
|
||||
key_id,
|
||||
user_id,
|
||||
e,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"message": str(e), "hint": "Ensure permissions list is valid."},
|
||||
)
|
||||
return await api_key_db.update_api_key_permissions(
|
||||
key_id, user_id, request.permissions
|
||||
)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from io import BytesIO
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
@@ -265,6 +266,7 @@ def test_get_graphs(
|
||||
name="Test Graph",
|
||||
description="A test graph",
|
||||
user_id=test_user_id,
|
||||
created_at=datetime(2025, 9, 4, 13, 37),
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
@@ -299,6 +301,7 @@ def test_get_graph(
|
||||
name="Test Graph",
|
||||
description="A test graph",
|
||||
user_id=test_user_id,
|
||||
created_at=datetime(2025, 9, 4, 13, 37),
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
@@ -348,6 +351,7 @@ def test_delete_graph(
|
||||
name="Test Graph",
|
||||
description="A test graph",
|
||||
user_id=test_user_id,
|
||||
created_at=datetime(2025, 9, 4, 13, 37),
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
|
||||
@@ -3,8 +3,9 @@ API Key authentication utilities for FastAPI applications.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
import secrets
|
||||
from typing import Any, Callable, Optional
|
||||
from typing import Any, Awaitable, Callable, Optional
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
from fastapi.security import APIKeyHeader
|
||||
@@ -12,6 +13,8 @@ from starlette.status import HTTP_401_UNAUTHORIZED
|
||||
|
||||
from backend.util.exceptions import MissingConfigError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class APIKeyAuthenticator(APIKeyHeader):
|
||||
"""
|
||||
@@ -51,7 +54,8 @@ class APIKeyAuthenticator(APIKeyHeader):
|
||||
header_name (str): The name of the header containing the API key
|
||||
expected_token (Optional[str]): The expected API key value for simple token matching
|
||||
validator (Optional[Callable]): Custom validation function that takes an API key
|
||||
string and returns a boolean or object. Can be async.
|
||||
string and returns a truthy value if and only if the passed string is a
|
||||
valid API key. Can be async.
|
||||
status_if_missing (int): HTTP status code to use for validation errors
|
||||
message_if_invalid (str): Error message to return when validation fails
|
||||
"""
|
||||
@@ -60,7 +64,9 @@ class APIKeyAuthenticator(APIKeyHeader):
|
||||
self,
|
||||
header_name: str,
|
||||
expected_token: Optional[str] = None,
|
||||
validator: Optional[Callable[[str], bool]] = None,
|
||||
validator: Optional[
|
||||
Callable[[str], Any] | Callable[[str], Awaitable[Any]]
|
||||
] = None,
|
||||
status_if_missing: int = HTTP_401_UNAUTHORIZED,
|
||||
message_if_invalid: str = "Invalid API key",
|
||||
):
|
||||
@@ -75,7 +81,7 @@ class APIKeyAuthenticator(APIKeyHeader):
|
||||
self.message_if_invalid = message_if_invalid
|
||||
|
||||
async def __call__(self, request: Request) -> Any:
|
||||
api_key = await super()(request)
|
||||
api_key = await super().__call__(request)
|
||||
if api_key is None:
|
||||
raise HTTPException(
|
||||
status_code=self.status_if_missing, detail="No API key in request"
|
||||
@@ -106,4 +112,9 @@ class APIKeyAuthenticator(APIKeyHeader):
|
||||
f"{self.__class__.__name__}.expected_token is not set; "
|
||||
"either specify it or provide a custom validator"
|
||||
)
|
||||
return secrets.compare_digest(api_key, self.expected_token)
|
||||
try:
|
||||
return secrets.compare_digest(api_key, self.expected_token)
|
||||
except TypeError as e:
|
||||
# If value is not an ASCII string, compare_digest raises a TypeError
|
||||
logger.warning(f"{self.model.name} API key check failed: {e}")
|
||||
return False
|
||||
|
||||
@@ -0,0 +1,537 @@
|
||||
"""
|
||||
Unit tests for APIKeyAuthenticator class.
|
||||
"""
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException, Request
|
||||
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
|
||||
|
||||
from backend.server.utils.api_key_auth import APIKeyAuthenticator
|
||||
from backend.util.exceptions import MissingConfigError
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_request():
|
||||
"""Create a mock request object."""
|
||||
request = Mock(spec=Request)
|
||||
request.state = Mock()
|
||||
request.headers = {}
|
||||
return request
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api_key_auth():
|
||||
"""Create a basic APIKeyAuthenticator instance."""
|
||||
return APIKeyAuthenticator(
|
||||
header_name="X-API-Key", expected_token="test-secret-token"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api_key_auth_custom_validator():
|
||||
"""Create APIKeyAuthenticator with custom validator."""
|
||||
|
||||
def custom_validator(api_key: str) -> bool:
|
||||
return api_key == "custom-valid-key"
|
||||
|
||||
return APIKeyAuthenticator(header_name="X-API-Key", validator=custom_validator)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api_key_auth_async_validator():
|
||||
"""Create APIKeyAuthenticator with async custom validator."""
|
||||
|
||||
async def async_validator(api_key: str) -> bool:
|
||||
return api_key == "async-valid-key"
|
||||
|
||||
return APIKeyAuthenticator(header_name="X-API-Key", validator=async_validator)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api_key_auth_object_validator():
|
||||
"""Create APIKeyAuthenticator that returns objects from validator."""
|
||||
|
||||
async def object_validator(api_key: str):
|
||||
if api_key == "user-key":
|
||||
return {"user_id": "123", "permissions": ["read", "write"]}
|
||||
return None
|
||||
|
||||
return APIKeyAuthenticator(header_name="X-API-Key", validator=object_validator)
|
||||
|
||||
|
||||
# ========== Basic Initialization Tests ========== #
|
||||
|
||||
|
||||
def test_init_with_expected_token():
|
||||
"""Test initialization with expected token."""
|
||||
auth = APIKeyAuthenticator(header_name="X-API-Key", expected_token="test-token")
|
||||
|
||||
assert auth.model.name == "X-API-Key"
|
||||
assert auth.expected_token == "test-token"
|
||||
assert auth.custom_validator is None
|
||||
assert auth.status_if_missing == HTTP_401_UNAUTHORIZED
|
||||
assert auth.message_if_invalid == "Invalid API key"
|
||||
|
||||
|
||||
def test_init_with_custom_validator():
|
||||
"""Test initialization with custom validator."""
|
||||
|
||||
def validator(key: str) -> bool:
|
||||
return True
|
||||
|
||||
auth = APIKeyAuthenticator(header_name="Authorization", validator=validator)
|
||||
|
||||
assert auth.model.name == "Authorization"
|
||||
assert auth.expected_token is None
|
||||
assert auth.custom_validator == validator
|
||||
assert auth.status_if_missing == HTTP_401_UNAUTHORIZED
|
||||
assert auth.message_if_invalid == "Invalid API key"
|
||||
|
||||
|
||||
def test_init_with_custom_parameters():
|
||||
"""Test initialization with custom status and message."""
|
||||
auth = APIKeyAuthenticator(
|
||||
header_name="X-Custom-Key",
|
||||
expected_token="token",
|
||||
status_if_missing=HTTP_403_FORBIDDEN,
|
||||
message_if_invalid="Access denied",
|
||||
)
|
||||
|
||||
assert auth.model.name == "X-Custom-Key"
|
||||
assert auth.status_if_missing == HTTP_403_FORBIDDEN
|
||||
assert auth.message_if_invalid == "Access denied"
|
||||
|
||||
|
||||
def test_scheme_name_generation():
|
||||
"""Test that scheme_name is generated correctly."""
|
||||
auth = APIKeyAuthenticator(header_name="X-Custom-Header", expected_token="token")
|
||||
|
||||
assert auth.scheme_name == "APIKeyAuthenticator-X-Custom-Header"
|
||||
|
||||
|
||||
# ========== Authentication Flow Tests ========== #
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_missing(api_key_auth, mock_request):
|
||||
"""Test behavior when API key is missing from request."""
|
||||
# Mock the parent class method to return None (no API key)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await api_key_auth(mock_request)
|
||||
|
||||
assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED
|
||||
assert exc_info.value.detail == "No API key in request"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_valid(api_key_auth, mock_request):
|
||||
"""Test behavior with valid API key."""
|
||||
# Mock the parent class to return the API key
|
||||
with patch.object(
|
||||
api_key_auth.__class__.__bases__[0],
|
||||
"__call__",
|
||||
return_value="test-secret-token",
|
||||
):
|
||||
result = await api_key_auth(mock_request)
|
||||
|
||||
assert result is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_invalid(api_key_auth, mock_request):
|
||||
"""Test behavior with invalid API key."""
|
||||
# Mock the parent class to return an invalid API key
|
||||
with patch.object(
|
||||
api_key_auth.__class__.__bases__[0], "__call__", return_value="invalid-token"
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await api_key_auth(mock_request)
|
||||
|
||||
assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED
|
||||
assert exc_info.value.detail == "Invalid API key"
|
||||
|
||||
|
||||
# ========== Custom Validator Tests ========== #
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_custom_status_and_message(mock_request):
|
||||
"""Test custom status code and message."""
|
||||
auth = APIKeyAuthenticator(
|
||||
header_name="X-API-Key",
|
||||
expected_token="valid-token",
|
||||
status_if_missing=HTTP_403_FORBIDDEN,
|
||||
message_if_invalid="Access forbidden",
|
||||
)
|
||||
|
||||
# Test missing key
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await auth(mock_request)
|
||||
|
||||
assert exc_info.value.status_code == HTTP_403_FORBIDDEN
|
||||
assert exc_info.value.detail == "No API key in request"
|
||||
|
||||
# Test invalid key
|
||||
with patch.object(
|
||||
auth.__class__.__bases__[0], "__call__", return_value="invalid-token"
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await auth(mock_request)
|
||||
|
||||
assert exc_info.value.status_code == HTTP_403_FORBIDDEN
|
||||
assert exc_info.value.detail == "Access forbidden"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_custom_sync_validator(api_key_auth_custom_validator, mock_request):
|
||||
"""Test with custom synchronous validator."""
|
||||
# Mock the parent class to return the API key
|
||||
with patch.object(
|
||||
api_key_auth_custom_validator.__class__.__bases__[0],
|
||||
"__call__",
|
||||
return_value="custom-valid-key",
|
||||
):
|
||||
result = await api_key_auth_custom_validator(mock_request)
|
||||
|
||||
assert result is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_custom_sync_validator_invalid(
|
||||
api_key_auth_custom_validator, mock_request
|
||||
):
|
||||
"""Test custom synchronous validator with invalid key."""
|
||||
with patch.object(
|
||||
api_key_auth_custom_validator.__class__.__bases__[0],
|
||||
"__call__",
|
||||
return_value="invalid-key",
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await api_key_auth_custom_validator(mock_request)
|
||||
|
||||
assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED
|
||||
assert exc_info.value.detail == "Invalid API key"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_custom_async_validator(api_key_auth_async_validator, mock_request):
|
||||
"""Test with custom async validator."""
|
||||
with patch.object(
|
||||
api_key_auth_async_validator.__class__.__bases__[0],
|
||||
"__call__",
|
||||
return_value="async-valid-key",
|
||||
):
|
||||
result = await api_key_auth_async_validator(mock_request)
|
||||
|
||||
assert result is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_custom_async_validator_invalid(
|
||||
api_key_auth_async_validator, mock_request
|
||||
):
|
||||
"""Test custom async validator with invalid key."""
|
||||
with patch.object(
|
||||
api_key_auth_async_validator.__class__.__bases__[0],
|
||||
"__call__",
|
||||
return_value="invalid-key",
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await api_key_auth_async_validator(mock_request)
|
||||
|
||||
assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED
|
||||
assert exc_info.value.detail == "Invalid API key"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validator_returns_object(api_key_auth_object_validator, mock_request):
|
||||
"""Test validator that returns an object instead of boolean."""
|
||||
with patch.object(
|
||||
api_key_auth_object_validator.__class__.__bases__[0],
|
||||
"__call__",
|
||||
return_value="user-key",
|
||||
):
|
||||
result = await api_key_auth_object_validator(mock_request)
|
||||
|
||||
expected_result = {"user_id": "123", "permissions": ["read", "write"]}
|
||||
assert result == expected_result
|
||||
# Verify the object is stored in request state
|
||||
assert mock_request.state.api_key == expected_result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validator_returns_none(api_key_auth_object_validator, mock_request):
|
||||
"""Test validator that returns None (falsy)."""
|
||||
with patch.object(
|
||||
api_key_auth_object_validator.__class__.__bases__[0],
|
||||
"__call__",
|
||||
return_value="invalid-key",
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await api_key_auth_object_validator(mock_request)
|
||||
|
||||
assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED
|
||||
assert exc_info.value.detail == "Invalid API key"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validator_database_lookup_simulation(mock_request):
|
||||
"""Test simulation of database lookup validator."""
|
||||
# Simulate database records
|
||||
valid_api_keys = {
|
||||
"key123": {"user_id": "user1", "active": True},
|
||||
"key456": {"user_id": "user2", "active": False},
|
||||
}
|
||||
|
||||
async def db_validator(api_key: str):
|
||||
record = valid_api_keys.get(api_key)
|
||||
return record if record and record["active"] else None
|
||||
|
||||
auth = APIKeyAuthenticator(header_name="X-API-Key", validator=db_validator)
|
||||
|
||||
# Test valid active key
|
||||
with patch.object(auth.__class__.__bases__[0], "__call__", return_value="key123"):
|
||||
result = await auth(mock_request)
|
||||
assert result == {"user_id": "user1", "active": True}
|
||||
assert mock_request.state.api_key == {"user_id": "user1", "active": True}
|
||||
|
||||
# Test inactive key
|
||||
mock_request.state = Mock() # Reset state
|
||||
with patch.object(auth.__class__.__bases__[0], "__call__", return_value="key456"):
|
||||
with pytest.raises(HTTPException):
|
||||
await auth(mock_request)
|
||||
|
||||
# Test non-existent key
|
||||
with patch.object(
|
||||
auth.__class__.__bases__[0], "__call__", return_value="nonexistent"
|
||||
):
|
||||
with pytest.raises(HTTPException):
|
||||
await auth(mock_request)
|
||||
|
||||
|
||||
# ========== Default Validator Tests ========== #
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_validator_key_valid(api_key_auth):
|
||||
"""Test default validator with valid token."""
|
||||
result = await api_key_auth.default_validator("test-secret-token")
|
||||
assert result is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_validator_key_invalid(api_key_auth):
|
||||
"""Test default validator with invalid token."""
|
||||
result = await api_key_auth.default_validator("wrong-token")
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_validator_missing_expected_token():
|
||||
"""Test default validator when expected_token is not set."""
|
||||
auth = APIKeyAuthenticator(header_name="X-API-Key")
|
||||
|
||||
with pytest.raises(MissingConfigError) as exc_info:
|
||||
await auth.default_validator("any-token")
|
||||
|
||||
assert "expected_token is not set" in str(exc_info.value)
|
||||
assert "either specify it or provide a custom validator" in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_validator_uses_constant_time_comparison(api_key_auth):
|
||||
"""
|
||||
Test that default validator uses secrets.compare_digest for timing attack protection
|
||||
"""
|
||||
with patch("secrets.compare_digest") as mock_compare:
|
||||
mock_compare.return_value = True
|
||||
|
||||
await api_key_auth.default_validator("test-token")
|
||||
|
||||
mock_compare.assert_called_once_with("test-token", "test-secret-token")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_empty(mock_request):
|
||||
"""Test behavior with empty string API key."""
|
||||
auth = APIKeyAuthenticator(header_name="X-API-Key", expected_token="valid-token")
|
||||
|
||||
with patch.object(auth.__class__.__bases__[0], "__call__", return_value=""):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await auth(mock_request)
|
||||
|
||||
assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED
|
||||
assert exc_info.value.detail == "Invalid API key"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_whitespace_only(mock_request):
|
||||
"""Test behavior with whitespace-only API key."""
|
||||
auth = APIKeyAuthenticator(header_name="X-API-Key", expected_token="valid-token")
|
||||
|
||||
with patch.object(
|
||||
auth.__class__.__bases__[0], "__call__", return_value=" \t\n "
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await auth(mock_request)
|
||||
|
||||
assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED
|
||||
assert exc_info.value.detail == "Invalid API key"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_very_long(mock_request):
|
||||
"""Test behavior with extremely long API key (potential DoS protection)."""
|
||||
auth = APIKeyAuthenticator(header_name="X-API-Key", expected_token="valid-token")
|
||||
|
||||
# Create a very long API key (10MB)
|
||||
long_api_key = "a" * (10 * 1024 * 1024)
|
||||
|
||||
with patch.object(
|
||||
auth.__class__.__bases__[0], "__call__", return_value=long_api_key
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await auth(mock_request)
|
||||
|
||||
assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED
|
||||
assert exc_info.value.detail == "Invalid API key"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_with_null_bytes(mock_request):
|
||||
"""Test behavior with API key containing null bytes."""
|
||||
auth = APIKeyAuthenticator(header_name="X-API-Key", expected_token="valid-token")
|
||||
|
||||
api_key_with_null = "valid\x00token"
|
||||
|
||||
with patch.object(
|
||||
auth.__class__.__bases__[0], "__call__", return_value=api_key_with_null
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await auth(mock_request)
|
||||
|
||||
assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED
|
||||
assert exc_info.value.detail == "Invalid API key"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_with_control_characters(mock_request):
|
||||
"""Test behavior with API key containing control characters."""
|
||||
auth = APIKeyAuthenticator(header_name="X-API-Key", expected_token="valid-token")
|
||||
|
||||
# API key with various control characters
|
||||
api_key_with_control = "valid\r\n\t\x1b[31mtoken"
|
||||
|
||||
with patch.object(
|
||||
auth.__class__.__bases__[0], "__call__", return_value=api_key_with_control
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await auth(mock_request)
|
||||
|
||||
assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED
|
||||
assert exc_info.value.detail == "Invalid API key"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_with_unicode_characters(mock_request):
|
||||
"""Test behavior with Unicode characters in API key."""
|
||||
auth = APIKeyAuthenticator(header_name="X-API-Key", expected_token="valid-token")
|
||||
|
||||
# API key with Unicode characters
|
||||
unicode_api_key = "validтокен🔑"
|
||||
|
||||
with patch.object(
|
||||
auth.__class__.__bases__[0], "__call__", return_value=unicode_api_key
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await auth(mock_request)
|
||||
|
||||
assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED
|
||||
assert exc_info.value.detail == "Invalid API key"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_with_unicode_characters_normalization_attack(mock_request):
|
||||
"""Test that Unicode normalization doesn't bypass validation."""
|
||||
# Create auth with composed Unicode character
|
||||
auth = APIKeyAuthenticator(
|
||||
header_name="X-API-Key", expected_token="café" # é is composed
|
||||
)
|
||||
|
||||
# Try with decomposed version (c + a + f + e + ´)
|
||||
decomposed_key = "cafe\u0301" # é as combining character
|
||||
|
||||
with patch.object(
|
||||
auth.__class__.__bases__[0], "__call__", return_value=decomposed_key
|
||||
):
|
||||
# Should fail because secrets.compare_digest doesn't normalize
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await auth(mock_request)
|
||||
|
||||
assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED
|
||||
assert exc_info.value.detail == "Invalid API key"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_with_binary_data(mock_request):
|
||||
"""Test behavior with binary data in API key."""
|
||||
auth = APIKeyAuthenticator(header_name="X-API-Key", expected_token="valid-token")
|
||||
|
||||
# Binary data that might cause encoding issues
|
||||
binary_api_key = bytes([0xFF, 0xFE, 0xFD, 0xFC, 0x80, 0x81]).decode(
|
||||
"latin1", errors="ignore"
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
auth.__class__.__bases__[0], "__call__", return_value=binary_api_key
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await auth(mock_request)
|
||||
|
||||
assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED
|
||||
assert exc_info.value.detail == "Invalid API key"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_with_regex_dos_attack_pattern(mock_request):
|
||||
"""Test behavior with API key of repeated characters (pattern attack)."""
|
||||
auth = APIKeyAuthenticator(header_name="X-API-Key", expected_token="valid-token")
|
||||
|
||||
# Pattern that might cause regex DoS in poorly implemented validators
|
||||
repeated_key = "a" * 1000 + "b" * 1000 + "c" * 1000
|
||||
|
||||
with patch.object(
|
||||
auth.__class__.__bases__[0], "__call__", return_value=repeated_key
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await auth(mock_request)
|
||||
|
||||
assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED
|
||||
assert exc_info.value.detail == "Invalid API key"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_keys_with_newline_variations(mock_request):
|
||||
"""Test different newline characters in API key."""
|
||||
auth = APIKeyAuthenticator(header_name="X-API-Key", expected_token="valid-token")
|
||||
|
||||
newline_variations = [
|
||||
"valid\ntoken", # Unix newline
|
||||
"valid\r\ntoken", # Windows newline
|
||||
"valid\rtoken", # Mac newline
|
||||
"valid\x85token", # NEL (Next Line)
|
||||
"valid\x0Btoken", # Vertical Tab
|
||||
"valid\x0Ctoken", # Form Feed
|
||||
]
|
||||
|
||||
for api_key in newline_variations:
|
||||
with patch.object(
|
||||
auth.__class__.__bases__[0], "__call__", return_value=api_key
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await auth(mock_request)
|
||||
|
||||
assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED
|
||||
assert exc_info.value.detail == "Invalid API key"
|
||||
@@ -7,12 +7,10 @@ import prisma
|
||||
import backend.data.block
|
||||
from backend.blocks import load_all_blocks
|
||||
from backend.blocks.llm import LlmModel
|
||||
from backend.data.block import Block, BlockCategory, BlockSchema
|
||||
from backend.data.credit import get_block_costs
|
||||
from backend.data.block import Block, BlockCategory, BlockInfo, BlockSchema
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.server.v2.builder.model import (
|
||||
BlockCategoryResponse,
|
||||
BlockData,
|
||||
BlockResponse,
|
||||
BlockType,
|
||||
CountResponse,
|
||||
@@ -25,7 +23,7 @@ from backend.util.models import Pagination
|
||||
logger = logging.getLogger(__name__)
|
||||
llm_models = [name.name.lower().replace("_", " ") for name in LlmModel]
|
||||
_static_counts_cache: dict | None = None
|
||||
_suggested_blocks: list[BlockData] | None = None
|
||||
_suggested_blocks: list[BlockInfo] | None = None
|
||||
|
||||
|
||||
def get_block_categories(category_blocks: int = 3) -> list[BlockCategoryResponse]:
|
||||
@@ -53,7 +51,7 @@ def get_block_categories(category_blocks: int = 3) -> list[BlockCategoryResponse
|
||||
|
||||
# Append if the category has less than the specified number of blocks
|
||||
if len(categories[category].blocks) < category_blocks:
|
||||
categories[category].blocks.append(block.to_dict())
|
||||
categories[category].blocks.append(block.get_info())
|
||||
|
||||
# Sort categories by name
|
||||
return sorted(categories.values(), key=lambda x: x.name)
|
||||
@@ -109,10 +107,8 @@ def get_blocks(
|
||||
take -= 1
|
||||
blocks.append(block)
|
||||
|
||||
costs = get_block_costs()
|
||||
|
||||
return BlockResponse(
|
||||
blocks=[{**b.to_dict(), "costs": costs.get(b.id, [])} for b in blocks],
|
||||
blocks=[b.get_info() for b in blocks],
|
||||
pagination=Pagination(
|
||||
total_items=total,
|
||||
total_pages=(total + page_size - 1) // page_size,
|
||||
@@ -174,11 +170,9 @@ def search_blocks(
|
||||
take -= 1
|
||||
blocks.append(block)
|
||||
|
||||
costs = get_block_costs()
|
||||
|
||||
return SearchBlocksResponse(
|
||||
blocks=BlockResponse(
|
||||
blocks=[{**b.to_dict(), "costs": costs.get(b.id, [])} for b in blocks],
|
||||
blocks=[b.get_info() for b in blocks],
|
||||
pagination=Pagination(
|
||||
total_items=total,
|
||||
total_pages=(total + page_size - 1) // page_size,
|
||||
@@ -323,7 +317,7 @@ def _get_all_providers() -> dict[ProviderName, Provider]:
|
||||
return providers
|
||||
|
||||
|
||||
async def get_suggested_blocks(count: int = 5) -> list[BlockData]:
|
||||
async def get_suggested_blocks(count: int = 5) -> list[BlockInfo]:
|
||||
global _suggested_blocks
|
||||
|
||||
if _suggested_blocks is not None and len(_suggested_blocks) >= count:
|
||||
@@ -351,7 +345,7 @@ async def get_suggested_blocks(count: int = 5) -> list[BlockData]:
|
||||
|
||||
# Get the top blocks based on execution count
|
||||
# But ignore Input and Output blocks
|
||||
blocks: list[tuple[BlockData, int]] = []
|
||||
blocks: list[tuple[BlockInfo, int]] = []
|
||||
|
||||
for block_type in load_all_blocks().values():
|
||||
block: Block[BlockSchema, BlockSchema] = block_type()
|
||||
@@ -366,7 +360,7 @@ async def get_suggested_blocks(count: int = 5) -> list[BlockData]:
|
||||
(row["execution_count"] for row in results if row["block_id"] == block.id),
|
||||
0,
|
||||
)
|
||||
blocks.append((block.to_dict(), execution_count))
|
||||
blocks.append((block.get_info(), execution_count))
|
||||
# Sort blocks by execution count
|
||||
blocks.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from typing import Any, Literal
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
import backend.server.v2.library.model as library_model
|
||||
import backend.server.v2.store.model as store_model
|
||||
from backend.data.block import BlockInfo
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.models import Pagination
|
||||
|
||||
@@ -16,29 +17,27 @@ FilterType = Literal[
|
||||
|
||||
BlockType = Literal["all", "input", "action", "output"]
|
||||
|
||||
BlockData = dict[str, Any]
|
||||
|
||||
|
||||
# Suggestions
|
||||
class SuggestionsResponse(BaseModel):
|
||||
otto_suggestions: list[str]
|
||||
recent_searches: list[str]
|
||||
providers: list[ProviderName]
|
||||
top_blocks: list[BlockData]
|
||||
top_blocks: list[BlockInfo]
|
||||
|
||||
|
||||
# All blocks
|
||||
class BlockCategoryResponse(BaseModel):
|
||||
name: str
|
||||
total_blocks: int
|
||||
blocks: list[BlockData]
|
||||
blocks: list[BlockInfo]
|
||||
|
||||
model_config = {"use_enum_values": False} # <== use enum names like "AI"
|
||||
|
||||
|
||||
# Input/Action/Output and see all for block categories
|
||||
class BlockResponse(BaseModel):
|
||||
blocks: list[BlockData]
|
||||
blocks: list[BlockInfo]
|
||||
pagination: Pagination
|
||||
|
||||
|
||||
@@ -71,7 +70,7 @@ class SearchBlocksResponse(BaseModel):
|
||||
|
||||
|
||||
class SearchResponse(BaseModel):
|
||||
items: list[BlockData | library_model.LibraryAgent | store_model.StoreAgent]
|
||||
items: list[BlockInfo | library_model.LibraryAgent | store_model.StoreAgent]
|
||||
total_items: dict[FilterType, int]
|
||||
page: int
|
||||
more_pages: bool
|
||||
|
||||
@@ -144,6 +144,92 @@ async def list_library_agents(
|
||||
raise store_exceptions.DatabaseError("Failed to fetch library agents") from e
|
||||
|
||||
|
||||
async def list_favorite_library_agents(
|
||||
user_id: str,
|
||||
page: int = 1,
|
||||
page_size: int = 50,
|
||||
) -> library_model.LibraryAgentResponse:
|
||||
"""
|
||||
Retrieves a paginated list of favorite LibraryAgent records for a given user.
|
||||
|
||||
Args:
|
||||
user_id: The ID of the user whose favorite LibraryAgents we want to retrieve.
|
||||
page: Current page (1-indexed).
|
||||
page_size: Number of items per page.
|
||||
|
||||
Returns:
|
||||
A LibraryAgentResponse containing the list of favorite agents and pagination details.
|
||||
|
||||
Raises:
|
||||
DatabaseError: If there is an issue fetching from Prisma.
|
||||
"""
|
||||
logger.debug(
|
||||
f"Fetching favorite library agents for user_id={user_id}, "
|
||||
f"page={page}, page_size={page_size}"
|
||||
)
|
||||
|
||||
if page < 1 or page_size < 1:
|
||||
logger.warning(f"Invalid pagination: page={page}, page_size={page_size}")
|
||||
raise store_exceptions.DatabaseError("Invalid pagination input")
|
||||
|
||||
where_clause: prisma.types.LibraryAgentWhereInput = {
|
||||
"userId": user_id,
|
||||
"isDeleted": False,
|
||||
"isArchived": False,
|
||||
"isFavorite": True, # Only fetch favorites
|
||||
}
|
||||
|
||||
# Sort favorites by updated date descending
|
||||
order_by: prisma.types.LibraryAgentOrderByInput = {"updatedAt": "desc"}
|
||||
|
||||
try:
|
||||
library_agents = await prisma.models.LibraryAgent.prisma().find_many(
|
||||
where=where_clause,
|
||||
include=library_agent_include(user_id),
|
||||
order=order_by,
|
||||
skip=(page - 1) * page_size,
|
||||
take=page_size,
|
||||
)
|
||||
agent_count = await prisma.models.LibraryAgent.prisma().count(
|
||||
where=where_clause
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Retrieved {len(library_agents)} favorite library agents for user #{user_id}"
|
||||
)
|
||||
|
||||
# Only pass valid agents to the response
|
||||
valid_library_agents: list[library_model.LibraryAgent] = []
|
||||
|
||||
for agent in library_agents:
|
||||
try:
|
||||
library_agent = library_model.LibraryAgent.from_db(agent)
|
||||
valid_library_agents.append(library_agent)
|
||||
except Exception as e:
|
||||
# Skip this agent if there was an error
|
||||
logger.error(
|
||||
f"Error parsing LibraryAgent #{agent.id} from DB item: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
# Return the response with only valid agents
|
||||
return library_model.LibraryAgentResponse(
|
||||
agents=valid_library_agents,
|
||||
pagination=Pagination(
|
||||
total_items=agent_count,
|
||||
total_pages=(agent_count + page_size - 1) // page_size,
|
||||
current_page=page,
|
||||
page_size=page_size,
|
||||
),
|
||||
)
|
||||
|
||||
except prisma.errors.PrismaError as e:
|
||||
logger.error(f"Database error fetching favorite library agents: {e}")
|
||||
raise store_exceptions.DatabaseError(
|
||||
"Failed to fetch favorite library agents"
|
||||
) from e
|
||||
|
||||
|
||||
async def get_library_agent(id: str, user_id: str) -> library_model.LibraryAgent:
|
||||
"""
|
||||
Get a specific agent from the user's library.
|
||||
@@ -709,10 +795,7 @@ async def create_preset(
|
||||
)
|
||||
for name, data in {
|
||||
**preset.inputs,
|
||||
**{
|
||||
key: creds_meta.model_dump(exclude_none=True)
|
||||
for key, creds_meta in preset.credentials.items()
|
||||
},
|
||||
**preset.credentials,
|
||||
}.items()
|
||||
]
|
||||
},
|
||||
|
||||
@@ -43,6 +43,7 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
|
||||
name: str
|
||||
description: str
|
||||
instructions: str | None = None
|
||||
|
||||
input_schema: dict[str, Any] # Should be BlockIOObjectSubSchema in frontend
|
||||
output_schema: dict[str, Any]
|
||||
@@ -64,6 +65,12 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
# Indicates if this agent is the latest version
|
||||
is_latest_version: bool
|
||||
|
||||
# Whether the agent is marked as favorite by the user
|
||||
is_favorite: bool
|
||||
|
||||
# Recommended schedule cron (from marketplace agents)
|
||||
recommended_schedule_cron: str | None = None
|
||||
|
||||
@staticmethod
|
||||
def from_db(
|
||||
agent: prisma.models.LibraryAgent,
|
||||
@@ -120,6 +127,7 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
updated_at=updated_at,
|
||||
name=graph.name,
|
||||
description=graph.description,
|
||||
instructions=graph.instructions,
|
||||
input_schema=graph.input_schema,
|
||||
output_schema=graph.output_schema,
|
||||
credentials_input_schema=(
|
||||
@@ -130,6 +138,8 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
new_output=new_output,
|
||||
can_access_graph=can_access_graph,
|
||||
is_latest_version=is_latest_version,
|
||||
is_favorite=agent.isFavorite,
|
||||
recommended_schedule_cron=agent.AgentGraph.recommendedScheduleCron,
|
||||
)
|
||||
|
||||
|
||||
@@ -253,6 +263,7 @@ class LibraryAgentPreset(LibraryAgentPresetCreatable):
|
||||
|
||||
id: str
|
||||
user_id: str
|
||||
created_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
|
||||
webhook: "Webhook | None"
|
||||
@@ -282,6 +293,7 @@ class LibraryAgentPreset(LibraryAgentPresetCreatable):
|
||||
return cls(
|
||||
id=preset.id,
|
||||
user_id=preset.userId,
|
||||
created_at=preset.createdAt,
|
||||
updated_at=preset.updatedAt,
|
||||
graph_id=preset.agentGraphId,
|
||||
graph_version=preset.agentGraphVersion,
|
||||
|
||||
@@ -79,6 +79,54 @@ async def list_library_agents(
|
||||
) from e
|
||||
|
||||
|
||||
@router.get(
|
||||
"/favorites",
|
||||
summary="List Favorite Library Agents",
|
||||
responses={
|
||||
500: {"description": "Server error", "content": {"application/json": {}}},
|
||||
},
|
||||
)
|
||||
async def list_favorite_library_agents(
|
||||
user_id: str = Security(autogpt_auth_lib.get_user_id),
|
||||
page: int = Query(
|
||||
1,
|
||||
ge=1,
|
||||
description="Page number to retrieve (must be >= 1)",
|
||||
),
|
||||
page_size: int = Query(
|
||||
15,
|
||||
ge=1,
|
||||
description="Number of agents per page (must be >= 1)",
|
||||
),
|
||||
) -> library_model.LibraryAgentResponse:
|
||||
"""
|
||||
Get all favorite agents in the user's library.
|
||||
|
||||
Args:
|
||||
user_id: ID of the authenticated user.
|
||||
page: Page number to retrieve.
|
||||
page_size: Number of agents per page.
|
||||
|
||||
Returns:
|
||||
A LibraryAgentResponse containing favorite agents and pagination metadata.
|
||||
|
||||
Raises:
|
||||
HTTPException: If a server/database error occurs.
|
||||
"""
|
||||
try:
|
||||
return await library_db.list_favorite_library_agents(
|
||||
user_id=user_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Could not list favorite library agents for user #{user_id}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=str(e),
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/{library_agent_id}", summary="Get Library Agent")
|
||||
async def get_library_agent(
|
||||
library_agent_id: str,
|
||||
|
||||
@@ -50,9 +50,11 @@ async def test_get_library_agents_success(
|
||||
credentials_input_schema={"type": "object", "properties": {}},
|
||||
has_external_trigger=False,
|
||||
status=library_model.LibraryAgentStatus.COMPLETED,
|
||||
recommended_schedule_cron=None,
|
||||
new_output=False,
|
||||
can_access_graph=True,
|
||||
is_latest_version=True,
|
||||
is_favorite=False,
|
||||
updated_at=datetime.datetime(2023, 1, 1, 0, 0, 0),
|
||||
),
|
||||
library_model.LibraryAgent(
|
||||
@@ -69,9 +71,11 @@ async def test_get_library_agents_success(
|
||||
credentials_input_schema={"type": "object", "properties": {}},
|
||||
has_external_trigger=False,
|
||||
status=library_model.LibraryAgentStatus.COMPLETED,
|
||||
recommended_schedule_cron=None,
|
||||
new_output=False,
|
||||
can_access_graph=False,
|
||||
is_latest_version=True,
|
||||
is_favorite=False,
|
||||
updated_at=datetime.datetime(2023, 1, 1, 0, 0, 0),
|
||||
),
|
||||
],
|
||||
@@ -119,6 +123,76 @@ def test_get_library_agents_error(mocker: pytest_mock.MockFixture, test_user_id:
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_favorite_library_agents_success(
|
||||
mocker: pytest_mock.MockFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
mocked_value = library_model.LibraryAgentResponse(
|
||||
agents=[
|
||||
library_model.LibraryAgent(
|
||||
id="test-agent-1",
|
||||
graph_id="test-agent-1",
|
||||
graph_version=1,
|
||||
name="Favorite Agent 1",
|
||||
description="Test Favorite Description 1",
|
||||
image_url=None,
|
||||
creator_name="Test Creator",
|
||||
creator_image_url="",
|
||||
input_schema={"type": "object", "properties": {}},
|
||||
output_schema={"type": "object", "properties": {}},
|
||||
credentials_input_schema={"type": "object", "properties": {}},
|
||||
has_external_trigger=False,
|
||||
status=library_model.LibraryAgentStatus.COMPLETED,
|
||||
recommended_schedule_cron=None,
|
||||
new_output=False,
|
||||
can_access_graph=True,
|
||||
is_latest_version=True,
|
||||
is_favorite=True,
|
||||
updated_at=datetime.datetime(2023, 1, 1, 0, 0, 0),
|
||||
),
|
||||
],
|
||||
pagination=Pagination(
|
||||
total_items=1, total_pages=1, current_page=1, page_size=15
|
||||
),
|
||||
)
|
||||
mock_db_call = mocker.patch(
|
||||
"backend.server.v2.library.db.list_favorite_library_agents"
|
||||
)
|
||||
mock_db_call.return_value = mocked_value
|
||||
|
||||
response = client.get("/agents/favorites")
|
||||
assert response.status_code == 200
|
||||
|
||||
data = library_model.LibraryAgentResponse.model_validate(response.json())
|
||||
assert len(data.agents) == 1
|
||||
assert data.agents[0].is_favorite is True
|
||||
assert data.agents[0].name == "Favorite Agent 1"
|
||||
|
||||
mock_db_call.assert_called_once_with(
|
||||
user_id=test_user_id,
|
||||
page=1,
|
||||
page_size=15,
|
||||
)
|
||||
|
||||
|
||||
def test_get_favorite_library_agents_error(
|
||||
mocker: pytest_mock.MockFixture, test_user_id: str
|
||||
):
|
||||
mock_db_call = mocker.patch(
|
||||
"backend.server.v2.library.db.list_favorite_library_agents"
|
||||
)
|
||||
mock_db_call.side_effect = Exception("Test error")
|
||||
|
||||
response = client.get("/agents/favorites")
|
||||
assert response.status_code == 500
|
||||
mock_db_call.assert_called_once_with(
|
||||
user_id=test_user_id,
|
||||
page=1,
|
||||
page_size=15,
|
||||
)
|
||||
|
||||
|
||||
def test_add_agent_to_library_success(
|
||||
mocker: pytest_mock.MockFixture, test_user_id: str
|
||||
):
|
||||
@@ -139,6 +213,7 @@ def test_add_agent_to_library_success(
|
||||
new_output=False,
|
||||
can_access_graph=True,
|
||||
is_latest_version=True,
|
||||
is_favorite=False,
|
||||
updated_at=FIXED_NOW,
|
||||
)
|
||||
|
||||
|
||||
@@ -183,6 +183,36 @@ async def get_store_agent_details(
|
||||
store_listing.hasApprovedVersion if store_listing else False
|
||||
)
|
||||
|
||||
if active_version_id:
|
||||
agent_by_active = await prisma.models.StoreAgent.prisma().find_first(
|
||||
where={"storeListingVersionId": active_version_id}
|
||||
)
|
||||
if agent_by_active:
|
||||
agent = agent_by_active
|
||||
elif store_listing:
|
||||
latest_approved = (
|
||||
await prisma.models.StoreListingVersion.prisma().find_first(
|
||||
where={
|
||||
"storeListingId": store_listing.id,
|
||||
"submissionStatus": prisma.enums.SubmissionStatus.APPROVED,
|
||||
},
|
||||
order=[{"version": "desc"}],
|
||||
)
|
||||
)
|
||||
if latest_approved:
|
||||
agent_latest = await prisma.models.StoreAgent.prisma().find_first(
|
||||
where={"storeListingVersionId": latest_approved.id}
|
||||
)
|
||||
if agent_latest:
|
||||
agent = agent_latest
|
||||
|
||||
if store_listing and store_listing.ActiveVersion:
|
||||
recommended_schedule_cron = (
|
||||
store_listing.ActiveVersion.recommendedScheduleCron
|
||||
)
|
||||
else:
|
||||
recommended_schedule_cron = None
|
||||
|
||||
logger.debug(f"Found agent details for {username}/{agent_name}")
|
||||
return backend.server.v2.store.model.StoreAgentDetails(
|
||||
store_listing_version_id=agent.storeListingVersionId,
|
||||
@@ -201,6 +231,7 @@ async def get_store_agent_details(
|
||||
last_updated=agent.updated_at,
|
||||
active_version_id=active_version_id,
|
||||
has_approved_version=has_approved_version,
|
||||
recommended_schedule_cron=recommended_schedule_cron,
|
||||
)
|
||||
except backend.server.v2.store.exceptions.AgentNotFoundError:
|
||||
raise
|
||||
@@ -468,6 +499,7 @@ async def get_store_submissions(
|
||||
sub_heading=sub.sub_heading,
|
||||
slug=sub.slug,
|
||||
description=sub.description,
|
||||
instructions=getattr(sub, "instructions", None),
|
||||
image_urls=sub.image_urls or [],
|
||||
date_submitted=sub.date_submitted or datetime.now(tz=timezone.utc),
|
||||
status=sub.status,
|
||||
@@ -559,9 +591,11 @@ async def create_store_submission(
|
||||
video_url: str | None = None,
|
||||
image_urls: list[str] = [],
|
||||
description: str = "",
|
||||
instructions: str | None = None,
|
||||
sub_heading: str = "",
|
||||
categories: list[str] = [],
|
||||
changes_summary: str | None = "Initial Submission",
|
||||
recommended_schedule_cron: str | None = None,
|
||||
) -> backend.server.v2.store.model.StoreSubmission:
|
||||
"""
|
||||
Create the first (and only) store listing and thus submission as a normal user
|
||||
@@ -629,6 +663,7 @@ async def create_store_submission(
|
||||
video_url=video_url,
|
||||
image_urls=image_urls,
|
||||
description=description,
|
||||
instructions=instructions,
|
||||
sub_heading=sub_heading,
|
||||
categories=categories,
|
||||
changes_summary=changes_summary,
|
||||
@@ -650,11 +685,13 @@ async def create_store_submission(
|
||||
videoUrl=video_url,
|
||||
imageUrls=image_urls,
|
||||
description=description,
|
||||
instructions=instructions,
|
||||
categories=categories,
|
||||
subHeading=sub_heading,
|
||||
submissionStatus=prisma.enums.SubmissionStatus.PENDING,
|
||||
submittedAt=datetime.now(tz=timezone.utc),
|
||||
changesSummary=changes_summary,
|
||||
recommendedScheduleCron=recommended_schedule_cron,
|
||||
)
|
||||
]
|
||||
},
|
||||
@@ -679,6 +716,7 @@ async def create_store_submission(
|
||||
slug=slug,
|
||||
sub_heading=sub_heading,
|
||||
description=description,
|
||||
instructions=instructions,
|
||||
image_urls=image_urls,
|
||||
date_submitted=listing.createdAt,
|
||||
status=prisma.enums.SubmissionStatus.PENDING,
|
||||
@@ -710,6 +748,8 @@ async def edit_store_submission(
|
||||
sub_heading: str = "",
|
||||
categories: list[str] = [],
|
||||
changes_summary: str | None = "Update submission",
|
||||
recommended_schedule_cron: str | None = None,
|
||||
instructions: str | None = None,
|
||||
) -> backend.server.v2.store.model.StoreSubmission:
|
||||
"""
|
||||
Edit an existing store listing submission.
|
||||
@@ -789,6 +829,8 @@ async def edit_store_submission(
|
||||
sub_heading=sub_heading,
|
||||
categories=categories,
|
||||
changes_summary=changes_summary,
|
||||
recommended_schedule_cron=recommended_schedule_cron,
|
||||
instructions=instructions,
|
||||
)
|
||||
|
||||
# For PENDING submissions, we can update the existing version
|
||||
@@ -804,6 +846,8 @@ async def edit_store_submission(
|
||||
categories=categories,
|
||||
subHeading=sub_heading,
|
||||
changesSummary=changes_summary,
|
||||
recommendedScheduleCron=recommended_schedule_cron,
|
||||
instructions=instructions,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -822,6 +866,7 @@ async def edit_store_submission(
|
||||
sub_heading=sub_heading,
|
||||
slug=current_version.StoreListing.slug,
|
||||
description=description,
|
||||
instructions=instructions,
|
||||
image_urls=image_urls,
|
||||
date_submitted=updated_version.submittedAt or updated_version.createdAt,
|
||||
status=updated_version.submissionStatus,
|
||||
@@ -863,9 +908,11 @@ async def create_store_version(
|
||||
video_url: str | None = None,
|
||||
image_urls: list[str] = [],
|
||||
description: str = "",
|
||||
instructions: str | None = None,
|
||||
sub_heading: str = "",
|
||||
categories: list[str] = [],
|
||||
changes_summary: str | None = "Initial submission",
|
||||
recommended_schedule_cron: str | None = None,
|
||||
) -> backend.server.v2.store.model.StoreSubmission:
|
||||
"""
|
||||
Create a new version for an existing store listing
|
||||
@@ -930,11 +977,13 @@ async def create_store_version(
|
||||
videoUrl=video_url,
|
||||
imageUrls=image_urls,
|
||||
description=description,
|
||||
instructions=instructions,
|
||||
categories=categories,
|
||||
subHeading=sub_heading,
|
||||
submissionStatus=prisma.enums.SubmissionStatus.PENDING,
|
||||
submittedAt=datetime.now(),
|
||||
changesSummary=changes_summary,
|
||||
recommendedScheduleCron=recommended_schedule_cron,
|
||||
storeListingId=store_listing_id,
|
||||
)
|
||||
)
|
||||
@@ -950,6 +999,7 @@ async def create_store_version(
|
||||
slug=listing.slug,
|
||||
sub_heading=sub_heading,
|
||||
description=description,
|
||||
instructions=instructions,
|
||||
image_urls=image_urls,
|
||||
date_submitted=datetime.now(),
|
||||
status=prisma.enums.SubmissionStatus.PENDING,
|
||||
@@ -1126,7 +1176,20 @@ async def get_my_agents(
|
||||
try:
|
||||
search_filter: prisma.types.LibraryAgentWhereInput = {
|
||||
"userId": user_id,
|
||||
"AgentGraph": {"is": {"StoreListings": {"none": {"isDeleted": False}}}},
|
||||
"AgentGraph": {
|
||||
"is": {
|
||||
"StoreListings": {
|
||||
"none": {
|
||||
"isDeleted": False,
|
||||
"Versions": {
|
||||
"some": {
|
||||
"isAvailable": True,
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"isArchived": False,
|
||||
"isDeleted": False,
|
||||
}
|
||||
@@ -1150,6 +1213,7 @@ async def get_my_agents(
|
||||
last_edited=graph.updatedAt or graph.createdAt,
|
||||
description=graph.description or "",
|
||||
agent_image=library_agent.imageUrl,
|
||||
recommended_schedule_cron=graph.recommendedScheduleCron,
|
||||
)
|
||||
for library_agent in library_agents
|
||||
if (graph := library_agent.AgentGraph)
|
||||
@@ -1351,6 +1415,22 @@ async def review_store_submission(
|
||||
]
|
||||
)
|
||||
|
||||
# Update the AgentGraph with store listing data
|
||||
await prisma.models.AgentGraph.prisma().update(
|
||||
where={
|
||||
"graphVersionId": {
|
||||
"id": store_listing_version.agentGraphId,
|
||||
"version": store_listing_version.agentGraphVersion,
|
||||
}
|
||||
},
|
||||
data={
|
||||
"name": store_listing_version.name,
|
||||
"description": store_listing_version.description,
|
||||
"recommendedScheduleCron": store_listing_version.recommendedScheduleCron,
|
||||
"instructions": store_listing_version.instructions,
|
||||
},
|
||||
)
|
||||
|
||||
await prisma.models.StoreListing.prisma(tx).update(
|
||||
where={"id": store_listing_version.StoreListing.id},
|
||||
data={
|
||||
@@ -1513,6 +1593,7 @@ async def review_store_submission(
|
||||
else ""
|
||||
),
|
||||
description=submission.description,
|
||||
instructions=submission.instructions,
|
||||
image_urls=submission.imageUrls or [],
|
||||
date_submitted=submission.submittedAt or submission.createdAt,
|
||||
status=submission.submissionStatus,
|
||||
@@ -1648,6 +1729,7 @@ async def get_admin_listings_with_versions(
|
||||
sub_heading=version.subHeading,
|
||||
slug=listing.slug,
|
||||
description=version.description,
|
||||
instructions=version.instructions,
|
||||
image_urls=version.imageUrls or [],
|
||||
date_submitted=version.submittedAt or version.createdAt,
|
||||
status=version.submissionStatus,
|
||||
|
||||
@@ -86,14 +86,50 @@ async def test_get_store_agent_details(mocker):
|
||||
is_available=False,
|
||||
)
|
||||
|
||||
# Mock active version agent (what we want to return for active version)
|
||||
mock_active_agent = prisma.models.StoreAgent(
|
||||
listing_id="test-id",
|
||||
storeListingVersionId="active-version-id",
|
||||
slug="test-agent",
|
||||
agent_name="Test Agent Active",
|
||||
agent_video="active_video.mp4",
|
||||
agent_image=["active_image.jpg"],
|
||||
featured=False,
|
||||
creator_username="creator",
|
||||
creator_avatar="avatar.jpg",
|
||||
sub_heading="Test heading active",
|
||||
description="Test description active",
|
||||
categories=["test"],
|
||||
runs=15,
|
||||
rating=4.8,
|
||||
versions=["1.0", "2.0"],
|
||||
updated_at=datetime.now(),
|
||||
is_available=True,
|
||||
)
|
||||
|
||||
# Create a mock StoreListing result
|
||||
mock_store_listing = mocker.MagicMock()
|
||||
mock_store_listing.activeVersionId = "active-version-id"
|
||||
mock_store_listing.hasApprovedVersion = True
|
||||
mock_store_listing.ActiveVersion = mocker.MagicMock()
|
||||
mock_store_listing.ActiveVersion.recommendedScheduleCron = None
|
||||
|
||||
# Mock StoreAgent prisma call
|
||||
# Mock StoreAgent prisma call - need to handle multiple calls
|
||||
mock_store_agent = mocker.patch("prisma.models.StoreAgent.prisma")
|
||||
mock_store_agent.return_value.find_first = mocker.AsyncMock(return_value=mock_agent)
|
||||
|
||||
# Set up side_effect to return different results for different calls
|
||||
def mock_find_first_side_effect(*args, **kwargs):
|
||||
where_clause = kwargs.get("where", {})
|
||||
if "storeListingVersionId" in where_clause:
|
||||
# Second call for active version
|
||||
return mock_active_agent
|
||||
else:
|
||||
# First call for initial lookup
|
||||
return mock_agent
|
||||
|
||||
mock_store_agent.return_value.find_first = mocker.AsyncMock(
|
||||
side_effect=mock_find_first_side_effect
|
||||
)
|
||||
|
||||
# Mock Profile prisma call
|
||||
mock_profile = mocker.MagicMock()
|
||||
@@ -103,7 +139,7 @@ async def test_get_store_agent_details(mocker):
|
||||
return_value=mock_profile
|
||||
)
|
||||
|
||||
# Mock StoreListing prisma call - this is what was missing
|
||||
# Mock StoreListing prisma call
|
||||
mock_store_listing_db = mocker.patch("prisma.models.StoreListing.prisma")
|
||||
mock_store_listing_db.return_value.find_first = mocker.AsyncMock(
|
||||
return_value=mock_store_listing
|
||||
@@ -112,16 +148,25 @@ async def test_get_store_agent_details(mocker):
|
||||
# Call function
|
||||
result = await db.get_store_agent_details("creator", "test-agent")
|
||||
|
||||
# Verify results
|
||||
# Verify results - should use active version data
|
||||
assert result.slug == "test-agent"
|
||||
assert result.agent_name == "Test Agent"
|
||||
assert result.agent_name == "Test Agent Active" # From active version
|
||||
assert result.active_version_id == "active-version-id"
|
||||
assert result.has_approved_version is True
|
||||
assert (
|
||||
result.store_listing_version_id == "active-version-id"
|
||||
) # Should be active version ID
|
||||
|
||||
# Verify mocks called correctly
|
||||
mock_store_agent.return_value.find_first.assert_called_once_with(
|
||||
# Verify mocks called correctly - now expecting 2 calls
|
||||
assert mock_store_agent.return_value.find_first.call_count == 2
|
||||
|
||||
# Check the specific calls
|
||||
calls = mock_store_agent.return_value.find_first.call_args_list
|
||||
assert calls[0] == mocker.call(
|
||||
where={"creator_username": "creator", "slug": "test-agent"}
|
||||
)
|
||||
assert calls[1] == mocker.call(where={"storeListingVersionId": "active-version-id"})
|
||||
|
||||
mock_store_listing_db.return_value.find_first.assert_called_once()
|
||||
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ class MyAgent(pydantic.BaseModel):
|
||||
agent_image: str | None = None
|
||||
description: str
|
||||
last_edited: datetime.datetime
|
||||
recommended_schedule_cron: str | None = None
|
||||
|
||||
|
||||
class MyAgentsResponse(pydantic.BaseModel):
|
||||
@@ -48,11 +49,13 @@ class StoreAgentDetails(pydantic.BaseModel):
|
||||
creator_avatar: str
|
||||
sub_heading: str
|
||||
description: str
|
||||
instructions: str | None = None
|
||||
categories: list[str]
|
||||
runs: int
|
||||
rating: float
|
||||
versions: list[str]
|
||||
last_updated: datetime.datetime
|
||||
recommended_schedule_cron: str | None = None
|
||||
|
||||
active_version_id: str | None = None
|
||||
has_approved_version: bool = False
|
||||
@@ -101,6 +104,7 @@ class StoreSubmission(pydantic.BaseModel):
|
||||
sub_heading: str
|
||||
slug: str
|
||||
description: str
|
||||
instructions: str | None = None
|
||||
image_urls: list[str]
|
||||
date_submitted: datetime.datetime
|
||||
status: prisma.enums.SubmissionStatus
|
||||
@@ -155,8 +159,10 @@ class StoreSubmissionRequest(pydantic.BaseModel):
|
||||
video_url: str | None = None
|
||||
image_urls: list[str] = []
|
||||
description: str = ""
|
||||
instructions: str | None = None
|
||||
categories: list[str] = []
|
||||
changes_summary: str | None = None
|
||||
recommended_schedule_cron: str | None = None
|
||||
|
||||
|
||||
class StoreSubmissionEditRequest(pydantic.BaseModel):
|
||||
@@ -165,8 +171,10 @@ class StoreSubmissionEditRequest(pydantic.BaseModel):
|
||||
video_url: str | None = None
|
||||
image_urls: list[str] = []
|
||||
description: str = ""
|
||||
instructions: str | None = None
|
||||
categories: list[str] = []
|
||||
changes_summary: str | None = None
|
||||
recommended_schedule_cron: str | None = None
|
||||
|
||||
|
||||
class ProfileDetails(pydantic.BaseModel):
|
||||
|
||||
@@ -532,9 +532,11 @@ async def create_submission(
|
||||
video_url=submission_request.video_url,
|
||||
image_urls=submission_request.image_urls,
|
||||
description=submission_request.description,
|
||||
instructions=submission_request.instructions,
|
||||
sub_heading=submission_request.sub_heading,
|
||||
categories=submission_request.categories,
|
||||
changes_summary=submission_request.changes_summary or "Initial Submission",
|
||||
recommended_schedule_cron=submission_request.recommended_schedule_cron,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Exception occurred whilst creating store submission")
|
||||
@@ -577,9 +579,11 @@ async def edit_submission(
|
||||
video_url=submission_request.video_url,
|
||||
image_urls=submission_request.image_urls,
|
||||
description=submission_request.description,
|
||||
instructions=submission_request.instructions,
|
||||
sub_heading=submission_request.sub_heading,
|
||||
categories=submission_request.categories,
|
||||
changes_summary=submission_request.changes_summary,
|
||||
recommended_schedule_cron=submission_request.recommended_schedule_cron,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -11,6 +11,10 @@ from starlette.middleware.cors import CORSMiddleware
|
||||
|
||||
from backend.data.execution import AsyncRedisExecutionEventBus
|
||||
from backend.data.user import DEFAULT_USER_ID
|
||||
from backend.monitoring.instrumentation import (
|
||||
instrument_fastapi,
|
||||
update_websocket_connections,
|
||||
)
|
||||
from backend.server.conn_manager import ConnectionManager
|
||||
from backend.server.model import (
|
||||
WSMessage,
|
||||
@@ -38,6 +42,15 @@ docs_url = "/docs" if settings.config.app_env == AppEnvironment.LOCAL else None
|
||||
app = FastAPI(lifespan=lifespan, docs_url=docs_url)
|
||||
_connection_manager = None
|
||||
|
||||
# Add Prometheus instrumentation
|
||||
instrument_fastapi(
|
||||
app,
|
||||
service_name="websocket-server",
|
||||
expose_endpoint=True,
|
||||
endpoint="/metrics",
|
||||
include_in_schema=settings.config.app_env == AppEnvironment.LOCAL,
|
||||
)
|
||||
|
||||
|
||||
def get_connection_manager():
|
||||
global _connection_manager
|
||||
@@ -216,6 +229,10 @@ async def websocket_router(
|
||||
if not user_id:
|
||||
return
|
||||
await manager.connect_socket(websocket)
|
||||
|
||||
# Track WebSocket connection
|
||||
update_websocket_connections(user_id, 1)
|
||||
|
||||
try:
|
||||
while True:
|
||||
data = await websocket.receive_text()
|
||||
@@ -286,6 +303,8 @@ async def websocket_router(
|
||||
except WebSocketDisconnect:
|
||||
manager.disconnect_socket(websocket)
|
||||
logger.debug("WebSocket client disconnected")
|
||||
finally:
|
||||
update_websocket_connections(user_id, -1)
|
||||
|
||||
|
||||
@app.get("/")
|
||||
|
||||
@@ -33,7 +33,7 @@ def sentry_init():
|
||||
)
|
||||
|
||||
|
||||
def sentry_capture_error(error: Exception):
|
||||
def sentry_capture_error(error: BaseException):
|
||||
sentry_sdk.capture_exception(error)
|
||||
sentry_sdk.flush()
|
||||
|
||||
|
||||
@@ -76,6 +76,14 @@ class AppProcess(ABC):
|
||||
logger.warning(
|
||||
f"[{self.service_name}] Termination request: {type(e).__name__}; {e} executing cleanup."
|
||||
)
|
||||
# Send error to Sentry before cleanup
|
||||
if not isinstance(e, (KeyboardInterrupt, SystemExit)):
|
||||
try:
|
||||
from backend.util.metrics import sentry_capture_error
|
||||
|
||||
sentry_capture_error(e)
|
||||
except Exception:
|
||||
pass # Silently ignore if Sentry isn't available
|
||||
finally:
|
||||
self.cleanup()
|
||||
logger.info(f"[{self.service_name}] Terminated.")
|
||||
|
||||
@@ -479,6 +479,9 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
|
||||
)
|
||||
|
||||
openai_api_key: str = Field(default="", description="OpenAI API key")
|
||||
openai_internal_api_key: str = Field(
|
||||
default="", description="OpenAI Internal API key"
|
||||
)
|
||||
aiml_api_key: str = Field(default="", description="'AI/ML API' key")
|
||||
anthropic_api_key: str = Field(default="", description="Anthropic API key")
|
||||
groq_api_key: str = Field(default="", description="Groq API key")
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
-- These changes are part of improvements to our API key system.
|
||||
-- See https://github.com/Significant-Gravitas/AutoGPT/pull/10796 for context.
|
||||
|
||||
-- Add 'salt' column for Scrypt hashing
|
||||
ALTER TABLE "APIKey" ADD COLUMN "salt" TEXT;
|
||||
|
||||
-- Rename columns for clarity
|
||||
ALTER TABLE "APIKey" RENAME COLUMN "key" TO "hash";
|
||||
ALTER TABLE "APIKey" RENAME COLUMN "prefix" TO "head";
|
||||
ALTER TABLE "APIKey" RENAME COLUMN "postfix" TO "tail";
|
||||
@@ -0,0 +1,3 @@
|
||||
-- AlterTable
|
||||
ALTER TABLE "StoreListingVersion" ADD COLUMN "recommendedScheduleCron" TEXT;
|
||||
ALTER TABLE "AgentGraph" ADD COLUMN "recommendedScheduleCron" TEXT;
|
||||
@@ -0,0 +1,66 @@
|
||||
-- Fixes the refresh function+job introduced in 20250604130249_optimise_store_agent_and_creator_views
|
||||
-- by improving the function to accept a schema parameter and updating the cron job to use it.
|
||||
-- This resolves the issue where pg_cron jobs fail because they run in 'public' schema
|
||||
-- but the materialized views exist in 'platform' schema.
|
||||
|
||||
|
||||
-- Create parameterized refresh function that accepts schema name
|
||||
CREATE OR REPLACE FUNCTION refresh_store_materialized_views()
|
||||
RETURNS void
|
||||
LANGUAGE plpgsql
|
||||
AS $$
|
||||
DECLARE
|
||||
target_schema text := current_schema(); -- Use the current schema where the function is called
|
||||
BEGIN
|
||||
-- Use CONCURRENTLY for better performance during refresh
|
||||
REFRESH MATERIALIZED VIEW CONCURRENTLY "mv_agent_run_counts";
|
||||
REFRESH MATERIALIZED VIEW CONCURRENTLY "mv_review_stats";
|
||||
RAISE NOTICE 'Materialized views refreshed in schema % at %', target_schema, NOW();
|
||||
EXCEPTION
|
||||
WHEN OTHERS THEN
|
||||
-- Fallback to non-concurrent refresh if concurrent fails
|
||||
REFRESH MATERIALIZED VIEW "mv_agent_run_counts";
|
||||
REFRESH MATERIALIZED VIEW "mv_review_stats";
|
||||
RAISE NOTICE 'Materialized views refreshed (non-concurrent) in schema % at %. Concurrent refresh failed due to: %', target_schema, NOW(), SQLERRM;
|
||||
END;
|
||||
$$;
|
||||
|
||||
-- Initial refresh + test of the function to ensure it works
|
||||
SELECT refresh_store_materialized_views();
|
||||
|
||||
-- Re-create the cron job to use the improved function
|
||||
DO $$
|
||||
DECLARE
|
||||
has_pg_cron BOOLEAN;
|
||||
current_schema_name text := current_schema();
|
||||
old_job_name text;
|
||||
job_name text;
|
||||
BEGIN
|
||||
-- Check if pg_cron extension exists
|
||||
SELECT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'pg_cron') INTO has_pg_cron;
|
||||
|
||||
IF has_pg_cron THEN
|
||||
old_job_name := format('refresh-store-views-%s', current_schema_name);
|
||||
job_name := format('refresh-store-views_%s', current_schema_name);
|
||||
|
||||
-- Try to unschedule existing job (ignore errors if it doesn't exist)
|
||||
BEGIN
|
||||
PERFORM cron.unschedule(old_job_name);
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
NULL;
|
||||
END;
|
||||
|
||||
-- Schedule the new job with explicit schema parameter
|
||||
PERFORM cron.schedule(
|
||||
job_name,
|
||||
'*/15 * * * *',
|
||||
format('SET search_path TO %I; SELECT refresh_store_materialized_views();', current_schema_name)
|
||||
);
|
||||
RAISE NOTICE 'Scheduled job %; runs every 15 minutes for schema %', job_name, current_schema_name;
|
||||
ELSE
|
||||
RAISE WARNING '⚠️ Automatic refresh NOT configured - pg_cron is not available';
|
||||
RAISE WARNING '⚠️ You must manually refresh views with: SELECT refresh_store_materialized_views();';
|
||||
RAISE WARNING '⚠️ Or install pg_cron for automatic refresh in production';
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
@@ -0,0 +1,3 @@
|
||||
-- Re-create foreign key CreditTransaction <- User with ON DELETE NO ACTION
|
||||
ALTER TABLE "CreditTransaction" DROP CONSTRAINT "CreditTransaction_userId_fkey";
|
||||
ALTER TABLE "CreditTransaction" ADD CONSTRAINT "CreditTransaction_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE NO ACTION ON UPDATE CASCADE;
|
||||
@@ -0,0 +1,22 @@
|
||||
/*
|
||||
Warnings:
|
||||
|
||||
- A unique constraint covering the columns `[shareToken]` on the table `AgentGraphExecution` will be added. If there are existing duplicate values, this will fail.
|
||||
|
||||
*/
|
||||
-- AlterTable
|
||||
ALTER TABLE "AgentGraphExecution" ADD COLUMN "isShared" BOOLEAN NOT NULL DEFAULT false,
|
||||
ADD COLUMN "shareToken" TEXT,
|
||||
ADD COLUMN "sharedAt" TIMESTAMP(3);
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "AgentGraphExecution_shareToken_key" ON "AgentGraphExecution"("shareToken");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "AgentGraphExecution_shareToken_idx" ON "AgentGraphExecution"("shareToken");
|
||||
|
||||
-- RenameIndex
|
||||
ALTER INDEX "APIKey_key_key" RENAME TO "APIKey_hash_key";
|
||||
|
||||
-- RenameIndex
|
||||
ALTER INDEX "APIKey_prefix_name_idx" RENAME TO "APIKey_head_name_idx";
|
||||
@@ -0,0 +1,53 @@
|
||||
-- Add instructions field to AgentGraph and StoreListingVersion tables and update StoreSubmission view
|
||||
|
||||
BEGIN;
|
||||
|
||||
-- AddColumn
|
||||
ALTER TABLE "AgentGraph" ADD COLUMN "instructions" TEXT;
|
||||
|
||||
-- AddColumn
|
||||
ALTER TABLE "StoreListingVersion" ADD COLUMN "instructions" TEXT;
|
||||
|
||||
-- Drop the existing view
|
||||
DROP VIEW IF EXISTS "StoreSubmission";
|
||||
|
||||
-- Recreate the view with the new instructions field
|
||||
CREATE VIEW "StoreSubmission" AS
|
||||
SELECT
|
||||
sl.id AS listing_id,
|
||||
sl."owningUserId" AS user_id,
|
||||
slv."agentGraphId" AS agent_id,
|
||||
slv.version AS agent_version,
|
||||
sl.slug,
|
||||
COALESCE(slv.name, '') AS name,
|
||||
slv."subHeading" AS sub_heading,
|
||||
slv.description,
|
||||
slv.instructions,
|
||||
slv."imageUrls" AS image_urls,
|
||||
slv."submittedAt" AS date_submitted,
|
||||
slv."submissionStatus" AS status,
|
||||
COALESCE(ar.run_count, 0::bigint) AS runs,
|
||||
COALESCE(avg(sr.score::numeric), 0.0)::double precision AS rating,
|
||||
slv.id AS store_listing_version_id,
|
||||
slv."reviewerId" AS reviewer_id,
|
||||
slv."reviewComments" AS review_comments,
|
||||
slv."internalComments" AS internal_comments,
|
||||
slv."reviewedAt" AS reviewed_at,
|
||||
slv."changesSummary" AS changes_summary,
|
||||
slv."videoUrl" AS video_url,
|
||||
slv.categories
|
||||
FROM "StoreListing" sl
|
||||
JOIN "StoreListingVersion" slv ON slv."storeListingId" = sl.id
|
||||
LEFT JOIN "StoreListingReview" sr ON sr."storeListingVersionId" = slv.id
|
||||
LEFT JOIN (
|
||||
SELECT "AgentGraphExecution"."agentGraphId", count(*) AS run_count
|
||||
FROM "AgentGraphExecution"
|
||||
GROUP BY "AgentGraphExecution"."agentGraphId"
|
||||
) ar ON ar."agentGraphId" = slv."agentGraphId"
|
||||
WHERE sl."isDeleted" = false
|
||||
GROUP BY sl.id, sl."owningUserId", slv.id, slv."agentGraphId", slv.version, sl.slug, slv.name,
|
||||
slv."subHeading", slv.description, slv.instructions, slv."imageUrls", slv."submittedAt",
|
||||
slv."submissionStatus", slv."reviewerId", slv."reviewComments", slv."internalComments",
|
||||
slv."reviewedAt", slv."changesSummary", slv."videoUrl", slv.categories, ar.run_count;
|
||||
|
||||
COMMIT;
|
||||
19
autogpt_platform/backend/poetry.lock
generated
@@ -403,6 +403,7 @@ develop = true
|
||||
|
||||
[package.dependencies]
|
||||
colorama = "^0.4.6"
|
||||
cryptography = "^45.0"
|
||||
expiringdict = "^1.2.2"
|
||||
fastapi = "^0.116.1"
|
||||
google-cloud-logging = "^3.12.1"
|
||||
@@ -4144,6 +4145,22 @@ files = [
|
||||
[package.extras]
|
||||
twisted = ["twisted"]
|
||||
|
||||
[[package]]
|
||||
name = "prometheus-fastapi-instrumentator"
|
||||
version = "7.1.0"
|
||||
description = "Instrument your FastAPI app with Prometheus metrics"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "prometheus_fastapi_instrumentator-7.1.0-py3-none-any.whl", hash = "sha256:978130f3c0bb7b8ebcc90d35516a6fe13e02d2eb358c8f83887cdef7020c31e9"},
|
||||
{file = "prometheus_fastapi_instrumentator-7.1.0.tar.gz", hash = "sha256:be7cd61eeea4e5912aeccb4261c6631b3f227d8924542d79eaf5af3f439cbe5e"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
prometheus-client = ">=0.8.0,<1.0.0"
|
||||
starlette = ">=0.30.0,<1.0.0"
|
||||
|
||||
[[package]]
|
||||
name = "propcache"
|
||||
version = "0.3.2"
|
||||
@@ -7142,4 +7159,4 @@ cffi = ["cffi (>=1.11)"]
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.10,<3.14"
|
||||
content-hash = "80d4dc2cbcd1ae33b2fa3920db5dcb1f82ad252d1e4a8bfeba8b2f2eebbdda0d"
|
||||
content-hash = "2c7e9370f500039b99868376021627c5a120e0ee31c5c5e6de39db2c3d82f414"
|
||||
|
||||
@@ -45,6 +45,7 @@ postmarker = "^1.0"
|
||||
praw = "~7.8.1"
|
||||
prisma = "^0.15.0"
|
||||
prometheus-client = "^0.22.1"
|
||||
prometheus-fastapi-instrumentator = "^7.0.0"
|
||||
psutil = "^7.0.0"
|
||||
psycopg2-binary = "^2.9.10"
|
||||
pydantic = { extras = ["email"], version = "^2.11.7" }
|
||||
|
||||
@@ -110,6 +110,8 @@ model AgentGraph {
|
||||
|
||||
name String?
|
||||
description String?
|
||||
instructions String?
|
||||
recommendedScheduleCron String?
|
||||
|
||||
isActive Boolean @default(true)
|
||||
|
||||
@@ -369,10 +371,16 @@ model AgentGraphExecution {
|
||||
|
||||
stats Json?
|
||||
|
||||
// Sharing fields
|
||||
isShared Boolean @default(false)
|
||||
shareToken String? @unique
|
||||
sharedAt DateTime?
|
||||
|
||||
@@index([agentGraphId, agentGraphVersion])
|
||||
@@index([userId])
|
||||
@@index([createdAt])
|
||||
@@index([agentPresetId])
|
||||
@@index([shareToken])
|
||||
}
|
||||
|
||||
// This model describes the execution of an AgentNode.
|
||||
@@ -527,7 +535,7 @@ model CreditTransaction {
|
||||
createdAt DateTime @default(now())
|
||||
|
||||
userId String
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
User User? @relation(fields: [userId], references: [id], onDelete: NoAction)
|
||||
|
||||
amount Int
|
||||
type CreditTransactionType
|
||||
@@ -756,6 +764,7 @@ model StoreListingVersion {
|
||||
videoUrl String?
|
||||
imageUrls String[]
|
||||
description String
|
||||
instructions String?
|
||||
categories String[]
|
||||
|
||||
isFeatured Boolean @default(false)
|
||||
@@ -785,6 +794,8 @@ model StoreListingVersion {
|
||||
reviewComments String? // Comments visible to creator
|
||||
reviewedAt DateTime?
|
||||
|
||||
recommendedScheduleCron String? // cron expression like "0 9 * * *"
|
||||
|
||||
// Reviews for this specific version
|
||||
Reviews StoreListingReview[]
|
||||
|
||||
@@ -828,11 +839,13 @@ enum APIKeyPermission {
|
||||
}
|
||||
|
||||
model APIKey {
|
||||
id String @id @default(uuid())
|
||||
name String
|
||||
prefix String // First 8 chars for identification
|
||||
postfix String
|
||||
key String @unique // Hashed key
|
||||
id String @id @default(uuid())
|
||||
name String
|
||||
head String // First few chars for identification
|
||||
tail String
|
||||
hash String @unique
|
||||
salt String? // null for legacy unsalted keys
|
||||
|
||||
status APIKeyStatus @default(ACTIVE)
|
||||
permissions APIKeyPermission[]
|
||||
|
||||
@@ -846,7 +859,7 @@ model APIKey {
|
||||
userId String
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
|
||||
@@index([prefix, name])
|
||||
@@index([head, name])
|
||||
@@index([userId, status])
|
||||
}
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
"creator_avatar": "avatar1.jpg",
|
||||
"sub_heading": "Test agent subheading",
|
||||
"description": "Test agent description",
|
||||
"instructions": null,
|
||||
"categories": [
|
||||
"category1",
|
||||
"category2"
|
||||
@@ -22,6 +23,7 @@
|
||||
"1.1.0"
|
||||
],
|
||||
"last_updated": "2023-01-01T00:00:00",
|
||||
"recommended_schedule_cron": null,
|
||||
"active_version_id": null,
|
||||
"has_approved_version": false
|
||||
}
|
||||
@@ -1,4 +1,5 @@
|
||||
{
|
||||
"created_at": "2025-09-04T13:37:00",
|
||||
"credentials_input_schema": {
|
||||
"properties": {},
|
||||
"title": "TestGraphCredentialsInputSchema",
|
||||
@@ -14,6 +15,7 @@
|
||||
"required": [],
|
||||
"type": "object"
|
||||
},
|
||||
"instructions": null,
|
||||
"is_active": true,
|
||||
"links": [],
|
||||
"name": "Test Graph",
|
||||
@@ -23,6 +25,7 @@
|
||||
"required": [],
|
||||
"type": "object"
|
||||
},
|
||||
"recommended_schedule_cron": null,
|
||||
"sub_graphs": [],
|
||||
"trigger_setup_info": null,
|
||||
"user_id": "test-user-id",
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
"required": [],
|
||||
"type": "object"
|
||||
},
|
||||
"instructions": null,
|
||||
"is_active": true,
|
||||
"name": "Test Graph",
|
||||
"output_schema": {
|
||||
@@ -22,6 +23,7 @@
|
||||
"required": [],
|
||||
"type": "object"
|
||||
},
|
||||
"recommended_schedule_cron": null,
|
||||
"sub_graphs": [],
|
||||
"trigger_setup_info": null,
|
||||
"user_id": "test-user-id",
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
"updated_at": "2023-01-01T00:00:00",
|
||||
"name": "Test Agent 1",
|
||||
"description": "Test Description 1",
|
||||
"instructions": null,
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {}
|
||||
@@ -27,7 +28,9 @@
|
||||
"trigger_setup_info": null,
|
||||
"new_output": false,
|
||||
"can_access_graph": true,
|
||||
"is_latest_version": true
|
||||
"is_latest_version": true,
|
||||
"is_favorite": false,
|
||||
"recommended_schedule_cron": null
|
||||
},
|
||||
{
|
||||
"id": "test-agent-2",
|
||||
@@ -40,6 +43,7 @@
|
||||
"updated_at": "2023-01-01T00:00:00",
|
||||
"name": "Test Agent 2",
|
||||
"description": "Test Description 2",
|
||||
"instructions": null,
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {}
|
||||
@@ -56,7 +60,9 @@
|
||||
"trigger_setup_info": null,
|
||||
"new_output": false,
|
||||
"can_access_graph": false,
|
||||
"is_latest_version": true
|
||||
"is_latest_version": true,
|
||||
"is_favorite": false,
|
||||
"recommended_schedule_cron": null
|
||||
}
|
||||
],
|
||||
"pagination": {
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
"sub_heading": "Test agent subheading",
|
||||
"slug": "test-agent",
|
||||
"description": "Test agent description",
|
||||
"instructions": null,
|
||||
"image_urls": [
|
||||
"test.jpg"
|
||||
],
|
||||
|
||||
@@ -23,7 +23,7 @@ from typing import Any, Dict, List
|
||||
|
||||
from faker import Faker
|
||||
|
||||
from backend.data.api_key import generate_api_key
|
||||
from backend.data.api_key import create_api_key
|
||||
from backend.data.credit import get_user_credit_model
|
||||
from backend.data.db import prisma
|
||||
from backend.data.graph import Graph, Link, Node, create_graph
|
||||
@@ -466,7 +466,7 @@ class TestDataCreator:
|
||||
|
||||
try:
|
||||
# Use the API function to create API key
|
||||
api_key, _ = await generate_api_key(
|
||||
api_key, _ = await create_api_key(
|
||||
name=faker.word(),
|
||||
user_id=user["id"],
|
||||
permissions=[
|
||||
|
||||
@@ -146,16 +146,23 @@ class TestAutoRegistry:
|
||||
"""Test API key environment variable registration."""
|
||||
import os
|
||||
|
||||
from backend.sdk.builder import ProviderBuilder
|
||||
|
||||
# Set up a test environment variable
|
||||
os.environ["TEST_API_KEY"] = "test-api-key-value"
|
||||
|
||||
try:
|
||||
AutoRegistry.register_api_key("test_provider", "TEST_API_KEY")
|
||||
# Use ProviderBuilder which calls register_api_key and creates the credential
|
||||
provider = (
|
||||
ProviderBuilder("test_provider")
|
||||
.with_api_key("TEST_API_KEY", "Test API Key")
|
||||
.build()
|
||||
)
|
||||
|
||||
# Verify the mapping is stored
|
||||
assert AutoRegistry._api_key_mappings["test_provider"] == "TEST_API_KEY"
|
||||
|
||||
# Verify a credential was created
|
||||
# Verify a credential was created through the provider
|
||||
all_creds = AutoRegistry.get_all_credentials()
|
||||
test_cred = next(
|
||||
(c for c in all_creds if c.id == "test_provider-default"), None
|
||||
@@ -370,7 +377,7 @@ class TestProviderBuilder:
|
||||
|
||||
def test_provider_builder_with_base_cost(self):
|
||||
"""Test building a provider with base costs."""
|
||||
from backend.data.cost import BlockCostType
|
||||
from backend.data.block import BlockCostType
|
||||
|
||||
provider = (
|
||||
ProviderBuilder("cost_test")
|
||||
@@ -411,7 +418,7 @@ class TestProviderBuilder:
|
||||
|
||||
def test_provider_builder_complete_example(self):
|
||||
"""Test building a complete provider with all features."""
|
||||
from backend.data.cost import BlockCostType
|
||||
from backend.data.block import BlockCostType
|
||||
|
||||
class TestOAuth(BaseOAuthHandler):
|
||||
PROVIDER_NAME = ProviderName.GITHUB
|
||||
|
||||
@@ -21,6 +21,7 @@ import random
|
||||
from datetime import datetime
|
||||
|
||||
import prisma.enums
|
||||
from autogpt_libs.api_key.keysmith import APIKeySmith
|
||||
from faker import Faker
|
||||
from prisma import Json, Prisma
|
||||
from prisma.types import (
|
||||
@@ -30,7 +31,6 @@ from prisma.types import (
|
||||
AgentNodeLinkCreateInput,
|
||||
AnalyticsDetailsCreateInput,
|
||||
AnalyticsMetricsCreateInput,
|
||||
APIKeyCreateInput,
|
||||
CreditTransactionCreateInput,
|
||||
IntegrationWebhookCreateInput,
|
||||
ProfileCreateInput,
|
||||
@@ -544,20 +544,22 @@ async def main():
|
||||
# Insert APIKeys
|
||||
print(f"Inserting {NUM_USERS} api keys")
|
||||
for user in users:
|
||||
api_key = APIKeySmith().generate_key()
|
||||
await db.apikey.create(
|
||||
data=APIKeyCreateInput(
|
||||
name=faker.word(),
|
||||
prefix=str(faker.uuid4())[:8],
|
||||
postfix=str(faker.uuid4())[-8:],
|
||||
key=str(faker.sha256()),
|
||||
status=prisma.enums.APIKeyStatus.ACTIVE,
|
||||
permissions=[
|
||||
data={
|
||||
"name": faker.word(),
|
||||
"head": api_key.head,
|
||||
"tail": api_key.tail,
|
||||
"hash": api_key.hash,
|
||||
"salt": api_key.salt,
|
||||
"status": prisma.enums.APIKeyStatus.ACTIVE,
|
||||
"permissions": [
|
||||
prisma.enums.APIKeyPermission.EXECUTE_GRAPH,
|
||||
prisma.enums.APIKeyPermission.READ_GRAPH,
|
||||
],
|
||||
description=faker.text(),
|
||||
userId=user.id,
|
||||
)
|
||||
"description": faker.text(),
|
||||
"userId": user.id,
|
||||
}
|
||||
)
|
||||
|
||||
# Refresh materialized views
|
||||
|
||||
@@ -65,7 +65,6 @@ services:
|
||||
|
||||
redis:
|
||||
image: redis:latest
|
||||
command: redis-server --requirepass password
|
||||
ports:
|
||||
- "6379:6379"
|
||||
networks:
|
||||
|
||||
@@ -4,3 +4,5 @@ pnpm-lock.yaml
|
||||
.auth
|
||||
build
|
||||
public
|
||||
Dockerfile
|
||||
.prettierignore
|
||||
|
||||
@@ -12,14 +12,21 @@ COPY autogpt_platform/frontend/ .
|
||||
# Allow CI to opt-in to Playwright test build-time flags
|
||||
ARG NEXT_PUBLIC_PW_TEST="false"
|
||||
ENV NEXT_PUBLIC_PW_TEST=$NEXT_PUBLIC_PW_TEST
|
||||
RUN if [ -f .env ]; then \
|
||||
ENV NODE_ENV="production"
|
||||
# Merge env files appropriately based on environment
|
||||
RUN if [ -f .env.production ]; then \
|
||||
# In CI/CD: merge defaults with production (production takes precedence)
|
||||
cat .env.default .env.production > .env.merged && mv .env.merged .env.production; \
|
||||
elif [ -f .env ]; then \
|
||||
# Local with custom .env: merge defaults with .env
|
||||
cat .env.default .env > .env.merged && mv .env.merged .env; \
|
||||
else \
|
||||
# Local without custom .env: use defaults
|
||||
cp .env.default .env; \
|
||||
fi
|
||||
RUN pnpm run generate:api
|
||||
# In CI, we want NEXT_PUBLIC_PW_TEST=true during build so Next.js inlines it
|
||||
RUN if [ "$NEXT_PUBLIC_PW_TEST" = "true" ]; then NEXT_PUBLIC_PW_TEST=true pnpm build; else pnpm build; fi
|
||||
RUN if [ "$NEXT_PUBLIC_PW_TEST" = "true" ]; then NEXT_PUBLIC_PW_TEST=true NODE_OPTIONS="--max-old-space-size=4096" pnpm build; else NODE_OPTIONS="--max-old-space-size=4096" pnpm build; fi
|
||||
|
||||
# Prod stage - based on NextJS reference Dockerfile https://github.com/vercel/next.js/blob/64271354533ed16da51be5dce85f0dbd15f17517/examples/with-docker/Dockerfile
|
||||
FROM node:21-alpine AS prod
|
||||
|
||||
@@ -2,18 +2,28 @@
|
||||
// The config you add here will be used whenever a users loads a page in their browser.
|
||||
// https://docs.sentry.io/platforms/javascript/guides/nextjs/
|
||||
|
||||
import { BehaveAs, getBehaveAs, getEnvironmentStr } from "@/lib/utils";
|
||||
import {
|
||||
AppEnv,
|
||||
BehaveAs,
|
||||
getAppEnv,
|
||||
getBehaveAs,
|
||||
getEnvironmentStr,
|
||||
} from "@/lib/utils";
|
||||
import * as Sentry from "@sentry/nextjs";
|
||||
|
||||
const isProductionCloud =
|
||||
process.env.NODE_ENV === "production" && getBehaveAs() === BehaveAs.CLOUD;
|
||||
const isProdOrDev = [AppEnv.PROD, AppEnv.DEV].includes(getAppEnv());
|
||||
|
||||
const isCloud = getBehaveAs() === BehaveAs.CLOUD;
|
||||
const isDisabled = process.env.DISABLE_SENTRY === "true";
|
||||
|
||||
const shouldEnable = !isDisabled && isProdOrDev && isCloud;
|
||||
|
||||
Sentry.init({
|
||||
dsn: "https://fe4e4aa4a283391808a5da396da20159@o4505260022104064.ingest.us.sentry.io/4507946746380288",
|
||||
|
||||
environment: getEnvironmentStr(),
|
||||
|
||||
enabled: isProductionCloud,
|
||||
enabled: shouldEnable,
|
||||
|
||||
// Add optional integrations for additional features
|
||||
integrations: [
|
||||
@@ -56,10 +66,7 @@ Sentry.init({
|
||||
// For example, a tracesSampleRate of 0.5 and profilesSampleRate of 0.5 would
|
||||
// result in 25% of transactions being profiled (0.5*0.5=0.25)
|
||||
profilesSampleRate: 1.0,
|
||||
_experiments: {
|
||||
// Enable logs to be sent to Sentry.
|
||||
enableLogs: true,
|
||||
},
|
||||
enableLogs: true,
|
||||
});
|
||||
|
||||
export const onRouterTransitionStart = Sentry.captureRouterTransitionStart;
|
||||
|
||||
@@ -12,6 +12,23 @@ const nextConfig = {
|
||||
"ideogram.ai", // for generated images
|
||||
"picsum.photos", // for placeholder images
|
||||
],
|
||||
remotePatterns: [
|
||||
{
|
||||
protocol: "https",
|
||||
hostname: "storage.googleapis.com",
|
||||
pathname: "/**",
|
||||
},
|
||||
{
|
||||
protocol: "https",
|
||||
hostname: "storage.cloud.google.com",
|
||||
pathname: "/**",
|
||||
},
|
||||
{
|
||||
protocol: "https",
|
||||
hostname: "lh3.googleusercontent.com",
|
||||
pathname: "/**",
|
||||
},
|
||||
],
|
||||
},
|
||||
output: "standalone",
|
||||
transpilePackages: ["geist"],
|
||||
|
||||
@@ -35,12 +35,38 @@ export default defineConfig({
|
||||
useInfiniteQueryParam: "page",
|
||||
},
|
||||
},
|
||||
"getV2List favorite library agents": {
|
||||
query: {
|
||||
useInfinite: true,
|
||||
useInfiniteQueryParam: "page",
|
||||
},
|
||||
},
|
||||
"getV1List graph executions": {
|
||||
query: {
|
||||
useInfinite: true,
|
||||
useInfiniteQueryParam: "page",
|
||||
},
|
||||
},
|
||||
"getV2Get builder blocks": {
|
||||
query: {
|
||||
useInfinite: true,
|
||||
useInfiniteQueryParam: "page",
|
||||
useQuery: true,
|
||||
},
|
||||
},
|
||||
"getV2Get builder integration providers": {
|
||||
query: {
|
||||
useInfinite: true,
|
||||
useInfiniteQueryParam: "page",
|
||||
},
|
||||
},
|
||||
"getV2List store agents": {
|
||||
query: {
|
||||
useInfinite: true,
|
||||
useInfiniteQueryParam: "page",
|
||||
useQuery: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -54,6 +54,8 @@
|
||||
"@tanstack/react-query": "5.85.3",
|
||||
"@tanstack/react-table": "8.21.3",
|
||||
"@types/jaro-winkler": "0.2.4",
|
||||
"@vercel/analytics": "1.5.0",
|
||||
"@vercel/speed-insights": "1.2.0",
|
||||
"@xyflow/react": "12.8.3",
|
||||
"boring-avatars": "1.11.2",
|
||||
"class-variance-authority": "0.7.1",
|
||||
@@ -104,32 +106,32 @@
|
||||
"zod": "3.25.76"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@chromatic-com/storybook": "4.1.0",
|
||||
"@playwright/test": "1.54.2",
|
||||
"@storybook/addon-a11y": "9.1.2",
|
||||
"@storybook/addon-docs": "9.1.2",
|
||||
"@storybook/addon-links": "9.1.2",
|
||||
"@storybook/addon-onboarding": "9.1.2",
|
||||
"@storybook/nextjs": "9.1.2",
|
||||
"@tanstack/eslint-plugin-query": "5.83.1",
|
||||
"@tanstack/react-query-devtools": "5.84.2",
|
||||
"@chromatic-com/storybook": "4.1.1",
|
||||
"@playwright/test": "1.55.0",
|
||||
"@storybook/addon-a11y": "9.1.5",
|
||||
"@storybook/addon-docs": "9.1.5",
|
||||
"@storybook/addon-links": "9.1.5",
|
||||
"@storybook/addon-onboarding": "9.1.5",
|
||||
"@storybook/nextjs": "9.1.5",
|
||||
"@tanstack/eslint-plugin-query": "5.86.0",
|
||||
"@tanstack/react-query-devtools": "5.87.3",
|
||||
"@types/canvas-confetti": "1.9.0",
|
||||
"@types/lodash": "4.17.20",
|
||||
"@types/negotiator": "0.6.4",
|
||||
"@types/node": "24.2.1",
|
||||
"@types/node": "24.3.1",
|
||||
"@types/react": "18.3.17",
|
||||
"@types/react-dom": "18.3.5",
|
||||
"@types/react-modal": "3.16.3",
|
||||
"@types/react-window": "1.8.8",
|
||||
"axe-playwright": "2.1.0",
|
||||
"chromatic": "13.1.3",
|
||||
"concurrently": "9.2.0",
|
||||
"chromatic": "13.1.4",
|
||||
"concurrently": "9.2.1",
|
||||
"cross-env": "7.0.3",
|
||||
"eslint": "8.57.1",
|
||||
"eslint-config-next": "15.4.6",
|
||||
"eslint-plugin-storybook": "9.1.2",
|
||||
"eslint-config-next": "15.5.2",
|
||||
"eslint-plugin-storybook": "9.1.5",
|
||||
"import-in-the-middle": "1.14.2",
|
||||
"msw": "2.10.4",
|
||||
"msw": "2.11.1",
|
||||
"msw-storybook-addon": "2.0.5",
|
||||
"orval": "7.11.2",
|
||||
"pbkdf2": "3.1.3",
|
||||
@@ -137,7 +139,7 @@
|
||||
"prettier": "3.6.2",
|
||||
"prettier-plugin-tailwindcss": "0.6.14",
|
||||
"require-in-the-middle": "7.5.2",
|
||||
"storybook": "9.1.2",
|
||||
"storybook": "9.1.5",
|
||||
"tailwindcss": "3.4.17",
|
||||
"typescript": "5.9.2"
|
||||
},
|
||||
|
||||
1773
autogpt_platform/frontend/pnpm-lock.yaml
generated
BIN
autogpt_platform/frontend/public/integrations/aiml_api.png
Normal file
|
After Width: | Height: | Size: 7.7 KiB |
BIN
autogpt_platform/frontend/public/integrations/airtable.png
Normal file
|
After Width: | Height: | Size: 1.7 KiB |
BIN
autogpt_platform/frontend/public/integrations/anthropic.png
Normal file
|
After Width: | Height: | Size: 10 KiB |
BIN
autogpt_platform/frontend/public/integrations/apollo.png
Normal file
|
After Width: | Height: | Size: 4.5 KiB |
BIN
autogpt_platform/frontend/public/integrations/baas.png
Normal file
|
After Width: | Height: | Size: 2.7 KiB |
BIN
autogpt_platform/frontend/public/integrations/bannerbear.png
Normal file
|
After Width: | Height: | Size: 2.8 KiB |
|
Before Width: | Height: | Size: 36 KiB After Width: | Height: | Size: 4.2 KiB |
BIN
autogpt_platform/frontend/public/integrations/d_id.png
Normal file
|
After Width: | Height: | Size: 2.6 KiB |
BIN
autogpt_platform/frontend/public/integrations/dataforseo.png
Normal file
|
After Width: | Height: | Size: 14 KiB |
|
Before Width: | Height: | Size: 19 KiB After Width: | Height: | Size: 5.1 KiB |
BIN
autogpt_platform/frontend/public/integrations/e2b.png
Normal file
|
After Width: | Height: | Size: 7.2 KiB |