Mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-02-13 08:14:58 -05:00)

Compare commits: fix/copilo...pwuts/open (31 commits)
| Author | SHA1 | Date |
|---|---|---|
| | cabda535ea | |
| | 43b25b5e2f | |
| | ab0b537cc7 | |
| | 9a8c6ad609 | |
| | 746a36822d | |
| | 2a46d3fbf4 | |
| | ab25516a46 | |
| | 6e2f595c7d | |
| | e523eb62b5 | |
| | 97ff65ef6a | |
| | e8b81f71ef | |
| | d652821ed5 | |
| | 80659d90e4 | |
| | eef892893c | |
| | 23175708e6 | |
| | f02c00374e | |
| | 2fa166d839 | |
| | d927e4b611 | |
| | 6591b2171c | |
| | 85d97a9d5c | |
| | 16c8b2a6e3 | |
| | cad54a9f3e | |
| | ca0620b102 | |
| | 7a4cf4e186 | |
| | fe9debd80f | |
| | 7083dcf226 | |
| | ee2805d14c | |
| | f15362d619 | |
| | 6c2374593f | |
| | 0f4c33308f | |
| | ecb9fdae25 | |
.dockerignore

```
@@ -5,42 +5,13 @@
!docs/

# Platform - Libs
!autogpt_platform/autogpt_libs/autogpt_libs/
!autogpt_platform/autogpt_libs/pyproject.toml
!autogpt_platform/autogpt_libs/poetry.lock
!autogpt_platform/autogpt_libs/README.md
!autogpt_platform/autogpt_libs/

# Platform - Backend
!autogpt_platform/backend/backend/
!autogpt_platform/backend/test/e2e_test_data.py
!autogpt_platform/backend/migrations/
!autogpt_platform/backend/schema.prisma
!autogpt_platform/backend/pyproject.toml
!autogpt_platform/backend/poetry.lock
!autogpt_platform/backend/README.md
!autogpt_platform/backend/.env
!autogpt_platform/backend/gen_prisma_types_stub.py

# Platform - Market
!autogpt_platform/market/market/
!autogpt_platform/market/scripts.py
!autogpt_platform/market/schema.prisma
!autogpt_platform/market/pyproject.toml
!autogpt_platform/market/poetry.lock
!autogpt_platform/market/README.md
!autogpt_platform/backend/

# Platform - Frontend
!autogpt_platform/frontend/src/
!autogpt_platform/frontend/public/
!autogpt_platform/frontend/scripts/
!autogpt_platform/frontend/package.json
!autogpt_platform/frontend/pnpm-lock.yaml
!autogpt_platform/frontend/tsconfig.json
!autogpt_platform/frontend/README.md
## config
!autogpt_platform/frontend/*.config.*
!autogpt_platform/frontend/.env.*
!autogpt_platform/frontend/.env
!autogpt_platform/frontend/

# Classic - AutoGPT
!classic/original_autogpt/autogpt/

@@ -64,6 +35,38 @@
# Classic - Frontend
!classic/frontend/build/web/

# Explicitly re-ignore some folders
.*
**/__pycache__
# Explicitly re-ignore unwanted files from whitelisted directories
# Note: These patterns MUST come after the whitelist rules to take effect

# Hidden files and directories (but keep frontend .env files needed for build)
**/.*
!autogpt_platform/frontend/.env
!autogpt_platform/frontend/.env.default
!autogpt_platform/frontend/.env.production

# Python artifacts
**/__pycache__/
**/*.pyc
**/*.pyo
**/.venv/
**/.ruff_cache/
**/.pytest_cache/
**/.coverage
**/htmlcov/

# Node artifacts
**/node_modules/
**/.next/
**/storybook-static/
**/playwright-report/
**/test-results/

# Build artifacts
**/dist/
**/build/
!autogpt_platform/frontend/src/**/build/
**/target/

# Logs and temp files
**/*.log
**/*.tmp
```
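The re-ignore note in this file matters because .dockerignore evaluates patterns in order, last match wins. A minimal sketch of that behavior, using the third-party `pathspec` package as a stand-in matcher (the package choice is illustrative, not part of the change):

```python
# Illustration only: the last matching pattern wins, so a later "**/.*"
# re-ignores a previously whitelisted .env unless the "!" rule is repeated
# after the re-ignore. gitwildmatch rules approximate .dockerignore here.
import pathspec  # assumed available: pip install pathspec

rules = [
    "*",                                # ignore everything
    "!autogpt_platform/frontend/.env",  # whitelist the frontend .env
    "**/.*",                            # re-ignore hidden files (later, so it wins)
    "!autogpt_platform/frontend/.env",  # re-whitelist: must come after the re-ignore
]
spec = pathspec.PathSpec.from_lines("gitwildmatch", rules)
# False => the file is NOT ignored, because the final "!" rule matches last
print(spec.match_file("autogpt_platform/frontend/.env"))
```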
.github/workflows/platform-frontend-ci.yml (vendored, 241 changed lines)
```yaml
@@ -26,7 +26,6 @@ jobs:
  setup:
    runs-on: ubuntu-latest
    outputs:
      cache-key: ${{ steps.cache-key.outputs.key }}
      components-changed: ${{ steps.filter.outputs.components }}

    steps:
@@ -41,28 +40,17 @@
          components:
            - 'autogpt_platform/frontend/src/components/**'

      - name: Set up Node.js
        uses: actions/setup-node@v6
        with:
          node-version: "22.18.0"

      - name: Enable corepack
        run: corepack enable

      - name: Generate cache key
        id: cache-key
        run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT

      - name: Cache dependencies
        uses: actions/cache@v5
      - name: Set up Node
        uses: actions/setup-node@v6
        with:
          path: ~/.pnpm-store
          key: ${{ steps.cache-key.outputs.key }}
          restore-keys: |
            ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
            ${{ runner.os }}-pnpm-
          node-version: "22.18.0"
          cache: "pnpm"
          cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml

      - name: Install dependencies
      - name: Install dependencies to populate cache
        run: pnpm install --frozen-lockfile

  lint:
@@ -73,22 +61,15 @@ jobs:
      - name: Checkout repository
        uses: actions/checkout@v6

      - name: Set up Node.js
        uses: actions/setup-node@v6
        with:
          node-version: "22.18.0"

      - name: Enable corepack
        run: corepack enable

      - name: Restore dependencies cache
        uses: actions/cache@v5
      - name: Set up Node
        uses: actions/setup-node@v6
        with:
          path: ~/.pnpm-store
          key: ${{ needs.setup.outputs.cache-key }}
          restore-keys: |
            ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
            ${{ runner.os }}-pnpm-
          node-version: "22.18.0"
          cache: "pnpm"
          cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml

      - name: Install dependencies
        run: pnpm install --frozen-lockfile
@@ -111,22 +92,15 @@ jobs:
        with:
          fetch-depth: 0

      - name: Set up Node.js
        uses: actions/setup-node@v6
        with:
          node-version: "22.18.0"

      - name: Enable corepack
        run: corepack enable

      - name: Restore dependencies cache
        uses: actions/cache@v5
      - name: Set up Node
        uses: actions/setup-node@v6
        with:
          path: ~/.pnpm-store
          key: ${{ needs.setup.outputs.cache-key }}
          restore-keys: |
            ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
            ${{ runner.os }}-pnpm-
          node-version: "22.18.0"
          cache: "pnpm"
          cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml

      - name: Install dependencies
        run: pnpm install --frozen-lockfile
@@ -141,10 +115,8 @@ jobs:
          exitOnceUploaded: true

  e2e_test:
    name: end-to-end tests
    runs-on: big-boi
    needs: setup
    strategy:
      fail-fast: false

    steps:
      - name: Checkout repository
@@ -152,19 +124,11 @@ jobs:
        with:
          submodules: recursive

      - name: Set up Node.js
        uses: actions/setup-node@v6
        with:
          node-version: "22.18.0"

      - name: Enable corepack
        run: corepack enable

      - name: Copy default supabase .env
      - name: Set up Platform - Copy default supabase .env
        run: |
          cp ../.env.default ../.env

      - name: Copy backend .env and set OpenAI API key
      - name: Set up Platform - Copy backend .env and set OpenAI API key
        run: |
          cp ../backend/.env.default ../backend/.env
          echo "OPENAI_INTERNAL_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> ../backend/.env
@@ -172,77 +136,125 @@ jobs:
          # Used by E2E test data script to generate embeddings for approved store agents
          OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}

      - name: Set up Docker Buildx
      - name: Set up Platform - Set up Docker Buildx
        uses: docker/setup-buildx-action@v3
        with:
          driver: docker-container
          driver-opts: network=host

      - name: Cache Docker layers
      - name: Set up Platform - Expose GHA cache to docker buildx CLI
        uses: crazy-max/ghaction-github-runtime@v3

      - name: Set up Platform - Build Docker images (with cache)
        working-directory: autogpt_platform
        run: |
          pip install pyyaml

          # Resolve extends and generate a flat compose file that bake can understand
          docker compose -f docker-compose.yml config > docker-compose.resolved.yml

          # Add cache configuration to the resolved compose file
          python ../.github/workflows/scripts/docker-ci-fix-compose-build-cache.py \
            --source docker-compose.resolved.yml \
            --cache-from "type=gha" \
            --cache-to "type=gha,mode=max" \
            --backend-hash "${{ hashFiles('autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/poetry.lock', 'autogpt_platform/backend/backend') }}" \
            --frontend-hash "${{ hashFiles('autogpt_platform/frontend/Dockerfile', 'autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/src') }}" \
            --git-ref "${{ github.ref }}"

          # Build with bake using the resolved compose file (now includes cache config)
          docker buildx bake --allow=fs.read=.. -f docker-compose.resolved.yml --load
        env:
          NEXT_PUBLIC_PW_TEST: true

      - name: Set up tests - Cache E2E test data
        id: e2e-data-cache
        uses: actions/cache@v5
        with:
          path: /tmp/.buildx-cache
          key: ${{ runner.os }}-buildx-frontend-test-${{ hashFiles('autogpt_platform/docker-compose.yml', 'autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/pyproject.toml', 'autogpt_platform/backend/poetry.lock') }}
          restore-keys: |
            ${{ runner.os }}-buildx-frontend-test-
          path: /tmp/e2e_test_data.sql
          key: e2e-test-data-${{ hashFiles('autogpt_platform/backend/test/e2e_test_data.py', 'autogpt_platform/backend/migrations/**', '.github/workflows/platform-frontend-ci.yml') }}

      - name: Run docker compose
      - name: Set up Platform - Start Supabase DB + Auth
        run: |
          NEXT_PUBLIC_PW_TEST=true docker compose -f ../docker-compose.yml up -d
          docker compose -f ../docker-compose.resolved.yml up -d db auth --no-build
          echo "Waiting for database to be ready..."
          timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done'
          echo "Waiting for auth service to be ready..."
          timeout 60 sh -c 'until docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -c "SELECT 1 FROM auth.users LIMIT 1" 2>/dev/null; do sleep 2; done' || echo "Auth schema check timeout, continuing..."
```
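The `timeout 60 sh -c 'until ...; do sleep 2; done'` lines above are a readiness-polling idiom: re-run a probe every two seconds until it succeeds or a 60-second deadline passes. A rough Python equivalent, for illustration only (the workflow itself stays in shell):

```python
# Minimal sketch of the readiness-polling idiom used by the workflow steps.
import subprocess
import time

def wait_for(cmd: list[str], timeout: float = 60.0, interval: float = 2.0) -> bool:
    """Re-run `cmd` until it exits 0 or `timeout` seconds elapse."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if subprocess.run(cmd, capture_output=True).returncode == 0:
            return True
        time.sleep(interval)
    return False

# Same probe the workflow runs for the database container:
db_ready = wait_for([
    "docker", "compose", "-f", "../docker-compose.resolved.yml",
    "exec", "-T", "db", "pg_isready", "-U", "postgres",
])
print("db ready:", db_ready)
```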
```yaml
      - name: Set up Platform - Run migrations
        run: |
          echo "Running migrations..."
          docker compose -f ../docker-compose.resolved.yml run --rm migrate
          echo "✅ Migrations completed"
        env:
          DOCKER_BUILDKIT: 1
          BUILDX_CACHE_FROM: type=local,src=/tmp/.buildx-cache
          BUILDX_CACHE_TO: type=local,dest=/tmp/.buildx-cache-new,mode=max
          NEXT_PUBLIC_PW_TEST: true

      - name: Move cache
      - name: Set up tests - Load cached E2E test data
        if: steps.e2e-data-cache.outputs.cache-hit == 'true'
        run: |
          rm -rf /tmp/.buildx-cache
          if [ -d "/tmp/.buildx-cache-new" ]; then
            mv /tmp/.buildx-cache-new /tmp/.buildx-cache
          fi
          echo "✅ Found cached E2E test data, restoring..."
          {
            echo "SET session_replication_role = 'replica';"
            cat /tmp/e2e_test_data.sql
            echo "SET session_replication_role = 'origin';"
          } | docker compose -f ../docker-compose.resolved.yml exec -T db psql -U postgres -d postgres -b
          # Refresh materialized views after restore
          docker compose -f ../docker-compose.resolved.yml exec -T db \
            psql -U postgres -d postgres -b -c "SET search_path TO platform; SELECT refresh_store_materialized_views();" || true

      - name: Wait for services to be ready
          echo "✅ E2E test data restored from cache"
```
```yaml
      - name: Set up Platform - Start (all other services)
        run: |
          docker compose -f ../docker-compose.resolved.yml up -d --no-build
          echo "Waiting for rest_server to be ready..."
          timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
          echo "Waiting for database to be ready..."
          timeout 60 sh -c 'until docker compose -f ../docker-compose.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done' || echo "Database ready check timeout, continuing..."
        env:
          NEXT_PUBLIC_PW_TEST: true

      - name: Create E2E test data
      - name: Set up tests - Create E2E test data
        if: steps.e2e-data-cache.outputs.cache-hit != 'true'
        run: |
          echo "Creating E2E test data..."
          # First try to run the script from inside the container
          if docker compose -f ../docker-compose.yml exec -T rest_server test -f /app/autogpt_platform/backend/test/e2e_test_data.py; then
            echo "✅ Found e2e_test_data.py in container, running it..."
            docker compose -f ../docker-compose.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python backend/test/e2e_test_data.py" || {
              echo "❌ E2E test data creation failed!"
              docker compose -f ../docker-compose.yml logs --tail=50 rest_server
              exit 1
            }
          else
            echo "⚠️ e2e_test_data.py not found in container, copying and running..."
            # Copy the script into the container and run it
            docker cp ../backend/test/e2e_test_data.py $(docker compose -f ../docker-compose.yml ps -q rest_server):/tmp/e2e_test_data.py || {
              echo "❌ Failed to copy script to container"
              exit 1
            }
            docker compose -f ../docker-compose.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python /tmp/e2e_test_data.py" || {
              echo "❌ E2E test data creation failed!"
              docker compose -f ../docker-compose.yml logs --tail=50 rest_server
              exit 1
            }
          fi
          docker cp ../backend/test/e2e_test_data.py $(docker compose -f ../docker-compose.resolved.yml ps -q rest_server):/tmp/e2e_test_data.py
          docker compose -f ../docker-compose.resolved.yml exec -T rest_server sh -c "cd /app/autogpt_platform && python /tmp/e2e_test_data.py" || {
            echo "❌ E2E test data creation failed!"
            docker compose -f ../docker-compose.resolved.yml logs --tail=50 rest_server
            exit 1
          }

      - name: Restore dependencies cache
        uses: actions/cache@v5
          # Dump auth.users + platform schema for cache (two separate dumps)
          echo "Dumping database for cache..."
          {
            docker compose -f ../docker-compose.resolved.yml exec -T db \
              pg_dump -U postgres --data-only --column-inserts \
              --table='auth.users' postgres
            docker compose -f ../docker-compose.resolved.yml exec -T db \
              pg_dump -U postgres --data-only --column-inserts \
              --schema=platform \
              --exclude-table='platform._prisma_migrations' \
              --exclude-table='platform.apscheduler_jobs' \
              --exclude-table='platform.apscheduler_jobs_batched_notifications' \
              postgres
          } > /tmp/e2e_test_data.sql

          echo "✅ Database dump created for caching ($(wc -l < /tmp/e2e_test_data.sql) lines)"

      - name: Set up tests - Enable corepack
        run: corepack enable

      - name: Set up tests - Set up Node
        uses: actions/setup-node@v6
        with:
          path: ~/.pnpm-store
          key: ${{ needs.setup.outputs.cache-key }}
          restore-keys: |
            ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
            ${{ runner.os }}-pnpm-
          node-version: "22.18.0"
          cache: "pnpm"
          cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml

      - name: Install dependencies
      - name: Set up tests - Install dependencies
        run: pnpm install --frozen-lockfile

      - name: Install Browser 'chromium'
      - name: Set up tests - Install browser 'chromium'
        run: pnpm playwright install --with-deps chromium

      - name: Run Playwright tests
@@ -269,7 +281,7 @@ jobs:

      - name: Print Final Docker Compose logs
        if: always()
        run: docker compose -f ../docker-compose.yml logs
        run: docker compose -f ../docker-compose.resolved.yml logs

  integration_test:
    runs-on: ubuntu-latest
@@ -281,22 +293,15 @@ jobs:
        with:
          submodules: recursive

      - name: Set up Node.js
        uses: actions/setup-node@v6
        with:
          node-version: "22.18.0"

      - name: Enable corepack
        run: corepack enable

      - name: Restore dependencies cache
        uses: actions/cache@v5
      - name: Set up Node
        uses: actions/setup-node@v6
        with:
          path: ~/.pnpm-store
          key: ${{ needs.setup.outputs.cache-key }}
          restore-keys: |
            ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
            ${{ runner.os }}-pnpm-
          node-version: "22.18.0"
          cache: "pnpm"
          cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml

      - name: Install dependencies
        run: pnpm install --frozen-lockfile
```
.github/workflows/scripts/docker-ci-fix-compose-build-cache.py (vendored, new file, 195 lines)

@@ -0,0 +1,195 @@
```python
#!/usr/bin/env python3
"""
Add cache configuration to a resolved docker-compose file for all services
that have a build key, and ensure image names match what docker compose expects.
"""

import argparse

import yaml

DEFAULT_BRANCH = "dev"
CACHE_BUILDS_FOR_COMPONENTS = ["backend", "frontend"]


def main():
    parser = argparse.ArgumentParser(
        description="Add cache config to a resolved compose file"
    )
    parser.add_argument(
        "--source",
        required=True,
        help="Source compose file to read (should be output of `docker compose config`)",
    )
    parser.add_argument(
        "--cache-from",
        default="type=gha",
        help="Cache source configuration",
    )
    parser.add_argument(
        "--cache-to",
        default="type=gha,mode=max",
        help="Cache destination configuration",
    )
    for component in CACHE_BUILDS_FOR_COMPONENTS:
        parser.add_argument(
            f"--{component}-hash",
            default="",
            help=f"Hash for {component} cache scope (e.g., from hashFiles())",
        )
    parser.add_argument(
        "--git-ref",
        default="",
        help="Git ref for branch-based cache scope (e.g., refs/heads/master)",
    )
    args = parser.parse_args()

    # Normalize git ref to a safe scope name (e.g., refs/heads/master -> master)
    git_ref_scope = ""
    if args.git_ref:
        git_ref_scope = args.git_ref.replace("refs/heads/", "").replace("/", "-")

    with open(args.source, "r") as f:
        compose = yaml.safe_load(f)

    # Get project name from compose file or default
    project_name = compose.get("name", "autogpt_platform")

    def get_image_name(dockerfile: str, target: str) -> str:
        """Generate image name based on Dockerfile folder and build target."""
        dockerfile_parts = dockerfile.replace("\\", "/").split("/")
        if len(dockerfile_parts) >= 2:
            folder_name = dockerfile_parts[-2]  # e.g., "backend" or "frontend"
        else:
            folder_name = "app"
        return f"{project_name}-{folder_name}:{target}"

    def get_build_key(dockerfile: str, target: str) -> str:
        """Generate a unique key for a Dockerfile+target combination."""
        return f"{dockerfile}:{target}"

    def get_component(dockerfile: str) -> str | None:
        """Get component name (frontend/backend) from dockerfile path."""
        for component in CACHE_BUILDS_FOR_COMPONENTS:
            if component in dockerfile:
                return component
        return None

    # First pass: collect all services with build configs and identify duplicates
    # Track which (dockerfile, target) combinations we've seen
    build_key_to_first_service: dict[str, str] = {}
    services_to_build: list[str] = []
    services_to_dedupe: list[str] = []

    for service_name, service_config in compose.get("services", {}).items():
        if "build" not in service_config:
            continue

        build_config = service_config["build"]
        dockerfile = build_config.get("dockerfile", "Dockerfile")
        target = build_config.get("target", "default")
        build_key = get_build_key(dockerfile, target)

        if build_key not in build_key_to_first_service:
            # First service with this build config - it will do the actual build
            build_key_to_first_service[build_key] = service_name
            services_to_build.append(service_name)
        else:
            # Duplicate - will just use the image from the first service
            services_to_dedupe.append(service_name)

    # Second pass: configure builds and deduplicate
    modified_services = []
    for service_name, service_config in compose.get("services", {}).items():
        if "build" not in service_config:
            continue

        build_config = service_config["build"]
        dockerfile = build_config.get("dockerfile", "Dockerfile")
        target = build_config.get("target", "latest")
        image_name = get_image_name(dockerfile, target)

        # Set image name for all services (needed for both builders and deduped)
        service_config["image"] = image_name

        if service_name in services_to_dedupe:
            # Remove build config - this service will use the pre-built image
            del service_config["build"]
            continue

        # This service will do the actual build - add cache config
        cache_from_list = []
        cache_to_list = []

        component = get_component(dockerfile)
        if not component:
            # Skip services that don't clearly match frontend/backend
            continue

        # Get the hash for this component
        component_hash = getattr(args, f"{component}_hash")

        # Scope format: platform-{component}-{target}-{hash|ref}
        # Example: platform-backend-server-abc123

        if "type=gha" in args.cache_from:
            # 1. Primary: exact hash match (most specific)
            if component_hash:
                hash_scope = f"platform-{component}-{target}-{component_hash}"
                cache_from_list.append(f"{args.cache_from},scope={hash_scope}")

            # 2. Fallback: branch-based cache
            if git_ref_scope:
                ref_scope = f"platform-{component}-{target}-{git_ref_scope}"
                cache_from_list.append(f"{args.cache_from},scope={ref_scope}")

            # 3. Fallback: dev branch cache (for PRs/feature branches)
            if git_ref_scope and git_ref_scope != DEFAULT_BRANCH:
                master_scope = f"platform-{component}-{target}-{DEFAULT_BRANCH}"
                cache_from_list.append(f"{args.cache_from},scope={master_scope}")

        if "type=gha" in args.cache_to:
            # Write to both hash-based and branch-based scopes
            if component_hash:
                hash_scope = f"platform-{component}-{target}-{component_hash}"
                cache_to_list.append(f"{args.cache_to},scope={hash_scope}")

            if git_ref_scope:
                ref_scope = f"platform-{component}-{target}-{git_ref_scope}"
                cache_to_list.append(f"{args.cache_to},scope={ref_scope}")

        # Ensure we have at least one cache source/target
        if not cache_from_list:
            cache_from_list.append(args.cache_from)
        if not cache_to_list:
            cache_to_list.append(args.cache_to)

        build_config["cache_from"] = cache_from_list
        build_config["cache_to"] = cache_to_list
        modified_services.append(service_name)

    # Write back to the same file
    with open(args.source, "w") as f:
        yaml.dump(compose, f, default_flow_style=False, sort_keys=False)

    print(f"Added cache config to {len(modified_services)} services in {args.source}:")
    for svc in modified_services:
        svc_config = compose["services"][svc]
        build_cfg = svc_config.get("build", {})
        cache_from_list = build_cfg.get("cache_from", ["none"])
        cache_to_list = build_cfg.get("cache_to", ["none"])
        print(f"  - {svc}")
        print(f"    image: {svc_config.get('image', 'N/A')}")
        print(f"    cache_from: {cache_from_list}")
        print(f"    cache_to: {cache_to_list}")
    if services_to_dedupe:
        print(
            f"Deduplicated {len(services_to_dedupe)} services (will use pre-built images):"
        )
        for svc in services_to_dedupe:
            print(f"  - {svc} -> {compose['services'][svc].get('image', 'N/A')}")


if __name__ == "__main__":
    main()
```
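To make the scope naming concrete, here is what the logic above produces for a hypothetical backend service with build target `server`, `--backend-hash abc123`, and `--git-ref refs/heads/my-feature` (all three values invented for illustration):

```python
# Hypothetical inputs: --backend-hash abc123 --git-ref refs/heads/my-feature
component, target = "backend", "server"
component_hash, git_ref_scope, default_branch = "abc123", "my-feature", "dev"

cache_from = [  # read order: exact hash, then same branch, then dev fallback
    f"type=gha,scope=platform-{component}-{target}-{component_hash}",
    f"type=gha,scope=platform-{component}-{target}-{git_ref_scope}",
    f"type=gha,scope=platform-{component}-{target}-{default_branch}",
]
cache_to = [  # write back to both the hash scope and the branch scope
    f"type=gha,mode=max,scope=platform-{component}-{target}-{component_hash}",
    f"type=gha,mode=max,scope=platform-{component}-{target}-{git_ref_scope}",
]
print(*cache_from, sep="\n")
print(*cache_to, sep="\n")
```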
```markdown
@@ -45,6 +45,11 @@ AutoGPT Platform is a monorepo containing:
- Backend/Frontend services use YAML anchors for consistent configuration
- Supabase services (`db/docker/docker-compose.yml`) follow the same pattern

### Branching Strategy

- **`dev`** is the main development branch. All PRs should target `dev`.
- **`master`** is the production branch. Only used for production releases.

### Creating Pull Requests

- Create the PR against the `dev` branch of the repository.
```
autogpt_platform/autogpt_libs/poetry.lock (generated, 169 changed lines)

```toml
@@ -448,61 +448,61 @@ toml = ["tomli ; python_full_version <= \"3.11.0a6\""]

[[package]]
name = "cryptography"
version = "46.0.4"
version = "46.0.5"
description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers."
optional = false
python-versions = "!=3.9.0,!=3.9.1,>=3.8"
groups = ["main"]
files = [
    {file = "cryptography-46.0.4-cp311-abi3-macosx_10_9_universal2.whl", hash = "sha256:281526e865ed4166009e235afadf3a4c4cba6056f99336a99efba65336fd5485"},
    {file = "cryptography-46.0.4-cp311-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5f14fba5bf6f4390d7ff8f086c566454bff0411f6d8aa7af79c88b6f9267aecc"},
    {file = "cryptography-46.0.4-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:47bcd19517e6389132f76e2d5303ded6cf3f78903da2158a671be8de024f4cd0"},
    {file = "cryptography-46.0.4-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:01df4f50f314fbe7009f54046e908d1754f19d0c6d3070df1e6268c5a4af09fa"},
    {file = "cryptography-46.0.4-cp311-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:5aa3e463596b0087b3da0dbe2b2487e9fc261d25da85754e30e3b40637d61f81"},
    {file = "cryptography-46.0.4-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:0a9ad24359fee86f131836a9ac3bffc9329e956624a2d379b613f8f8abaf5255"},
    {file = "cryptography-46.0.4-cp311-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:dc1272e25ef673efe72f2096e92ae39dea1a1a450dd44918b15351f72c5a168e"},
    {file = "cryptography-46.0.4-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:de0f5f4ec8711ebc555f54735d4c673fc34b65c44283895f1a08c2b49d2fd99c"},
    {file = "cryptography-46.0.4-cp311-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:eeeb2e33d8dbcccc34d64651f00a98cb41b2dc69cef866771a5717e6734dfa32"},
    {file = "cryptography-46.0.4-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:3d425eacbc9aceafd2cb429e42f4e5d5633c6f873f5e567077043ef1b9bbf616"},
    {file = "cryptography-46.0.4-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:91627ebf691d1ea3976a031b61fb7bac1ccd745afa03602275dda443e11c8de0"},
    {file = "cryptography-46.0.4-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:2d08bc22efd73e8854b0b7caff402d735b354862f1145d7be3b9c0f740fef6a0"},
    {file = "cryptography-46.0.4-cp311-abi3-win32.whl", hash = "sha256:82a62483daf20b8134f6e92898da70d04d0ef9a75829d732ea1018678185f4f5"},
    {file = "cryptography-46.0.4-cp311-abi3-win_amd64.whl", hash = "sha256:6225d3ebe26a55dbc8ead5ad1265c0403552a63336499564675b29eb3184c09b"},
    {file = "cryptography-46.0.4-cp314-cp314t-macosx_10_9_universal2.whl", hash = "sha256:485e2b65d25ec0d901bca7bcae0f53b00133bf3173916d8e421f6fddde103908"},
    {file = "cryptography-46.0.4-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:078e5f06bd2fa5aea5a324f2a09f914b1484f1d0c2a4d6a8a28c74e72f65f2da"},
    {file = "cryptography-46.0.4-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:dce1e4f068f03008da7fa51cc7abc6ddc5e5de3e3d1550334eaf8393982a5829"},
    {file = "cryptography-46.0.4-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:2067461c80271f422ee7bdbe79b9b4be54a5162e90345f86a23445a0cf3fd8a2"},
    {file = "cryptography-46.0.4-cp314-cp314t-manylinux_2_28_ppc64le.whl", hash = "sha256:c92010b58a51196a5f41c3795190203ac52edfd5dc3ff99149b4659eba9d2085"},
    {file = "cryptography-46.0.4-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:829c2b12bbc5428ab02d6b7f7e9bbfd53e33efd6672d21341f2177470171ad8b"},
    {file = "cryptography-46.0.4-cp314-cp314t-manylinux_2_31_armv7l.whl", hash = "sha256:62217ba44bf81b30abaeda1488686a04a702a261e26f87db51ff61d9d3510abd"},
    {file = "cryptography-46.0.4-cp314-cp314t-manylinux_2_34_aarch64.whl", hash = "sha256:9c2da296c8d3415b93e6053f5a728649a87a48ce084a9aaf51d6e46c87c7f2d2"},
    {file = "cryptography-46.0.4-cp314-cp314t-manylinux_2_34_ppc64le.whl", hash = "sha256:9b34d8ba84454641a6bf4d6762d15847ecbd85c1316c0a7984e6e4e9f748ec2e"},
    {file = "cryptography-46.0.4-cp314-cp314t-manylinux_2_34_x86_64.whl", hash = "sha256:df4a817fa7138dd0c96c8c8c20f04b8aaa1fac3bbf610913dcad8ea82e1bfd3f"},
    {file = "cryptography-46.0.4-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:b1de0ebf7587f28f9190b9cb526e901bf448c9e6a99655d2b07fff60e8212a82"},
    {file = "cryptography-46.0.4-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:9b4d17bc7bd7cdd98e3af40b441feaea4c68225e2eb2341026c84511ad246c0c"},
    {file = "cryptography-46.0.4-cp314-cp314t-win32.whl", hash = "sha256:c411f16275b0dea722d76544a61d6421e2cc829ad76eec79280dbdc9ddf50061"},
    {file = "cryptography-46.0.4-cp314-cp314t-win_amd64.whl", hash = "sha256:728fedc529efc1439eb6107b677f7f7558adab4553ef8669f0d02d42d7b959a7"},
    {file = "cryptography-46.0.4-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:a9556ba711f7c23f77b151d5798f3ac44a13455cc68db7697a1096e6d0563cab"},
    {file = "cryptography-46.0.4-cp38-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8bf75b0259e87fa70bddc0b8b4078b76e7fd512fd9afae6c1193bcf440a4dbef"},
    {file = "cryptography-46.0.4-cp38-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:3c268a3490df22270955966ba236d6bc4a8f9b6e4ffddb78aac535f1a5ea471d"},
    {file = "cryptography-46.0.4-cp38-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:812815182f6a0c1d49a37893a303b44eaac827d7f0d582cecfc81b6427f22973"},
    {file = "cryptography-46.0.4-cp38-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:a90e43e3ef65e6dcf969dfe3bb40cbf5aef0d523dff95bfa24256be172a845f4"},
    {file = "cryptography-46.0.4-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:a05177ff6296644ef2876fce50518dffb5bcdf903c85250974fc8bc85d54c0af"},
    {file = "cryptography-46.0.4-cp38-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:daa392191f626d50f1b136c9b4cf08af69ca8279d110ea24f5c2700054d2e263"},
    {file = "cryptography-46.0.4-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:e07ea39c5b048e085f15923511d8121e4a9dc45cee4e3b970ca4f0d338f23095"},
    {file = "cryptography-46.0.4-cp38-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:d5a45ddc256f492ce42a4e35879c5e5528c09cd9ad12420828c972951d8e016b"},
    {file = "cryptography-46.0.4-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:6bb5157bf6a350e5b28aee23beb2d84ae6f5be390b2f8ee7ea179cda077e1019"},
    {file = "cryptography-46.0.4-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:dd5aba870a2c40f87a3af043e0dee7d9eb02d4aff88a797b48f2b43eff8c3ab4"},
    {file = "cryptography-46.0.4-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:93d8291da8d71024379ab2cb0b5c57915300155ad42e07f76bea6ad838d7e59b"},
    {file = "cryptography-46.0.4-cp38-abi3-win32.whl", hash = "sha256:0563655cb3c6d05fb2afe693340bc050c30f9f34e15763361cf08e94749401fc"},
    {file = "cryptography-46.0.4-cp38-abi3-win_amd64.whl", hash = "sha256:fa0900b9ef9c49728887d1576fd8d9e7e3ea872fa9b25ef9b64888adc434e976"},
    {file = "cryptography-46.0.4-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:766330cce7416c92b5e90c3bb71b1b79521760cdcfc3a6a1a182d4c9fab23d2b"},
    {file = "cryptography-46.0.4-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:c236a44acfb610e70f6b3e1c3ca20ff24459659231ef2f8c48e879e2d32b73da"},
    {file = "cryptography-46.0.4-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:8a15fb869670efa8f83cbffbc8753c1abf236883225aed74cd179b720ac9ec80"},
    {file = "cryptography-46.0.4-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:fdc3daab53b212472f1524d070735b2f0c214239df131903bae1d598016fa822"},
    {file = "cryptography-46.0.4-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:44cc0675b27cadb71bdbb96099cca1fa051cd11d2ade09e5cd3a2edb929ed947"},
    {file = "cryptography-46.0.4-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:be8c01a7d5a55f9a47d1888162b76c8f49d62b234d88f0ff91a9fbebe32ffbc3"},
    {file = "cryptography-46.0.4.tar.gz", hash = "sha256:bfd019f60f8abc2ed1b9be4ddc21cfef059c841d86d710bb69909a688cbb8f59"},
    {file = "cryptography-46.0.5-cp311-abi3-macosx_10_9_universal2.whl", hash = "sha256:351695ada9ea9618b3500b490ad54c739860883df6c1f555e088eaf25b1bbaad"},
    {file = "cryptography-46.0.5-cp311-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:c18ff11e86df2e28854939acde2d003f7984f721eba450b56a200ad90eeb0e6b"},
    {file = "cryptography-46.0.5-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d7e3d356b8cd4ea5aff04f129d5f66ebdc7b6f8eae802b93739ed520c47c79b"},
    {file = "cryptography-46.0.5-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:50bfb6925eff619c9c023b967d5b77a54e04256c4281b0e21336a130cd7fc263"},
    {file = "cryptography-46.0.5-cp311-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:803812e111e75d1aa73690d2facc295eaefd4439be1023fefc4995eaea2af90d"},
    {file = "cryptography-46.0.5-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:3ee190460e2fbe447175cda91b88b84ae8322a104fc27766ad09428754a618ed"},
    {file = "cryptography-46.0.5-cp311-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:f145bba11b878005c496e93e257c1e88f154d278d2638e6450d17e0f31e558d2"},
    {file = "cryptography-46.0.5-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:e9251e3be159d1020c4030bd2e5f84d6a43fe54b6c19c12f51cde9542a2817b2"},
    {file = "cryptography-46.0.5-cp311-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:47fb8a66058b80e509c47118ef8a75d14c455e81ac369050f20ba0d23e77fee0"},
    {file = "cryptography-46.0.5-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:4c3341037c136030cb46e4b1e17b7418ea4cbd9dd207e4a6f3b2b24e0d4ac731"},
    {file = "cryptography-46.0.5-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:890bcb4abd5a2d3f852196437129eb3667d62630333aacc13dfd470fad3aaa82"},
    {file = "cryptography-46.0.5-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:80a8d7bfdf38f87ca30a5391c0c9ce4ed2926918e017c29ddf643d0ed2778ea1"},
    {file = "cryptography-46.0.5-cp311-abi3-win32.whl", hash = "sha256:60ee7e19e95104d4c03871d7d7dfb3d22ef8a9b9c6778c94e1c8fcc8365afd48"},
    {file = "cryptography-46.0.5-cp311-abi3-win_amd64.whl", hash = "sha256:38946c54b16c885c72c4f59846be9743d699eee2b69b6988e0a00a01f46a61a4"},
    {file = "cryptography-46.0.5-cp314-cp314t-macosx_10_9_universal2.whl", hash = "sha256:94a76daa32eb78d61339aff7952ea819b1734b46f73646a07decb40e5b3448e2"},
    {file = "cryptography-46.0.5-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5be7bf2fb40769e05739dd0046e7b26f9d4670badc7b032d6ce4db64dddc0678"},
    {file = "cryptography-46.0.5-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fe346b143ff9685e40192a4960938545c699054ba11d4f9029f94751e3f71d87"},
    {file = "cryptography-46.0.5-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:c69fd885df7d089548a42d5ec05be26050ebcd2283d89b3d30676eb32ff87dee"},
    {file = "cryptography-46.0.5-cp314-cp314t-manylinux_2_28_ppc64le.whl", hash = "sha256:8293f3dea7fc929ef7240796ba231413afa7b68ce38fd21da2995549f5961981"},
    {file = "cryptography-46.0.5-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:1abfdb89b41c3be0365328a410baa9df3ff8a9110fb75e7b52e66803ddabc9a9"},
    {file = "cryptography-46.0.5-cp314-cp314t-manylinux_2_31_armv7l.whl", hash = "sha256:d66e421495fdb797610a08f43b05269e0a5ea7f5e652a89bfd5a7d3c1dee3648"},
    {file = "cryptography-46.0.5-cp314-cp314t-manylinux_2_34_aarch64.whl", hash = "sha256:4e817a8920bfbcff8940ecfd60f23d01836408242b30f1a708d93198393a80b4"},
    {file = "cryptography-46.0.5-cp314-cp314t-manylinux_2_34_ppc64le.whl", hash = "sha256:68f68d13f2e1cb95163fa3b4db4bf9a159a418f5f6e7242564fc75fcae667fd0"},
    {file = "cryptography-46.0.5-cp314-cp314t-manylinux_2_34_x86_64.whl", hash = "sha256:a3d1fae9863299076f05cb8a778c467578262fae09f9dc0ee9b12eb4268ce663"},
    {file = "cryptography-46.0.5-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:c4143987a42a2397f2fc3b4d7e3a7d313fbe684f67ff443999e803dd75a76826"},
    {file = "cryptography-46.0.5-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:7d731d4b107030987fd61a7f8ab512b25b53cef8f233a97379ede116f30eb67d"},
    {file = "cryptography-46.0.5-cp314-cp314t-win32.whl", hash = "sha256:c3bcce8521d785d510b2aad26ae2c966092b7daa8f45dd8f44734a104dc0bc1a"},
    {file = "cryptography-46.0.5-cp314-cp314t-win_amd64.whl", hash = "sha256:4d8ae8659ab18c65ced284993c2265910f6c9e650189d4e3f68445ef82a810e4"},
    {file = "cryptography-46.0.5-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:4108d4c09fbbf2789d0c926eb4152ae1760d5a2d97612b92d508d96c861e4d31"},
    {file = "cryptography-46.0.5-cp38-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:7d1f30a86d2757199cb2d56e48cce14deddf1f9c95f1ef1b64ee91ea43fe2e18"},
    {file = "cryptography-46.0.5-cp38-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:039917b0dc418bb9f6edce8a906572d69e74bd330b0b3fea4f79dab7f8ddd235"},
    {file = "cryptography-46.0.5-cp38-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:ba2a27ff02f48193fc4daeadf8ad2590516fa3d0adeeb34336b96f7fa64c1e3a"},
    {file = "cryptography-46.0.5-cp38-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:61aa400dce22cb001a98014f647dc21cda08f7915ceb95df0c9eaf84b4b6af76"},
    {file = "cryptography-46.0.5-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:3ce58ba46e1bc2aac4f7d9290223cead56743fa6ab94a5d53292ffaac6a91614"},
    {file = "cryptography-46.0.5-cp38-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:420d0e909050490d04359e7fdb5ed7e667ca5c3c402b809ae2563d7e66a92229"},
    {file = "cryptography-46.0.5-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:582f5fcd2afa31622f317f80426a027f30dc792e9c80ffee87b993200ea115f1"},
    {file = "cryptography-46.0.5-cp38-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:bfd56bb4b37ed4f330b82402f6f435845a5f5648edf1ad497da51a8452d5d62d"},
    {file = "cryptography-46.0.5-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:a3d507bb6a513ca96ba84443226af944b0f7f47dcc9a399d110cd6146481d24c"},
    {file = "cryptography-46.0.5-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:9f16fbdf4da055efb21c22d81b89f155f02ba420558db21288b3d0035bafd5f4"},
    {file = "cryptography-46.0.5-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:ced80795227d70549a411a4ab66e8ce307899fad2220ce5ab2f296e687eacde9"},
    {file = "cryptography-46.0.5-cp38-abi3-win32.whl", hash = "sha256:02f547fce831f5096c9a567fd41bc12ca8f11df260959ecc7c3202555cc47a72"},
    {file = "cryptography-46.0.5-cp38-abi3-win_amd64.whl", hash = "sha256:556e106ee01aa13484ce9b0239bca667be5004efb0aabbed28d353df86445595"},
    {file = "cryptography-46.0.5-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:3b4995dc971c9fb83c25aa44cf45f02ba86f71ee600d81091c2f0cbae116b06c"},
    {file = "cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:bc84e875994c3b445871ea7181d424588171efec3e185dced958dad9e001950a"},
    {file = "cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:2ae6971afd6246710480e3f15824ed3029a60fc16991db250034efd0b9fb4356"},
    {file = "cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:d861ee9e76ace6cf36a6a89b959ec08e7bc2493ee39d07ffe5acb23ef46d27da"},
    {file = "cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:2b7a67c9cd56372f3249b39699f2ad479f6991e62ea15800973b956f4b73e257"},
    {file = "cryptography-46.0.5-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:8456928655f856c6e1533ff59d5be76578a7157224dbd9ce6872f25055ab9ab7"},
    {file = "cryptography-46.0.5.tar.gz", hash = "sha256:abace499247268e3757271b2f1e244b36b06f8515cf27c4d49468fc9eb16e93d"},
]

[package.dependencies]
@@ -516,7 +516,7 @@ nox = ["nox[uv] (>=2024.4.15)"]
pep8test = ["check-sdist", "click (>=8.0.1)", "mypy (>=1.14)", "ruff (>=0.11.11)"]
sdist = ["build (>=1.0.0)"]
ssh = ["bcrypt (>=3.1.5)"]
test = ["certifi (>=2024)", "cryptography-vectors (==46.0.4)", "pretend (>=0.7)", "pytest (>=7.4.0)", "pytest-benchmark (>=4.0)", "pytest-cov (>=2.10.1)", "pytest-xdist (>=3.5.0)"]
test = ["certifi (>=2024)", "cryptography-vectors (==46.0.5)", "pretend (>=0.7)", "pytest (>=7.4.0)", "pytest-benchmark (>=4.0)", "pytest-cov (>=2.10.1)", "pytest-xdist (>=3.5.0)"]
test-randomorder = ["pytest-randomly"]

[[package]]
@@ -570,24 +570,25 @@ tests = ["coverage", "coveralls", "dill", "mock", "nose"]

[[package]]
name = "fastapi"
version = "0.128.0"
version = "0.128.7"
description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production"
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
    {file = "fastapi-0.128.0-py3-none-any.whl", hash = "sha256:aebd93f9716ee3b4f4fcfe13ffb7cf308d99c9f3ab5622d8877441072561582d"},
    {file = "fastapi-0.128.0.tar.gz", hash = "sha256:1cc179e1cef10a6be60ffe429f79b829dce99d8de32d7acb7e6c8dfdf7f2645a"},
    {file = "fastapi-0.128.7-py3-none-any.whl", hash = "sha256:6bd9bd31cb7047465f2d3fa3ba3f33b0870b17d4eaf7cdb36d1576ab060ad662"},
    {file = "fastapi-0.128.7.tar.gz", hash = "sha256:783c273416995486c155ad2c0e2b45905dedfaf20b9ef8d9f6a9124670639a24"},
]

[package.dependencies]
annotated-doc = ">=0.0.2"
pydantic = ">=2.7.0"
starlette = ">=0.40.0,<0.51.0"
starlette = ">=0.40.0,<1.0.0"
typing-extensions = ">=4.8.0"
typing-inspection = ">=0.4.2"

[package.extras]
all = ["email-validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.8)", "httpx (>=0.23.0,<1.0.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=3.1.5)", "orjson (>=3.2.1)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.18)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"]
all = ["email-validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.8)", "httpx (>=0.23.0,<1.0.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=3.1.5)", "orjson (>=3.9.3)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.18)", "pyyaml (>=5.3.1)", "ujson (>=5.8.0)", "uvicorn[standard] (>=0.12.0)"]
standard = ["email-validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.8)", "httpx (>=0.23.0,<1.0.0)", "jinja2 (>=3.1.5)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.18)", "uvicorn[standard] (>=0.12.0)"]
standard-no-fastapi-cloud-cli = ["email-validator (>=2.0.0)", "fastapi-cli[standard-no-fastapi-cloud-cli] (>=0.0.8)", "httpx (>=0.23.0,<1.0.0)", "jinja2 (>=3.1.5)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.18)", "uvicorn[standard] (>=0.12.0)"]

@@ -1062,14 +1063,14 @@ urllib3 = ">=1.26.0,<3"

[[package]]
name = "launchdarkly-server-sdk"
version = "9.14.1"
version = "9.15.0"
description = "LaunchDarkly SDK for Python"
optional = false
python-versions = ">=3.9"
python-versions = ">=3.10"
groups = ["main"]
files = [
    {file = "launchdarkly_server_sdk-9.14.1-py3-none-any.whl", hash = "sha256:a9e2bd9ecdef845cd631ae0d4334a1115e5b44257c42eb2349492be4bac7815c"},
    {file = "launchdarkly_server_sdk-9.14.1.tar.gz", hash = "sha256:1df44baf0a0efa74d8c1dad7a00592b98bce7d19edded7f770da8dbc49922213"},
    {file = "launchdarkly_server_sdk-9.15.0-py3-none-any.whl", hash = "sha256:c267e29bfa3fb5e2a06a208448ada6ed5557a2924979b8d79c970b45d227c668"},
    {file = "launchdarkly_server_sdk-9.15.0.tar.gz", hash = "sha256:f31441b74bc1a69c381db57c33116509e407a2612628ad6dff0a7dbb39d5020b"},
]

[package.dependencies]
@@ -1478,14 +1479,14 @@ testing = ["coverage", "pytest", "pytest-benchmark"]

[[package]]
name = "postgrest"
version = "2.27.2"
version = "2.28.0"
description = "PostgREST client for Python. This library provides an ORM interface to PostgREST."
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
    {file = "postgrest-2.27.2-py3-none-any.whl", hash = "sha256:1666fef3de05ca097a314433dd5ae2f2d71c613cb7b233d0f468c4ffe37277da"},
    {file = "postgrest-2.27.2.tar.gz", hash = "sha256:55407d530b5af3d64e883a71fec1f345d369958f723ce4a8ab0b7d169e313242"},
    {file = "postgrest-2.28.0-py3-none-any.whl", hash = "sha256:7bca2f24dd1a1bf8a3d586c7482aba6cd41662da6733045fad585b63b7f7df75"},
    {file = "postgrest-2.28.0.tar.gz", hash = "sha256:c36b38646d25ea4255321d3d924ce70f8d20ec7799cb42c1221d6a818d4f6515"},
]

[package.dependencies]
@@ -2248,14 +2249,14 @@ cli = ["click (>=5.0)"]

[[package]]
name = "realtime"
version = "2.27.2"
version = "2.28.0"
description = ""
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
    {file = "realtime-2.27.2-py3-none-any.whl", hash = "sha256:34a9cbb26a274e707e8fc9e3ee0a66de944beac0fe604dc336d1e985db2c830f"},
    {file = "realtime-2.27.2.tar.gz", hash = "sha256:b960a90294d2cea1b3f1275ecb89204304728e08fff1c393cc1b3150739556b3"},
    {file = "realtime-2.28.0-py3-none-any.whl", hash = "sha256:db1bd59bab9b1fcc9f9d3b1a073bed35bf4994d720e6751f10031a58d57a3836"},
    {file = "realtime-2.28.0.tar.gz", hash = "sha256:d18cedcebd6a8f22fcd509bc767f639761eb218b7b2b6f14fc4205b6259b50fc"},
]

[package.dependencies]
@@ -2436,14 +2437,14 @@ full = ["httpx (>=0.27.0,<0.29.0)", "itsdangerous", "jinja2", "python-multipart

[[package]]
name = "storage3"
version = "2.27.2"
version = "2.28.0"
description = "Supabase Storage client for Python."
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
    {file = "storage3-2.27.2-py3-none-any.whl", hash = "sha256:e6f16e7a260729e7b1f46e9bf61746805a02e30f5e419ee1291007c432e3ec63"},
    {file = "storage3-2.27.2.tar.gz", hash = "sha256:cb4807b7f86b4bb1272ac6fdd2f3cfd8ba577297046fa5f88557425200275af5"},
    {file = "storage3-2.28.0-py3-none-any.whl", hash = "sha256:ecb50efd2ac71dabbdf97e99ad346eafa630c4c627a8e5a138ceb5fbbadae716"},
    {file = "storage3-2.28.0.tar.gz", hash = "sha256:bc1d008aff67de7a0f2bd867baee7aadbcdb6f78f5a310b4f7a38e8c13c19865"},
]

[package.dependencies]
@@ -2487,35 +2488,35 @@ python-dateutil = ">=2.6.0"

[[package]]
name = "supabase"
version = "2.27.2"
version = "2.28.0"
description = "Supabase client for Python."
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
    {file = "supabase-2.27.2-py3-none-any.whl", hash = "sha256:d4dce00b3a418ee578017ec577c0e5be47a9a636355009c76f20ed2faa15bc54"},
    {file = "supabase-2.27.2.tar.gz", hash = "sha256:2aed40e4f3454438822442a1e94a47be6694c2c70392e7ae99b51a226d4293f7"},
    {file = "supabase-2.28.0-py3-none-any.whl", hash = "sha256:42776971c7d0ccca16034df1ab96a31c50228eb1eb19da4249ad2f756fc20272"},
    {file = "supabase-2.28.0.tar.gz", hash = "sha256:aea299aaab2a2eed3c57e0be7fc035c6807214194cce795a3575add20268ece1"},
]

[package.dependencies]
httpx = ">=0.26,<0.29"
postgrest = "2.27.2"
realtime = "2.27.2"
storage3 = "2.27.2"
supabase-auth = "2.27.2"
supabase-functions = "2.27.2"
postgrest = "2.28.0"
realtime = "2.28.0"
storage3 = "2.28.0"
supabase-auth = "2.28.0"
supabase-functions = "2.28.0"
yarl = ">=1.22.0"

[[package]]
name = "supabase-auth"
version = "2.27.2"
version = "2.28.0"
description = "Python Client Library for Supabase Auth"
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
    {file = "supabase_auth-2.27.2-py3-none-any.whl", hash = "sha256:78ec25b11314d0a9527a7205f3b1c72560dccdc11b38392f80297ef98664ee91"},
    {file = "supabase_auth-2.27.2.tar.gz", hash = "sha256:0f5bcc79b3677cb42e9d321f3c559070cfa40d6a29a67672cc8382fb7dc2fe97"},
    {file = "supabase_auth-2.28.0-py3-none-any.whl", hash = "sha256:2ac85026cc285054c7fa6d41924f3a333e9ec298c013e5b5e1754039ba7caec9"},
    {file = "supabase_auth-2.28.0.tar.gz", hash = "sha256:2bb8f18ff39934e44b28f10918db965659f3735cd6fbfcc022fe0b82dbf8233e"},
]

[package.dependencies]
@@ -2525,14 +2526,14 @@ pyjwt = {version = ">=2.10.1", extras = ["crypto"]}

[[package]]
name = "supabase-functions"
version = "2.27.2"
version = "2.28.0"
description = "Library for Supabase Functions"
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
    {file = "supabase_functions-2.27.2-py3-none-any.whl", hash = "sha256:db480efc669d0bca07605b9b6f167312af43121adcc842a111f79bea416ef754"},
    {file = "supabase_functions-2.27.2.tar.gz", hash = "sha256:d0c8266207a94371cb3fd35ad3c7f025b78a97cf026861e04ccd35ac1775f80b"},
    {file = "supabase_functions-2.28.0-py3-none-any.whl", hash = "sha256:30bf2d586f8df285faf0621bb5d5bb3ec3157234fc820553ca156f009475e4ae"},
    {file = "supabase_functions-2.28.0.tar.gz", hash = "sha256:db3dddfc37aca5858819eb461130968473bd8c75bd284581013958526dac718b"},
]

[package.dependencies]
@@ -2911,4 +2912,4 @@ type = ["pytest-mypy"]

[metadata]
lock-version = "2.1"
python-versions = ">=3.10,<4.0"
content-hash = "40eae94995dc0a388fa832ed4af9b6137f28d5b5ced3aaea70d5f91d4d9a179d"
content-hash = "9619cae908ad38fa2c48016a58bcf4241f6f5793aa0e6cc140276e91c433cbbb"
```
autogpt_platform/autogpt_libs/pyproject.toml

```toml
@@ -11,14 +11,14 @@ python = ">=3.10,<4.0"
colorama = "^0.4.6"
cryptography = "^46.0"
expiringdict = "^1.2.2"
fastapi = "^0.128.0"
fastapi = "^0.128.7"
google-cloud-logging = "^3.13.0"
launchdarkly-server-sdk = "^9.14.1"
launchdarkly-server-sdk = "^9.15.0"
pydantic = "^2.12.5"
pydantic-settings = "^2.12.0"
pyjwt = { version = "^2.11.0", extras = ["crypto"] }
redis = "^6.2.0"
supabase = "^2.27.2"
supabase = "^2.28.0"
uvicorn = "^0.40.0"

[tool.poetry.group.dev.dependencies]
```
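One nuance in the constraints above: Poetry's caret operator pins more tightly for 0.x versions, so `fastapi = "^0.128.7"` still forbids 0.129.x while `launchdarkly-server-sdk = "^9.15.0"` allows any 9.x at or above 9.15.0. A small check using the `packaging` library (my choice of tool, not the project's):

```python
# Illustration of Poetry caret semantics, spelled out as explicit ranges.
from packaging.specifiers import SpecifierSet

caret_fastapi = SpecifierSet(">=0.128.7,<0.129.0")       # fastapi = "^0.128.7"
caret_launchdarkly = SpecifierSet(">=9.15.0,<10.0.0")    # launchdarkly-server-sdk = "^9.15.0"

print("0.128.9" in caret_fastapi)      # True  (patch bump allowed)
print("0.129.0" in caret_fastapi)      # False (minor bump excluded for 0.x)
print("9.16.2" in caret_launchdarkly)  # True  (minor bump allowed for >=1.0)
```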
autogpt_platform/backend/Dockerfile

```dockerfile
@@ -1,3 +1,5 @@
# ============================ DEPENDENCY BUILDER ============================ #

FROM debian:13-slim AS builder

# Set environment variables
@@ -51,7 +53,9 @@ COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/parti
COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
RUN poetry run prisma generate && poetry run gen-prisma-stub

FROM debian:13-slim AS server_dependencies
# ============================== BACKEND SERVER ============================== #

FROM debian:13-slim AS server

WORKDIR /app

@@ -63,15 +67,14 @@ ENV POETRY_HOME=/opt/poetry \
ENV PATH=/opt/poetry/bin:$PATH

# Install Python, FFmpeg, and ImageMagick (required for video processing blocks)
RUN apt-get update && apt-get install -y \
# Using --no-install-recommends saves ~650MB by skipping unnecessary deps like llvm, mesa, etc.
RUN apt-get update && apt-get install -y --no-install-recommends \
    python3.13 \
    python3-pip \
    ffmpeg \
    imagemagick \
    && rm -rf /var/lib/apt/lists/*

# Copy only necessary files from builder
COPY --from=builder /app /app
COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3*
COPY --from=builder /usr/local/bin/poetry /usr/local/bin/poetry
# Copy Node.js installation for Prisma
@@ -81,30 +84,54 @@ COPY --from=builder /usr/bin/npm /usr/bin/npm
COPY --from=builder /usr/bin/npx /usr/bin/npx
COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries

ENV PATH="/app/autogpt_platform/backend/.venv/bin:$PATH"

RUN mkdir -p /app/autogpt_platform/autogpt_libs
RUN mkdir -p /app/autogpt_platform/backend

COPY autogpt_platform/autogpt_libs /app/autogpt_platform/autogpt_libs

COPY autogpt_platform/backend/poetry.lock autogpt_platform/backend/pyproject.toml /app/autogpt_platform/backend/

WORKDIR /app/autogpt_platform/backend

FROM server_dependencies AS migrate
# Copy only the .venv from builder (not the entire /app directory)
# The .venv includes the generated Prisma client
COPY --from=builder /app/autogpt_platform/backend/.venv ./.venv
ENV PATH="/app/autogpt_platform/backend/.venv/bin:$PATH"

# Migration stage only needs schema and migrations - much lighter than full backend
COPY autogpt_platform/backend/schema.prisma /app/autogpt_platform/backend/
COPY autogpt_platform/backend/backend/data/partial_types.py /app/autogpt_platform/backend/backend/data/partial_types.py
COPY autogpt_platform/backend/migrations /app/autogpt_platform/backend/migrations
# Copy dependency files + autogpt_libs (path dependency)
COPY autogpt_platform/autogpt_libs /app/autogpt_platform/autogpt_libs
COPY autogpt_platform/backend/poetry.lock autogpt_platform/backend/pyproject.toml ./

FROM server_dependencies AS server

COPY autogpt_platform/backend /app/autogpt_platform/backend
# Copy backend code + docs (for Copilot docs search)
COPY autogpt_platform/backend ./
COPY docs /app/docs
RUN poetry install --no-ansi --only-root

ENV PORT=8000

CMD ["poetry", "run", "rest"]

# =============================== DB MIGRATOR =============================== #

# Lightweight migrate stage - only needs Prisma CLI, not full Python environment
FROM debian:13-slim AS migrate

WORKDIR /app/autogpt_platform/backend

ENV DEBIAN_FRONTEND=noninteractive

# Install only what's needed for prisma migrate: Node.js and minimal Python for prisma-python
RUN apt-get update && apt-get install -y --no-install-recommends \
    python3.13 \
    python3-pip \
    ca-certificates \
    && rm -rf /var/lib/apt/lists/*

# Copy Node.js from builder (needed for Prisma CLI)
COPY --from=builder /usr/bin/node /usr/bin/node
COPY --from=builder /usr/lib/node_modules /usr/lib/node_modules
COPY --from=builder /usr/bin/npm /usr/bin/npm

# Copy Prisma binaries
COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries

# Install prisma-client-py directly (much smaller than copying full venv)
RUN pip3 install prisma>=0.15.0 --break-system-packages

COPY autogpt_platform/backend/schema.prisma ./
COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/partial_types.py
COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
COPY autogpt_platform/backend/migrations ./migrations
```
@@ -1,4 +1,9 @@
"""Common test fixtures for server tests."""
"""Common test fixtures for server tests.

Note: Common fixtures like test_user_id, admin_user_id, target_user_id,
setup_test_user, and setup_admin_user are defined in the parent conftest.py
(backend/conftest.py) and are available here automatically.
"""

import pytest
from pytest_snapshot.plugin import Snapshot
@@ -11,54 +16,6 @@ def configured_snapshot(snapshot: Snapshot) -> Snapshot:
    return snapshot


@pytest.fixture
def test_user_id() -> str:
    """Test user ID fixture."""
    return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"


@pytest.fixture
def admin_user_id() -> str:
    """Admin user ID fixture."""
    return "4e53486c-cf57-477e-ba2a-cb02dc828e1b"


@pytest.fixture
def target_user_id() -> str:
    """Target user ID fixture."""
    return "5e53486c-cf57-477e-ba2a-cb02dc828e1c"


@pytest.fixture
async def setup_test_user(test_user_id):
    """Create test user in database before tests."""
    from backend.data.user import get_or_create_user

    # Create the test user in the database using JWT token format
    user_data = {
        "sub": test_user_id,
        "email": "test@example.com",
        "user_metadata": {"name": "Test User"},
    }
    await get_or_create_user(user_data)
    return test_user_id


@pytest.fixture
async def setup_admin_user(admin_user_id):
    """Create admin user in database before tests."""
    from backend.data.user import get_or_create_user

    # Create the admin user in the database using JWT token format
    user_data = {
        "sub": admin_user_id,
        "email": "test-admin@example.com",
        "user_metadata": {"name": "Test Admin"},
    }
    await get_or_create_user(user_data)
    return admin_user_id


@pytest.fixture
def mock_jwt_user(test_user_id):
    """Provide mock JWT payload for regular user testing."""

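With the shared fixtures now inherited from backend/conftest.py, a test in this package can use them without any local definition. A minimal sketch, assuming the project's pytest-asyncio conventions (the test itself is illustrative, not part of the diff):

import pytest


@pytest.mark.asyncio(loop_scope="session")
async def test_uses_parent_fixture(setup_test_user: str):
    # setup_test_user is resolved from backend/conftest.py; it creates the
    # user in the database and returns its ID.
    assert setup_test_user == "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
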
@@ -15,9 +15,9 @@ from prisma.enums import APIKeyPermission
from pydantic import BaseModel, Field

from backend.api.external.middleware import require_permission
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools import find_agent_tool, run_agent_tool
from backend.api.features.chat.tools.models import ToolResponseBase
from backend.copilot.model import ChatSession
from backend.copilot.tools import find_agent_tool, run_agent_tool
from backend.copilot.tools.models import ToolResponseBase
from backend.data.auth.base import APIAuthorizationInfo

logger = logging.getLogger(__name__)

@@ -10,20 +10,28 @@ from fastapi import APIRouter, Depends, Header, HTTPException, Query, Response,
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from backend.util.exceptions import NotFoundError

from . import service as chat_service
from . import stream_registry
from .completion_handler import process_operation_failure, process_operation_success
from .config import ChatConfig
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
from .response_model import StreamFinish, StreamHeartbeat
from .tools.models import (
from backend.copilot import service as chat_service
from backend.copilot import stream_registry
from backend.copilot.completion_handler import (
    process_operation_failure,
    process_operation_success,
)
from backend.copilot.config import ChatConfig
from backend.copilot.executor.utils import enqueue_copilot_task
from backend.copilot.model import (
    ChatSession,
    create_chat_session,
    get_chat_session,
    get_user_sessions,
)
from backend.copilot.response_model import StreamFinish, StreamHeartbeat
from backend.copilot.tools.models import (
    AgentDetailsResponse,
    AgentOutputResponse,
    AgentPreviewResponse,
    AgentSavedResponse,
    AgentsFoundResponse,
    BlockDetailsResponse,
    BlockListResponse,
    BlockOutputResponse,
    ClarificationNeededResponse,
@@ -40,6 +48,7 @@ from .tools.models import (
    SetupRequirementsResponse,
    UnderstandingUpdatedResponse,
)
from backend.util.exceptions import NotFoundError

config = ChatConfig()

@@ -301,7 +310,7 @@ async def stream_chat_post(
        extra={"json_fields": log_meta},
    )

    session = await _validate_and_get_session(session_id, user_id)
    _session = await _validate_and_get_session(session_id, user_id)  # noqa: F841
    logger.info(
        f"[TIMING] session validated in {(time.perf_counter() - stream_start_time)*1000:.1f}ms",
        extra={
@@ -336,82 +345,20 @@ async def stream_chat_post(
        },
    )

    # Background task that runs the AI generation independently of SSE connection
    async def run_ai_generation():
        import time as time_module
    # Enqueue the task to RabbitMQ for processing by the CoPilot executor
    await enqueue_copilot_task(
        task_id=task_id,
        session_id=session_id,
        user_id=user_id,
        operation_id=operation_id,
        message=request.message,
        is_user_message=request.is_user_message,
        context=request.context,
    )

        gen_start_time = time_module.perf_counter()
        logger.info(
            f"[TIMING] run_ai_generation STARTED, task={task_id}, session={session_id}, user={user_id}",
            extra={"json_fields": log_meta},
        )
        first_chunk_time, ttfc = None, None
        chunk_count = 0
        try:
            async for chunk in chat_service.stream_chat_completion(
                session_id,
                request.message,
                is_user_message=request.is_user_message,
                user_id=user_id,
                session=session,  # Pass pre-fetched session to avoid double-fetch
                context=request.context,
                _task_id=task_id,  # Pass task_id so service emits start with taskId for reconnection
            ):
                chunk_count += 1
                if first_chunk_time is None:
                    first_chunk_time = time_module.perf_counter()
                    ttfc = first_chunk_time - gen_start_time
                    logger.info(
                        f"[TIMING] FIRST AI CHUNK at {ttfc:.2f}s, type={type(chunk).__name__}",
                        extra={
                            "json_fields": {
                                **log_meta,
                                "chunk_type": type(chunk).__name__,
                                "time_to_first_chunk_ms": ttfc * 1000,
                            }
                        },
                    )
                # Write to Redis (subscribers will receive via XREAD)
                await stream_registry.publish_chunk(task_id, chunk)

            gen_end_time = time_module.perf_counter()
            total_time = (gen_end_time - gen_start_time) * 1000
            logger.info(
                f"[TIMING] run_ai_generation FINISHED in {total_time/1000:.1f}s; "
                f"task={task_id}, session={session_id}, "
                f"ttfc={ttfc or -1:.2f}s, n_chunks={chunk_count}",
                extra={
                    "json_fields": {
                        **log_meta,
                        "total_time_ms": total_time,
                        "time_to_first_chunk_ms": (
                            ttfc * 1000 if ttfc is not None else None
                        ),
                        "n_chunks": chunk_count,
                    }
                },
            )
            await stream_registry.mark_task_completed(task_id, "completed")
        except Exception as e:
            elapsed = time_module.perf_counter() - gen_start_time
            logger.error(
                f"[TIMING] run_ai_generation ERROR after {elapsed:.2f}s: {e}",
                extra={
                    "json_fields": {
                        **log_meta,
                        "elapsed_ms": elapsed * 1000,
                        "error": str(e),
                    }
                },
            )
            await stream_registry.mark_task_completed(task_id, "failed")

    # Start the AI generation in a background task
    bg_task = asyncio.create_task(run_ai_generation())
    await stream_registry.set_task_asyncio_task(task_id, bg_task)
    setup_time = (time.perf_counter() - stream_start_time) * 1000
    logger.info(
        f"[TIMING] Background task started, setup={setup_time:.1f}ms",
        f"[TIMING] Task enqueued to RabbitMQ, setup={setup_time:.1f}ms",
        extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
    )

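The net effect of the hunk above: the API process no longer runs generation in an in-process asyncio task; it only enqueues, and a CoPilot executor replica produces the chunks into Redis Streams. A minimal sketch of the new producer path, reusing enqueue_copilot_task as imported above (the IDs here are illustrative):

import uuid


async def start_generation(session_id: str, user_id: str | None, message: str) -> str:
    task_id = str(uuid.uuid4())
    await enqueue_copilot_task(
        task_id=task_id,
        session_id=session_id,
        user_id=user_id,
        operation_id=str(uuid.uuid4()),
        message=message,
    )
    # Chunks for task_id now arrive via the stream registry (Redis Streams),
    # so any API replica can serve the SSE connection for this task.
    return task_id
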
@@ -971,6 +918,7 @@ ToolResponseUnion = (
    | AgentSavedResponse
    | ClarificationNeededResponse
    | BlockListResponse
    | BlockDetailsResponse
    | BlockOutputResponse
    | DocSearchResultsResponse
    | DocPageResponse

@@ -1,139 +0,0 @@
"""Tests for block filtering in FindBlockTool."""

from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from backend.api.features.chat.tools.find_block import (
    COPILOT_EXCLUDED_BLOCK_IDS,
    COPILOT_EXCLUDED_BLOCK_TYPES,
    FindBlockTool,
)
from backend.api.features.chat.tools.models import BlockListResponse
from backend.blocks._base import BlockType

from ._test_data import make_session

_TEST_USER_ID = "test-user-find-block"


def make_mock_block(
    block_id: str, name: str, block_type: BlockType, disabled: bool = False
):
    """Create a mock block for testing."""
    mock = MagicMock()
    mock.id = block_id
    mock.name = name
    mock.description = f"{name} description"
    mock.block_type = block_type
    mock.disabled = disabled
    mock.input_schema = MagicMock()
    mock.input_schema.jsonschema.return_value = {"properties": {}, "required": []}
    mock.input_schema.get_credentials_fields.return_value = {}
    mock.output_schema = MagicMock()
    mock.output_schema.jsonschema.return_value = {}
    mock.categories = []
    return mock


class TestFindBlockFiltering:
    """Tests for block filtering in FindBlockTool."""

    def test_excluded_block_types_contains_expected_types(self):
        """Verify COPILOT_EXCLUDED_BLOCK_TYPES contains all graph-only types."""
        assert BlockType.INPUT in COPILOT_EXCLUDED_BLOCK_TYPES
        assert BlockType.OUTPUT in COPILOT_EXCLUDED_BLOCK_TYPES
        assert BlockType.WEBHOOK in COPILOT_EXCLUDED_BLOCK_TYPES
        assert BlockType.WEBHOOK_MANUAL in COPILOT_EXCLUDED_BLOCK_TYPES
        assert BlockType.NOTE in COPILOT_EXCLUDED_BLOCK_TYPES
        assert BlockType.HUMAN_IN_THE_LOOP in COPILOT_EXCLUDED_BLOCK_TYPES
        assert BlockType.AGENT in COPILOT_EXCLUDED_BLOCK_TYPES

    def test_excluded_block_ids_contains_smart_decision_maker(self):
        """Verify SmartDecisionMakerBlock is in COPILOT_EXCLUDED_BLOCK_IDS."""
        assert "3b191d9f-356f-482d-8238-ba04b6d18381" in COPILOT_EXCLUDED_BLOCK_IDS

    @pytest.mark.asyncio(loop_scope="session")
    async def test_excluded_block_type_filtered_from_results(self):
        """Verify blocks with excluded BlockTypes are filtered from search results."""
        session = make_session(user_id=_TEST_USER_ID)

        # Mock search returns an INPUT block (excluded) and a STANDARD block (included)
        search_results = [
            {"content_id": "input-block-id", "score": 0.9},
            {"content_id": "standard-block-id", "score": 0.8},
        ]

        input_block = make_mock_block("input-block-id", "Input Block", BlockType.INPUT)
        standard_block = make_mock_block(
            "standard-block-id", "HTTP Request", BlockType.STANDARD
        )

        def mock_get_block(block_id):
            return {
                "input-block-id": input_block,
                "standard-block-id": standard_block,
            }.get(block_id)

        with patch(
            "backend.api.features.chat.tools.find_block.unified_hybrid_search",
            new_callable=AsyncMock,
            return_value=(search_results, 2),
        ):
            with patch(
                "backend.api.features.chat.tools.find_block.get_block",
                side_effect=mock_get_block,
            ):
                tool = FindBlockTool()
                response = await tool._execute(
                    user_id=_TEST_USER_ID, session=session, query="test"
                )

        # Should only return the standard block, not the INPUT block
        assert isinstance(response, BlockListResponse)
        assert len(response.blocks) == 1
        assert response.blocks[0].id == "standard-block-id"

    @pytest.mark.asyncio(loop_scope="session")
    async def test_excluded_block_id_filtered_from_results(self):
        """Verify SmartDecisionMakerBlock is filtered from search results."""
        session = make_session(user_id=_TEST_USER_ID)

        smart_decision_id = "3b191d9f-356f-482d-8238-ba04b6d18381"
        search_results = [
            {"content_id": smart_decision_id, "score": 0.9},
            {"content_id": "normal-block-id", "score": 0.8},
        ]

        # SmartDecisionMakerBlock has STANDARD type but is excluded by ID
        smart_block = make_mock_block(
            smart_decision_id, "Smart Decision Maker", BlockType.STANDARD
        )
        normal_block = make_mock_block(
            "normal-block-id", "Normal Block", BlockType.STANDARD
        )

        def mock_get_block(block_id):
            return {
                smart_decision_id: smart_block,
                "normal-block-id": normal_block,
            }.get(block_id)

        with patch(
            "backend.api.features.chat.tools.find_block.unified_hybrid_search",
            new_callable=AsyncMock,
            return_value=(search_results, 2),
        ):
            with patch(
                "backend.api.features.chat.tools.find_block.get_block",
                side_effect=mock_get_block,
            ):
                tool = FindBlockTool()
                response = await tool._execute(
                    user_id=_TEST_USER_ID, session=session, query="decision"
                )

        # Should only return normal block, not SmartDecisionMakerBlock
        assert isinstance(response, BlockListResponse)
        assert len(response.blocks) == 1
        assert response.blocks[0].id == "normal-block-id"
@@ -1,106 +0,0 @@
"""Tests for block execution guards in RunBlockTool."""

from unittest.mock import MagicMock, patch

import pytest

from backend.api.features.chat.tools.models import ErrorResponse
from backend.api.features.chat.tools.run_block import RunBlockTool
from backend.blocks._base import BlockType

from ._test_data import make_session

_TEST_USER_ID = "test-user-run-block"


def make_mock_block(
    block_id: str, name: str, block_type: BlockType, disabled: bool = False
):
    """Create a mock block for testing."""
    mock = MagicMock()
    mock.id = block_id
    mock.name = name
    mock.block_type = block_type
    mock.disabled = disabled
    mock.input_schema = MagicMock()
    mock.input_schema.jsonschema.return_value = {"properties": {}, "required": []}
    mock.input_schema.get_credentials_fields_info.return_value = []
    return mock


class TestRunBlockFiltering:
    """Tests for block execution guards in RunBlockTool."""

    @pytest.mark.asyncio(loop_scope="session")
    async def test_excluded_block_type_returns_error(self):
        """Attempting to execute a block with excluded BlockType returns error."""
        session = make_session(user_id=_TEST_USER_ID)

        input_block = make_mock_block("input-block-id", "Input Block", BlockType.INPUT)

        with patch(
            "backend.api.features.chat.tools.run_block.get_block",
            return_value=input_block,
        ):
            tool = RunBlockTool()
            response = await tool._execute(
                user_id=_TEST_USER_ID,
                session=session,
                block_id="input-block-id",
                input_data={},
            )

        assert isinstance(response, ErrorResponse)
        assert "cannot be run directly in CoPilot" in response.message
        assert "designed for use within graphs only" in response.message

    @pytest.mark.asyncio(loop_scope="session")
    async def test_excluded_block_id_returns_error(self):
        """Attempting to execute SmartDecisionMakerBlock returns error."""
        session = make_session(user_id=_TEST_USER_ID)

        smart_decision_id = "3b191d9f-356f-482d-8238-ba04b6d18381"
        smart_block = make_mock_block(
            smart_decision_id, "Smart Decision Maker", BlockType.STANDARD
        )

        with patch(
            "backend.api.features.chat.tools.run_block.get_block",
            return_value=smart_block,
        ):
            tool = RunBlockTool()
            response = await tool._execute(
                user_id=_TEST_USER_ID,
                session=session,
                block_id=smart_decision_id,
                input_data={},
            )

        assert isinstance(response, ErrorResponse)
        assert "cannot be run directly in CoPilot" in response.message

    @pytest.mark.asyncio(loop_scope="session")
    async def test_non_excluded_block_passes_guard(self):
        """Non-excluded blocks pass the filtering guard (may fail later for other reasons)."""
        session = make_session(user_id=_TEST_USER_ID)

        standard_block = make_mock_block(
            "standard-id", "HTTP Request", BlockType.STANDARD
        )

        with patch(
            "backend.api.features.chat.tools.run_block.get_block",
            return_value=standard_block,
        ):
            tool = RunBlockTool()
            response = await tool._execute(
                user_id=_TEST_USER_ID,
                session=session,
                block_id="standard-id",
                input_data={},
            )

        # Should NOT be an ErrorResponse about CoPilot exclusion
        # (may be other errors like missing credentials, but not the exclusion guard)
        if isinstance(response, ErrorResponse):
            assert "cannot be run directly in CoPilot" not in response.message
@@ -40,11 +40,11 @@ import backend.data.user
import backend.integrations.webhooks.utils
import backend.util.service
import backend.util.settings
from backend.api.features.chat.completion_consumer import (
from backend.blocks.llm import DEFAULT_LLM_MODEL
from backend.copilot.completion_consumer import (
    start_completion_consumer,
    stop_completion_consumer,
)
from backend.blocks.llm import DEFAULT_LLM_MODEL
from backend.data.model import Credentials
from backend.integrations.providers import ProviderName
from backend.monitoring.instrumentation import instrument_fastapi

@@ -38,7 +38,9 @@ def main(**kwargs):

    from backend.api.rest_api import AgentServer
    from backend.api.ws_api import WebsocketServer
    from backend.executor import DatabaseManager, ExecutionManager, Scheduler
    from backend.copilot.executor.manager import CoPilotExecutor
    from backend.data.db_manager import DatabaseManager
    from backend.executor import ExecutionManager, Scheduler
    from backend.notifications import NotificationManager

    run_processes(
@@ -48,6 +50,7 @@ def main(**kwargs):
        WebsocketServer(),
        AgentServer(),
        ExecutionManager(),
        CoPilotExecutor(),
        **kwargs,
    )

@@ -1,6 +1,7 @@
import logging
import os

import pytest
import pytest_asyncio
from dotenv import load_dotenv

@@ -27,6 +28,54 @@ async def server():
    yield server


@pytest.fixture
def test_user_id() -> str:
    """Test user ID fixture."""
    return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"


@pytest.fixture
def admin_user_id() -> str:
    """Admin user ID fixture."""
    return "4e53486c-cf57-477e-ba2a-cb02dc828e1b"


@pytest.fixture
def target_user_id() -> str:
    """Target user ID fixture."""
    return "5e53486c-cf57-477e-ba2a-cb02dc828e1c"


@pytest.fixture
async def setup_test_user(test_user_id):
    """Create test user in database before tests."""
    from backend.data.user import get_or_create_user

    # Create the test user in the database using JWT token format
    user_data = {
        "sub": test_user_id,
        "email": "test@example.com",
        "user_metadata": {"name": "Test User"},
    }
    await get_or_create_user(user_data)
    return test_user_id


@pytest.fixture
async def setup_admin_user(admin_user_id):
    """Create admin user in database before tests."""
    from backend.data.user import get_or_create_user

    # Create the admin user in the database using JWT token format
    user_data = {
        "sub": admin_user_id,
        "email": "test-admin@example.com",
        "user_metadata": {"name": "Test Admin"},
    }
    await get_or_create_user(user_data)
    return admin_user_id


@pytest_asyncio.fixture(scope="session", loop_scope="session", autouse=True)
async def graph_cleanup(server):
    created_graph_ids = []

autogpt_platform/backend/backend/copilot/__init__.py (new file, 8 lines)
@@ -0,0 +1,8 @@
"""CoPilot module - AI assistant for AutoGPT platform.

This module contains the core CoPilot functionality including:
- AI generation service (LLM calls)
- Tool execution
- Session management
- Stream registry for SSE reconnection
"""
@@ -119,8 +119,9 @@ class ChatCompletionConsumer:
        """Lazily initialize Prisma client on first use."""
        if self._prisma is None:
            database_url = os.getenv("DATABASE_URL", "postgresql://localhost:5432")
            self._prisma = Prisma(datasource={"url": database_url})
            await self._prisma.connect()
            prisma = Prisma(datasource={"url": database_url})
            await prisma.connect()
            self._prisma = prisma
            logger.info("[COMPLETION] Consumer Prisma client connected (lazy init)")
        return self._prisma

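The reordering above matters: self._prisma is assigned only after connect() succeeds, so a failed connection attempt is not cached and the next call retries cleanly. The same pattern in isolation, as a minimal sketch (the client factory is hypothetical):

class LazyClient:
    def __init__(self) -> None:
        self._client = None

    async def get(self):
        if self._client is None:
            client = make_client()  # hypothetical factory; connect() may raise
            await client.connect()  # nothing has been cached yet at this point
            self._client = client  # cache only a fully connected client
        return self._client
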
@@ -14,7 +14,7 @@ from prisma.types import (
    ChatSessionWhereInput,
)

from backend.data.db import transaction
from backend.data import db
from backend.util.json import SafeJson

logger = logging.getLogger(__name__)
@@ -147,7 +147,7 @@ async def add_chat_messages_batch(

    created_messages = []

    async with transaction() as tx:
    async with db.transaction() as tx:
        for i, msg in enumerate(messages):
            # Build input dict dynamically rather than using ChatMessageCreateInput
            # directly because Prisma's TypedDict validation rejects optional fields

@@ -0,0 +1,5 @@
"""CoPilot Executor - Dedicated service for AI generation and tool execution.

This module contains the executor service that processes CoPilot tasks
from RabbitMQ, following the graph executor pattern.
"""
@@ -0,0 +1,18 @@
"""Entry point for running the CoPilot Executor service.

Usage:
    python -m backend.copilot.executor
"""

from backend.app import run_processes

from .manager import CoPilotExecutor


def main():
    """Run the CoPilot Executor service."""
    run_processes(CoPilotExecutor())


if __name__ == "__main__":
    main()
autogpt_platform/backend/backend/copilot/executor/manager.py (new file, 508 lines)
@@ -0,0 +1,508 @@
"""CoPilot Executor Manager - main service for CoPilot task execution.

This module contains the CoPilotExecutor class that consumes chat tasks from
RabbitMQ and processes them using a thread pool, following the graph executor pattern.
"""

import logging
import os
import threading
import time
import uuid
from concurrent.futures import Future, ThreadPoolExecutor

from pika.adapters.blocking_connection import BlockingChannel
from pika.exceptions import AMQPChannelError, AMQPConnectionError
from pika.spec import Basic, BasicProperties
from prometheus_client import Gauge, start_http_server

from backend.data import redis_client as redis
from backend.data.rabbitmq import SyncRabbitMQ
from backend.executor.cluster_lock import ClusterLock
from backend.util.decorator import error_logged
from backend.util.logging import TruncatedLogger
from backend.util.process import AppProcess
from backend.util.retry import continuous_retry
from backend.util.settings import Settings

from .processor import execute_copilot_task, init_worker
from .utils import (
    COPILOT_CANCEL_QUEUE_NAME,
    COPILOT_EXECUTION_QUEUE_NAME,
    GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS,
    CancelCoPilotEvent,
    CoPilotExecutionEntry,
    create_copilot_queue_config,
)

logger = TruncatedLogger(logging.getLogger(__name__), prefix="[CoPilotExecutor]")
settings = Settings()

# Prometheus metrics
active_tasks_gauge = Gauge(
    "copilot_executor_active_tasks",
    "Number of active CoPilot tasks",
)
pool_size_gauge = Gauge(
    "copilot_executor_pool_size",
    "Maximum number of CoPilot executor workers",
)
utilization_gauge = Gauge(
    "copilot_executor_utilization_ratio",
    "Ratio of active tasks to pool size",
)


class CoPilotExecutor(AppProcess):
    """CoPilot Executor service for processing chat generation tasks.

    This service consumes tasks from RabbitMQ, processes them using a thread pool,
    and publishes results to Redis Streams. It follows the graph executor pattern
    for reliable message handling and graceful shutdown.

    Key features:
    - RabbitMQ-based task distribution with manual acknowledgment
    - Thread pool executor for concurrent task processing
    - Cluster lock for duplicate prevention across pods
    - Graceful shutdown with timeout for in-flight tasks
    - FANOUT exchange for cancellation broadcast
    """

    def __init__(self):
        super().__init__()
        self.pool_size = settings.config.num_copilot_workers
        self.active_tasks: dict[str, tuple[Future, threading.Event]] = {}
        self.executor_id = str(uuid.uuid4())

        self._executor = None
        self._stop_consuming = None

        self._cancel_thread = None
        self._cancel_client = None
        self._run_thread = None
        self._run_client = None

        self._task_locks: dict[str, ClusterLock] = {}

    # ============ Main Entry Points (AppProcess interface) ============ #

    def run(self):
        """Main service loop - consume from RabbitMQ."""
        logger.info(f"Pod assigned executor_id: {self.executor_id}")
        logger.info(f"Spawn max-{self.pool_size} workers...")

        pool_size_gauge.set(self.pool_size)
        self._update_metrics()
        start_http_server(settings.config.copilot_executor_port)

        self.cancel_thread.start()
        self.run_thread.start()

        while True:
            time.sleep(1e5)

    def cleanup(self):
        """Graceful shutdown with active execution waiting."""
        pid = os.getpid()
        logger.info(f"[cleanup {pid}] Starting graceful shutdown...")

        # Signal the consumer thread to stop
        try:
            self.stop_consuming.set()
            run_channel = self.run_client.get_channel()
            run_channel.connection.add_callback_threadsafe(
                lambda: run_channel.stop_consuming()
            )
            logger.info(f"[cleanup {pid}] Consumer has been signaled to stop")
        except Exception as e:
            logger.error(f"[cleanup {pid}] Error stopping consumer: {e}")

        # Wait for active executions to complete
        if self.active_tasks:
            logger.info(
                f"[cleanup {pid}] Waiting for {len(self.active_tasks)} active tasks to complete (timeout: {GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS}s)..."
            )

            start_time = time.monotonic()
            last_refresh = start_time
            lock_refresh_interval = settings.config.cluster_lock_timeout / 10

            while (
                self.active_tasks
                and (time.monotonic() - start_time) < GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS
            ):
                self._cleanup_completed_tasks()
                if not self.active_tasks:
                    break

                # Refresh cluster locks periodically
                current_time = time.monotonic()
                if current_time - last_refresh >= lock_refresh_interval:
                    for lock in self._task_locks.values():
                        try:
                            lock.refresh()
                        except Exception as e:
                            logger.warning(
                                f"[cleanup {pid}] Failed to refresh lock: {e}"
                            )
                    last_refresh = current_time

                logger.info(
                    f"[cleanup {pid}] {len(self.active_tasks)} tasks still active, waiting..."
                )
                time.sleep(10.0)

        # Stop message consumers
        if self._run_thread:
            self._stop_message_consumers(
                self._run_thread, self.run_client, "[cleanup][run]"
            )
        if self._cancel_thread:
            self._stop_message_consumers(
                self._cancel_thread, self.cancel_client, "[cleanup][cancel]"
            )

        # Shutdown executor
        if self._executor:
            logger.info(f"[cleanup {pid}] Shutting down executor...")
            self._executor.shutdown(wait=False)

        # Release any remaining locks
        for task_id, lock in list(self._task_locks.items()):
            try:
                lock.release()
                logger.info(f"[cleanup {pid}] Released lock for {task_id}")
            except Exception as e:
                logger.error(
                    f"[cleanup {pid}] Failed to release lock for {task_id}: {e}"
                )

        logger.info(f"[cleanup {pid}] Graceful shutdown completed")

    # ============ RabbitMQ Consumer Methods ============ #

    @continuous_retry()
    def _consume_cancel(self):
        """Consume cancellation messages from FANOUT exchange."""
        if self.stop_consuming.is_set() and not self.active_tasks:
            logger.info("Stop reconnecting cancel consumer - service cleaned up")
            return

        if not self.cancel_client.is_ready:
            self.cancel_client.disconnect()
            self.cancel_client.connect()

        # Check again after connect - shutdown may have been requested
        if self.stop_consuming.is_set() and not self.active_tasks:
            logger.info("Stop consuming requested during reconnect - disconnecting")
            self.cancel_client.disconnect()
            return

        cancel_channel = self.cancel_client.get_channel()
        cancel_channel.basic_consume(
            queue=COPILOT_CANCEL_QUEUE_NAME,
            on_message_callback=self._handle_cancel_message,
            auto_ack=True,
        )
        logger.info("Starting cancel message consumer...")
        cancel_channel.start_consuming()
        if not self.stop_consuming.is_set() or self.active_tasks:
            raise RuntimeError("Cancel message consumer stopped unexpectedly")
        logger.info("Cancel message consumer stopped gracefully")

    @continuous_retry()
    def _consume_run(self):
        """Consume run messages from DIRECT exchange."""
        if self.stop_consuming.is_set():
            logger.info("Stop reconnecting run consumer - service cleaned up")
            return

        if not self.run_client.is_ready:
            self.run_client.disconnect()
            self.run_client.connect()

        # Check again after connect - shutdown may have been requested
        if self.stop_consuming.is_set():
            logger.info("Stop consuming requested during reconnect - disconnecting")
            self.run_client.disconnect()
            return

        run_channel = self.run_client.get_channel()
        run_channel.basic_qos(prefetch_count=self.pool_size)

        run_channel.basic_consume(
            queue=COPILOT_EXECUTION_QUEUE_NAME,
            on_message_callback=self._handle_run_message,
            auto_ack=False,
            consumer_tag="copilot_execution_consumer",
        )
        logger.info("Starting to consume run messages...")
        run_channel.start_consuming()
        if not self.stop_consuming.is_set():
            raise RuntimeError("Run message consumer stopped unexpectedly")
        logger.info("Run message consumer stopped gracefully")

    # ============ Message Handlers ============ #

    @error_logged(swallow=True)
    def _handle_cancel_message(
        self,
        _channel: BlockingChannel,
        _method: Basic.Deliver,
        _properties: BasicProperties,
        body: bytes,
    ):
        """Handle cancel message from FANOUT exchange."""
        request = CancelCoPilotEvent.model_validate_json(body)
        task_id = request.task_id
        if not task_id:
            logger.warning("Cancel message missing 'task_id'")
            return
        if task_id not in self.active_tasks:
            logger.debug(f"Cancel received for {task_id} but not active")
            return

        _, cancel_event = self.active_tasks[task_id]
        logger.info(f"Received cancel for {task_id}")
        if not cancel_event.is_set():
            cancel_event.set()
        else:
            logger.debug(f"Cancel already set for {task_id}")

    def _handle_run_message(
        self,
        _channel: BlockingChannel,
        method: Basic.Deliver,
        _properties: BasicProperties,
        body: bytes,
    ):
        """Handle run message from DIRECT exchange."""
        delivery_tag = method.delivery_tag
        # Capture the channel used at message delivery time to ensure we ack
        # on the correct channel. Delivery tags are channel-scoped and become
        # invalid if the channel is recreated after reconnection.
        delivery_channel = _channel

        def ack_message(reject: bool, requeue: bool):
            """Acknowledge or reject the message.

            Uses the channel from the original message delivery. If the channel
            is no longer open (e.g., after reconnection), logs a warning and
            skips the ack - RabbitMQ will redeliver the message automatically.
            """
            try:
                if not delivery_channel.is_open:
                    logger.warning(
                        f"Channel closed, cannot ack delivery_tag={delivery_tag}. "
                        "Message will be redelivered by RabbitMQ."
                    )
                    return

                if reject:
                    delivery_channel.connection.add_callback_threadsafe(
                        lambda: delivery_channel.basic_nack(
                            delivery_tag, requeue=requeue
                        )
                    )
                else:
                    delivery_channel.connection.add_callback_threadsafe(
                        lambda: delivery_channel.basic_ack(delivery_tag)
                    )
            except (AMQPChannelError, AMQPConnectionError) as e:
                # Channel/connection errors indicate stale delivery tag - don't retry
                logger.warning(
                    f"Cannot ack delivery_tag={delivery_tag} due to channel/connection "
                    f"error: {e}. Message will be redelivered by RabbitMQ."
                )
            except Exception as e:
                # Other errors might be transient, but log and skip to avoid blocking
                logger.error(
                    f"Unexpected error acking delivery_tag={delivery_tag}: {e}"
                )

        # Check if we're shutting down
        if self.stop_consuming.is_set():
            logger.info("Rejecting new task during shutdown")
            ack_message(reject=True, requeue=True)
            return

        # Check if we can accept more tasks
        self._cleanup_completed_tasks()
        if len(self.active_tasks) >= self.pool_size:
            ack_message(reject=True, requeue=True)
            return

        try:
            entry = CoPilotExecutionEntry.model_validate_json(body)
        except Exception as e:
            logger.error(f"Could not parse run message: {e}, body={body}")
            ack_message(reject=True, requeue=False)
            return

        task_id = entry.task_id

        # Check for local duplicate - task is already running on this executor
        if task_id in self.active_tasks:
            logger.warning(
                f"Task {task_id} already running locally, rejecting duplicate"
            )
            ack_message(reject=True, requeue=False)
            return

        # Try to acquire cluster-wide lock
        cluster_lock = ClusterLock(
            redis=redis.get_redis(),
            key=f"copilot:task:{task_id}:lock",
            owner_id=self.executor_id,
            timeout=settings.config.cluster_lock_timeout,
        )
        current_owner = cluster_lock.try_acquire()
        if current_owner != self.executor_id:
            if current_owner is not None:
                logger.warning(f"Task {task_id} already running on pod {current_owner}")
                ack_message(reject=True, requeue=False)
            else:
                logger.warning(
                    f"Could not acquire lock for {task_id} - Redis unavailable"
                )
                ack_message(reject=True, requeue=True)
            return

        # Execute the task
        try:
            self._task_locks[task_id] = cluster_lock

            logger.info(
                f"Acquired cluster lock for {task_id}, executor_id={self.executor_id}"
            )

            cancel_event = threading.Event()
            future = self.executor.submit(
                execute_copilot_task, entry, cancel_event, cluster_lock
            )
            self.active_tasks[task_id] = (future, cancel_event)
        except Exception as e:
            logger.warning(f"Failed to setup execution for {task_id}: {e}")
            cluster_lock.release()
            if task_id in self._task_locks:
                del self._task_locks[task_id]
            ack_message(reject=True, requeue=True)
            return

        self._update_metrics()

        def on_run_done(f: Future):
            logger.info(f"Run completed for {task_id}")
            try:
                if exec_error := f.exception():
                    logger.error(f"Execution for {task_id} failed: {exec_error}")
                    # Don't requeue failed tasks - they've been marked as failed
                    # in the stream registry. Requeuing would cause infinite retries
                    # for deterministic failures.
                    ack_message(reject=True, requeue=False)
                else:
                    ack_message(reject=False, requeue=False)
            except BaseException as e:
                logger.exception(f"Error in run completion callback: {e}")
            finally:
                # Release the cluster lock
                if task_id in self._task_locks:
                    logger.info(f"Releasing cluster lock for {task_id}")
                    self._task_locks[task_id].release()
                    del self._task_locks[task_id]
                self._cleanup_completed_tasks()

        future.add_done_callback(on_run_done)

    # ============ Helper Methods ============ #

    def _cleanup_completed_tasks(self) -> list[str]:
        """Remove completed futures from active_tasks and update metrics."""
        completed_tasks = []
        for task_id, (future, _) in self.active_tasks.items():
            if future.done():
                completed_tasks.append(task_id)

        for task_id in completed_tasks:
            logger.info(f"Cleaned up completed task {task_id}")
            self.active_tasks.pop(task_id, None)

        self._update_metrics()
        return completed_tasks

    def _update_metrics(self):
        """Update Prometheus metrics."""
        active_count = len(self.active_tasks)
        active_tasks_gauge.set(active_count)
        if self.stop_consuming.is_set():
            utilization_gauge.set(1.0)
        else:
            utilization_gauge.set(
                active_count / self.pool_size if self.pool_size > 0 else 0
            )

    def _stop_message_consumers(
        self, thread: threading.Thread, client: SyncRabbitMQ, prefix: str
    ):
        """Stop a message consumer thread."""
        try:
            channel = client.get_channel()
            channel.connection.add_callback_threadsafe(lambda: channel.stop_consuming())

            thread.join(timeout=300)
            if thread.is_alive():
                logger.error(
                    f"{prefix} Thread did not finish in time, forcing disconnect"
                )

            client.disconnect()
            logger.info(f"{prefix} Client disconnected")
        except Exception as e:
            logger.error(f"{prefix} Error disconnecting client: {e}")

    # ============ Lazy-initialized Properties ============ #

    @property
    def cancel_thread(self) -> threading.Thread:
        if self._cancel_thread is None:
            self._cancel_thread = threading.Thread(
                target=lambda: self._consume_cancel(),
                daemon=True,
            )
        return self._cancel_thread

    @property
    def run_thread(self) -> threading.Thread:
        if self._run_thread is None:
            self._run_thread = threading.Thread(
                target=lambda: self._consume_run(),
                daemon=True,
            )
        return self._run_thread

    @property
    def stop_consuming(self) -> threading.Event:
        if self._stop_consuming is None:
            self._stop_consuming = threading.Event()
        return self._stop_consuming

    @property
    def executor(self) -> ThreadPoolExecutor:
        if self._executor is None:
            self._executor = ThreadPoolExecutor(
                max_workers=self.pool_size,
                initializer=init_worker,
            )
        return self._executor

    @property
    def cancel_client(self) -> SyncRabbitMQ:
        if self._cancel_client is None:
            self._cancel_client = SyncRabbitMQ(create_copilot_queue_config())
        return self._cancel_client

    @property
    def run_client(self) -> SyncRabbitMQ:
        if self._run_client is None:
            self._run_client = SyncRabbitMQ(create_copilot_queue_config())
        return self._run_client
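The cancel consumer above needs a matching producer, which falls outside this excerpt. A minimal sketch of what publishing a cancellation could look like, reusing the models and exchange defined in executor/utils.py below (the queue-client accessor mirrors the one used by enqueue_copilot_task and is an assumption here):

from backend.copilot.executor.utils import (
    COPILOT_CANCEL_EXCHANGE,
    CancelCoPilotEvent,
)
from backend.util.clients import get_async_copilot_queue  # assumed accessor


async def cancel_copilot_task(task_id: str) -> None:
    queue_client = await get_async_copilot_queue()
    await queue_client.publish_message(
        routing_key="",  # ignored for FANOUT exchanges
        message=CancelCoPilotEvent(task_id=task_id).model_dump_json(),
        exchange=COPILOT_CANCEL_EXCHANGE,
    )
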
autogpt_platform/backend/backend/copilot/executor/processor.py (new file, 237 lines)
@@ -0,0 +1,237 @@
"""CoPilot execution processor - per-worker execution logic.

This module contains the processor class that handles CoPilot task execution
in a thread-local context, following the graph executor pattern.
"""

import asyncio
import logging
import threading
import time

from backend.copilot import service as copilot_service
from backend.copilot import stream_registry
from backend.copilot.response_model import StreamError, StreamFinish, StreamFinishStep
from backend.executor.cluster_lock import ClusterLock
from backend.util.decorator import error_logged
from backend.util.logging import TruncatedLogger, configure_logging
from backend.util.process import set_service_name
from backend.util.retry import func_retry

from .utils import CoPilotExecutionEntry, CoPilotLogMetadata

logger = TruncatedLogger(logging.getLogger(__name__), prefix="[CoPilotExecutor]")


# ============ Module Entry Points ============ #

# Thread-local storage for processor instances
_tls = threading.local()


def execute_copilot_task(
    entry: CoPilotExecutionEntry,
    cancel: threading.Event,
    cluster_lock: ClusterLock,
):
    """Execute a CoPilot task using the thread-local processor.

    This function is the entry point called by the thread pool executor.

    Args:
        entry: The task payload
        cancel: Threading event to signal cancellation
        cluster_lock: Distributed lock for this execution
    """
    processor: CoPilotProcessor = _tls.processor
    return processor.execute(entry, cancel, cluster_lock)


def init_worker():
    """Initialize the processor for the current worker thread.

    This function is called by the thread pool executor when a new worker
    thread is created. It ensures each worker has its own processor instance.
    """
    _tls.processor = CoPilotProcessor()
    _tls.processor.on_executor_start()


# ============ Processor Class ============ #


class CoPilotProcessor:
    """Per-worker execution logic for CoPilot tasks.

    This class is instantiated once per worker thread and handles the execution
    of CoPilot chat generation tasks. It maintains an async event loop for
    running the async service code.

    The execution flow:
    1. CoPilot task is picked from RabbitMQ queue
    2. Manager submits task to thread pool
    3. Processor executes the task in its event loop
    4. Results are published to Redis Streams
    """

    @func_retry
    def on_executor_start(self):
        """Initialize the processor when the worker thread starts.

        This method is called once per worker thread to set up the async event
        loop and initialize any required resources.

        Database is accessed only through DatabaseManager, so we don't need to connect
        to Prisma directly.
        """
        configure_logging()
        set_service_name("CoPilotExecutor")
        self.tid = threading.get_ident()
        self.execution_loop = asyncio.new_event_loop()
        self.execution_thread = threading.Thread(
            target=self.execution_loop.run_forever, daemon=True
        )
        self.execution_thread.start()

        logger.info(f"[CoPilotExecutor] Worker {self.tid} started")

    @error_logged(swallow=False)
    def execute(
        self,
        entry: CoPilotExecutionEntry,
        cancel: threading.Event,
        cluster_lock: ClusterLock,
    ):
        """Execute a CoPilot task.

        This is the main entry point for task execution. It runs the async
        execution logic in the worker's event loop and handles errors.

        Args:
            entry: The task payload containing session and message info
            cancel: Threading event to signal cancellation
            cluster_lock: Distributed lock to prevent duplicate execution
        """
        log = CoPilotLogMetadata(
            logging.getLogger(__name__),
            task_id=entry.task_id,
            session_id=entry.session_id,
            user_id=entry.user_id,
        )
        log.info("Starting execution")

        start_time = time.monotonic()

        try:
            # Run the async execution in our event loop
            future = asyncio.run_coroutine_threadsafe(
                self._execute_async(entry, cancel, cluster_lock, log),
                self.execution_loop,
            )

            # Wait for completion, checking cancel periodically
            while not future.done():
                try:
                    future.result(timeout=1.0)
                except asyncio.TimeoutError:
                    if cancel.is_set():
                        log.info("Cancellation requested")
                        future.cancel()
                        break
                    # Refresh cluster lock to maintain ownership
                    cluster_lock.refresh()

            if not future.cancelled():
                # Get result to propagate any exceptions
                future.result()

            elapsed = time.monotonic() - start_time
            log.info(f"Execution completed in {elapsed:.2f}s")

        except Exception as e:
            elapsed = time.monotonic() - start_time
            log.error(f"Execution failed after {elapsed:.2f}s: {e}")
            # Note: _execute_async already marks the task as failed before re-raising,
            # so we don't call _mark_task_failed here to avoid duplicate error events.
            raise

    async def _execute_async(
        self,
        entry: CoPilotExecutionEntry,
        cancel: threading.Event,
        cluster_lock: ClusterLock,
        log: CoPilotLogMetadata,
    ):
        """Async execution logic for CoPilot task.

        This method calls the existing stream_chat_completion service function
        and publishes results to the stream registry.

        Args:
            entry: The task payload
            cancel: Threading event to signal cancellation
            cluster_lock: Distributed lock for refresh
            log: Structured logger for this task
        """
        last_refresh = time.monotonic()
        refresh_interval = 30.0  # Refresh lock every 30 seconds

        try:
            # Stream chat completion and publish chunks to Redis
            async for chunk in copilot_service.stream_chat_completion(
                session_id=entry.session_id,
                message=entry.message if entry.message else None,
                is_user_message=entry.is_user_message,
                user_id=entry.user_id,
                context=entry.context,
                _task_id=entry.task_id,
            ):
                # Check for cancellation
                if cancel.is_set():
                    log.info("Cancelled during streaming")
                    await stream_registry.publish_chunk(
                        entry.task_id, StreamError(errorText="Operation cancelled")
                    )
                    await stream_registry.publish_chunk(
                        entry.task_id, StreamFinishStep()
                    )
                    await stream_registry.publish_chunk(entry.task_id, StreamFinish())
                    await stream_registry.mark_task_completed(
                        entry.task_id, status="failed"
                    )
                    return

                # Refresh cluster lock periodically
                current_time = time.monotonic()
                if current_time - last_refresh >= refresh_interval:
                    cluster_lock.refresh()
                    last_refresh = current_time

                # Publish chunk to stream registry
                await stream_registry.publish_chunk(entry.task_id, chunk)

            # Mark task as completed
            await stream_registry.mark_task_completed(entry.task_id, status="completed")
            log.info("Task completed successfully")

        except asyncio.CancelledError:
            log.info("Task cancelled")
            await stream_registry.mark_task_completed(entry.task_id, status="failed")
            raise

        except Exception as e:
            log.error(f"Task failed: {e}")
            await self._mark_task_failed(entry.task_id, str(e))
            raise

    async def _mark_task_failed(self, task_id: str, error_message: str):
        """Mark a task as failed and publish error to stream registry."""
        try:
            await stream_registry.publish_chunk(
                task_id, StreamError(errorText=error_message)
            )
            await stream_registry.publish_chunk(task_id, StreamFinishStep())
            await stream_registry.publish_chunk(task_id, StreamFinish())
            await stream_registry.mark_task_completed(task_id, status="failed")
        except Exception as e:
            logger.error(f"Failed to mark task {task_id} as failed: {e}")
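The processor's core trick is the per-worker event loop: each worker thread owns a loop running forever in a helper thread and submits coroutines to it with run_coroutine_threadsafe. The pattern in isolation, as a self-contained sketch:

import asyncio
import threading

loop = asyncio.new_event_loop()
threading.Thread(target=loop.run_forever, daemon=True).start()


async def work() -> str:
    await asyncio.sleep(0.1)
    return "done"


# Submit from a plain (non-async) thread; result() blocks the caller,
# not the event loop, mirroring CoPilotProcessor.execute above.
future = asyncio.run_coroutine_threadsafe(work(), loop)
print(future.result(timeout=5))
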
autogpt_platform/backend/backend/copilot/executor/utils.py (new file, 207 lines)
@@ -0,0 +1,207 @@
"""RabbitMQ queue configuration for CoPilot executor.

Defines two exchanges and queues following the graph executor pattern:
- 'copilot_execution' (DIRECT) for chat generation tasks
- 'copilot_cancel' (FANOUT) for cancellation requests
"""

import logging

from pydantic import BaseModel

from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
from backend.util.logging import TruncatedLogger, is_structured_logging_enabled

logger = logging.getLogger(__name__)


# ============ Logging Helper ============ #


class CoPilotLogMetadata(TruncatedLogger):
    """Structured logging helper for CoPilot executor.

    In cloud environments (structured logging enabled), uses a simple prefix
    and passes metadata via json_fields. In local environments, uses a detailed
    prefix with all metadata key-value pairs for easier debugging.

    Args:
        logger: The underlying logger instance
        max_length: Maximum log message length before truncation
        **kwargs: Metadata key-value pairs (e.g., task_id="abc", session_id="xyz")
            These are added to json_fields in cloud mode, or to the prefix in local mode.
    """

    def __init__(
        self,
        logger: logging.Logger,
        max_length: int = 1000,
        **kwargs: str | None,
    ):
        # Filter out None values
        metadata = {k: v for k, v in kwargs.items() if v is not None}
        metadata["component"] = "CoPilotExecutor"

        if is_structured_logging_enabled():
            prefix = "[CoPilotExecutor]"
        else:
            # Build prefix from metadata key-value pairs
            meta_parts = "|".join(
                f"{k}:{v}" for k, v in metadata.items() if k != "component"
            )
            prefix = (
                f"[CoPilotExecutor|{meta_parts}]" if meta_parts else "[CoPilotExecutor]"
            )

        super().__init__(
            logger,
            max_length=max_length,
            prefix=prefix,
            metadata=metadata,
        )

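A hypothetical usage of this helper, matching how processor.py constructs it above: the metadata kwargs become json_fields in cloud mode, or part of the prefix locally.

log = CoPilotLogMetadata(
    logging.getLogger(__name__),
    task_id="task-123",
    session_id="sess-456",
    user_id=None,  # None values are filtered out of the metadata
)
log.info("Starting execution")
# Local-mode prefix: [CoPilotExecutor|task_id:task-123|session_id:sess-456]
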
# ============ Exchange and Queue Configuration ============ #
|
||||
|
||||
COPILOT_EXECUTION_EXCHANGE = Exchange(
|
||||
name="copilot_execution",
|
||||
type=ExchangeType.DIRECT,
|
||||
durable=True,
|
||||
auto_delete=False,
|
||||
)
|
||||
COPILOT_EXECUTION_QUEUE_NAME = "copilot_execution_queue"
|
||||
COPILOT_EXECUTION_ROUTING_KEY = "copilot.run"
|
||||
|
||||
COPILOT_CANCEL_EXCHANGE = Exchange(
|
||||
name="copilot_cancel",
|
||||
type=ExchangeType.FANOUT,
|
||||
durable=True,
|
||||
auto_delete=False,
|
||||
)
|
||||
COPILOT_CANCEL_QUEUE_NAME = "copilot_cancel_queue"
|
||||
|
||||
# CoPilot operations can include extended thinking and agent generation
|
||||
# which may take 30+ minutes to complete
|
||||
COPILOT_CONSUMER_TIMEOUT_SECONDS = 60 * 60 # 1 hour
|
||||
|
||||
# Graceful shutdown timeout - allow in-flight operations to complete
|
||||
GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS = 30 * 60 # 30 minutes
|
||||
|
||||
|
||||
def create_copilot_queue_config() -> RabbitMQConfig:
|
||||
"""Create RabbitMQ configuration for CoPilot executor.
|
||||
|
||||
Defines two exchanges and queues:
|
||||
- 'copilot_execution' (DIRECT) for chat generation tasks
|
||||
- 'copilot_cancel' (FANOUT) for cancellation requests
|
||||
|
||||
Returns:
|
||||
RabbitMQConfig with exchanges and queues defined
|
||||
"""
|
||||
run_queue = Queue(
|
||||
name=COPILOT_EXECUTION_QUEUE_NAME,
|
||||
exchange=COPILOT_EXECUTION_EXCHANGE,
|
||||
routing_key=COPILOT_EXECUTION_ROUTING_KEY,
|
||||
durable=True,
|
||||
auto_delete=False,
|
||||
arguments={
|
||||
# Extended consumer timeout for long-running LLM operations
|
||||
# Default 30-minute timeout is insufficient for extended thinking
|
||||
# and agent generation which can take 30+ minutes
|
||||
"x-consumer-timeout": COPILOT_CONSUMER_TIMEOUT_SECONDS
|
||||
* 1000,
|
||||
},
|
||||
)
|
||||
cancel_queue = Queue(
|
||||
name=COPILOT_CANCEL_QUEUE_NAME,
|
||||
exchange=COPILOT_CANCEL_EXCHANGE,
|
||||
routing_key="", # not used for FANOUT
|
||||
durable=True,
|
||||
auto_delete=False,
|
||||
)
|
||||
return RabbitMQConfig(
|
||||
vhost="/",
|
||||
exchanges=[COPILOT_EXECUTION_EXCHANGE, COPILOT_CANCEL_EXCHANGE],
|
||||
queues=[run_queue, cancel_queue],
|
||||
)
|
||||
|
||||
|
||||
# ============ Message Models ============ #
|
||||
|
||||
|
||||
class CoPilotExecutionEntry(BaseModel):
|
||||
"""Task payload for CoPilot AI generation.
|
||||
|
||||
This model represents a chat generation task to be processed by the executor.
|
||||
"""
|
||||
|
||||
task_id: str
|
||||
"""Unique identifier for this task (used for stream registry)"""
|
||||
|
||||
session_id: str
|
||||
"""Chat session ID"""
|
||||
|
||||
user_id: str | None
|
||||
"""User ID (may be None for anonymous users)"""
|
||||
|
||||
operation_id: str
|
||||
"""Operation ID for webhook callbacks and completion tracking"""
|
||||
|
||||
message: str
|
||||
"""User's message to process"""
|
||||
|
||||
is_user_message: bool = True
|
||||
"""Whether the message is from the user (vs system/assistant)"""
|
||||
|
||||
context: dict[str, str] | None = None
|
||||
"""Optional context for the message (e.g., {url: str, content: str})"""
|
||||
|
||||
|
||||
class CancelCoPilotEvent(BaseModel):
|
||||
"""Event to cancel a CoPilot operation."""
|
||||
|
||||
task_id: str
|
||||
"""Task ID to cancel"""
|
||||
|
||||
|
||||
# ============ Queue Publishing Helpers ============ #
|
||||
|
||||
|
||||
async def enqueue_copilot_task(
|
||||
task_id: str,
|
||||
session_id: str,
|
||||
user_id: str | None,
|
||||
operation_id: str,
|
||||
message: str,
|
||||
is_user_message: bool = True,
|
||||
context: dict[str, str] | None = None,
|
||||
) -> None:
|
||||
"""Enqueue a CoPilot task for processing by the executor service.
|
||||
|
||||
Args:
|
||||
task_id: Unique identifier for this task (used for stream registry)
|
||||
session_id: Chat session ID
|
||||
user_id: User ID (may be None for anonymous users)
|
||||
operation_id: Operation ID for webhook callbacks and completion tracking
|
||||
message: User's message to process
|
||||
is_user_message: Whether the message is from the user (vs system/assistant)
|
||||
context: Optional context for the message (e.g., {url: str, content: str})
|
||||
"""
|
||||
from backend.util.clients import get_async_copilot_queue
|
||||
|
||||
entry = CoPilotExecutionEntry(
|
||||
task_id=task_id,
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
operation_id=operation_id,
|
||||
message=message,
|
||||
is_user_message=is_user_message,
|
||||
context=context,
|
||||
)
|
||||
|
||||
queue_client = await get_async_copilot_queue()
|
||||
await queue_client.publish_message(
|
||||
routing_key=COPILOT_EXECUTION_ROUTING_KEY,
|
||||
message=entry.model_dump_json(),
|
||||
exchange=COPILOT_EXECUTION_EXCHANGE,
|
||||
)
|
||||
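For orientation, here is a minimal sketch of how a caller might use the `enqueue_copilot_task` helper above; the IDs, session name, and message are illustrative placeholders, not values from the real system.

```python
import asyncio
import uuid


async def main() -> None:
    # Publishes a CoPilotExecutionEntry to the copilot_execution exchange;
    # the executor service consumes it from copilot_execution_queue.
    await enqueue_copilot_task(
        task_id=str(uuid.uuid4()),
        session_id="session-123",        # existing chat session (placeholder)
        user_id=None,                    # anonymous users are allowed
        operation_id=str(uuid.uuid4()),  # used for completion tracking
        message="Build me a daily news digest agent",
        context={"url": "https://example.com", "content": "page text"},
    )


asyncio.run(main())
```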
@@ -23,26 +23,17 @@ from prisma.models import ChatMessage as PrismaChatMessage
from prisma.models import ChatSession as PrismaChatSession
from pydantic import BaseModel

from backend.data.db_accessors import chat_db
from backend.data.redis_client import get_redis_async
from backend.util import json
from backend.util.exceptions import DatabaseError, RedisError

from . import db as chat_db
from .config import ChatConfig

logger = logging.getLogger(__name__)
config = ChatConfig()


def _parse_json_field(value: str | dict | list | None, default: Any = None) -> Any:
    """Parse a JSON field that may be stored as string or already parsed."""
    if value is None:
        return default
    if isinstance(value, str):
        return json.loads(value)
    return value


# Redis cache key prefix for chat sessions
CHAT_SESSION_CACHE_PREFIX = "chat:session:"

@@ -52,28 +43,7 @@ def _get_session_cache_key(session_id: str) -> str:
    return f"{CHAT_SESSION_CACHE_PREFIX}{session_id}"


# Session-level locks to prevent race conditions during concurrent upserts.
# Uses WeakValueDictionary to automatically garbage collect locks when no longer referenced,
# preventing unbounded memory growth while maintaining lock semantics for active sessions.
# Invalidation: Locks are auto-removed by GC when no coroutine holds a reference (after
# async with lock: completes). Explicit cleanup also occurs in delete_chat_session().
_session_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
_session_locks_mutex = asyncio.Lock()


async def _get_session_lock(session_id: str) -> asyncio.Lock:
    """Get or create a lock for a specific session to prevent concurrent upserts.

    Uses WeakValueDictionary for automatic cleanup: locks are garbage collected
    when no coroutine holds a reference to them, preventing memory leaks from
    unbounded growth of session locks.
    """
    async with _session_locks_mutex:
        lock = _session_locks.get(session_id)
        if lock is None:
            lock = asyncio.Lock()
            _session_locks[session_id] = lock
        return lock
# ===================== Chat data models ===================== #


class ChatMessage(BaseModel):
@@ -322,38 +292,26 @@ class ChatSession(BaseModel):
        return self._merge_consecutive_assistant_messages(messages)


async def _get_session_from_cache(session_id: str) -> ChatSession | None:
    """Get a chat session from Redis cache."""
    redis_key = _get_session_cache_key(session_id)
    async_redis = await get_redis_async()
    raw_session: bytes | None = await async_redis.get(redis_key)

    if raw_session is None:
        return None

    try:
        session = ChatSession.model_validate_json(raw_session)
        logger.info(
            f"Loading session {session_id} from cache: "
            f"message_count={len(session.messages)}, "
            f"roles={[m.role for m in session.messages]}"
        )
        return session
    except Exception as e:
        logger.error(f"Failed to deserialize session {session_id}: {e}", exc_info=True)
        raise RedisError(f"Corrupted session data for {session_id}") from e
def _parse_json_field(value: str | dict | list | None, default: Any = None) -> Any:
    """Parse a JSON field that may be stored as string or already parsed."""
    if value is None:
        return default
    if isinstance(value, str):
        return json.loads(value)
    return value


async def _cache_session(session: ChatSession) -> None:
    """Cache a chat session in Redis."""
    redis_key = _get_session_cache_key(session.session_id)
    async_redis = await get_redis_async()
    await async_redis.setex(redis_key, config.session_ttl, session.model_dump_json())
# ================ Chat cache + DB operations ================ #

# NOTE: Database calls are automatically routed through DatabaseManager if Prisma is not
# connected directly.


async def cache_chat_session(session: ChatSession) -> None:
    """Cache a chat session without persisting to the database."""
    await _cache_session(session)
    """Cache a chat session in Redis (without persisting to the database)."""
    redis_key = _get_session_cache_key(session.session_id)
    async_redis = await get_redis_async()
    await async_redis.setex(redis_key, config.session_ttl, session.model_dump_json())


async def invalidate_session_cache(session_id: str) -> None:
@@ -371,80 +329,6 @@ async def invalidate_session_cache(session_id: str) -> None:
        logger.warning(f"Failed to invalidate session cache for {session_id}: {e}")


async def _get_session_from_db(session_id: str) -> ChatSession | None:
    """Get a chat session from the database."""
    prisma_session = await chat_db.get_chat_session(session_id)
    if not prisma_session:
        return None

    messages = prisma_session.Messages
    logger.info(
        f"Loading session {session_id} from DB: "
        f"has_messages={messages is not None}, "
        f"message_count={len(messages) if messages else 0}, "
        f"roles={[m.role for m in messages] if messages else []}"
    )

    return ChatSession.from_db(prisma_session, messages)


async def _save_session_to_db(
    session: ChatSession, existing_message_count: int
) -> None:
    """Save or update a chat session in the database."""
    # Check if session exists in DB
    existing = await chat_db.get_chat_session(session.session_id)

    if not existing:
        # Create new session
        await chat_db.create_chat_session(
            session_id=session.session_id,
            user_id=session.user_id,
        )
        existing_message_count = 0

    # Calculate total tokens from usage
    total_prompt = sum(u.prompt_tokens for u in session.usage)
    total_completion = sum(u.completion_tokens for u in session.usage)

    # Update session metadata
    await chat_db.update_chat_session(
        session_id=session.session_id,
        credentials=session.credentials,
        successful_agent_runs=session.successful_agent_runs,
        successful_agent_schedules=session.successful_agent_schedules,
        total_prompt_tokens=total_prompt,
        total_completion_tokens=total_completion,
    )

    # Add new messages (only those after existing count)
    new_messages = session.messages[existing_message_count:]
    if new_messages:
        messages_data = []
        for msg in new_messages:
            messages_data.append(
                {
                    "role": msg.role,
                    "content": msg.content,
                    "name": msg.name,
                    "tool_call_id": msg.tool_call_id,
                    "refusal": msg.refusal,
                    "tool_calls": msg.tool_calls,
                    "function_call": msg.function_call,
                }
            )
        logger.info(
            f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: "
            f"roles={[m['role'] for m in messages_data]}, "
            f"start_sequence={existing_message_count}"
        )
        await chat_db.add_chat_messages_batch(
            session_id=session.session_id,
            messages=messages_data,
            start_sequence=existing_message_count,
        )


async def get_chat_session(
    session_id: str,
    user_id: str | None = None,
@@ -492,7 +376,7 @@ async def get_chat_session(

    # Cache the session from DB
    try:
        await _cache_session(session)
        await cache_chat_session(session)
        logger.info(f"Cached session {session_id} from database")
    except Exception as e:
        logger.warning(f"Failed to cache session {session_id}: {e}")
@@ -500,6 +384,45 @@ async def get_chat_session(
    return session


async def _get_session_from_cache(session_id: str) -> ChatSession | None:
    """Get a chat session from Redis cache."""
    redis_key = _get_session_cache_key(session_id)
    async_redis = await get_redis_async()
    raw_session: bytes | None = await async_redis.get(redis_key)

    if raw_session is None:
        return None

    try:
        session = ChatSession.model_validate_json(raw_session)
        logger.info(
            f"Loading session {session_id} from cache: "
            f"message_count={len(session.messages)}, "
            f"roles={[m.role for m in session.messages]}"
        )
        return session
    except Exception as e:
        logger.error(f"Failed to deserialize session {session_id}: {e}", exc_info=True)
        raise RedisError(f"Corrupted session data for {session_id}") from e


async def _get_session_from_db(session_id: str) -> ChatSession | None:
    """Get a chat session from the database."""
    prisma_session = await chat_db().get_chat_session(session_id)
    if not prisma_session:
        return None

    messages = prisma_session.Messages
    logger.info(
        f"Loading session {session_id} from DB: "
        f"has_messages={messages is not None}, "
        f"message_count={len(messages) if messages else 0}, "
        f"roles={[m.role for m in messages] if messages else []}"
    )

    return ChatSession.from_db(prisma_session, messages)


async def upsert_chat_session(
    session: ChatSession,
) -> ChatSession:
@@ -520,7 +443,7 @@ async def upsert_chat_session(

    async with lock:
        # Get existing message count from DB for incremental saves
        existing_message_count = await chat_db.get_chat_session_message_count(
        existing_message_count = await chat_db().get_chat_session_message_count(
            session.session_id
        )

@@ -537,7 +460,7 @@ async def upsert_chat_session(

    # Save to cache (best-effort, even if DB failed)
    try:
        await _cache_session(session)
        await cache_chat_session(session)
    except Exception as e:
        # If DB succeeded but cache failed, raise cache error
        if db_error is None:
@@ -558,6 +481,65 @@ async def upsert_chat_session(
    return session


async def _save_session_to_db(
    session: ChatSession, existing_message_count: int
) -> None:
    """Save or update a chat session in the database."""
    db = chat_db()

    # Check if session exists in DB
    existing = await db.get_chat_session(session.session_id)

    if not existing:
        # Create new session
        await db.create_chat_session(
            session_id=session.session_id,
            user_id=session.user_id,
        )
        existing_message_count = 0

    # Calculate total tokens from usage
    total_prompt = sum(u.prompt_tokens for u in session.usage)
    total_completion = sum(u.completion_tokens for u in session.usage)

    # Update session metadata
    await db.update_chat_session(
        session_id=session.session_id,
        credentials=session.credentials,
        successful_agent_runs=session.successful_agent_runs,
        successful_agent_schedules=session.successful_agent_schedules,
        total_prompt_tokens=total_prompt,
        total_completion_tokens=total_completion,
    )

    # Add new messages (only those after existing count)
    new_messages = session.messages[existing_message_count:]
    if new_messages:
        messages_data = []
        for msg in new_messages:
            messages_data.append(
                {
                    "role": msg.role,
                    "content": msg.content,
                    "name": msg.name,
                    "tool_call_id": msg.tool_call_id,
                    "refusal": msg.refusal,
                    "tool_calls": msg.tool_calls,
                    "function_call": msg.function_call,
                }
            )
        logger.info(
            f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: "
            f"roles={[m['role'] for m in messages_data]}, "
            f"start_sequence={existing_message_count}"
        )
        await db.add_chat_messages_batch(
            session_id=session.session_id,
            messages=messages_data,
            start_sequence=existing_message_count,
        )


async def create_chat_session(user_id: str) -> ChatSession:
    """Create a new chat session and persist it.

@@ -570,7 +552,7 @@ async def create_chat_session(user_id: str) -> ChatSession:

    # Create in database first - fail fast if this fails
    try:
        await chat_db.create_chat_session(
        await chat_db().create_chat_session(
            session_id=session.session_id,
            user_id=user_id,
        )
@@ -582,7 +564,7 @@ async def create_chat_session(user_id: str) -> ChatSession:

    # Cache the session (best-effort optimization, DB is source of truth)
    try:
        await _cache_session(session)
        await cache_chat_session(session)
    except Exception as e:
        logger.warning(f"Failed to cache new session {session.session_id}: {e}")

@@ -600,8 +582,9 @@ async def get_user_sessions(
        A tuple of (sessions, total_count) where total_count is the overall
        number of sessions for the user (not just the current page).
    """
    prisma_sessions = await chat_db.get_user_chat_sessions(user_id, limit, offset)
    total_count = await chat_db.get_user_session_count(user_id)
    db = chat_db()
    prisma_sessions = await db.get_user_chat_sessions(user_id, limit, offset)
    total_count = await db.get_user_session_count(user_id)

    sessions = []
    for prisma_session in prisma_sessions:
@@ -624,7 +607,7 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bool:
    """
    # Delete from database first (with optional user_id validation)
    # This confirms ownership before invalidating cache
    deleted = await chat_db.delete_chat_session(session_id, user_id)
    deleted = await chat_db().delete_chat_session(session_id, user_id)

    if not deleted:
        return False
@@ -659,7 +642,7 @@ async def update_session_title(session_id: str, title: str) -> bool:
        True if updated successfully, False otherwise.
    """
    try:
        result = await chat_db.update_chat_session(session_id=session_id, title=title)
        result = await chat_db().update_chat_session(session_id=session_id, title=title)
        if result is None:
            logger.warning(f"Session {session_id} not found for title update")
            return False
@@ -676,3 +659,29 @@ async def update_session_title(session_id: str, title: str) -> bool:
    except Exception as e:
        logger.error(f"Failed to update title for session {session_id}: {e}")
        return False


# ==================== Chat session locks ==================== #

_session_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
_session_locks_mutex = asyncio.Lock()


async def _get_session_lock(session_id: str) -> asyncio.Lock:
    """Get or create a lock for a specific session to prevent concurrent upserts.

    This was originally added to solve the specific problem of race conditions between
    the session title thread and the conversation thread, which always occur on the
    same instance, as we prevent rapid request sends on the frontend.

    Uses WeakValueDictionary for automatic cleanup: locks are garbage collected
    when no coroutine holds a reference to them, preventing memory leaks from
    unbounded growth of session locks. Explicit cleanup also occurs
    in `delete_chat_session()`.
    """
    async with _session_locks_mutex:
        lock = _session_locks.get(session_id)
        if lock is None:
            lock = asyncio.Lock()
            _session_locks[session_id] = lock
        return lock
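As a standalone illustration of the cleanup behaviour the docstring above describes (stdlib only, not part of this diff): an entry in a `WeakValueDictionary` disappears once the last strong reference to the stored lock is dropped.

```python
import asyncio
import gc
from weakref import WeakValueDictionary

locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()


async def demo() -> None:
    lock = asyncio.Lock()
    locks["session-1"] = lock

    async with lock:
        assert "session-1" in locks  # entry lives while a strong ref exists

    del lock
    gc.collect()  # CPython refcounting already collects it; this is belt-and-braces
    assert "session-1" not in locks  # entry vanished with the last strong ref


asyncio.run(demo())
```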
@@ -27,6 +27,7 @@ from openai.types.chat import (
    ChatCompletionToolParam,
)

from backend.data.db_accessors import chat_db
from backend.data.redis_client import get_redis_async
from backend.data.understanding import (
    format_understanding_for_prompt,
@@ -35,7 +36,6 @@ from backend.data.understanding import (
from backend.util.exceptions import NotFoundError
from backend.util.settings import AppEnvironment, Settings

from . import db as chat_db
from . import stream_registry
from .config import ChatConfig
from .model import (
@@ -1744,7 +1744,7 @@ async def _update_pending_operation(
    This is called by background tasks when long-running operations complete.
    """
    # Update the message in database
    updated = await chat_db.update_tool_message_content(
    updated = await chat_db().update_tool_message_content(
        session_id=session_id,
        tool_call_id=tool_call_id,
        new_content=result,
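A recurring change across these hunks is `chat_db.some_call(...)` becoming `chat_db().some_call(...)`: a direct module import is replaced by an accessor function from `backend.data.db_accessors`. That module's implementation is not shown in this diff; the sketch below is only a guess at the shape of the indirection, based on the NOTE earlier about routing through DatabaseManager when Prisma is not connected directly.

```python
# Hypothetical sketch -- not the actual backend.data.db_accessors code.
from typing import Protocol


class ChatDB(Protocol):
    async def get_chat_session(self, session_id: str): ...


_direct_impl: ChatDB | None = None  # set when Prisma is connected in-process
_remote_impl: ChatDB | None = None  # DatabaseManager-backed client


def chat_db() -> ChatDB:
    """Return the object to issue chat DB calls against.

    Routing through a function (chat_db().get_chat_session(...)) lets the
    accessor pick the Prisma-backed implementation when connected directly,
    or a client that forwards the same calls through DatabaseManager.
    """
    impl = _direct_impl or _remote_impl
    if impl is None:
        raise RuntimeError("no chat DB implementation configured")
    return impl
```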
@@ -3,8 +3,8 @@ from typing import TYPE_CHECKING, Any

from openai.types.chat import ChatCompletionToolParam

from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tracking import track_tool_called
from backend.copilot.model import ChatSession
from backend.copilot.tracking import track_tool_called

from .add_understanding import AddUnderstandingTool
from .agent_output import AgentOutputTool
@@ -27,7 +27,7 @@ from .workspace_files import (
)

if TYPE_CHECKING:
    from backend.api.features.chat.response_model import StreamToolOutputAvailable
    from backend.copilot.response_model import StreamToolOutputAvailable

logger = logging.getLogger(__name__)

@@ -6,11 +6,11 @@ import pytest
from prisma.types import ProfileCreateInput
from pydantic import SecretStr

from backend.api.features.chat.model import ChatSession
from backend.api.features.store import db as store_db
from backend.blocks.firecrawl.scrape import FirecrawlScrapeBlock
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
from backend.blocks.llm import AITextGeneratorBlock
from backend.copilot.model import ChatSession
from backend.data.db import prisma
from backend.data.graph import Graph, Link, Node, create_graph
from backend.data.model import APIKeyCredentials
@@ -3,11 +3,9 @@
import logging
from typing import Any

from backend.api.features.chat.model import ChatSession
from backend.data.understanding import (
    BusinessUnderstandingInput,
    upsert_business_understanding,
)
from backend.copilot.model import ChatSession
from backend.data.db_accessors import understanding_db
from backend.data.understanding import BusinessUnderstandingInput

from .base import BaseTool
from .models import ErrorResponse, ToolResponseBase, UnderstandingUpdatedResponse
@@ -99,7 +97,9 @@ and automations for the user's specific needs."""
        ]

        # Upsert with merge
        understanding = await upsert_business_understanding(user_id, input_data)
        understanding = await understanding_db().upsert_business_understanding(
            user_id, input_data
        )

        # Build current understanding summary (filter out empty values)
        current_understanding = {
@@ -5,9 +5,8 @@ import re
import uuid
from typing import Any, NotRequired, TypedDict

from backend.api.features.library import db as library_db
from backend.api.features.store import db as store_db
from backend.data.graph import Graph, Link, Node, get_graph, get_store_listed_graphs
from backend.data.db_accessors import graph_db, library_db, store_db
from backend.data.graph import Graph, Link, Node
from backend.util.exceptions import DatabaseError, NotFoundError

from .service import (
@@ -145,8 +144,9 @@ async def get_library_agent_by_id(
    Returns:
        LibraryAgentSummary if found, None otherwise
    """
    db = library_db()
    try:
        agent = await library_db.get_library_agent_by_graph_id(user_id, agent_id)
        agent = await db.get_library_agent_by_graph_id(user_id, agent_id)
        if agent:
            logger.debug(f"Found library agent by graph_id: {agent.name}")
            return LibraryAgentSummary(
@@ -163,7 +163,7 @@ async def get_library_agent_by_id(
        logger.debug(f"Could not fetch library agent by graph_id {agent_id}: {e}")

    try:
        agent = await library_db.get_library_agent(agent_id, user_id)
        agent = await db.get_library_agent(agent_id, user_id)
        if agent:
            logger.debug(f"Found library agent by library_id: {agent.name}")
            return LibraryAgentSummary(
@@ -215,7 +215,7 @@ async def get_library_agents_for_generation(
        List of LibraryAgentSummary with schemas and recent executions for sub-agent composition
    """
    try:
        response = await library_db.list_library_agents(
        response = await library_db().list_library_agents(
            user_id=user_id,
            search_term=search_query,
            page=1,
@@ -272,7 +272,7 @@ async def search_marketplace_agents_for_generation(
        List of LibraryAgentSummary with full input/output schemas
    """
    try:
        response = await store_db.get_store_agents(
        response = await store_db().get_store_agents(
            search_query=search_query,
            page=1,
            page_size=max_results,
@@ -286,7 +286,7 @@ async def search_marketplace_agents_for_generation(
        return []

    graph_ids = [agent.agent_graph_id for agent in agents_with_graphs]
    graphs = await get_store_listed_graphs(*graph_ids)
    graphs = await graph_db().get_store_listed_graphs(*graph_ids)

    results: list[LibraryAgentSummary] = []
    for agent in agents_with_graphs:
@@ -673,9 +673,10 @@ async def save_agent_to_library(
        Tuple of (created Graph, LibraryAgent)
    """
    graph = json_to_graph(agent_json)
    db = library_db()
    if is_update:
        return await library_db.update_graph_in_library(graph, user_id)
    return await library_db.create_graph_in_library(graph, user_id)
        return await db.update_graph_in_library(graph, user_id)
    return await db.create_graph_in_library(graph, user_id)


def graph_to_json(graph: Graph) -> dict[str, Any]:
@@ -735,12 +736,14 @@ async def get_agent_as_json(
    Returns:
        Agent as JSON dict or None if not found
    """
    graph = await get_graph(agent_id, version=None, user_id=user_id)
    db = graph_db()

    graph = await db.get_graph(agent_id, version=None, user_id=user_id)

    if not graph and user_id:
        try:
            library_agent = await library_db.get_library_agent(agent_id, user_id)
            graph = await get_graph(
            library_agent = await library_db().get_library_agent(agent_id, user_id)
            graph = await db.get_graph(
                library_agent.graph_id, version=None, user_id=user_id
            )
        except NotFoundError:
@@ -7,10 +7,9 @@ from typing import Any

from pydantic import BaseModel, field_validator

from backend.api.features.chat.model import ChatSession
from backend.api.features.library import db as library_db
from backend.api.features.library.model import LibraryAgent
from backend.data import execution as execution_db
from backend.copilot.model import ChatSession
from backend.data.db_accessors import execution_db, library_db
from backend.data.execution import ExecutionStatus, GraphExecution, GraphExecutionMeta

from .base import BaseTool
@@ -165,10 +164,12 @@ class AgentOutputTool(BaseTool):
        Resolve agent from provided identifiers.
        Returns (library_agent, error_message).
        """
        lib_db = library_db()

        # Priority 1: Exact library agent ID
        if library_agent_id:
            try:
                agent = await library_db.get_library_agent(library_agent_id, user_id)
                agent = await lib_db.get_library_agent(library_agent_id, user_id)
                return agent, None
            except Exception as e:
                logger.warning(f"Failed to get library agent by ID: {e}")
@@ -182,7 +183,7 @@ class AgentOutputTool(BaseTool):
                return None, f"Agent '{store_slug}' not found in marketplace"

            # Find in user's library by graph_id
            agent = await library_db.get_library_agent_by_graph_id(user_id, graph.id)
            agent = await lib_db.get_library_agent_by_graph_id(user_id, graph.id)
            if not agent:
                return (
                    None,
@@ -194,7 +195,7 @@ class AgentOutputTool(BaseTool):
        # Priority 3: Fuzzy name search in library
        if agent_name:
            try:
                response = await library_db.list_library_agents(
                response = await lib_db.list_library_agents(
                    user_id=user_id,
                    search_term=agent_name,
                    page_size=5,
@@ -228,9 +229,11 @@ class AgentOutputTool(BaseTool):
        Fetch execution(s) based on filters.
        Returns (single_execution, available_executions_meta, error_message).
        """
        exec_db = execution_db()

        # If specific execution_id provided, fetch it directly
        if execution_id:
            execution = await execution_db.get_graph_execution(
            execution = await exec_db.get_graph_execution(
                user_id=user_id,
                execution_id=execution_id,
                include_node_executions=False,
@@ -240,7 +243,7 @@ class AgentOutputTool(BaseTool):
            return execution, [], None

        # Get completed executions with time filters
        executions = await execution_db.get_graph_executions(
        executions = await exec_db.get_graph_executions(
            graph_id=graph_id,
            user_id=user_id,
            statuses=[ExecutionStatus.COMPLETED],
@@ -254,7 +257,7 @@ class AgentOutputTool(BaseTool):

        # If only one execution, fetch full details
        if len(executions) == 1:
            full_execution = await execution_db.get_graph_execution(
            full_execution = await exec_db.get_graph_execution(
                user_id=user_id,
                execution_id=executions[0].id,
                include_node_executions=False,
@@ -262,7 +265,7 @@ class AgentOutputTool(BaseTool):
            return full_execution, [], None

        # Multiple executions - return latest with full details, plus list of available
        full_execution = await execution_db.get_graph_execution(
        full_execution = await exec_db.get_graph_execution(
            user_id=user_id,
            execution_id=executions[0].id,
            include_node_executions=False,
@@ -380,7 +383,7 @@ class AgentOutputTool(BaseTool):
            and not input_data.store_slug
        ):
            # Fetch execution directly to get graph_id
            execution = await execution_db.get_graph_execution(
            execution = await execution_db().get_graph_execution(
                user_id=user_id,
                execution_id=input_data.execution_id,
                include_node_executions=False,
@@ -392,7 +395,7 @@ class AgentOutputTool(BaseTool):
            )

            # Find library agent by graph_id
            agent = await library_db.get_library_agent_by_graph_id(
            agent = await library_db().get_library_agent_by_graph_id(
                user_id, execution.graph_id
            )
            if not agent:
@@ -4,8 +4,7 @@ import logging
import re
from typing import Literal

from backend.api.features.library import db as library_db
from backend.api.features.store import db as store_db
from backend.data.db_accessors import library_db, store_db
from backend.util.exceptions import DatabaseError, NotFoundError

from .models import (
@@ -45,8 +44,10 @@ async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | None:
    Returns:
        AgentInfo if found, None otherwise
    """
    lib_db = library_db()

    try:
        agent = await library_db.get_library_agent_by_graph_id(user_id, agent_id)
        agent = await lib_db.get_library_agent_by_graph_id(user_id, agent_id)
        if agent:
            logger.debug(f"Found library agent by graph_id: {agent.name}")
            return AgentInfo(
@@ -71,7 +72,7 @@ async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | None:
            )

    try:
        agent = await library_db.get_library_agent(agent_id, user_id)
        agent = await lib_db.get_library_agent(agent_id, user_id)
        if agent:
            logger.debug(f"Found library agent by library_id: {agent.name}")
            return AgentInfo(
@@ -133,7 +134,7 @@ async def search_agents(
    try:
        if source == "marketplace":
            logger.info(f"Searching marketplace for: {query}")
            results = await store_db.get_store_agents(search_query=query, page_size=5)
            results = await store_db().get_store_agents(search_query=query, page_size=5)
            for agent in results.agents:
                agents.append(
                    AgentInfo(
@@ -159,7 +160,7 @@ async def search_agents(

        if not agents:
            logger.info(f"Searching user library for: {query}")
            results = await library_db.list_library_agents(
            results = await library_db().list_library_agents(
                user_id=user_id,  # type: ignore[arg-type]
                search_term=query,
                page_size=10,
@@ -5,8 +5,8 @@ from typing import Any

from openai.types.chat import ChatCompletionToolParam

from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.response_model import StreamToolOutputAvailable
from backend.copilot.model import ChatSession
from backend.copilot.response_model import StreamToolOutputAvailable

from .models import ErrorResponse, NeedLoginResponse, ToolResponseBase

@@ -3,7 +3,7 @@
import logging
from typing import Any

from backend.api.features.chat.model import ChatSession
from backend.copilot.model import ChatSession

from .agent_generator import (
    AgentGeneratorNotConfiguredError,
@@ -3,9 +3,9 @@
import logging
from typing import Any

from backend.api.features.chat.model import ChatSession
from backend.api.features.store import db as store_db
from backend.api.features.store.exceptions import AgentNotFoundError
from backend.copilot.model import ChatSession
from backend.data.db_accessors import store_db as get_store_db

from .agent_generator import (
    AgentGeneratorNotConfiguredError,
@@ -137,6 +137,8 @@ class CustomizeAgentTool(BaseTool):

        creator_username, agent_slug = parts

        store_db = get_store_db()

        # Fetch the marketplace agent details
        try:
            agent_details = await store_db.get_store_agent_details(
@@ -3,7 +3,7 @@
import logging
from typing import Any

from backend.api.features.chat.model import ChatSession
from backend.copilot.model import ChatSession

from .agent_generator import (
    AgentGeneratorNotConfiguredError,
@@ -2,7 +2,7 @@

from typing import Any

from backend.api.features.chat.model import ChatSession
from backend.copilot.model import ChatSession

from .agent_search import search_agents
from .base import BaseTool
@@ -3,18 +3,17 @@ from typing import Any

from prisma.enums import ContentType

from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool, ToolResponseBase
from backend.api.features.chat.tools.models import (
from backend.blocks import get_block
from backend.blocks._base import BlockType
from backend.copilot.model import ChatSession
from backend.copilot.tools.base import BaseTool, ToolResponseBase
from backend.copilot.tools.models import (
    BlockInfoSummary,
    BlockInputFieldInfo,
    BlockListResponse,
    ErrorResponse,
    NoResultsResponse,
)
from backend.api.features.store.hybrid_search import unified_hybrid_search
from backend.blocks import get_block
from backend.blocks._base import BlockType
from backend.data.db_accessors import search

logger = logging.getLogger(__name__)

@@ -55,7 +54,8 @@ class FindBlockTool(BaseTool):
        "Blocks are reusable components that perform specific tasks like "
        "sending emails, making API calls, processing text, etc. "
        "IMPORTANT: Use this tool FIRST to get the block's 'id' before calling run_block. "
        "The response includes each block's id, required_inputs, and input_schema."
        "The response includes each block's id, name, and description. "
        "Call run_block with the block's id **with no inputs** to see detailed inputs/outputs and execute it."
    )

    @property
@@ -107,7 +107,7 @@ class FindBlockTool(BaseTool):

        try:
            # Search for blocks using hybrid search
            results, total = await unified_hybrid_search(
            results, total = await search().unified_hybrid_search(
                query=query,
                content_types=[ContentType.BLOCK],
                page=1,
@@ -124,7 +124,7 @@ class FindBlockTool(BaseTool):
                session_id=session_id,
            )

        # Enrich results with full block information
        # Enrich results with block information
        blocks: list[BlockInfoSummary] = []
        for result in results:
            block_id = result["content_id"]
@@ -141,65 +141,11 @@ class FindBlockTool(BaseTool):
            ):
                continue

            # Get input/output schemas
            input_schema = {}
            output_schema = {}
            try:
                input_schema = block.input_schema.jsonschema()
            except Exception as e:
                logger.debug(
                    "Failed to generate input schema for block %s: %s",
                    block_id,
                    e,
                )
            try:
                output_schema = block.output_schema.jsonschema()
            except Exception as e:
                logger.debug(
                    "Failed to generate output schema for block %s: %s",
                    block_id,
                    e,
                )

            # Get categories from block instance
            categories = []
            if hasattr(block, "categories") and block.categories:
                categories = [cat.value for cat in block.categories]

            # Extract required inputs for easier use
            required_inputs: list[BlockInputFieldInfo] = []
            if input_schema:
                properties = input_schema.get("properties", {})
                required_fields = set(input_schema.get("required", []))
                # Get credential field names to exclude from required inputs
                credentials_fields = set(
                    block.input_schema.get_credentials_fields().keys()
                )

                for field_name, field_schema in properties.items():
                    # Skip credential fields - they're handled separately
                    if field_name in credentials_fields:
                        continue

                    required_inputs.append(
                        BlockInputFieldInfo(
                            name=field_name,
                            type=field_schema.get("type", "string"),
                            description=field_schema.get("description", ""),
                            required=field_name in required_fields,
                            default=field_schema.get("default"),
                        )
                    )

            blocks.append(
                BlockInfoSummary(
                    id=block_id,
                    name=block.name,
                    description=block.description or "",
                    categories=categories,
                    input_schema=input_schema,
                    output_schema=output_schema,
                    required_inputs=required_inputs,
                )
            )

@@ -228,8 +174,7 @@ class FindBlockTool(BaseTool):
        return BlockListResponse(
            message=(
                f"Found {len(blocks)} block(s) matching '{query}'. "
                "To execute a block, use run_block with the block's 'id' field "
                "and provide 'input_data' matching the block's input_schema."
                "To see a block's inputs/outputs and execute it, use run_block with the block's 'id' - providing no inputs."
            ),
            blocks=blocks,
            count=len(blocks),
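The test file that follows exercises this filtering, but only fragments of the filter condition survive in the hunk above (`): continue`). The predicate below is a reconstruction of what it plausibly checks, using the constants named in this diff; treat it as a sketch, not the actual implementation.

```python
# Reconstructed sketch; the exact condition in FindBlockTool is not fully
# visible in the hunk above.
def is_copilot_visible(block_id: str, block) -> bool:
    return not (
        block is None
        or block.disabled
        or block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES
        or block_id in COPILOT_EXCLUDED_BLOCK_IDS
    )
```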
@@ -0,0 +1,394 @@
"""Tests for block filtering in FindBlockTool."""

from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from backend.blocks._base import BlockType
from backend.copilot.tools.find_block import (
    COPILOT_EXCLUDED_BLOCK_IDS,
    COPILOT_EXCLUDED_BLOCK_TYPES,
    FindBlockTool,
)
from backend.copilot.tools.models import BlockListResponse

from ._test_data import make_session

_TEST_USER_ID = "test-user-find-block"


def make_mock_block(
    block_id: str,
    name: str,
    block_type: BlockType,
    disabled: bool = False,
    input_schema: dict | None = None,
    output_schema: dict | None = None,
    credentials_fields: dict | None = None,
):
    """Create a mock block for testing."""
    mock = MagicMock()
    mock.id = block_id
    mock.name = name
    mock.description = f"{name} description"
    mock.block_type = block_type
    mock.disabled = disabled
    mock.input_schema = MagicMock()
    mock.input_schema.jsonschema.return_value = input_schema or {
        "properties": {},
        "required": [],
    }
    mock.input_schema.get_credentials_fields.return_value = credentials_fields or {}
    mock.output_schema = MagicMock()
    mock.output_schema.jsonschema.return_value = output_schema or {}
    mock.categories = []
    return mock


class TestFindBlockFiltering:
    """Tests for block filtering in FindBlockTool."""

    def test_excluded_block_types_contains_expected_types(self):
        """Verify COPILOT_EXCLUDED_BLOCK_TYPES contains all graph-only types."""
        assert BlockType.INPUT in COPILOT_EXCLUDED_BLOCK_TYPES
        assert BlockType.OUTPUT in COPILOT_EXCLUDED_BLOCK_TYPES
        assert BlockType.WEBHOOK in COPILOT_EXCLUDED_BLOCK_TYPES
        assert BlockType.WEBHOOK_MANUAL in COPILOT_EXCLUDED_BLOCK_TYPES
        assert BlockType.NOTE in COPILOT_EXCLUDED_BLOCK_TYPES
        assert BlockType.HUMAN_IN_THE_LOOP in COPILOT_EXCLUDED_BLOCK_TYPES
        assert BlockType.AGENT in COPILOT_EXCLUDED_BLOCK_TYPES

    def test_excluded_block_ids_contains_smart_decision_maker(self):
        """Verify SmartDecisionMakerBlock is in COPILOT_EXCLUDED_BLOCK_IDS."""
        assert "3b191d9f-356f-482d-8238-ba04b6d18381" in COPILOT_EXCLUDED_BLOCK_IDS

    @pytest.mark.asyncio(loop_scope="session")
    async def test_excluded_block_type_filtered_from_results(self):
        """Verify blocks with excluded BlockTypes are filtered from search results."""
        session = make_session(user_id=_TEST_USER_ID)

        # Mock search returns an INPUT block (excluded) and a STANDARD block (included)
        search_results = [
            {"content_id": "input-block-id", "score": 0.9},
            {"content_id": "standard-block-id", "score": 0.8},
        ]

        input_block = make_mock_block("input-block-id", "Input Block", BlockType.INPUT)
        standard_block = make_mock_block(
            "standard-block-id", "HTTP Request", BlockType.STANDARD
        )

        def mock_get_block(block_id):
            return {
                "input-block-id": input_block,
                "standard-block-id": standard_block,
            }.get(block_id)

        mock_search_db = MagicMock()
        mock_search_db.unified_hybrid_search = AsyncMock(
            return_value=(search_results, 2)
        )

        with patch(
            "backend.copilot.tools.find_block.search",
            return_value=mock_search_db,
        ):
            with patch(
                "backend.copilot.tools.find_block.get_block",
                side_effect=mock_get_block,
            ):
                tool = FindBlockTool()
                response = await tool._execute(
                    user_id=_TEST_USER_ID, session=session, query="test"
                )

        # Should only return the standard block, not the INPUT block
        assert isinstance(response, BlockListResponse)
        assert len(response.blocks) == 1
        assert response.blocks[0].id == "standard-block-id"

    @pytest.mark.asyncio(loop_scope="session")
    async def test_excluded_block_id_filtered_from_results(self):
        """Verify SmartDecisionMakerBlock is filtered from search results."""
        session = make_session(user_id=_TEST_USER_ID)

        smart_decision_id = "3b191d9f-356f-482d-8238-ba04b6d18381"
        search_results = [
            {"content_id": smart_decision_id, "score": 0.9},
            {"content_id": "normal-block-id", "score": 0.8},
        ]

        # SmartDecisionMakerBlock has STANDARD type but is excluded by ID
        smart_block = make_mock_block(
            smart_decision_id, "Smart Decision Maker", BlockType.STANDARD
        )
        normal_block = make_mock_block(
            "normal-block-id", "Normal Block", BlockType.STANDARD
        )

        def mock_get_block(block_id):
            return {
                smart_decision_id: smart_block,
                "normal-block-id": normal_block,
            }.get(block_id)

        mock_search_db = MagicMock()
        mock_search_db.unified_hybrid_search = AsyncMock(
            return_value=(search_results, 2)
        )

        with patch(
            "backend.copilot.tools.find_block.search",
            return_value=mock_search_db,
        ):
            with patch(
                "backend.copilot.tools.find_block.get_block",
                side_effect=mock_get_block,
            ):
                tool = FindBlockTool()
                response = await tool._execute(
                    user_id=_TEST_USER_ID, session=session, query="decision"
                )

        # Should only return normal block, not SmartDecisionMakerBlock
        assert isinstance(response, BlockListResponse)
        assert len(response.blocks) == 1
        assert response.blocks[0].id == "normal-block-id"

    @pytest.mark.asyncio(loop_scope="session")
    async def test_response_size_average_chars_per_block(self):
        """Measure average chars per block in the serialized response."""
        session = make_session(user_id=_TEST_USER_ID)

        # Realistic block definitions modeled after real blocks
        block_defs = [
            {
                "id": "http-block-id",
                "name": "Send Web Request",
                "input_schema": {
                    "properties": {
                        "url": {
                            "type": "string",
                            "description": "The URL to send the request to",
                        },
                        "method": {
                            "type": "string",
                            "description": "The HTTP method to use",
                        },
                        "headers": {
                            "type": "object",
                            "description": "Headers to include in the request",
                        },
                        "json_format": {
                            "type": "boolean",
                            "description": "If true, send the body as JSON",
                        },
                        "body": {
                            "type": "object",
                            "description": "Form/JSON body payload",
                        },
                        "credentials": {
                            "type": "object",
                            "description": "HTTP credentials",
                        },
                    },
                    "required": ["url", "method"],
                },
                "output_schema": {
                    "properties": {
                        "response": {
                            "type": "object",
                            "description": "The response from the server",
                        },
                        "client_error": {
                            "type": "object",
                            "description": "Errors on 4xx status codes",
                        },
                        "server_error": {
                            "type": "object",
                            "description": "Errors on 5xx status codes",
                        },
                        "error": {
                            "type": "string",
                            "description": "Errors for all other exceptions",
                        },
                    },
                },
                "credentials_fields": {"credentials": True},
            },
            {
                "id": "email-block-id",
                "name": "Send Email",
                "input_schema": {
                    "properties": {
                        "to_email": {
                            "type": "string",
                            "description": "Recipient email address",
                        },
                        "subject": {
                            "type": "string",
                            "description": "Subject of the email",
                        },
                        "body": {
                            "type": "string",
                            "description": "Body of the email",
                        },
                        "config": {
                            "type": "object",
                            "description": "SMTP Config",
                        },
                        "credentials": {
                            "type": "object",
                            "description": "SMTP credentials",
                        },
                    },
                    "required": ["to_email", "subject", "body", "credentials"],
                },
                "output_schema": {
                    "properties": {
                        "status": {
                            "type": "string",
                            "description": "Status of the email sending operation",
                        },
                        "error": {
                            "type": "string",
                            "description": "Error message if sending failed",
                        },
                    },
                },
                "credentials_fields": {"credentials": True},
            },
            {
                "id": "claude-code-block-id",
                "name": "Claude Code",
                "input_schema": {
                    "properties": {
                        "e2b_credentials": {
                            "type": "object",
                            "description": "API key for E2B platform",
                        },
                        "anthropic_credentials": {
                            "type": "object",
                            "description": "API key for Anthropic",
                        },
                        "prompt": {
                            "type": "string",
                            "description": "Task or instruction for Claude Code",
                        },
                        "timeout": {
                            "type": "integer",
                            "description": "Sandbox timeout in seconds",
                        },
                        "setup_commands": {
                            "type": "array",
                            "description": "Shell commands to run before execution",
                        },
                        "working_directory": {
                            "type": "string",
                            "description": "Working directory for Claude Code",
                        },
                        "session_id": {
                            "type": "string",
                            "description": "Session ID to resume a conversation",
                        },
                        "sandbox_id": {
                            "type": "string",
                            "description": "Sandbox ID to reconnect to",
                        },
                        "conversation_history": {
                            "type": "string",
                            "description": "Previous conversation history",
                        },
                        "dispose_sandbox": {
                            "type": "boolean",
                            "description": "Whether to dispose sandbox after execution",
                        },
                    },
                    "required": [
                        "e2b_credentials",
                        "anthropic_credentials",
                        "prompt",
                    ],
                },
                "output_schema": {
                    "properties": {
                        "response": {
                            "type": "string",
                            "description": "Output from Claude Code execution",
                        },
                        "files": {
                            "type": "array",
                            "description": "Files created/modified by Claude Code",
                        },
                        "conversation_history": {
                            "type": "string",
                            "description": "Full conversation history",
                        },
                        "session_id": {
                            "type": "string",
                            "description": "Session ID for this conversation",
                        },
                        "sandbox_id": {
                            "type": "string",
                            "description": "ID of the sandbox instance",
                        },
                        "error": {
                            "type": "string",
                            "description": "Error message if execution failed",
                        },
                    },
                },
                "credentials_fields": {
                    "e2b_credentials": True,
                    "anthropic_credentials": True,
                },
            },
        ]

        search_results = [
            {"content_id": d["id"], "score": 0.9 - i * 0.1}
            for i, d in enumerate(block_defs)
        ]
        mock_blocks = {
            d["id"]: make_mock_block(
                block_id=d["id"],
                name=d["name"],
                block_type=BlockType.STANDARD,
                input_schema=d["input_schema"],
                output_schema=d["output_schema"],
                credentials_fields=d["credentials_fields"],
            )
            for d in block_defs
        }

        mock_search_db = MagicMock()
        mock_search_db.unified_hybrid_search = AsyncMock(
            return_value=(search_results, len(search_results))
        )

        with patch(
            "backend.copilot.tools.find_block.search",
            return_value=mock_search_db,
        ), patch(
            "backend.copilot.tools.find_block.get_block",
            side_effect=lambda bid: mock_blocks.get(bid),
        ):
            tool = FindBlockTool()
            response = await tool._execute(
                user_id=_TEST_USER_ID, session=session, query="test"
            )

        assert isinstance(response, BlockListResponse)
        assert response.count == len(block_defs)

        total_chars = len(response.model_dump_json())
        avg_chars = total_chars // response.count

        # Print for visibility in test output
        print(f"\nTotal response size: {total_chars} chars")
        print(f"Number of blocks: {response.count}")
        print(f"Average chars per block: {avg_chars}")

        # The old response was ~90K for 10 blocks (~9K per block).
        # Previous optimization reduced it to ~1.5K per block (no raw JSON schemas).
        # Now with only id/name/description, we expect ~300 chars per block.
        assert avg_chars < 500, (
            f"Average chars per block ({avg_chars}) exceeds 500. "
            f"Total response: {total_chars} chars for {response.count} blocks."
        )
@@ -2,7 +2,7 @@

from typing import Any

from backend.api.features.chat.model import ChatSession
from backend.copilot.model import ChatSession

from .agent_search import search_agents
from .base import BaseTool
@@ -4,9 +4,9 @@ import logging
from pathlib import Path
from typing import Any

from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool
from backend.api.features.chat.tools.models import (
from backend.copilot.model import ChatSession
from backend.copilot.tools.base import BaseTool
from backend.copilot.tools.models import (
    DocPageResponse,
    ErrorResponse,
    ToolResponseBase,
@@ -25,6 +25,7 @@ class ResponseType(str, Enum):
    AGENT_SAVED = "agent_saved"
    CLARIFICATION_NEEDED = "clarification_needed"
    BLOCK_LIST = "block_list"
    BLOCK_DETAILS = "block_details"
    BLOCK_OUTPUT = "block_output"
    DOC_SEARCH_RESULTS = "doc_search_results"
    DOC_PAGE = "doc_page"
@@ -334,13 +335,6 @@ class BlockInfoSummary(BaseModel):
    id: str
    name: str
    description: str
    categories: list[str]
    input_schema: dict[str, Any]
    output_schema: dict[str, Any]
    required_inputs: list[BlockInputFieldInfo] = Field(
        default_factory=list,
        description="List of required input fields for this block",
    )


class BlockListResponse(ToolResponseBase):
@@ -350,10 +344,25 @@ class BlockListResponse(ToolResponseBase):
    blocks: list[BlockInfoSummary]
    count: int
    query: str
    usage_hint: str = Field(
        default="To execute a block, call run_block with block_id set to the block's "
        "'id' field and input_data containing the required fields from input_schema."
    )


class BlockDetails(BaseModel):
    """Detailed block information."""

    id: str
    name: str
    description: str
    inputs: dict[str, Any] = {}
    outputs: dict[str, Any] = {}
    credentials: list[CredentialsMetaInput] = []


class BlockDetailsResponse(ToolResponseBase):
    """Response for block details (first run_block attempt)."""

    type: ResponseType = ResponseType.BLOCK_DETAILS
    block: BlockDetails
    user_authenticated: bool = False


class BlockOutputResponse(ToolResponseBase):
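To make the size budget from the filtering test concrete: a rough sketch of one slimmed search result after this change, assuming `BlockInfoSummary` now keeps only id, name, and description (as the deletions above suggest). The UUID and text are invented.

```python
block = BlockInfoSummary(
    id="f47ac10b-58cc-4372-a567-0e02b2c3d479",
    name="Send Web Request",
    description="Makes an HTTP request to the given URL and returns the response",
)
# Serialized, this is on the order of ~150 chars per block -- comfortably
# under the 500-chars-per-block budget asserted in the test.
print(len(block.model_dump_json()))
```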
@@ -5,16 +5,12 @@ from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from backend.api.features.chat.config import ChatConfig
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.tracking import (
|
||||
track_agent_run_success,
|
||||
track_agent_scheduled,
|
||||
)
|
||||
from backend.api.features.library import db as library_db
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tracking import track_agent_run_success, track_agent_scheduled
|
||||
from backend.data.db_accessors import graph_db, library_db, user_db
|
||||
from backend.data.graph import GraphModel
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.data.user import get_user_by_id
|
||||
from backend.executor import utils as execution_utils
|
||||
from backend.util.clients import get_scheduler_client
|
||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||
@@ -200,7 +196,7 @@ class RunAgentTool(BaseTool):
|
||||
|
||||
# Priority: library_agent_id if provided
|
||||
if has_library_id:
|
||||
library_agent = await library_db.get_library_agent(
|
||||
library_agent = await library_db().get_library_agent(
|
||||
params.library_agent_id, user_id
|
||||
)
|
||||
if not library_agent:
|
||||
@@ -209,9 +205,7 @@ class RunAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
# Get the graph from the library agent
|
||||
from backend.data.graph import get_graph
|
||||
|
||||
graph = await get_graph(
|
||||
graph = await graph_db().get_graph(
|
||||
library_agent.graph_id,
|
||||
library_agent.graph_version,
|
||||
user_id=user_id,
|
||||
@@ -522,7 +516,7 @@ class RunAgentTool(BaseTool):
|
||||
library_agent = await get_or_create_library_agent(graph, user_id)
|
||||
|
||||
# Get user timezone
|
||||
user = await get_user_by_id(user_id)
|
||||
user = await user_db().get_user_by_id(user_id)
|
||||
user_timezone = get_user_timezone_or_utc(user.timezone if user else timezone)
|
||||
|
||||
# Create schedule
|
||||
@@ -7,24 +7,27 @@ from typing import Any
|
||||
|
||||
from pydantic_core import PydanticUndefined
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.tools.find_block import (
|
||||
from backend.blocks import get_block
|
||||
from backend.blocks._base import AnyBlockSchema
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools.find_block import (
|
||||
COPILOT_EXCLUDED_BLOCK_IDS,
|
||||
COPILOT_EXCLUDED_BLOCK_TYPES,
|
||||
)
|
||||
from backend.blocks import get_block
|
||||
from backend.blocks._base import AnyBlockSchema
|
||||
from backend.data.db_accessors import workspace_db
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
|
||||
from backend.data.workspace import get_or_create_workspace
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.util.exceptions import BlockError
|
||||
|
||||
from .base import BaseTool
|
||||
from .helpers import get_inputs_from_schema
|
||||
from .models import (
|
||||
BlockDetails,
|
||||
BlockDetailsResponse,
|
||||
BlockOutputResponse,
|
||||
ErrorResponse,
|
||||
InputValidationErrorResponse,
|
||||
SetupInfo,
|
||||
SetupRequirementsResponse,
|
||||
ToolResponseBase,
|
||||
@@ -51,8 +54,8 @@ class RunBlockTool(BaseTool):
|
||||
"Execute a specific block with the provided input data. "
|
||||
"IMPORTANT: You MUST call find_block first to get the block's 'id' - "
|
||||
"do NOT guess or make up block IDs. "
|
||||
"Use the 'id' from find_block results and provide input_data "
|
||||
"matching the block's required_inputs."
|
||||
"On first attempt (without input_data), returns detailed schema showing "
|
||||
"required inputs and outputs. Then call again with proper input_data to execute."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -67,11 +70,19 @@ class RunBlockTool(BaseTool):
|
||||
"NEVER guess this - always get it from find_block first."
|
||||
),
|
||||
},
|
||||
"block_name": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The block's human-readable name from find_block results. "
|
||||
"Used for display purposes in the UI."
|
||||
),
|
||||
},
|
||||
"input_data": {
|
||||
"type": "object",
|
||||
"description": (
|
||||
"Input values for the block. Use the 'required_inputs' field "
|
||||
"from find_block to see what fields are needed."
|
||||
"Input values for the block. "
|
||||
"First call with empty {} to see the block's schema, "
|
||||
"then call again with proper values to execute."
|
||||
),
|
||||
},
|
||||
},
|
||||
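The parameter change above introduces a two-call protocol: an empty input_data returns the block's schema, and a second call with real values executes it. A minimal sketch of the flow, with placeholder objects (tool, session, and block_id here are stand-ins for illustration, not real fixtures):

    # First call: empty inputs -> BlockDetailsResponse describing the schema.
    details = await tool._execute(
        user_id="user-1", session=session, block_id=block_id, input_data={}
    )
    # Second call: inputs matching details.block.inputs -> the block executes.
    result = await tool._execute(
        user_id="user-1",
        session=session,
        block_id=block_id,
        input_data={"prompt": "Write a haiku"},
    )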
@@ -156,6 +167,34 @@ class RunBlockTool(BaseTool):
            await self._resolve_block_credentials(user_id, block, input_data)
        )

        # Get block schemas for details/validation
        try:
            input_schema: dict[str, Any] = block.input_schema.jsonschema()
        except Exception as e:
            logger.warning(
                "Failed to generate input schema for block %s: %s",
                block_id,
                e,
            )
            return ErrorResponse(
                message=f"Block '{block.name}' has an invalid input schema",
                error=str(e),
                session_id=session_id,
            )
        try:
            output_schema: dict[str, Any] = block.output_schema.jsonschema()
        except Exception as e:
            logger.warning(
                "Failed to generate output schema for block %s: %s",
                block_id,
                e,
            )
            return ErrorResponse(
                message=f"Block '{block.name}' has an invalid output schema",
                error=str(e),
                session_id=session_id,
            )

        if missing_credentials:
            # Return setup requirements response with missing credentials
            credentials_fields_info = block.input_schema.get_credentials_fields_info()
@@ -188,9 +227,56 @@ class RunBlockTool(BaseTool):
                graph_version=None,
            )

        # Check if this is a first attempt (required inputs missing)
        # Return block details so user can see what inputs are needed
        credentials_fields = set(block.input_schema.get_credentials_fields().keys())
        required_keys = set(input_schema.get("required", []))
        required_non_credential_keys = required_keys - credentials_fields
        provided_input_keys = set(input_data.keys()) - credentials_fields

        # Check for unknown input fields
        valid_fields = (
            set(input_schema.get("properties", {}).keys()) - credentials_fields
        )
        unrecognized_fields = provided_input_keys - valid_fields
        if unrecognized_fields:
            return InputValidationErrorResponse(
                message=(
                    f"Unknown input field(s) provided: {', '.join(sorted(unrecognized_fields))}. "
                    f"Block was not executed. Please use the correct field names from the schema."
                ),
                session_id=session_id,
                unrecognized_fields=sorted(unrecognized_fields),
                inputs=input_schema,
            )

        # Show details when not all required non-credential inputs are provided
        if not (required_non_credential_keys <= provided_input_keys):
            # Get credentials info for the response
            credentials_meta = []
            for field_name, cred_meta in matched_credentials.items():
                credentials_meta.append(cred_meta)

            return BlockDetailsResponse(
                message=(
                    f"Block '{block.name}' details. "
                    "Provide input_data matching the inputs schema to execute the block."
                ),
                session_id=session_id,
                block=BlockDetails(
                    id=block_id,
                    name=block.name,
                    description=block.description or "",
                    inputs=input_schema,
                    outputs=output_schema,
                    credentials=credentials_meta,
                ),
                user_authenticated=True,
            )

        try:
            # Get or create user's workspace for CoPilot file operations
            workspace = await get_or_create_workspace(user_id)
            workspace = await workspace_db().get_or_create_workspace(user_id)

            # Generate synthetic IDs for CoPilot context
            # Each chat session is treated as its own agent with one continuous run
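The unknown-field guard in this hunk is plain set arithmetic over the JSON schema; a small standalone illustration (field names invented for the example):

    # Keys that are provided but not declared in the schema are rejected
    # before execution, instead of being silently dropped.
    valid_fields = {"prompt", "model", "sys_prompt"}  # schema properties
    credentials_fields: set[str] = set()  # no credential inputs in this example
    provided = {"prompt", "LLM_Model"}  # caller's input_data keys
    unrecognized = (provided - credentials_fields) - valid_fields
    assert unrecognized == {"LLM_Model"}  # -> InputValidationErrorResponse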
autogpt_platform/backend/backend/copilot/tools/run_block_test.py (new file, 362 lines)
@@ -0,0 +1,362 @@
"""Tests for block execution guards and input validation in RunBlockTool."""

from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from backend.blocks._base import BlockType
from backend.copilot.tools.models import (
    BlockDetailsResponse,
    BlockOutputResponse,
    ErrorResponse,
    InputValidationErrorResponse,
)
from backend.copilot.tools.run_block import RunBlockTool

from ._test_data import make_session

_TEST_USER_ID = "test-user-run-block"


def make_mock_block(
    block_id: str, name: str, block_type: BlockType, disabled: bool = False
):
    """Create a mock block for testing."""
    mock = MagicMock()
    mock.id = block_id
    mock.name = name
    mock.block_type = block_type
    mock.disabled = disabled
    mock.input_schema = MagicMock()
    mock.input_schema.jsonschema.return_value = {"properties": {}, "required": []}
    mock.input_schema.get_credentials_fields_info.return_value = []
    return mock


def make_mock_block_with_schema(
    block_id: str,
    name: str,
    input_properties: dict,
    required_fields: list[str],
    output_properties: dict | None = None,
):
    """Create a mock block with a defined input/output schema for validation tests."""
    mock = MagicMock()
    mock.id = block_id
    mock.name = name
    mock.block_type = BlockType.STANDARD
    mock.disabled = False
    mock.description = f"Test block: {name}"

    input_schema = {
        "properties": input_properties,
        "required": required_fields,
    }
    mock.input_schema = MagicMock()
    mock.input_schema.jsonschema.return_value = input_schema
    mock.input_schema.get_credentials_fields_info.return_value = {}
    mock.input_schema.get_credentials_fields.return_value = {}

    output_schema = {
        "properties": output_properties or {"result": {"type": "string"}},
    }
    mock.output_schema = MagicMock()
    mock.output_schema.jsonschema.return_value = output_schema

    return mock


class TestRunBlockFiltering:
    """Tests for block execution guards in RunBlockTool."""

    @pytest.mark.asyncio(loop_scope="session")
    async def test_excluded_block_type_returns_error(self):
        """Attempting to execute a block with excluded BlockType returns error."""
        session = make_session(user_id=_TEST_USER_ID)

        input_block = make_mock_block("input-block-id", "Input Block", BlockType.INPUT)

        with patch(
            "backend.copilot.tools.run_block.get_block",
            return_value=input_block,
        ):
            tool = RunBlockTool()
            response = await tool._execute(
                user_id=_TEST_USER_ID,
                session=session,
                block_id="input-block-id",
                input_data={},
            )

        assert isinstance(response, ErrorResponse)
        assert "cannot be run directly in CoPilot" in response.message
        assert "designed for use within graphs only" in response.message

    @pytest.mark.asyncio(loop_scope="session")
    async def test_excluded_block_id_returns_error(self):
        """Attempting to execute SmartDecisionMakerBlock returns error."""
        session = make_session(user_id=_TEST_USER_ID)

        smart_decision_id = "3b191d9f-356f-482d-8238-ba04b6d18381"
        smart_block = make_mock_block(
            smart_decision_id, "Smart Decision Maker", BlockType.STANDARD
        )

        with patch(
            "backend.copilot.tools.run_block.get_block",
            return_value=smart_block,
        ):
            tool = RunBlockTool()
            response = await tool._execute(
                user_id=_TEST_USER_ID,
                session=session,
                block_id=smart_decision_id,
                input_data={},
            )

        assert isinstance(response, ErrorResponse)
        assert "cannot be run directly in CoPilot" in response.message

    @pytest.mark.asyncio(loop_scope="session")
    async def test_non_excluded_block_passes_guard(self):
        """Non-excluded blocks pass the filtering guard (may fail later for other reasons)."""
        session = make_session(user_id=_TEST_USER_ID)

        standard_block = make_mock_block(
            "standard-id", "HTTP Request", BlockType.STANDARD
        )

        with patch(
            "backend.copilot.tools.run_block.get_block",
            return_value=standard_block,
        ):
            tool = RunBlockTool()
            response = await tool._execute(
                user_id=_TEST_USER_ID,
                session=session,
                block_id="standard-id",
                input_data={},
            )

        # Should NOT be an ErrorResponse about CoPilot exclusion
        # (may be other errors like missing credentials, but not the exclusion guard)
        if isinstance(response, ErrorResponse):
            assert "cannot be run directly in CoPilot" not in response.message


class TestRunBlockInputValidation:
    """Tests for input field validation in RunBlockTool.

    run_block rejects unknown input field names with InputValidationErrorResponse,
    preventing silent failures where incorrect keys would be ignored and the block
    would execute with default values instead of the caller's intended values.
    """

    @pytest.mark.asyncio(loop_scope="session")
    async def test_unknown_input_fields_are_rejected(self):
        """run_block rejects unknown input fields instead of silently ignoring them.

        Scenario: The AI Text Generator block has a field called 'model' (for LLM model
        selection), but the LLM calling the tool guesses wrong and sends 'LLM_Model'
        instead. The block should reject the request and return the valid schema.
        """
        session = make_session(user_id=_TEST_USER_ID)

        mock_block = make_mock_block_with_schema(
            block_id="ai-text-gen-id",
            name="AI Text Generator",
            input_properties={
                "prompt": {"type": "string", "description": "The prompt to send"},
                "model": {
                    "type": "string",
                    "description": "The LLM model to use",
                    "default": "gpt-4o-mini",
                },
                "sys_prompt": {
                    "type": "string",
                    "description": "System prompt",
                    "default": "",
                },
            },
            required_fields=["prompt"],
            output_properties={"response": {"type": "string"}},
        )

        with patch(
            "backend.api.features.chat.tools.run_block.get_block",
            return_value=mock_block,
        ):
            tool = RunBlockTool()

            # Provide 'prompt' (correct) but 'LLM_Model' instead of 'model' (wrong key)
            response = await tool._execute(
                user_id=_TEST_USER_ID,
                session=session,
                block_id="ai-text-gen-id",
                input_data={
                    "prompt": "Write a haiku about coding",
                    "LLM_Model": "claude-opus-4-6",  # WRONG KEY - should be 'model'
                },
            )

        assert isinstance(response, InputValidationErrorResponse)
        assert "LLM_Model" in response.unrecognized_fields
        assert "Block was not executed" in response.message
        assert "inputs" in response.model_dump()  # valid schema included

    @pytest.mark.asyncio(loop_scope="session")
    async def test_multiple_wrong_keys_are_all_reported(self):
        """All unrecognized field names are reported in a single error response."""
        session = make_session(user_id=_TEST_USER_ID)

        mock_block = make_mock_block_with_schema(
            block_id="ai-text-gen-id",
            name="AI Text Generator",
            input_properties={
                "prompt": {"type": "string"},
                "model": {"type": "string", "default": "gpt-4o-mini"},
                "sys_prompt": {"type": "string", "default": ""},
                "retry": {"type": "integer", "default": 3},
            },
            required_fields=["prompt"],
        )

        with patch(
            "backend.api.features.chat.tools.run_block.get_block",
            return_value=mock_block,
        ):
            tool = RunBlockTool()

            response = await tool._execute(
                user_id=_TEST_USER_ID,
                session=session,
                block_id="ai-text-gen-id",
                input_data={
                    "prompt": "Hello",  # correct
                    "llm_model": "claude-opus-4-6",  # WRONG - should be 'model'
                    "system_prompt": "Be helpful",  # WRONG - should be 'sys_prompt'
                    "retries": 5,  # WRONG - should be 'retry'
                },
            )

        assert isinstance(response, InputValidationErrorResponse)
        assert set(response.unrecognized_fields) == {
            "llm_model",
            "system_prompt",
            "retries",
        }
        assert "Block was not executed" in response.message

    @pytest.mark.asyncio(loop_scope="session")
    async def test_unknown_fields_rejected_even_with_missing_required(self):
        """Unknown fields are caught before the missing-required-fields check."""
        session = make_session(user_id=_TEST_USER_ID)

        mock_block = make_mock_block_with_schema(
            block_id="ai-text-gen-id",
            name="AI Text Generator",
            input_properties={
                "prompt": {"type": "string"},
                "model": {"type": "string", "default": "gpt-4o-mini"},
            },
            required_fields=["prompt"],
        )

        with patch(
            "backend.api.features.chat.tools.run_block.get_block",
            return_value=mock_block,
        ):
            tool = RunBlockTool()

            # 'prompt' is missing AND 'LLM_Model' is an unknown field
            response = await tool._execute(
                user_id=_TEST_USER_ID,
                session=session,
                block_id="ai-text-gen-id",
                input_data={
                    "LLM_Model": "claude-opus-4-6",  # wrong key, and 'prompt' is missing
                },
            )

        # Unknown fields are caught first
        assert isinstance(response, InputValidationErrorResponse)
        assert "LLM_Model" in response.unrecognized_fields

    @pytest.mark.asyncio(loop_scope="session")
    async def test_correct_inputs_still_execute(self):
        """Correct input field names pass validation and the block executes."""
        session = make_session(user_id=_TEST_USER_ID)

        mock_block = make_mock_block_with_schema(
            block_id="ai-text-gen-id",
            name="AI Text Generator",
            input_properties={
                "prompt": {"type": "string"},
                "model": {"type": "string", "default": "gpt-4o-mini"},
            },
            required_fields=["prompt"],
        )

        async def mock_execute(input_data, **kwargs):
            yield "response", "Generated text"

        mock_block.execute = mock_execute

        with (
            patch(
                "backend.api.features.chat.tools.run_block.get_block",
                return_value=mock_block,
            ),
            patch(
                "backend.api.features.chat.tools.run_block.get_or_create_workspace",
                new_callable=AsyncMock,
                return_value=MagicMock(id="test-workspace-id"),
            ),
        ):
            tool = RunBlockTool()

            response = await tool._execute(
                user_id=_TEST_USER_ID,
                session=session,
                block_id="ai-text-gen-id",
                input_data={
                    "prompt": "Write a haiku",
                    "model": "gpt-4o-mini",  # correct field name
                },
            )

        assert isinstance(response, BlockOutputResponse)
        assert response.success is True

    @pytest.mark.asyncio(loop_scope="session")
    async def test_missing_required_fields_returns_details(self):
        """Missing required fields returns BlockDetailsResponse with schema."""
        session = make_session(user_id=_TEST_USER_ID)

        mock_block = make_mock_block_with_schema(
            block_id="ai-text-gen-id",
            name="AI Text Generator",
            input_properties={
                "prompt": {"type": "string"},
                "model": {"type": "string", "default": "gpt-4o-mini"},
            },
            required_fields=["prompt"],
        )

        with patch(
            "backend.api.features.chat.tools.run_block.get_block",
            return_value=mock_block,
        ):
            tool = RunBlockTool()

            # Only provide valid optional field, missing required 'prompt'
            response = await tool._execute(
                user_id=_TEST_USER_ID,
                session=session,
                block_id="ai-text-gen-id",
                input_data={
                    "model": "gpt-4o-mini",  # valid but optional
                },
            )

        assert isinstance(response, BlockDetailsResponse)
@@ -5,16 +5,16 @@ from typing import Any

from prisma.enums import ContentType

from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool
from backend.api.features.chat.tools.models import (
from backend.copilot.model import ChatSession
from backend.copilot.tools.base import BaseTool
from backend.copilot.tools.models import (
    DocSearchResult,
    DocSearchResultsResponse,
    ErrorResponse,
    NoResultsResponse,
    ToolResponseBase,
)
from backend.api.features.store.hybrid_search import unified_hybrid_search
from backend.data.db_accessors import search

logger = logging.getLogger(__name__)

@@ -117,7 +117,7 @@ class SearchDocsTool(BaseTool):

        try:
            # Search using hybrid search for DOCUMENTATION content type only
            results, total = await unified_hybrid_search(
            results, total = await search().unified_hybrid_search(
                query=query,
                content_types=[ContentType.DOCUMENTATION],
                page=1,
@@ -0,0 +1,153 @@
"""Tests for BlockDetailsResponse in RunBlockTool."""

from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from backend.blocks._base import BlockType
from backend.copilot.tools.models import BlockDetailsResponse
from backend.copilot.tools.run_block import RunBlockTool
from backend.data.model import CredentialsMetaInput
from backend.integrations.providers import ProviderName

from ._test_data import make_session

_TEST_USER_ID = "test-user-run-block-details"


def make_mock_block_with_inputs(
    block_id: str, name: str, description: str = "Test description"
):
    """Create a mock block with input/output schemas for testing."""
    mock = MagicMock()
    mock.id = block_id
    mock.name = name
    mock.description = description
    mock.block_type = BlockType.STANDARD
    mock.disabled = False

    # Input schema with non-credential fields
    mock.input_schema = MagicMock()
    mock.input_schema.jsonschema.return_value = {
        "properties": {
            "url": {"type": "string", "description": "URL to fetch"},
            "method": {"type": "string", "description": "HTTP method"},
        },
        "required": ["url"],
    }
    mock.input_schema.get_credentials_fields.return_value = {}
    mock.input_schema.get_credentials_fields_info.return_value = {}

    # Output schema
    mock.output_schema = MagicMock()
    mock.output_schema.jsonschema.return_value = {
        "properties": {
            "response": {"type": "object", "description": "HTTP response"},
            "error": {"type": "string", "description": "Error message"},
        }
    }

    return mock


@pytest.mark.asyncio(loop_scope="session")
async def test_run_block_returns_details_when_no_input_provided():
    """When run_block is called without input_data, it should return BlockDetailsResponse."""
    session = make_session(user_id=_TEST_USER_ID)

    # Create a block with inputs
    http_block = make_mock_block_with_inputs(
        "http-block-id", "HTTP Request", "Send HTTP requests"
    )

    with patch(
        "backend.api.features.chat.tools.run_block.get_block",
        return_value=http_block,
    ):
        # Mock credentials check to return no missing credentials
        with patch.object(
            RunBlockTool,
            "_resolve_block_credentials",
            new_callable=AsyncMock,
            return_value=({}, []),  # (matched_credentials, missing_credentials)
        ):
            tool = RunBlockTool()
            response = await tool._execute(
                user_id=_TEST_USER_ID,
                session=session,
                block_id="http-block-id",
                input_data={},  # Empty input data
            )

    # Should return BlockDetailsResponse showing the schema
    assert isinstance(response, BlockDetailsResponse)
    assert response.block.id == "http-block-id"
    assert response.block.name == "HTTP Request"
    assert response.block.description == "Send HTTP requests"
    assert "url" in response.block.inputs["properties"]
    assert "method" in response.block.inputs["properties"]
    assert "response" in response.block.outputs["properties"]
    assert response.user_authenticated is True


@pytest.mark.asyncio(loop_scope="session")
async def test_run_block_returns_details_when_only_credentials_provided():
    """When only credentials are provided (no actual input), should return details."""
    session = make_session(user_id=_TEST_USER_ID)

    # Create a block with both credential and non-credential inputs
    mock = MagicMock()
    mock.id = "api-block-id"
    mock.name = "API Call"
    mock.description = "Make API calls"
    mock.block_type = BlockType.STANDARD
    mock.disabled = False

    mock.input_schema = MagicMock()
    mock.input_schema.jsonschema.return_value = {
        "properties": {
            "credentials": {"type": "object", "description": "API credentials"},
            "endpoint": {"type": "string", "description": "API endpoint"},
        },
        "required": ["credentials", "endpoint"],
    }
    mock.input_schema.get_credentials_fields.return_value = {"credentials": True}
    mock.input_schema.get_credentials_fields_info.return_value = {}

    mock.output_schema = MagicMock()
    mock.output_schema.jsonschema.return_value = {
        "properties": {"result": {"type": "object"}}
    }

    with patch(
        "backend.api.features.chat.tools.run_block.get_block",
        return_value=mock,
    ):
        with patch.object(
            RunBlockTool,
            "_resolve_block_credentials",
            new_callable=AsyncMock,
            return_value=(
                {
                    "credentials": CredentialsMetaInput(
                        id="cred-id",
                        provider=ProviderName("test_provider"),
                        type="api_key",
                        title="Test Credential",
                    )
                },
                [],
            ),
        ):
            tool = RunBlockTool()
            response = await tool._execute(
                user_id=_TEST_USER_ID,
                session=session,
                block_id="api-block-id",
                input_data={"credentials": {"some": "cred"}},  # Only credential
            )

    # Should return details because no non-credential inputs provided
    assert isinstance(response, BlockDetailsResponse)
    assert response.block.id == "api-block-id"
    assert response.block.name == "API Call"
@@ -3,9 +3,8 @@
import logging
from typing import Any

from backend.api.features.library import db as library_db
from backend.api.features.library import model as library_model
from backend.api.features.store import db as store_db
from backend.data.db_accessors import library_db, store_db
from backend.data.graph import GraphModel
from backend.data.model import (
    Credentials,
@@ -38,13 +37,14 @@ async def fetch_graph_from_store_slug(
    Raises:
        DatabaseError: If there's a database error during lookup.
    """
    sdb = store_db()
    try:
        store_agent = await store_db.get_store_agent_details(username, agent_name)
        store_agent = await sdb.get_store_agent_details(username, agent_name)
    except NotFoundError:
        return None, None

    # Get the graph from store listing version
    graph = await store_db.get_available_graph(
    graph = await sdb.get_available_graph(
        store_agent.store_listing_version_id, hide_nodes=False
    )
    return graph, store_agent
@@ -209,13 +209,13 @@ async def get_or_create_library_agent(
    Returns:
        LibraryAgent instance
    """
    existing = await library_db.get_library_agent_by_graph_id(
    existing = await library_db().get_library_agent_by_graph_id(
        graph_id=graph.id, user_id=user_id
    )
    if existing:
        return existing

    library_agents = await library_db.create_library_agent(
    library_agents = await library_db().create_library_agent(
        graph=graph,
        user_id=user_id,
        create_library_agents_for_sub_graphs=False,
@@ -6,8 +6,8 @@ from typing import Any, Optional

from pydantic import BaseModel

from backend.api.features.chat.model import ChatSession
from backend.data.workspace import get_or_create_workspace
from backend.copilot.model import ChatSession
from backend.data.db_accessors import workspace_db
from backend.util.settings import Config
from backend.util.virus_scanner import scan_content_safe
from backend.util.workspace import WorkspaceManager
@@ -146,7 +146,7 @@ class ListWorkspaceFilesTool(BaseTool):
        include_all_sessions: bool = kwargs.get("include_all_sessions", False)

        try:
            workspace = await get_or_create_workspace(user_id)
            workspace = await workspace_db().get_or_create_workspace(user_id)
            # Pass session_id for session-scoped file access
            manager = WorkspaceManager(user_id, workspace.id, session_id)

@@ -280,7 +280,7 @@ class ReadWorkspaceFileTool(BaseTool):
        )

        try:
            workspace = await get_or_create_workspace(user_id)
            workspace = await workspace_db().get_or_create_workspace(user_id)
            # Pass session_id for session-scoped file access
            manager = WorkspaceManager(user_id, workspace.id, session_id)

@@ -478,7 +478,7 @@ class WriteWorkspaceFileTool(BaseTool):
        # Virus scan
        await scan_content_safe(content, filename=filename)

        workspace = await get_or_create_workspace(user_id)
        workspace = await workspace_db().get_or_create_workspace(user_id)
        # Pass session_id for session-scoped file access
        manager = WorkspaceManager(user_id, workspace.id, session_id)

@@ -577,7 +577,7 @@ class DeleteWorkspaceFileTool(BaseTool):
        )

        try:
            workspace = await get_or_create_workspace(user_id)
            workspace = await workspace_db().get_or_create_workspace(user_id)
            # Pass session_id for session-scoped file access
            manager = WorkspaceManager(user_id, workspace.id, session_id)

autogpt_platform/backend/backend/data/db_accessors.py (new file, 118 lines)
@@ -0,0 +1,118 @@
from backend.data import db


def chat_db():
    if db.is_connected():
        from backend.copilot import db as _chat_db

        chat_db = _chat_db
    else:
        from backend.util.clients import get_database_manager_async_client

        chat_db = get_database_manager_async_client()

    return chat_db


def graph_db():
    if db.is_connected():
        from backend.data import graph as _graph_db

        graph_db = _graph_db
    else:
        from backend.util.clients import get_database_manager_async_client

        graph_db = get_database_manager_async_client()

    return graph_db


def library_db():
    if db.is_connected():
        from backend.api.features.library import db as _library_db

        library_db = _library_db
    else:
        from backend.util.clients import get_database_manager_async_client

        library_db = get_database_manager_async_client()

    return library_db


def store_db():
    if db.is_connected():
        from backend.api.features.store import db as _store_db

        store_db = _store_db
    else:
        from backend.util.clients import get_database_manager_async_client

        store_db = get_database_manager_async_client()

    return store_db


def search():
    if db.is_connected():
        from backend.api.features.store import hybrid_search as _search

        search = _search
    else:
        from backend.util.clients import get_database_manager_async_client

        search = get_database_manager_async_client()

    return search


def execution_db():
    if db.is_connected():
        from backend.data import execution as _execution_db

        execution_db = _execution_db
    else:
        from backend.util.clients import get_database_manager_async_client

        execution_db = get_database_manager_async_client()

    return execution_db


def user_db():
    if db.is_connected():
        from backend.data import user as _user_db

        user_db = _user_db
    else:
        from backend.util.clients import get_database_manager_async_client

        user_db = get_database_manager_async_client()

    return user_db


def understanding_db():
    if db.is_connected():
        from backend.data import understanding as _understanding_db

        understanding_db = _understanding_db
    else:
        from backend.util.clients import get_database_manager_async_client

        understanding_db = get_database_manager_async_client()

    return understanding_db


def workspace_db():
    if db.is_connected():
        from backend.data import workspace as _workspace_db

        workspace_db = _workspace_db
    else:
        from backend.util.clients import get_database_manager_async_client

        workspace_db = get_database_manager_async_client()

    return workspace_db
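The accessor module above gives callers a single code path whether the current process owns a Prisma connection or must go through the DatabaseManager RPC service. A hedged usage sketch (the call site is invented for illustration; both return values expose the same method names, as the rest of this diff relies on):

    # In-process: library_db() returns the real library db module.
    # Remote: it returns the DatabaseManagerAsyncClient; same call shape.
    agent = await library_db().get_library_agent(library_agent_id, user_id)
    if agent is None:
        ...  # handle the missing agent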
@@ -4,14 +4,26 @@ from typing import TYPE_CHECKING, Callable, Concatenate, ParamSpec, TypeVar, cast

from backend.api.features.library.db import (
    add_store_agent_to_library,
    create_graph_in_library,
    create_library_agent,
    get_library_agent,
    get_library_agent_by_graph_id,
    list_library_agents,
    update_graph_in_library,
)
from backend.api.features.store.db import (
    get_agent,
    get_available_graph,
    get_store_agent_details,
    get_store_agents,
)
from backend.api.features.store.db import get_store_agent_details, get_store_agents
from backend.api.features.store.embeddings import (
    backfill_missing_embeddings,
    cleanup_orphaned_embeddings,
    get_embedding_stats,
)
from backend.api.features.store.hybrid_search import unified_hybrid_search
from backend.copilot import db as chat_db
from backend.data import db
from backend.data.analytics import (
    get_accuracy_trends_and_alerts,
@@ -48,6 +60,7 @@ from backend.data.graph import (
    get_graph_metadata,
    get_graph_settings,
    get_node,
    get_store_listed_graphs,
    validate_graph_execution_permissions,
)
from backend.data.human_review import (
@@ -67,6 +80,10 @@ from backend.data.notifications import (
    remove_notifications_from_batch,
)
from backend.data.onboarding import increment_onboarding_runs
from backend.data.understanding import (
    get_business_understanding,
    upsert_business_understanding,
)
from backend.data.user import (
    get_active_user_ids_in_timerange,
    get_user_by_id,
@@ -76,6 +93,7 @@ from backend.data.user import (
    get_user_notification_preference,
    update_user_integrations,
)
from backend.data.workspace import get_or_create_workspace
from backend.util.service import (
    AppService,
    AppServiceClient,
@@ -107,6 +125,13 @@ async def _get_credits(user_id: str) -> int:


class DatabaseManager(AppService):
    """Database connection pooling service.

    This service connects to the Prisma engine and exposes database
    operations via RPC endpoints. It acts as a centralized connection pool
    for all services that need database access.
    """

    @asynccontextmanager
    async def lifespan(self, app: "FastAPI"):
        async with super().lifespan(app):
@@ -142,11 +167,15 @@ class DatabaseManager(AppService):
    def _(
        f: Callable[P, R], name: str | None = None
    ) -> Callable[Concatenate[object, P], R]:
        """
        Exposes a function as an RPC endpoint, and adds a virtual `self` param
        to the function's type so it can be bound as a method.
        """
        if name is not None:
            f.__name__ = name
        return cast(Callable[Concatenate[object, P], R], expose(f))

    # Executions
    # ============ Graph Executions ============ #
    get_child_graph_executions = _(get_child_graph_executions)
    get_graph_executions = _(get_graph_executions)
    get_graph_executions_count = _(get_graph_executions_count)
@@ -170,36 +199,37 @@ class DatabaseManager(AppService):
    get_frequently_executed_graphs = _(get_frequently_executed_graphs)
    get_marketplace_graphs_for_monitoring = _(get_marketplace_graphs_for_monitoring)

    # Graphs
    # ============ Graphs ============ #
    get_node = _(get_node)
    get_graph = _(get_graph)
    get_connected_output_nodes = _(get_connected_output_nodes)
    get_graph_metadata = _(get_graph_metadata)
    get_graph_settings = _(get_graph_settings)
    get_store_listed_graphs = _(get_store_listed_graphs)

    # Credits
    # ============ Credits ============ #
    spend_credits = _(_spend_credits, name="spend_credits")
    get_credits = _(_get_credits, name="get_credits")

    # User + User Metadata + User Integrations
    # ============ User + Integrations ============ #
    get_user_by_id = _(get_user_by_id)
    get_user_integrations = _(get_user_integrations)
    update_user_integrations = _(update_user_integrations)

    # User Comms - async
    # ============ User Comms ============ #
    get_active_user_ids_in_timerange = _(get_active_user_ids_in_timerange)
    get_user_by_id = _(get_user_by_id)
    get_user_email_by_id = _(get_user_email_by_id)
    get_user_email_verification = _(get_user_email_verification)
    get_user_notification_preference = _(get_user_notification_preference)

    # Human In The Loop
    # ============ Human In The Loop ============ #
    cancel_pending_reviews_for_execution = _(cancel_pending_reviews_for_execution)
    check_approval = _(check_approval)
    get_or_create_human_review = _(get_or_create_human_review)
    has_pending_reviews_for_graph_exec = _(has_pending_reviews_for_graph_exec)
    update_review_processed_status = _(update_review_processed_status)

    # Notifications - async
    # ============ Notifications ============ #
    clear_all_user_notification_batches = _(clear_all_user_notification_batches)
    create_or_add_to_user_notification_batch = _(
        create_or_add_to_user_notification_batch
@@ -212,29 +242,56 @@ class DatabaseManager(AppService):
        get_user_notification_oldest_message_in_batch
    )

    # Library
    # ============ Library ============ #
    list_library_agents = _(list_library_agents)
    add_store_agent_to_library = _(add_store_agent_to_library)
    create_graph_in_library = _(create_graph_in_library)
    create_library_agent = _(create_library_agent)
    get_library_agent = _(get_library_agent)
    get_library_agent_by_graph_id = _(get_library_agent_by_graph_id)
    update_graph_in_library = _(update_graph_in_library)
    validate_graph_execution_permissions = _(validate_graph_execution_permissions)

    # Onboarding
    # ============ Onboarding ============ #
    increment_onboarding_runs = _(increment_onboarding_runs)

    # OAuth
    # ============ OAuth ============ #
    cleanup_expired_oauth_tokens = _(cleanup_expired_oauth_tokens)

    # Store
    # ============ Store ============ #
    get_store_agents = _(get_store_agents)
    get_store_agent_details = _(get_store_agent_details)
    get_agent = _(get_agent)
    get_available_graph = _(get_available_graph)

    # Store Embeddings
    # ============ Search ============ #
    get_embedding_stats = _(get_embedding_stats)
    backfill_missing_embeddings = _(backfill_missing_embeddings)
    cleanup_orphaned_embeddings = _(cleanup_orphaned_embeddings)
    unified_hybrid_search = _(unified_hybrid_search)

    # Summary data - async
    # ============ Summary Data ============ #
    get_user_execution_summary_data = _(get_user_execution_summary_data)

    # ============ Workspace ============ #
    get_or_create_workspace = _(get_or_create_workspace)

    # ============ Understanding ============ #
    get_business_understanding = _(get_business_understanding)
    upsert_business_understanding = _(upsert_business_understanding)

    # ============ CoPilot Chat Sessions ============ #
    get_chat_session = _(chat_db.get_chat_session)
    create_chat_session = _(chat_db.create_chat_session)
    update_chat_session = _(chat_db.update_chat_session)
    add_chat_message = _(chat_db.add_chat_message)
    add_chat_messages_batch = _(chat_db.add_chat_messages_batch)
    get_user_chat_sessions = _(chat_db.get_user_chat_sessions)
    get_user_session_count = _(chat_db.get_user_session_count)
    delete_chat_session = _(chat_db.delete_chat_session)
    get_chat_session_message_count = _(chat_db.get_chat_session_message_count)
    update_tool_message_content = _(chat_db.update_tool_message_content)


class DatabaseManagerClient(AppServiceClient):
    d = DatabaseManager
@@ -296,43 +353,50 @@ class DatabaseManagerAsyncClient(AppServiceClient):
    def get_service_type(cls):
        return DatabaseManager

    # ============ Graph Executions ============ #
    create_graph_execution = d.create_graph_execution
    get_child_graph_executions = d.get_child_graph_executions
    get_connected_output_nodes = d.get_connected_output_nodes
    get_latest_node_execution = d.get_latest_node_execution
    get_graph = d.get_graph
    get_graph_metadata = d.get_graph_metadata
    get_graph_settings = d.get_graph_settings
    get_graph_execution = d.get_graph_execution
    get_graph_execution_meta = d.get_graph_execution_meta
    get_node = d.get_node
    get_graph_executions = d.get_graph_executions
    get_node_execution = d.get_node_execution
    get_node_executions = d.get_node_executions
    get_user_by_id = d.get_user_by_id
    get_user_integrations = d.get_user_integrations
    upsert_execution_input = d.upsert_execution_input
    upsert_execution_output = d.upsert_execution_output
    get_execution_outputs_by_node_exec_id = d.get_execution_outputs_by_node_exec_id
    update_graph_execution_stats = d.update_graph_execution_stats
    update_node_execution_status = d.update_node_execution_status
    update_node_execution_status_batch = d.update_node_execution_status_batch
    update_user_integrations = d.update_user_integrations
    upsert_execution_input = d.upsert_execution_input
    upsert_execution_output = d.upsert_execution_output
    get_execution_outputs_by_node_exec_id = d.get_execution_outputs_by_node_exec_id
    get_execution_kv_data = d.get_execution_kv_data
    set_execution_kv_data = d.set_execution_kv_data

    # Human In The Loop
    # ============ Graphs ============ #
    get_graph = d.get_graph
    get_graph_metadata = d.get_graph_metadata
    get_graph_settings = d.get_graph_settings
    get_node = d.get_node
    get_store_listed_graphs = d.get_store_listed_graphs

    # ============ User + Integrations ============ #
    get_user_by_id = d.get_user_by_id
    get_user_integrations = d.get_user_integrations
    update_user_integrations = d.update_user_integrations

    # ============ Human In The Loop ============ #
    cancel_pending_reviews_for_execution = d.cancel_pending_reviews_for_execution
    check_approval = d.check_approval
    get_or_create_human_review = d.get_or_create_human_review
    update_review_processed_status = d.update_review_processed_status

    # User Comms
    # ============ User Comms ============ #
    get_active_user_ids_in_timerange = d.get_active_user_ids_in_timerange
    get_user_email_by_id = d.get_user_email_by_id
    get_user_email_verification = d.get_user_email_verification
    get_user_notification_preference = d.get_user_notification_preference

    # Notifications
    # ============ Notifications ============ #
    clear_all_user_notification_batches = d.clear_all_user_notification_batches
    create_or_add_to_user_notification_batch = (
        d.create_or_add_to_user_notification_batch
@@ -345,20 +409,49 @@ class DatabaseManagerAsyncClient(AppServiceClient):
        d.get_user_notification_oldest_message_in_batch
    )

    # Library
    # ============ Library ============ #
    list_library_agents = d.list_library_agents
    add_store_agent_to_library = d.add_store_agent_to_library
    create_graph_in_library = d.create_graph_in_library
    create_library_agent = d.create_library_agent
    get_library_agent = d.get_library_agent
    get_library_agent_by_graph_id = d.get_library_agent_by_graph_id
    update_graph_in_library = d.update_graph_in_library
    validate_graph_execution_permissions = d.validate_graph_execution_permissions

    # Onboarding
    # ============ Onboarding ============ #
    increment_onboarding_runs = d.increment_onboarding_runs

    # OAuth
    # ============ OAuth ============ #
    cleanup_expired_oauth_tokens = d.cleanup_expired_oauth_tokens

    # Store
    # ============ Store ============ #
    get_store_agents = d.get_store_agents
    get_store_agent_details = d.get_store_agent_details
    get_agent = d.get_agent
    get_available_graph = d.get_available_graph

    # Summary data
    # ============ Search ============ #
    unified_hybrid_search = d.unified_hybrid_search

    # ============ Summary Data ============ #
    get_user_execution_summary_data = d.get_user_execution_summary_data

    # ============ Workspace ============ #
    get_or_create_workspace = d.get_or_create_workspace

    # ============ Understanding ============ #
    get_business_understanding = d.get_business_understanding
    upsert_business_understanding = d.upsert_business_understanding

    # ============ CoPilot Chat Sessions ============ #
    get_chat_session = d.get_chat_session
    create_chat_session = d.create_chat_session
    update_chat_session = d.update_chat_session
    add_chat_message = d.add_chat_message
    add_chat_messages_batch = d.add_chat_messages_batch
    get_user_chat_sessions = d.get_user_chat_sessions
    get_user_session_count = d.get_user_session_count
    delete_chat_session = d.delete_chat_session
    get_chat_session_message_count = d.get_chat_session_message_count
    update_tool_message_content = d.update_tool_message_content
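The `_()` helper that all of these registrations go through registers a module-level function as an RPC endpoint while casting its type so it reads as a bound method on the service class. A simplified sketch of the idea (expose here is a stand-in for the real decorator in backend.util.service, and the name parameter is omitted):

    from typing import Callable, Concatenate, ParamSpec, TypeVar, cast

    P = ParamSpec("P")
    R = TypeVar("R")

    def expose(f):  # stand-in for backend.util.service.expose
        return f

    def _(f: Callable[P, R]) -> Callable[Concatenate[object, P], R]:
        # Adds a virtual `self` parameter at the type level only; at runtime
        # the very same function object is what gets registered.
        return cast(Callable[Concatenate[object, P], R], expose(f))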
@@ -1,5 +1,5 @@
from backend.app import run_processes
from backend.executor import DatabaseManager
from backend.data.db_manager import DatabaseManager


def main():

@@ -1,11 +1,7 @@
from .database import DatabaseManager, DatabaseManagerAsyncClient, DatabaseManagerClient
from .manager import ExecutionManager
from .scheduler import Scheduler

__all__ = [
    "DatabaseManager",
    "DatabaseManagerClient",
    "DatabaseManagerAsyncClient",
    "ExecutionManager",
    "Scheduler",
]

@@ -22,7 +22,7 @@ from backend.util.settings import Settings
from backend.util.truncate import truncate

if TYPE_CHECKING:
    from backend.executor import DatabaseManagerAsyncClient
    from backend.data.db_manager import DatabaseManagerAsyncClient

logger = logging.getLogger(__name__)


@@ -4,7 +4,7 @@ import logging
from typing import TYPE_CHECKING, Any, Literal

if TYPE_CHECKING:
    from backend.executor import DatabaseManagerAsyncClient
    from backend.data.db_manager import DatabaseManagerAsyncClient

from pydantic import ValidationError


@@ -1,6 +1,7 @@
"""Redis-based distributed locking for cluster coordination."""

import logging
import threading
import time
from typing import TYPE_CHECKING

@@ -19,6 +20,7 @@ class ClusterLock:
        self.owner_id = owner_id
        self.timeout = timeout
        self._last_refresh = 0.0
        self._refresh_lock = threading.Lock()

    def try_acquire(self) -> str | None:
        """Try to acquire the lock.
@@ -31,7 +33,8 @@ class ClusterLock:
        try:
            success = self.redis.set(self.key, self.owner_id, nx=True, ex=self.timeout)
            if success:
                self._last_refresh = time.time()
                with self._refresh_lock:
                    self._last_refresh = time.time()
                return self.owner_id  # Successfully acquired

            # Failed to acquire, get current owner
@@ -57,23 +60,27 @@ class ClusterLock:
        Rate limited to at most once every timeout/10 seconds (minimum 1 second).
        During rate limiting, still verifies lock existence but skips TTL extension.
        Setting _last_refresh to 0 bypasses rate limiting for testing.

        Thread-safe: uses _refresh_lock to protect _last_refresh access.
        """
        # Calculate refresh interval: max(timeout // 10, 1)
        refresh_interval = max(self.timeout // 10, 1)
        current_time = time.time()

        # Check if we're within the rate limit period
        # Check if we're within the rate limit period (thread-safe read)
        # _last_refresh == 0 forces a refresh (bypasses rate limiting for testing)
        with self._refresh_lock:
            last_refresh = self._last_refresh
        is_rate_limited = (
            self._last_refresh > 0
            and (current_time - self._last_refresh) < refresh_interval
            last_refresh > 0 and (current_time - last_refresh) < refresh_interval
        )

        try:
            # Always verify lock existence, even during rate limiting
            current_value = self.redis.get(self.key)
            if not current_value:
                self._last_refresh = 0
                with self._refresh_lock:
                    self._last_refresh = 0
                return False

            stored_owner = (
@@ -82,7 +89,8 @@ class ClusterLock:
                else str(current_value)
            )
            if stored_owner != self.owner_id:
                self._last_refresh = 0
                with self._refresh_lock:
                    self._last_refresh = 0
                return False

            # If rate limited, return True but don't update TTL or timestamp
@@ -91,25 +99,30 @@ class ClusterLock:

            # Perform actual refresh
            if self.redis.expire(self.key, self.timeout):
                self._last_refresh = current_time
                with self._refresh_lock:
                    self._last_refresh = current_time
                return True

            self._last_refresh = 0
            with self._refresh_lock:
                self._last_refresh = 0
            return False

        except Exception as e:
            logger.error(f"ClusterLock.refresh failed for key {self.key}: {e}")
            self._last_refresh = 0
            with self._refresh_lock:
                self._last_refresh = 0
            return False

    def release(self):
        """Release the lock."""
        if self._last_refresh == 0:
            return
        with self._refresh_lock:
            if self._last_refresh == 0:
                return

        try:
            self.redis.delete(self.key)
        except Exception:
            pass

        self._last_refresh = 0.0
        with self._refresh_lock:
            self._last_refresh = 0.0

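The refresh rate-limit rule in these hunks is refresh_interval = max(timeout // 10, 1); a short worked example of what that means in practice:

    # timeout=60s -> TTL is extended at most every 6s; calls in between only
    # verify the key still exists and still belongs to this owner.
    # timeout=5s  -> the floor applies: max(5 // 10, 1) == 1s between refreshes.
    for timeout in (60, 5):
        print(timeout, max(timeout // 10, 1))  # prints: 60 6, then 5 1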
@@ -92,7 +92,10 @@ from .utils import (
)

if TYPE_CHECKING:
    from backend.executor import DatabaseManagerAsyncClient, DatabaseManagerClient
    from backend.data.db_manager import (
        DatabaseManagerAsyncClient,
        DatabaseManagerClient,
    )


_logger = logging.getLogger(__name__)

@@ -13,12 +13,15 @@ if TYPE_CHECKING:
    from openai import AsyncOpenAI
    from supabase import AClient, Client

    from backend.data.db_manager import (
        DatabaseManagerAsyncClient,
        DatabaseManagerClient,
    )
    from backend.data.execution import (
        AsyncRedisExecutionEventBus,
        RedisExecutionEventBus,
    )
    from backend.data.rabbitmq import AsyncRabbitMQ, SyncRabbitMQ
    from backend.executor import DatabaseManagerAsyncClient, DatabaseManagerClient
    from backend.executor.scheduler import SchedulerClient
    from backend.integrations.credentials_store import IntegrationCredentialsStore
    from backend.notifications.notifications import NotificationManagerClient
@@ -27,7 +30,7 @@ if TYPE_CHECKING:
@thread_cached
def get_database_manager_client() -> "DatabaseManagerClient":
    """Get a thread-cached DatabaseManagerClient with request retry enabled."""
    from backend.executor import DatabaseManagerClient
    from backend.data.db_manager import DatabaseManagerClient
    from backend.util.service import get_service_client

    return get_service_client(DatabaseManagerClient, request_retry=True)
@@ -38,7 +41,7 @@ def get_database_manager_async_client(
    should_retry: bool = True,
) -> "DatabaseManagerAsyncClient":
    """Get a thread-cached DatabaseManagerAsyncClient with request retry enabled."""
    from backend.executor import DatabaseManagerAsyncClient
    from backend.data.db_manager import DatabaseManagerAsyncClient
    from backend.util.service import get_service_client

    return get_service_client(DatabaseManagerAsyncClient, request_retry=should_retry)
@@ -106,6 +109,20 @@ async def get_async_execution_queue() -> "AsyncRabbitMQ":
    return client


# ============ CoPilot Queue Helpers ============ #


@thread_cached
async def get_async_copilot_queue() -> "AsyncRabbitMQ":
    """Get a thread-cached AsyncRabbitMQ CoPilot queue client."""
    from backend.copilot.executor.utils import create_copilot_queue_config
    from backend.data.rabbitmq import AsyncRabbitMQ

    client = AsyncRabbitMQ(create_copilot_queue_config())
    await client.connect()
    return client


# ============ Integration Credentials Store ============ #


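A hedged usage sketch for the new queue helper (the helper itself is verbatim from the diff; the call site is invented, and the publish/consume API is not shown here):

    # thread_cached: repeated calls on the same thread reuse the connected client.
    queue = await get_async_copilot_queue()
    # ...publish or consume CoPilot work items via the AsyncRabbitMQ client...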
@@ -211,16 +211,23 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
description="The port for execution manager daemon to run on",
|
||||
)
|
||||
|
||||
num_copilot_workers: int = Field(
|
||||
default=5,
|
||||
ge=1,
|
||||
le=100,
|
||||
description="Number of concurrent CoPilot executor workers",
|
||||
)
|
||||
|
||||
copilot_executor_port: int = Field(
|
||||
default=8008,
|
||||
description="The port for CoPilot executor daemon to run on",
|
||||
)
|
||||
|
||||
execution_scheduler_port: int = Field(
|
||||
default=8003,
|
||||
description="The port for execution scheduler daemon to run on",
|
||||
)
|
||||
|
||||
agent_server_port: int = Field(
|
||||
default=8004,
|
||||
description="The port for agent server daemon to run on",
|
||||
)
|
||||
|
||||
database_api_port: int = Field(
|
||||
default=8005,
|
||||
description="The port for database server API to run on",
|
||||
|
||||
@@ -11,6 +11,7 @@ from backend.api.rest_api import AgentServer
|
||||
from backend.blocks._base import Block, BlockSchema
|
||||
from backend.data import db
|
||||
from backend.data.block import initialize_blocks
|
||||
from backend.data.db_manager import DatabaseManager
|
||||
from backend.data.execution import (
|
||||
ExecutionContext,
|
||||
ExecutionStatus,
|
||||
@@ -19,7 +20,7 @@ from backend.data.execution import (
|
||||
)
|
||||
from backend.data.model import _BaseCredentials
|
||||
from backend.data.user import create_default_user
|
||||
from backend.executor import DatabaseManager, ExecutionManager, Scheduler
|
||||
from backend.executor import ExecutionManager, Scheduler
|
||||
from backend.notifications.notifications import NotificationManager
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
68
autogpt_platform/backend/poetry.lock
generated
68
autogpt_platform/backend/poetry.lock
generated
@@ -441,14 +441,14 @@ develop = true
 colorama = "^0.4.6"
 cryptography = "^46.0"
 expiringdict = "^1.2.2"
-fastapi = "^0.128.0"
+fastapi = "^0.128.7"
 google-cloud-logging = "^3.13.0"
-launchdarkly-server-sdk = "^9.14.1"
+launchdarkly-server-sdk = "^9.15.0"
 pydantic = "^2.12.5"
 pydantic-settings = "^2.12.0"
 pyjwt = {version = "^2.11.0", extras = ["crypto"]}
 redis = "^6.2.0"
-supabase = "^2.27.2"
+supabase = "^2.28.0"
 uvicorn = "^0.40.0"

 [package.source]
@@ -1382,14 +1382,14 @@ tzdata = "*"

 [[package]]
 name = "fastapi"
-version = "0.128.6"
+version = "0.128.7"
 description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production"
 optional = false
 python-versions = ">=3.9"
 groups = ["main"]
 files = [
-    {file = "fastapi-0.128.6-py3-none-any.whl", hash = "sha256:bb1c1ef87d6086a7132d0ab60869d6f1ee67283b20fbf84ec0003bd335099509"},
-    {file = "fastapi-0.128.6.tar.gz", hash = "sha256:0cb3946557e792d731b26a42b04912f16367e3c3135ea8290f620e234f2b604f"},
+    {file = "fastapi-0.128.7-py3-none-any.whl", hash = "sha256:6bd9bd31cb7047465f2d3fa3ba3f33b0870b17d4eaf7cdb36d1576ab060ad662"},
+    {file = "fastapi-0.128.7.tar.gz", hash = "sha256:783c273416995486c155ad2c0e2b45905dedfaf20b9ef8d9f6a9124670639a24"},
 ]

 [package.dependencies]
@@ -3117,14 +3117,14 @@ urllib3 = ">=1.26.0,<3"

 [[package]]
 name = "launchdarkly-server-sdk"
-version = "9.14.1"
+version = "9.15.0"
 description = "LaunchDarkly SDK for Python"
 optional = false
-python-versions = ">=3.9"
+python-versions = ">=3.10"
 groups = ["main"]
 files = [
-    {file = "launchdarkly_server_sdk-9.14.1-py3-none-any.whl", hash = "sha256:a9e2bd9ecdef845cd631ae0d4334a1115e5b44257c42eb2349492be4bac7815c"},
-    {file = "launchdarkly_server_sdk-9.14.1.tar.gz", hash = "sha256:1df44baf0a0efa74d8c1dad7a00592b98bce7d19edded7f770da8dbc49922213"},
+    {file = "launchdarkly_server_sdk-9.15.0-py3-none-any.whl", hash = "sha256:c267e29bfa3fb5e2a06a208448ada6ed5557a2924979b8d79c970b45d227c668"},
+    {file = "launchdarkly_server_sdk-9.15.0.tar.gz", hash = "sha256:f31441b74bc1a69c381db57c33116509e407a2612628ad6dff0a7dbb39d5020b"},
 ]

 [package.dependencies]
@@ -4728,14 +4728,14 @@ tests = ["coverage-conditional-plugin (>=0.9.0)", "portalocker[redis]", "pytest

 [[package]]
 name = "postgrest"
-version = "2.27.3"
+version = "2.28.0"
 description = "PostgREST client for Python. This library provides an ORM interface to PostgREST."
 optional = false
 python-versions = ">=3.9"
 groups = ["main"]
 files = [
-    {file = "postgrest-2.27.3-py3-none-any.whl", hash = "sha256:ed79123af7127edd78d538bfe8351d277e45b1a36994a4dbf57ae27dde87a7b7"},
-    {file = "postgrest-2.27.3.tar.gz", hash = "sha256:c2e2679addfc8eaab23197bad7ddaee6cbb4cbe8c483ebd2d2e5219543037cc3"},
+    {file = "postgrest-2.28.0-py3-none-any.whl", hash = "sha256:7bca2f24dd1a1bf8a3d586c7482aba6cd41662da6733045fad585b63b7f7df75"},
+    {file = "postgrest-2.28.0.tar.gz", hash = "sha256:c36b38646d25ea4255321d3d924ce70f8d20ec7799cb42c1221d6a818d4f6515"},
 ]

 [package.dependencies]
@@ -6260,14 +6260,14 @@ all = ["numpy"]

 [[package]]
 name = "realtime"
-version = "2.27.3"
+version = "2.28.0"
 description = ""
 optional = false
 python-versions = ">=3.9"
 groups = ["main"]
 files = [
-    {file = "realtime-2.27.3-py3-none-any.whl", hash = "sha256:f571115f86988e33c41c895cb3fba2eaa1b693aeaede3617288f44274ca90f43"},
-    {file = "realtime-2.27.3.tar.gz", hash = "sha256:02b082243107656a5ef3fb63e8e2ab4c40bc199abb45adb8a42ed63f089a1041"},
+    {file = "realtime-2.28.0-py3-none-any.whl", hash = "sha256:db1bd59bab9b1fcc9f9d3b1a073bed35bf4994d720e6751f10031a58d57a3836"},
+    {file = "realtime-2.28.0.tar.gz", hash = "sha256:d18cedcebd6a8f22fcd509bc767f639761eb218b7b2b6f14fc4205b6259b50fc"},
 ]

 [package.dependencies]
@@ -7024,14 +7024,14 @@ full = ["httpx (>=0.27.0,<0.29.0)", "itsdangerous", "jinja2", "python-multipart

 [[package]]
 name = "storage3"
-version = "2.27.3"
+version = "2.28.0"
 description = "Supabase Storage client for Python."
 optional = false
 python-versions = ">=3.9"
 groups = ["main"]
 files = [
-    {file = "storage3-2.27.3-py3-none-any.whl", hash = "sha256:11a05b7da84bccabeeea12d940bca3760cf63fe6ca441868677335cfe4fdfbe0"},
-    {file = "storage3-2.27.3.tar.gz", hash = "sha256:dc1a4a010cf36d5482c5cb6c1c28fc5f00e23284342b89e4ae43b5eae8501ddb"},
+    {file = "storage3-2.28.0-py3-none-any.whl", hash = "sha256:ecb50efd2ac71dabbdf97e99ad346eafa630c4c627a8e5a138ceb5fbbadae716"},
+    {file = "storage3-2.28.0.tar.gz", hash = "sha256:bc1d008aff67de7a0f2bd867baee7aadbcdb6f78f5a310b4f7a38e8c13c19865"},
 ]

 [package.dependencies]
@@ -7091,35 +7091,35 @@ typing-extensions = {version = ">=4.5.0", markers = "python_version >= \"3.7\""}

 [[package]]
 name = "supabase"
-version = "2.27.3"
+version = "2.28.0"
 description = "Supabase client for Python."
 optional = false
 python-versions = ">=3.9"
 groups = ["main"]
 files = [
-    {file = "supabase-2.27.3-py3-none-any.whl", hash = "sha256:082a74642fcf9954693f1ce8c251baf23e4bda26ffdbc8dcd4c99c82e60d69ff"},
-    {file = "supabase-2.27.3.tar.gz", hash = "sha256:5e5a348232ac4315c1032ddd687278f0b982465471f0cbb52bca7e6a66495ff3"},
+    {file = "supabase-2.28.0-py3-none-any.whl", hash = "sha256:42776971c7d0ccca16034df1ab96a31c50228eb1eb19da4249ad2f756fc20272"},
+    {file = "supabase-2.28.0.tar.gz", hash = "sha256:aea299aaab2a2eed3c57e0be7fc035c6807214194cce795a3575add20268ece1"},
 ]

 [package.dependencies]
 httpx = ">=0.26,<0.29"
-postgrest = "2.27.3"
-realtime = "2.27.3"
-storage3 = "2.27.3"
-supabase-auth = "2.27.3"
-supabase-functions = "2.27.3"
+postgrest = "2.28.0"
+realtime = "2.28.0"
+storage3 = "2.28.0"
+supabase-auth = "2.28.0"
+supabase-functions = "2.28.0"
 yarl = ">=1.22.0"

 [[package]]
 name = "supabase-auth"
-version = "2.27.3"
+version = "2.28.0"
 description = "Python Client Library for Supabase Auth"
 optional = false
 python-versions = ">=3.9"
 groups = ["main"]
 files = [
-    {file = "supabase_auth-2.27.3-py3-none-any.whl", hash = "sha256:82a4262eaad85383319d394dab0eea11fcf3ebd774062aef8ea3874ae2f02579"},
-    {file = "supabase_auth-2.27.3.tar.gz", hash = "sha256:39894d4bc60b6f23b5cff4d0d7d4c1659e5d69563cadf014d4896f780ca8ca78"},
+    {file = "supabase_auth-2.28.0-py3-none-any.whl", hash = "sha256:2ac85026cc285054c7fa6d41924f3a333e9ec298c013e5b5e1754039ba7caec9"},
+    {file = "supabase_auth-2.28.0.tar.gz", hash = "sha256:2bb8f18ff39934e44b28f10918db965659f3735cd6fbfcc022fe0b82dbf8233e"},
 ]

 [package.dependencies]
@@ -7129,14 +7129,14 @@ pyjwt = {version = ">=2.10.1", extras = ["crypto"]}

 [[package]]
 name = "supabase-functions"
-version = "2.27.3"
+version = "2.28.0"
 description = "Library for Supabase Functions"
 optional = false
 python-versions = ">=3.9"
 groups = ["main"]
 files = [
-    {file = "supabase_functions-2.27.3-py3-none-any.whl", hash = "sha256:9d14a931d49ede1c6cf5fbfceb11c44061535ba1c3f310f15384964d86a83d9e"},
-    {file = "supabase_functions-2.27.3.tar.gz", hash = "sha256:e954f1646da8ca6e7e16accef58d0884a5f97b25956ee98e7d4927a210ed92f9"},
+    {file = "supabase_functions-2.28.0-py3-none-any.whl", hash = "sha256:30bf2d586f8df285faf0621bb5d5bb3ec3157234fc820553ca156f009475e4ae"},
+    {file = "supabase_functions-2.28.0.tar.gz", hash = "sha256:db3dddfc37aca5858819eb461130968473bd8c75bd284581013958526dac718b"},
 ]

 [package.dependencies]
@@ -8440,4 +8440,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt
 [metadata]
 lock-version = "2.1"
 python-versions = ">=3.10,<3.14"
-content-hash = "c06e96ad49388ba7a46786e9ea55ea2c1a57408e15613237b4bee40a592a12af"
+content-hash = "fa9c5deadf593e815dd2190f58e22152373900603f5f244b9616cd721de84d2f"
@@ -65,7 +65,7 @@ sentry-sdk = {extras = ["anthropic", "fastapi", "launchdarkly", "openai", "sqlal
 sqlalchemy = "^2.0.40"
 strenum = "^0.4.9"
 stripe = "^11.5.0"
-supabase = "2.27.3"
+supabase = "2.28.0"
 tenacity = "^9.1.4"
 todoist-api-python = "^2.1.7"
 tweepy = "^4.16.0"
@@ -116,6 +116,7 @@ ws = "backend.ws:main"
 scheduler = "backend.scheduler:main"
 notification = "backend.notification:main"
 executor = "backend.exec:main"
+copilot-executor = "backend.copilot.executor.__main__:main"
 cli = "backend.cli:main"
 format = "linter:format"
 lint = "linter:lint"
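The new copilot-executor entry in [tool.poetry.scripts] means `poetry run copilot-executor` resolves and calls backend.copilot.executor.__main__:main. A minimal sketch of that resolution, using only the mapping shown above; the daemon's internals are not part of this diff:

# What the new Poetry script does at run time: `poetry run copilot-executor`
# is equivalent to importing the mapped target and calling it. The module and
# function names come straight from the mapping above; what main() does
# internally is not shown in this diff.
from importlib import import_module

entry = "backend.copilot.executor.__main__:main"
module_name, func_name = entry.split(":")
main = getattr(import_module(module_name), func_name)
main()  # starts the CoPilot executor daemon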
@@ -9,10 +9,8 @@ from unittest.mock import AsyncMock, patch

 import pytest

-from backend.api.features.chat.tools.agent_generator import core
-from backend.api.features.chat.tools.agent_generator.core import (
-    AgentGeneratorNotConfiguredError,
-)
+from backend.copilot.tools.agent_generator import core
+from backend.copilot.tools.agent_generator.core import AgentGeneratorNotConfiguredError


 class TestServiceNotConfigured:
@@ -9,7 +9,7 @@ from unittest.mock import AsyncMock, MagicMock, patch

 import pytest

-from backend.api.features.chat.tools.agent_generator import core
+from backend.copilot.tools.agent_generator import core


 class TestGetLibraryAgentsForGeneration:
@@ -31,18 +31,20 @@ class TestGetLibraryAgentsForGeneration:
         mock_response = MagicMock()
         mock_response.agents = [mock_agent]

+        mock_db = MagicMock()
+        mock_db.list_library_agents = AsyncMock(return_value=mock_response)
+
         with patch.object(
-            core.library_db,
-            "list_library_agents",
-            new_callable=AsyncMock,
-            return_value=mock_response,
-        ) as mock_list:
+            core,
+            "library_db",
+            return_value=mock_db,
+        ):
             result = await core.get_library_agents_for_generation(
                 user_id="user-123",
                 search_query="send email",
             )

-        mock_list.assert_called_once_with(
+        mock_db.list_library_agents.assert_called_once_with(
             user_id="user-123",
             search_term="send email",
             page=1,
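The tests above switch from patching individual functions on core.library_db to patching the library_db accessor on core itself, so a single mock object receives every database call made by the code under test. A self-contained sketch of the pattern; every name except MagicMock/AsyncMock/patch is invented for illustration:

# Minimal illustration of the mocking pattern used above: instead of patching
# one function on an imported module, the test swaps the accessor (here `db()`)
# so every call in the code under test hits the same mock object.
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch


class _Module:  # stands in for a module like `core`
    @staticmethod
    def db():
        raise RuntimeError("real database access")

    @staticmethod
    async def fetch_agents():
        return await _Module.db().list_agents()


def test_pattern():
    mock_db = MagicMock()
    mock_db.list_agents = AsyncMock(return_value=["agent-1"])

    # patch.object(..., return_value=mock_db) makes _Module.db() return mock_db
    with patch.object(_Module, "db", return_value=mock_db):
        assert asyncio.run(_Module.fetch_agents()) == ["agent-1"]

    mock_db.list_agents.assert_called_once_with()


test_pattern()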
@@ -80,11 +82,13 @@ class TestGetLibraryAgentsForGeneration:
             ),
         ]

+        mock_db = MagicMock()
+        mock_db.list_library_agents = AsyncMock(return_value=mock_response)
+
         with patch.object(
-            core.library_db,
-            "list_library_agents",
-            new_callable=AsyncMock,
-            return_value=mock_response,
+            core,
+            "library_db",
+            return_value=mock_db,
         ):
             result = await core.get_library_agents_for_generation(
                 user_id="user-123",
@@ -101,18 +105,20 @@ class TestGetLibraryAgentsForGeneration:
         mock_response = MagicMock()
         mock_response.agents = []

+        mock_db = MagicMock()
+        mock_db.list_library_agents = AsyncMock(return_value=mock_response)
+
         with patch.object(
-            core.library_db,
-            "list_library_agents",
-            new_callable=AsyncMock,
-            return_value=mock_response,
-        ) as mock_list:
+            core,
+            "library_db",
+            return_value=mock_db,
+        ):
             await core.get_library_agents_for_generation(
                 user_id="user-123",
                 max_results=5,
             )

-        mock_list.assert_called_once_with(
+        mock_db.list_library_agents.assert_called_once_with(
             user_id="user-123",
             search_term=None,
             page=1,
@@ -144,24 +150,24 @@ class TestSearchMarketplaceAgentsForGeneration:
         mock_graph.input_schema = {"type": "object"}
         mock_graph.output_schema = {"type": "object"}

+        mock_store_db = MagicMock()
+        mock_store_db.get_store_agents = AsyncMock(return_value=mock_response)
+
+        mock_graph_db = MagicMock()
+        mock_graph_db.get_store_listed_graphs = AsyncMock(
+            return_value={"graph-123": mock_graph}
+        )
+
         with (
-            patch(
-                "backend.api.features.store.db.get_store_agents",
-                new_callable=AsyncMock,
-                return_value=mock_response,
-            ) as mock_search,
-            patch(
-                "backend.api.features.chat.tools.agent_generator.core.get_store_listed_graphs",
-                new_callable=AsyncMock,
-                return_value={"graph-123": mock_graph},
-            ),
+            patch.object(core, "store_db", return_value=mock_store_db),
+            patch.object(core, "graph_db", return_value=mock_graph_db),
         ):
             result = await core.search_marketplace_agents_for_generation(
                 search_query="automation",
                 max_results=10,
             )

-        mock_search.assert_called_once_with(
+        mock_store_db.get_store_agents.assert_called_once_with(
             search_query="automation",
             page=1,
             page_size=10,
@@ -707,7 +713,7 @@ class TestExtractUuidsFromText:


 class TestGetLibraryAgentById:
-    """Test get_library_agent_by_id function (and its alias get_library_agent_by_graph_id)."""
+    """Test get_library_agent_by_id function (alias: get_library_agent_by_graph_id)."""

     @pytest.mark.asyncio
     async def test_returns_agent_when_found_by_graph_id(self):
@@ -720,12 +726,10 @@ class TestGetLibraryAgentById:
         mock_agent.input_schema = {"properties": {}}
         mock_agent.output_schema = {"properties": {}}

-        with patch.object(
-            core.library_db,
-            "get_library_agent_by_graph_id",
-            new_callable=AsyncMock,
-            return_value=mock_agent,
-        ):
+        mock_db = MagicMock()
+        mock_db.get_library_agent_by_graph_id = AsyncMock(return_value=mock_agent)
+
+        with patch.object(core, "library_db", return_value=mock_db):
             result = await core.get_library_agent_by_id("user-123", "agent-123")

         assert result is not None
@@ -743,20 +747,11 @@ class TestGetLibraryAgentById:
         mock_agent.input_schema = {"properties": {}}
         mock_agent.output_schema = {"properties": {}}

-        with (
-            patch.object(
-                core.library_db,
-                "get_library_agent_by_graph_id",
-                new_callable=AsyncMock,
-                return_value=None,  # Not found by graph_id
-            ),
-            patch.object(
-                core.library_db,
-                "get_library_agent",
-                new_callable=AsyncMock,
-                return_value=mock_agent,  # Found by library ID
-            ),
-        ):
+        mock_db = MagicMock()
+        mock_db.get_library_agent_by_graph_id = AsyncMock(return_value=None)
+        mock_db.get_library_agent = AsyncMock(return_value=mock_agent)
+
+        with patch.object(core, "library_db", return_value=mock_db):
             result = await core.get_library_agent_by_id("user-123", "library-id-123")

         assert result is not None
@@ -766,20 +761,13 @@ class TestGetLibraryAgentById:
     @pytest.mark.asyncio
     async def test_returns_none_when_not_found_by_either_method(self):
         """Test that None is returned when agent not found by either method."""
-        with (
-            patch.object(
-                core.library_db,
-                "get_library_agent_by_graph_id",
-                new_callable=AsyncMock,
-                return_value=None,
-            ),
-            patch.object(
-                core.library_db,
-                "get_library_agent",
-                new_callable=AsyncMock,
-                side_effect=core.NotFoundError("Not found"),
-            ),
-        ):
+        mock_db = MagicMock()
+        mock_db.get_library_agent_by_graph_id = AsyncMock(return_value=None)
+        mock_db.get_library_agent = AsyncMock(
+            side_effect=core.NotFoundError("Not found")
+        )
+
+        with patch.object(core, "library_db", return_value=mock_db):
             result = await core.get_library_agent_by_id("user-123", "nonexistent")

         assert result is None
@@ -787,27 +775,20 @@ class TestGetLibraryAgentById:
     @pytest.mark.asyncio
     async def test_returns_none_on_exception(self):
         """Test that None is returned when exception occurs in both lookups."""
-        with (
-            patch.object(
-                core.library_db,
-                "get_library_agent_by_graph_id",
-                new_callable=AsyncMock,
-                side_effect=Exception("Database error"),
-            ),
-            patch.object(
-                core.library_db,
-                "get_library_agent",
-                new_callable=AsyncMock,
-                side_effect=Exception("Database error"),
-            ),
-        ):
+        mock_db = MagicMock()
+        mock_db.get_library_agent_by_graph_id = AsyncMock(
+            side_effect=Exception("Database error")
+        )
+        mock_db.get_library_agent = AsyncMock(side_effect=Exception("Database error"))
+
+        with patch.object(core, "library_db", return_value=mock_db):
             result = await core.get_library_agent_by_id("user-123", "agent-123")

         assert result is None

     @pytest.mark.asyncio
     async def test_alias_works(self):
-        """Test that get_library_agent_by_graph_id is an alias for get_library_agent_by_id."""
+        """Test that get_library_agent_by_graph_id is an alias."""
         assert core.get_library_agent_by_graph_id is core.get_library_agent_by_id


@@ -828,20 +809,11 @@ class TestGetAllRelevantAgentsWithUuids:
         mock_response = MagicMock()
         mock_response.agents = []

-        with (
-            patch.object(
-                core.library_db,
-                "get_library_agent_by_graph_id",
-                new_callable=AsyncMock,
-                return_value=mock_agent,
-            ),
-            patch.object(
-                core.library_db,
-                "list_library_agents",
-                new_callable=AsyncMock,
-                return_value=mock_response,
-            ),
-        ):
+        mock_db = MagicMock()
+        mock_db.get_library_agent_by_graph_id = AsyncMock(return_value=mock_agent)
+        mock_db.list_library_agents = AsyncMock(return_value=mock_response)
+
+        with patch.object(core, "library_db", return_value=mock_db):
             result = await core.get_all_relevant_agents_for_generation(
                 user_id="user-123",
                 search_query="Use agent 46631191-e8a8-486f-ad90-84f89738321d",

@@ -10,7 +10,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
 import httpx
 import pytest

-from backend.api.features.chat.tools.agent_generator import service
+from backend.copilot.tools.agent_generator import service


 class TestServiceConfiguration:
@@ -37,7 +37,7 @@ services:
       context: ../
       dockerfile: autogpt_platform/backend/Dockerfile
       target: migrate
-    command: ["sh", "-c", "poetry run prisma generate && poetry run gen-prisma-stub && poetry run prisma migrate deploy"]
+    command: ["sh", "-c", "prisma generate && python3 gen_prisma_types_stub.py && prisma migrate deploy"]
     develop:
       watch:
         - path: ./
@@ -56,7 +56,7 @@ services:
       test:
         [
          "CMD-SHELL",
-         "poetry run prisma migrate status | grep -q 'No pending migrations' || exit 1",
+         "prisma migrate status | grep -q 'No pending migrations' || exit 1",
        ]
      interval: 30s
      timeout: 10s
@@ -158,6 +158,41 @@ services:
         max-size: "10m"
         max-file: "3"

+  copilot_executor:
+    build:
+      context: ../
+      dockerfile: autogpt_platform/backend/Dockerfile
+      target: server
+    command: ["python", "-m", "backend.copilot.executor"]
+    develop:
+      watch:
+        - path: ./
+          target: autogpt_platform/backend/
+          action: rebuild
+    depends_on:
+      redis:
+        condition: service_healthy
+      rabbitmq:
+        condition: service_healthy
+      db:
+        condition: service_healthy
+      migrate:
+        condition: service_completed_successfully
+      database_manager:
+        condition: service_started
+    <<: *backend-env-files
+    environment:
+      <<: *backend-env
+    ports:
+      - "8008:8008"
+    networks:
+      - app-network
+    logging:
+      driver: json-file
+      options:
+        max-size: "10m"
+        max-file: "3"
+
   websocket_server:
     build:
       context: ../
@@ -53,6 +53,12 @@ services:
       file: ./docker-compose.platform.yml
       service: executor

+  copilot_executor:
+    <<: *agpt-services
+    extends:
+      file: ./docker-compose.platform.yml
+      service: copilot_executor
+
   websocket_server:
     <<: *agpt-services
     extends:
@@ -174,5 +180,6 @@ services:
       - deps
       - rest_server
       - executor
+      - copilot_executor
       - websocket_server
       - database_manager
@@ -3,6 +3,7 @@
 import type { ToolUIPart } from "ai";
 import { MorphingTextAnimation } from "../../components/MorphingTextAnimation/MorphingTextAnimation";
 import { ToolAccordion } from "../../components/ToolAccordion/ToolAccordion";
+import { BlockDetailsCard } from "./components/BlockDetailsCard/BlockDetailsCard";
 import { BlockOutputCard } from "./components/BlockOutputCard/BlockOutputCard";
 import { ErrorCard } from "./components/ErrorCard/ErrorCard";
 import { SetupRequirementsCard } from "./components/SetupRequirementsCard/SetupRequirementsCard";
@@ -11,6 +12,7 @@ import {
   getAnimationText,
   getRunBlockToolOutput,
   isRunBlockBlockOutput,
+  isRunBlockDetailsOutput,
   isRunBlockErrorOutput,
   isRunBlockSetupRequirementsOutput,
   ToolIcon,
@@ -41,6 +43,7 @@ export function RunBlockTool({ part }: Props) {
     part.state === "output-available" &&
     !!output &&
     (isRunBlockBlockOutput(output) ||
+      isRunBlockDetailsOutput(output) ||
       isRunBlockSetupRequirementsOutput(output) ||
       isRunBlockErrorOutput(output));

@@ -58,6 +61,10 @@ export function RunBlockTool({ part }: Props) {
         <ToolAccordion {...getAccordionMeta(output)}>
           {isRunBlockBlockOutput(output) && <BlockOutputCard output={output} />}

+          {isRunBlockDetailsOutput(output) && (
+            <BlockDetailsCard output={output} />
+          )}
+
           {isRunBlockSetupRequirementsOutput(output) && (
             <SetupRequirementsCard output={output} />
           )}
@@ -0,0 +1,188 @@
import type { Meta, StoryObj } from "@storybook/nextjs";
import { ResponseType } from "@/app/api/__generated__/models/responseType";
import type { BlockDetailsResponse } from "../../helpers";
import { BlockDetailsCard } from "./BlockDetailsCard";

const meta: Meta<typeof BlockDetailsCard> = {
  title: "Copilot/RunBlock/BlockDetailsCard",
  component: BlockDetailsCard,
  parameters: {
    layout: "centered",
  },
  tags: ["autodocs"],
  decorators: [
    (Story) => (
      <div style={{ maxWidth: 480 }}>
        <Story />
      </div>
    ),
  ],
};

export default meta;
type Story = StoryObj<typeof meta>;

const baseBlock: BlockDetailsResponse = {
  type: ResponseType.block_details,
  message:
    "Here are the details for the GetWeather block. Provide the required inputs to run it.",
  session_id: "session-123",
  user_authenticated: true,
  block: {
    id: "block-abc-123",
    name: "GetWeather",
    description: "Fetches current weather data for a given location.",
    inputs: {
      type: "object",
      properties: {
        location: {
          title: "Location",
          type: "string",
          description:
            "City name or coordinates (e.g. 'London' or '51.5,-0.1')",
        },
        units: {
          title: "Units",
          type: "string",
          description: "Temperature units: 'metric' or 'imperial'",
        },
      },
      required: ["location"],
    },
    outputs: {
      type: "object",
      properties: {
        temperature: {
          title: "Temperature",
          type: "number",
          description: "Current temperature in the requested units",
        },
        condition: {
          title: "Condition",
          type: "string",
          description: "Weather condition description (e.g. 'Sunny', 'Rain')",
        },
      },
    },
    credentials: [],
  },
};

export const Default: Story = {
  args: {
    output: baseBlock,
  },
};

export const InputsOnly: Story = {
  args: {
    output: {
      ...baseBlock,
      message: "This block requires inputs. No outputs are defined.",
      block: {
        ...baseBlock.block,
        outputs: {},
      },
    },
  },
};

export const OutputsOnly: Story = {
  args: {
    output: {
      ...baseBlock,
      message: "This block has no required inputs.",
      block: {
        ...baseBlock.block,
        inputs: {},
      },
    },
  },
};

export const ManyFields: Story = {
  args: {
    output: {
      ...baseBlock,
      message: "Block with many input and output fields.",
      block: {
        ...baseBlock.block,
        name: "SendEmail",
        description: "Sends an email via SMTP.",
        inputs: {
          type: "object",
          properties: {
            to: {
              title: "To",
              type: "string",
              description: "Recipient email address",
            },
            subject: {
              title: "Subject",
              type: "string",
              description: "Email subject line",
            },
            body: {
              title: "Body",
              type: "string",
              description: "Email body content",
            },
            cc: {
              title: "CC",
              type: "string",
              description: "CC recipients (comma-separated)",
            },
            bcc: {
              title: "BCC",
              type: "string",
              description: "BCC recipients (comma-separated)",
            },
          },
          required: ["to", "subject", "body"],
        },
        outputs: {
          type: "object",
          properties: {
            message_id: {
              title: "Message ID",
              type: "string",
              description: "Unique ID of the sent email",
            },
            status: {
              title: "Status",
              type: "string",
              description: "Delivery status",
            },
          },
        },
      },
    },
  },
};

export const NoFieldDescriptions: Story = {
  args: {
    output: {
      ...baseBlock,
      message: "Fields without descriptions.",
      block: {
        ...baseBlock.block,
        name: "SimpleBlock",
        inputs: {
          type: "object",
          properties: {
            input_a: { title: "Input A", type: "string" },
            input_b: { title: "Input B", type: "number" },
          },
          required: ["input_a"],
        },
        outputs: {
          type: "object",
          properties: {
            result: { title: "Result", type: "string" },
          },
        },
      },
    },
  },
};
@@ -0,0 +1,103 @@
"use client";

import type { BlockDetailsResponse } from "../../helpers";
import {
  ContentBadge,
  ContentCard,
  ContentCardDescription,
  ContentCardTitle,
  ContentGrid,
  ContentMessage,
} from "../../../../components/ToolAccordion/AccordionContent";

interface Props {
  output: BlockDetailsResponse;
}

function SchemaFieldList({
  title,
  properties,
  required,
}: {
  title: string;
  properties: Record<string, unknown>;
  required?: string[];
}) {
  const entries = Object.entries(properties);
  if (entries.length === 0) return null;

  const requiredSet = new Set(required ?? []);

  return (
    <ContentCard>
      <ContentCardTitle className="text-xs">{title}</ContentCardTitle>
      <div className="mt-2 grid gap-2">
        {entries.map(([name, schema]) => {
          const field = schema as Record<string, unknown> | undefined;
          const fieldTitle =
            typeof field?.title === "string" ? field.title : name;
          const fieldType =
            typeof field?.type === "string" ? field.type : "unknown";
          const description =
            typeof field?.description === "string"
              ? field.description
              : undefined;

          return (
            <div key={name} className="rounded-xl border p-2">
              <div className="flex items-center justify-between gap-2">
                <ContentCardTitle className="text-xs">
                  {fieldTitle}
                </ContentCardTitle>
                <div className="flex gap-1">
                  <ContentBadge>{fieldType}</ContentBadge>
                  {requiredSet.has(name) && (
                    <ContentBadge>Required</ContentBadge>
                  )}
                </div>
              </div>
              {description && (
                <ContentCardDescription className="mt-1 text-xs">
                  {description}
                </ContentCardDescription>
              )}
            </div>
          );
        })}
      </div>
    </ContentCard>
  );
}

export function BlockDetailsCard({ output }: Props) {
  const inputs = output.block.inputs as {
    properties?: Record<string, unknown>;
    required?: string[];
  } | null;
  const outputs = output.block.outputs as {
    properties?: Record<string, unknown>;
    required?: string[];
  } | null;

  return (
    <ContentGrid>
      <ContentMessage>{output.message}</ContentMessage>

      {inputs?.properties && Object.keys(inputs.properties).length > 0 && (
        <SchemaFieldList
          title="Inputs"
          properties={inputs.properties}
          required={inputs.required}
        />
      )}

      {outputs?.properties && Object.keys(outputs.properties).length > 0 && (
        <SchemaFieldList
          title="Outputs"
          properties={outputs.properties}
          required={outputs.required}
        />
      )}
    </ContentGrid>
  );
}
@@ -10,18 +10,37 @@ import {
 import type { ToolUIPart } from "ai";
 import { OrbitLoader } from "../../components/OrbitLoader/OrbitLoader";

+/** Block details returned on first run_block attempt (before input_data provided). */
+export interface BlockDetailsResponse {
+  type: typeof ResponseType.block_details;
+  message: string;
+  session_id?: string | null;
+  block: {
+    id: string;
+    name: string;
+    description: string;
+    inputs: Record<string, unknown>;
+    outputs: Record<string, unknown>;
+    credentials: unknown[];
+  };
+  user_authenticated: boolean;
+}
+
 export interface RunBlockInput {
   block_id?: string;
+  block_name?: string;
   input_data?: Record<string, unknown>;
 }

 export type RunBlockToolOutput =
   | SetupRequirementsResponse
+  | BlockDetailsResponse
   | BlockOutputResponse
   | ErrorResponse;

 const RUN_BLOCK_OUTPUT_TYPES = new Set<string>([
   ResponseType.setup_requirements,
+  ResponseType.block_details,
   ResponseType.block_output,
   ResponseType.error,
 ]);
@@ -35,6 +54,15 @@ export function isRunBlockSetupRequirementsOutput(
   );
 }

+export function isRunBlockDetailsOutput(
+  output: RunBlockToolOutput,
+): output is BlockDetailsResponse {
+  return (
+    output.type === ResponseType.block_details ||
+    ("block" in output && typeof output.block === "object")
+  );
+}
+
 export function isRunBlockBlockOutput(
   output: RunBlockToolOutput,
 ): output is BlockOutputResponse {
@@ -64,6 +92,7 @@ function parseOutput(output: unknown): RunBlockToolOutput | null {
     return output as RunBlockToolOutput;
   }
   if ("block_id" in output) return output as BlockOutputResponse;
+  if ("block" in output) return output as BlockDetailsResponse;
   if ("setup_info" in output) return output as SetupRequirementsResponse;
   if ("error" in output || "details" in output)
     return output as ErrorResponse;
@@ -84,17 +113,25 @@ export function getAnimationText(part: {
   output?: unknown;
 }): string {
   const input = part.input as RunBlockInput | undefined;
+  const blockName = input?.block_name?.trim();
   const blockId = input?.block_id?.trim();
-  const blockText = blockId ? ` "${blockId}"` : "";
+  // Prefer block_name if available, otherwise fall back to block_id
+  const blockText = blockName
+    ? ` "${blockName}"`
+    : blockId
+      ? ` "${blockId}"`
+      : "";

   switch (part.state) {
     case "input-streaming":
     case "input-available":
-      return `Running the block${blockText}`;
+      return `Running${blockText}`;
     case "output-available": {
       const output = parseOutput(part.output);
-      if (!output) return `Running the block${blockText}`;
+      if (!output) return `Running${blockText}`;
       if (isRunBlockBlockOutput(output)) return `Ran "${output.block_name}"`;
+      if (isRunBlockDetailsOutput(output))
+        return `Details for "${output.block.name}"`;
       if (isRunBlockSetupRequirementsOutput(output)) {
         return `Setup needed for "${output.setup_info.agent_name}"`;
       }
@@ -158,6 +195,21 @@ export function getAccordionMeta(output: RunBlockToolOutput): {
     };
   }

+  if (isRunBlockDetailsOutput(output)) {
+    const inputKeys = Object.keys(
+      (output.block.inputs as { properties?: Record<string, unknown> })
+        ?.properties ?? {},
+    );
+    return {
+      icon,
+      title: output.block.name,
+      description:
+        inputKeys.length > 0
+          ? `${inputKeys.length} input field${inputKeys.length === 1 ? "" : "s"} available`
+          : output.message,
+    };
+  }
+
   if (isRunBlockSetupRequirementsOutput(output)) {
     const missingCredsCount = Object.keys(
       (output.setup_info.user_readiness?.missing_credentials ?? {}) as Record<
@@ -1053,6 +1053,7 @@
           "$ref": "#/components/schemas/ClarificationNeededResponse"
         },
         { "$ref": "#/components/schemas/BlockListResponse" },
+        { "$ref": "#/components/schemas/BlockDetailsResponse" },
         { "$ref": "#/components/schemas/BlockOutputResponse" },
         { "$ref": "#/components/schemas/DocSearchResultsResponse" },
         { "$ref": "#/components/schemas/DocPageResponse" },
@@ -6958,6 +6959,58 @@
       "enum": ["run", "byte", "second"],
       "title": "BlockCostType"
     },
+    "BlockDetails": {
+      "properties": {
+        "id": { "type": "string", "title": "Id" },
+        "name": { "type": "string", "title": "Name" },
+        "description": { "type": "string", "title": "Description" },
+        "inputs": {
+          "additionalProperties": true,
+          "type": "object",
+          "title": "Inputs",
+          "default": {}
+        },
+        "outputs": {
+          "additionalProperties": true,
+          "type": "object",
+          "title": "Outputs",
+          "default": {}
+        },
+        "credentials": {
+          "items": { "$ref": "#/components/schemas/CredentialsMetaInput" },
+          "type": "array",
+          "title": "Credentials",
+          "default": []
+        }
+      },
+      "type": "object",
+      "required": ["id", "name", "description"],
+      "title": "BlockDetails",
+      "description": "Detailed block information."
+    },
+    "BlockDetailsResponse": {
+      "properties": {
+        "type": {
+          "$ref": "#/components/schemas/ResponseType",
+          "default": "block_details"
+        },
+        "message": { "type": "string", "title": "Message" },
+        "session_id": {
+          "anyOf": [{ "type": "string" }, { "type": "null" }],
+          "title": "Session Id"
+        },
+        "block": { "$ref": "#/components/schemas/BlockDetails" },
+        "user_authenticated": {
+          "type": "boolean",
+          "title": "User Authenticated",
+          "default": false
+        }
+      },
+      "type": "object",
+      "required": ["message", "block"],
+      "title": "BlockDetailsResponse",
+      "description": "Response for block details (first run_block attempt)."
+    },
     "BlockInfo": {
       "properties": {
         "id": { "type": "string", "title": "Id" },
@@ -7013,62 +7066,13 @@
       "properties": {
         "id": { "type": "string", "title": "Id" },
         "name": { "type": "string", "title": "Name" },
-        "description": { "type": "string", "title": "Description" },
-        "categories": {
-          "items": { "type": "string" },
-          "type": "array",
-          "title": "Categories"
-        },
-        "input_schema": {
-          "additionalProperties": true,
-          "type": "object",
-          "title": "Input Schema"
-        },
-        "output_schema": {
-          "additionalProperties": true,
-          "type": "object",
-          "title": "Output Schema"
-        },
-        "required_inputs": {
-          "items": { "$ref": "#/components/schemas/BlockInputFieldInfo" },
-          "type": "array",
-          "title": "Required Inputs",
-          "description": "List of required input fields for this block"
-        }
+        "description": { "type": "string", "title": "Description" }
       },
       "type": "object",
-      "required": [
-        "id",
-        "name",
-        "description",
-        "categories",
-        "input_schema",
-        "output_schema"
-      ],
+      "required": ["id", "name", "description"],
       "title": "BlockInfoSummary",
      "description": "Summary of a block for search results."
     },
-    "BlockInputFieldInfo": {
-      "properties": {
-        "name": { "type": "string", "title": "Name" },
-        "type": { "type": "string", "title": "Type" },
-        "description": {
-          "type": "string",
-          "title": "Description",
-          "default": ""
-        },
-        "required": {
-          "type": "boolean",
-          "title": "Required",
-          "default": false
-        },
-        "default": { "anyOf": [{}, { "type": "null" }], "title": "Default" }
-      },
-      "type": "object",
-      "required": ["name", "type"],
-      "title": "BlockInputFieldInfo",
-      "description": "Information about a block input field."
-    },
     "BlockListResponse": {
       "properties": {
         "type": {
@@ -7086,12 +7090,7 @@
         "title": "Blocks"
       },
       "count": { "type": "integer", "title": "Count" },
-      "query": { "type": "string", "title": "Query" },
-      "usage_hint": {
-        "type": "string",
-        "title": "Usage Hint",
-        "default": "To execute a block, call run_block with block_id set to the block's 'id' field and input_data containing the required fields from input_schema."
-      }
+      "query": { "type": "string", "title": "Query" }
     },
     "type": "object",
     "required": ["message", "blocks", "count", "query"],
@@ -10484,6 +10483,7 @@
       "agent_saved",
       "clarification_needed",
       "block_list",
+      "block_details",
       "block_output",
       "doc_search_results",
       "doc_page",
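For reference, the new BlockDetails/BlockDetailsResponse schemas above correspond to Pydantic models roughly like the following. This is a sketch reconstructed from the JSON schema alone, not the actual backend source (which this diff does not show):

# Reconstructed from the BlockDetailsResponse JSON schema above -- field names
# and defaults come from the schema; class/module placement is assumed.
from typing import Any, Optional

from pydantic import BaseModel, Field


class BlockDetails(BaseModel):
    """Detailed block information."""

    id: str
    name: str
    description: str
    inputs: dict[str, Any] = Field(default_factory=dict)
    outputs: dict[str, Any] = Field(default_factory=dict)
    credentials: list[Any] = Field(default_factory=list)  # CredentialsMetaInput items in the real schema


class BlockDetailsResponse(BaseModel):
    """Response for block details (first run_block attempt)."""

    type: str = "block_details"  # a ResponseType enum member in the real model
    message: str
    session_id: Optional[str] = None
    block: BlockDetails
    user_authenticated: bool = False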
@@ -10,7 +10,7 @@ import {
 } from "@/components/ui/tooltip";
 import { cn } from "@/lib/utils";
 import { cjk } from "@streamdown/cjk";
-import { code } from "@/lib/streamdown-code-singleton";
+import { code } from "@streamdown/code";
 import { math } from "@streamdown/math";
 import { mermaid } from "@streamdown/mermaid";
 import type { UIMessage } from "ai";
@@ -1,237 +0,0 @@
/**
 * Custom Streamdown code plugin with proper shiki singleton.
 *
 * Fixes SENTRY-1051: @streamdown/code creates a new shiki highlighter per
 * language, causing "10 instances created" warnings and memory bloat.
 *
 * This plugin creates ONE highlighter and loads languages dynamically.
 */

import {
  createHighlighter,
  bundledLanguages,
  type Highlighter,
  type BundledLanguage,
  type BundledTheme,
} from "shiki";

// Types matching streamdown's expected interface
interface HighlightToken {
  content: string;
  color?: string;
  bgColor?: string;
  htmlStyle?: Record<string, string>;
  htmlAttrs?: Record<string, string>;
  offset?: number;
}

interface HighlightResult {
  tokens: HighlightToken[][];
  fg?: string;
  bg?: string;
}

interface HighlightOptions {
  code: string;
  language: BundledLanguage;
  themes: [string, string];
}

interface CodeHighlighterPlugin {
  name: "shiki";
  type: "code-highlighter";
  highlight: (
    options: HighlightOptions,
    callback?: (result: HighlightResult) => void,
  ) => HighlightResult | null;
  supportsLanguage: (language: BundledLanguage) => boolean;
  getSupportedLanguages: () => BundledLanguage[];
  getThemes: () => [BundledTheme, BundledTheme];
}

// Singleton state
let highlighterPromise: Promise<Highlighter> | null = null;
let highlighterInstance: Highlighter | null = null;
const loadedLanguages = new Set<string>();
const pendingLanguages = new Map<string, Promise<void>>();

// Result cache (same as @streamdown/code)
const resultCache = new Map<string, HighlightResult>();
const pendingCallbacks = new Map<string, Set<(result: HighlightResult) => void>>();

// All supported languages
const supportedLanguages = new Set(Object.keys(bundledLanguages));

// Cache key for results
function getCacheKey(code: string, language: string, themes: [string, string]): string {
  const prefix = code.slice(0, 100);
  const suffix = code.length > 100 ? code.slice(-100) : "";
  return `${language}:${themes[0]}:${themes[1]}:${code.length}:${prefix}:${suffix}`;
}

// Get or create the singleton highlighter
async function getHighlighter(themes: [string, string]): Promise<Highlighter> {
  if (highlighterInstance) {
    return highlighterInstance;
  }

  if (!highlighterPromise) {
    highlighterPromise = createHighlighter({
      themes: themes as [BundledTheme, BundledTheme],
      // Start with common languages pre-loaded for faster first render
      langs: ["javascript", "typescript", "python", "json", "html", "css", "bash", "markdown"],
    }).then((h: Highlighter) => {
      highlighterInstance = h;
      ["javascript", "typescript", "python", "json", "html", "css", "bash", "markdown"].forEach(
        (l) => loadedLanguages.add(l),
      );
      return h;
    });
  }

  return highlighterPromise;
}

// Load a language dynamically
async function ensureLanguageLoaded(
  highlighter: Highlighter,
  language: string,
): Promise<void> {
  if (loadedLanguages.has(language)) {
    return;
  }

  if (pendingLanguages.has(language)) {
    return pendingLanguages.get(language);
  }

  const loadPromise = highlighter
    .loadLanguage(language as BundledLanguage)
    .then(() => {
      loadedLanguages.add(language);
      pendingLanguages.delete(language);
    })
    .catch((err: Error) => {
      console.warn(`[streamdown-code-singleton] Failed to load language: ${language}`, err);
      pendingLanguages.delete(language);
    });

  pendingLanguages.set(language, loadPromise);
  return loadPromise;
}

// Shiki token types
interface ShikiToken {
  content: string;
  color?: string;
  htmlStyle?: Record<string, string>;
}

// Convert shiki tokens to streamdown format
function convertTokens(
  shikiResult: ReturnType<Highlighter["codeToTokens"]>,
): HighlightResult {
  return {
    tokens: shikiResult.tokens.map((line: ShikiToken[]) =>
      line.map((token: ShikiToken) => ({
        content: token.content,
        color: token.color,
        htmlStyle: token.htmlStyle,
      })),
    ),
    fg: shikiResult.fg,
    bg: shikiResult.bg,
  };
}

export interface CodePluginOptions {
  themes?: [BundledTheme, BundledTheme];
}

export function createCodePlugin(
  options: CodePluginOptions = {},
): CodeHighlighterPlugin {
  const themes = options.themes ?? ["github-light", "github-dark"];

  return {
    name: "shiki",
    type: "code-highlighter",

    supportsLanguage(language: BundledLanguage): boolean {
      return supportedLanguages.has(language);
    },

    getSupportedLanguages(): BundledLanguage[] {
      return Array.from(supportedLanguages) as BundledLanguage[];
    },

    getThemes(): [BundledTheme, BundledTheme] {
      return themes as [BundledTheme, BundledTheme];
    },

    highlight(
      { code, language, themes: highlightThemes }: HighlightOptions,
      callback?: (result: HighlightResult) => void,
    ): HighlightResult | null {
      const cacheKey = getCacheKey(code, language, highlightThemes);

      // Return cached result if available
      if (resultCache.has(cacheKey)) {
        return resultCache.get(cacheKey)!;
      }

      // Register callback for async result
      if (callback) {
        if (!pendingCallbacks.has(cacheKey)) {
          pendingCallbacks.set(cacheKey, new Set());
        }
        pendingCallbacks.get(cacheKey)!.add(callback);
      }

      // Start async highlighting
      getHighlighter(highlightThemes)
        .then(async (highlighter) => {
          // Ensure language is loaded
          const lang = supportedLanguages.has(language) ? language : "text";
          if (lang !== "text") {
            await ensureLanguageLoaded(highlighter, lang);
          }

          // Highlight the code
          const effectiveLang = highlighter.getLoadedLanguages().includes(lang)
            ? lang
            : "text";

          const shikiResult = highlighter.codeToTokens(code, {
            lang: effectiveLang,
            themes: {
              light: highlightThemes[0] as BundledTheme,
              dark: highlightThemes[1] as BundledTheme,
            },
          });

          const result = convertTokens(shikiResult);
          resultCache.set(cacheKey, result);

          // Notify all pending callbacks
          const callbacks = pendingCallbacks.get(cacheKey);
          if (callbacks) {
            for (const cb of callbacks) {
              cb(result);
            }
            pendingCallbacks.delete(cacheKey);
          }
        })
        .catch((err) => {
          console.error("[streamdown-code-singleton] Failed to highlight code:", err);
          pendingCallbacks.delete(cacheKey);
        });

      // Return null while async loading
      return null;
    },
  };
}

// Pre-configured plugin with default settings (drop-in replacement for @streamdown/code)
export const code = createCodePlugin();