Mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-01-21 04:57:58 -05:00)

Compare commits: 25 commits, make-old-w...ci/test-op
| SHA1 |
|---|
| b33732fac1 |
| 62c2d1cdc7 |
| 93e5c40189 |
| 4ea2411fda |
| 95154f03e6 |
| f4b3358cb3 |
| 48f8c70e6f |
| 3e1a8c800c |
| bcf3a0cd9c |
| 4e0ae67067 |
| 32acf066d0 |
| 8268d919f5 |
| c7063a46a6 |
| 3509db9ebd |
| 79534efa68 |
| 69d0c05017 |
| 521dbdc25f |
| 3b9abbcdbc |
| e0cd070e4d |
| b6b7b77ddd |
| fc5cf113a7 |
| 9a5a041102 |
| 1137cfde48 |
| da8e7405b0 |
| 7b6db6e260 |
.github/workflows/platform-backend-ci.yml (vendored) — 50 changed lines

@@ -32,7 +32,9 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python-version: ["3.11", "3.12", "3.13"]
+        # Use Python 3.13 to match Docker image (see backend/Dockerfile)
+        # ClamAV tests moved to platform-backend-security-ci.yml (runs on merge to master)
+        python-version: ["3.13"]
     runs-on: ubuntu-latest

     services:
@@ -48,23 +50,6 @@ jobs:
         env:
           RABBITMQ_DEFAULT_USER: ${{ env.RABBITMQ_DEFAULT_USER }}
           RABBITMQ_DEFAULT_PASS: ${{ env.RABBITMQ_DEFAULT_PASS }}
-      clamav:
-        image: clamav/clamav-debian:latest
-        ports:
-          - 3310:3310
-        env:
-          CLAMAV_NO_FRESHCLAMD: false
-          CLAMD_CONF_StreamMaxLength: 50M
-          CLAMD_CONF_MaxFileSize: 100M
-          CLAMD_CONF_MaxScanSize: 100M
-          CLAMD_CONF_MaxThreads: 4
-          CLAMD_CONF_ReadTimeout: 300
-        options: >-
-          --health-cmd "clamdscan --version || exit 1"
-          --health-interval 30s
-          --health-timeout 10s
-          --health-retries 5
-          --health-start-period 180s

     steps:
       - name: Checkout repository
@@ -146,35 +131,6 @@ jobs:
       # outputs:
       #   DB_URL, API_URL, GRAPHQL_URL, ANON_KEY, SERVICE_ROLE_KEY, JWT_SECRET

-      - name: Wait for ClamAV to be ready
-        run: |
-          echo "Waiting for ClamAV daemon to start..."
-          max_attempts=60
-          attempt=0
-
-          until nc -z localhost 3310 || [ $attempt -eq $max_attempts ]; do
-            echo "ClamAV is unavailable - sleeping (attempt $((attempt+1))/$max_attempts)"
-            sleep 5
-            attempt=$((attempt+1))
-          done
-
-          if [ $attempt -eq $max_attempts ]; then
-            echo "ClamAV failed to start after $((max_attempts*5)) seconds"
-            echo "Checking ClamAV service logs..."
-            docker logs $(docker ps -q --filter "ancestor=clamav/clamav-debian:latest") 2>&1 | tail -50 || echo "No ClamAV container found"
-            exit 1
-          fi
-
-          echo "ClamAV is ready!"
-
-          # Verify ClamAV is responsive
-          echo "Testing ClamAV connection..."
-          timeout 10 bash -c 'echo "PING" | nc localhost 3310' || {
-            echo "ClamAV is not responding to PING"
-            docker logs $(docker ps -q --filter "ancestor=clamav/clamav-debian:latest") 2>&1 | tail -50 || echo "No ClamAV container found"
-            exit 1
-          }
-
       - name: Run Database Migrations
         run: poetry run prisma migrate dev --name updates
         env:
.github/workflows/platform-backend-security-ci.yml (new file, 145 lines, vendored)

name: AutoGPT Platform - Backend Security CI

# This workflow runs ClamAV-dependent security tests.
# It only runs on merge to master to avoid the 3-5 minute ClamAV startup time on every PR.

on:
  push:
    branches: [master]
    paths:
      - "autogpt_platform/backend/**/file*.py"
      - "autogpt_platform/backend/**/scan*.py"
      - "autogpt_platform/backend/**/virus*.py"
      - "autogpt_platform/backend/**/media*.py"
      - ".github/workflows/platform-backend-security-ci.yml"

concurrency:
  group: ${{ format('backend-security-ci-{0}', github.sha) }}
  cancel-in-progress: false

defaults:
  run:
    shell: bash
    working-directory: autogpt_platform/backend

jobs:
  security-tests:
    runs-on: ubuntu-latest
    timeout-minutes: 15

    services:
      redis:
        image: redis:latest
        ports:
          - 6379:6379
      clamav:
        image: clamav/clamav-debian:latest
        ports:
          - 3310:3310
        env:
          CLAMAV_NO_FRESHCLAMD: false
          CLAMD_CONF_StreamMaxLength: 50M
          CLAMD_CONF_MaxFileSize: 100M
          CLAMD_CONF_MaxScanSize: 100M
          CLAMD_CONF_MaxThreads: 4
          CLAMD_CONF_ReadTimeout: 300
        options: >-
          --health-cmd "clamdscan --version || exit 1"
          --health-interval 30s
          --health-timeout 10s
          --health-retries 5
          --health-start-period 180s

    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          fetch-depth: 0
          submodules: true

      - name: Set up Python 3.13
        uses: actions/setup-python@v5
        with:
          python-version: "3.13"

      - name: Setup Supabase
        uses: supabase/setup-cli@v1
        with:
          version: 1.178.1

      - name: Set up Python dependency cache
        uses: actions/cache@v4
        with:
          path: ~/.cache/pypoetry
          key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}

      - name: Install Poetry
        run: |
          HEAD_POETRY_VERSION=$(python ../../.github/workflows/scripts/get_package_version_from_lockfile.py poetry)
          echo "Using Poetry version ${HEAD_POETRY_VERSION}"
          curl -sSL https://install.python-poetry.org | POETRY_VERSION=$HEAD_POETRY_VERSION python3 -

      - name: Install Python dependencies
        run: poetry install

      - name: Generate Prisma Client
        run: poetry run prisma generate

      - id: supabase
        name: Start Supabase
        working-directory: .
        run: |
          supabase init
          supabase start --exclude postgres-meta,realtime,storage-api,imgproxy,inbucket,studio,edge-runtime,logflare,vector,supavisor
          supabase status -o env | sed 's/="/=/; s/"$//' >> $GITHUB_OUTPUT

      - name: Wait for ClamAV to be ready
        run: |
          echo "Waiting for ClamAV daemon to start..."
          max_attempts=60
          attempt=0

          until nc -z localhost 3310 || [ $attempt -eq $max_attempts ]; do
            echo "ClamAV is unavailable - sleeping (attempt $((attempt+1))/$max_attempts)"
            sleep 5
            attempt=$((attempt+1))
          done

          if [ $attempt -eq $max_attempts ]; then
            echo "ClamAV failed to start after $((max_attempts*5)) seconds"
            exit 1
          fi

          echo "ClamAV is ready!"

      - name: Run Database Migrations
        run: poetry run prisma migrate dev --name updates
        env:
          DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
          DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}

      - name: Run security-related tests
        run: |
          poetry run pytest -v \
            backend/util/virus_scanner_test.py \
            backend/util/file_test.py \
            backend/server/v2/store/media_test.py \
            -x
        env:
          DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
          DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}
          SUPABASE_URL: ${{ steps.supabase.outputs.API_URL }}
          SUPABASE_SERVICE_ROLE_KEY: ${{ steps.supabase.outputs.SERVICE_ROLE_KEY }}
          JWT_VERIFY_KEY: ${{ steps.supabase.outputs.JWT_SECRET }}
          REDIS_HOST: "localhost"
          REDIS_PORT: "6379"
          ENCRYPTION_KEY: "dvziYgz0KSK8FENhju0ZYi8-fRTfAdlz6YLhdB_jhNw="
          CLAMAV_SERVICE_HOST: "localhost"
          CLAMAV_SERVICE_PORT: "3310"
          CLAMAV_SERVICE_ENABLED: "true"

    env:
      CI: true
      PLAIN_OUTPUT: True
      RUN_ENV: local
      PORT: 8080
.github/workflows/platform-frontend-ci.yml (vendored) — 93 changed lines

@@ -154,35 +154,78 @@ jobs:
       - name: Set up Docker Buildx
         uses: docker/setup-buildx-action@v3

-      - name: Cache Docker layers
+      # Docker image tar caching - loads images from cache in parallel for faster startup
+      - name: Set up Docker image cache
+        id: docker-cache
         uses: actions/cache@v4
         with:
-          path: /tmp/.buildx-cache
-          key: ${{ runner.os }}-buildx-frontend-test-${{ hashFiles('autogpt_platform/docker-compose.yml', 'autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/pyproject.toml', 'autogpt_platform/backend/poetry.lock') }}
+          path: ~/docker-cache
+          key: docker-images-frontend-${{ runner.os }}-${{ hashFiles('autogpt_platform/docker-compose.yml') }}
           restore-keys: |
-            ${{ runner.os }}-buildx-frontend-test-
+            docker-images-frontend-${{ runner.os }}-
+
+      - name: Load or pull Docker images
+        working-directory: autogpt_platform
+        run: |
+          mkdir -p ~/docker-cache
+
+          # Define image list for easy maintenance
+          IMAGES=(
+            "redis:latest"
+            "rabbitmq:management"
+            "kong:2.8.1"
+            "supabase/gotrue:v2.170.0"
+            "supabase/postgres:15.8.1.049"
+          )
+
+          # Check if any cached tar files exist
+          if ls ~/docker-cache/*.tar 1> /dev/null 2>&1; then
+            echo "Docker cache found, loading images in parallel..."
+            for image in "${IMAGES[@]}"; do
+              filename=$(echo "$image" | tr ':/' '--')
+              if [ -f ~/docker-cache/${filename}.tar ]; then
+                echo "Loading $image..."
+                docker load -i ~/docker-cache/${filename}.tar || echo "Warning: Failed to load $image from cache" &
+              fi
+            done
+            wait
+            echo "All cached images loaded"
+          else
+            echo "No Docker cache found, pulling images in parallel..."
+            for image in "${IMAGES[@]}"; do
+              docker pull "$image" &
+            done
+            wait
+
+            # Only save cache on main branches (not PRs) to avoid cache pollution
+            if [[ "${{ github.ref }}" == "refs/heads/master" ]] || [[ "${{ github.ref }}" == "refs/heads/dev" ]]; then
+              echo "Saving Docker images to cache in parallel..."
+              for image in "${IMAGES[@]}"; do
+                filename=$(echo "$image" | tr ':/' '--')
+                echo "Saving $image..."
+                docker save -o ~/docker-cache/${filename}.tar "$image" || echo "Warning: Failed to save $image" &
+              done
+              wait
+              echo "Docker image cache saved"
+            else
+              echo "Skipping cache save for PR/feature branch"
+            fi
+          fi
+
+          echo "Docker images ready for use"

       - name: Run docker compose
         run: |
           NEXT_PUBLIC_PW_TEST=true docker compose -f ../docker-compose.yml up -d
         env:
           DOCKER_BUILDKIT: 1
-          BUILDX_CACHE_FROM: type=local,src=/tmp/.buildx-cache
-          BUILDX_CACHE_TO: type=local,dest=/tmp/.buildx-cache-new,mode=max
-
-      - name: Move cache
-        run: |
-          rm -rf /tmp/.buildx-cache
-          if [ -d "/tmp/.buildx-cache-new" ]; then
-            mv /tmp/.buildx-cache-new /tmp/.buildx-cache
-          fi

       - name: Wait for services to be ready
         run: |
           echo "Waiting for rest_server to be ready..."
-          timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
+          timeout 30 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
           echo "Waiting for database to be ready..."
-          timeout 60 sh -c 'until docker compose -f ../docker-compose.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done' || echo "Database ready check timeout, continuing..."
+          timeout 30 sh -c 'until docker compose -f ../docker-compose.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done' || echo "Database ready check timeout, continuing..."

       - name: Create E2E test data
         run: |
@@ -221,9 +264,27 @@ jobs:
       - name: Install dependencies
         run: pnpm install --frozen-lockfile

-      - name: Install Browser 'chromium'
+      # Playwright browser caching - saves 30-60s when cache hits
+      - name: Get Playwright version
+        id: playwright-version
+        run: |
+          echo "version=$(pnpm list @playwright/test --json | jq -r '.[0].dependencies["@playwright/test"].version')" >> $GITHUB_OUTPUT
+
+      - name: Cache Playwright browsers
+        uses: actions/cache@v4
+        id: playwright-cache
+        with:
+          path: ~/.cache/ms-playwright
+          key: playwright-${{ runner.os }}-${{ steps.playwright-version.outputs.version }}
+
+      - name: Install Playwright browsers
+        if: steps.playwright-cache.outputs.cache-hit != 'true'
         run: pnpm playwright install --with-deps chromium
+
+      - name: Install Playwright deps only (when cache hit)
+        if: steps.playwright-cache.outputs.cache-hit == 'true'
+        run: pnpm playwright install-deps chromium

       - name: Run Playwright tests
         run: pnpm test:no-build
.github/workflows/platform-fullstack-ci.yml (vendored) — 64 changed lines

@@ -83,6 +83,66 @@ jobs:
         run: |
           cp ../backend/.env.default ../backend/.env

+      # Docker image tar caching - loads images from cache in parallel for faster startup
+      - name: Set up Docker image cache
+        id: docker-cache
+        uses: actions/cache@v4
+        with:
+          path: ~/docker-cache
+          key: docker-images-fullstack-${{ runner.os }}-${{ hashFiles('autogpt_platform/docker-compose.yml') }}
+          restore-keys: |
+            docker-images-fullstack-${{ runner.os }}-
+
+      - name: Load or pull Docker images
+        working-directory: autogpt_platform
+        run: |
+          mkdir -p ~/docker-cache
+
+          # Define image list for easy maintenance
+          IMAGES=(
+            "redis:latest"
+            "rabbitmq:management"
+            "kong:2.8.1"
+            "supabase/gotrue:v2.170.0"
+            "supabase/postgres:15.8.1.049"
+          )
+
+          # Check if any cached tar files exist
+          if ls ~/docker-cache/*.tar 1> /dev/null 2>&1; then
+            echo "Docker cache found, loading images in parallel..."
+            for image in "${IMAGES[@]}"; do
+              filename=$(echo "$image" | tr ':/' '--')
+              if [ -f ~/docker-cache/${filename}.tar ]; then
+                echo "Loading $image..."
+                docker load -i ~/docker-cache/${filename}.tar || echo "Warning: Failed to load $image from cache" &
+              fi
+            done
+            wait
+            echo "All cached images loaded"
+          else
+            echo "No Docker cache found, pulling images in parallel..."
+            for image in "${IMAGES[@]}"; do
+              docker pull "$image" &
+            done
+            wait
+
+            # Only save cache on main branches (not PRs) to avoid cache pollution
+            if [[ "${{ github.ref }}" == "refs/heads/master" ]] || [[ "${{ github.ref }}" == "refs/heads/dev" ]]; then
+              echo "Saving Docker images to cache in parallel..."
+              for image in "${IMAGES[@]}"; do
+                filename=$(echo "$image" | tr ':/' '--')
+                echo "Saving $image..."
+                docker save -o ~/docker-cache/${filename}.tar "$image" || echo "Warning: Failed to save $image" &
+              done
+              wait
+              echo "Docker image cache saved"
+            else
+              echo "Skipping cache save for PR/feature branch"
+            fi
+          fi
+
+          echo "Docker images ready for use"
+
       - name: Run docker compose
         run: |
           docker compose -f ../docker-compose.yml --profile local --profile deps_backend up -d
@@ -104,9 +164,9 @@ jobs:
       - name: Wait for services to be ready
         run: |
           echo "Waiting for rest_server to be ready..."
-          timeout 60 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
+          timeout 30 sh -c 'until curl -f http://localhost:8006/health 2>/dev/null; do sleep 2; done' || echo "Rest server health check timeout, continuing..."
           echo "Waiting for database to be ready..."
-          timeout 60 sh -c 'until docker compose -f ../docker-compose.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done' || echo "Database ready check timeout, continuing..."
+          timeout 30 sh -c 'until docker compose -f ../docker-compose.yml exec -T db pg_isready -U postgres 2>/dev/null; do sleep 2; done' || echo "Database ready check timeout, continuing..."

       - name: Generate API queries
         run: pnpm generate:api:force
autogpt_platform/backend/backend/integrations/embeddings.py (new file, 157 lines)

"""
Embedding service for generating text embeddings using OpenAI.

Used for vector-based semantic search in the store.
"""

import functools
import logging
from typing import Optional

import openai

from backend.util.settings import Settings

logger = logging.getLogger(__name__)

# Model configuration
# Using text-embedding-3-small (1536 dimensions) for compatibility with pgvector indexes
# pgvector IVFFlat/HNSW indexes have dimension limits (2000 for IVFFlat, varies for HNSW)
EMBEDDING_MODEL = "text-embedding-3-small"
EMBEDDING_DIMENSIONS = 1536

# Input validation limits
# OpenAI text-embedding-3-large supports up to 8191 tokens (~32k chars)
# We set a conservative limit to prevent abuse
MAX_TEXT_LENGTH = 10000  # characters
MAX_BATCH_SIZE = 100  # maximum texts per batch request


class EmbeddingService:
    """Service for generating text embeddings using OpenAI.

    The service can be created without an API key - the key is validated
    only when the client property is first accessed. This allows the service
    to be instantiated at module load time without requiring configuration.
    """

    def __init__(self, api_key: Optional[str] = None):
        settings = Settings()
        self.api_key = (
            api_key
            or settings.secrets.openai_internal_api_key
            or settings.secrets.openai_api_key
        )

    @functools.cached_property
    def client(self) -> openai.AsyncOpenAI:
        """Lazily create the OpenAI client, raising if no API key is configured."""
        if not self.api_key:
            raise ValueError(
                "OpenAI API key not configured. "
                "Set OPENAI_API_KEY or OPENAI_INTERNAL_API_KEY environment variable."
            )
        return openai.AsyncOpenAI(api_key=self.api_key)

    async def generate_embedding(self, text: str) -> list[float]:
        """
        Generate embedding for a single text string.

        Args:
            text: The text to generate an embedding for.

        Returns:
            A list of floats representing the embedding vector.

        Raises:
            ValueError: If the text is empty or exceeds maximum length.
            openai.APIError: If the OpenAI API call fails.
        """
        # Input validation
        if not text or not text.strip():
            raise ValueError("Text cannot be empty")
        if len(text) > MAX_TEXT_LENGTH:
            raise ValueError(
                f"Text exceeds maximum length of {MAX_TEXT_LENGTH} characters"
            )

        response = await self.client.embeddings.create(
            model=EMBEDDING_MODEL,
            input=text,
            dimensions=EMBEDDING_DIMENSIONS,
        )
        if not response.data:
            raise ValueError("OpenAI API returned empty embedding data")
        return response.data[0].embedding

    async def generate_embeddings(self, texts: list[str]) -> list[list[float]]:
        """
        Generate embeddings for multiple texts (batch).

        Args:
            texts: List of texts to generate embeddings for.

        Returns:
            List of embedding vectors, one per input text.

        Raises:
            ValueError: If any text is invalid or batch size exceeds limit.
            openai.APIError: If the OpenAI API call fails.
        """
        # Input validation
        if not texts:
            raise ValueError("Texts list cannot be empty")
        if len(texts) > MAX_BATCH_SIZE:
            raise ValueError(f"Batch size exceeds maximum of {MAX_BATCH_SIZE} texts")
        for i, text in enumerate(texts):
            if not text or not text.strip():
                raise ValueError(f"Text at index {i} cannot be empty")
            if len(text) > MAX_TEXT_LENGTH:
                raise ValueError(
                    f"Text at index {i} exceeds maximum length of {MAX_TEXT_LENGTH} characters"
                )

        response = await self.client.embeddings.create(
            model=EMBEDDING_MODEL,
            input=texts,
            dimensions=EMBEDDING_DIMENSIONS,
        )
        # Sort by index to ensure correct ordering
        sorted_data = sorted(response.data, key=lambda x: x.index)
        return [item.embedding for item in sorted_data]


def create_search_text(name: str, sub_heading: str, description: str) -> str:
    """
    Combine fields into searchable text for embedding.

    This creates a single text string from the agent's name, sub-heading,
    and description, which is then converted to an embedding vector.

    Args:
        name: The agent name.
        sub_heading: The agent sub-heading/tagline.
        description: The agent description.

    Returns:
        A single string combining all non-empty fields.
    """
    parts = [name or "", sub_heading or "", description or ""]
    # filter(None, parts) removes empty strings since empty string is falsy
    return " ".join(filter(None, parts)).strip()


@functools.cache
def get_embedding_service() -> EmbeddingService:
    """
    Get or create the embedding service singleton.

    Uses functools.cache for thread-safe lazy initialization.

    Returns:
        The shared EmbeddingService instance.

    Raises:
        ValueError: If OpenAI API key is not configured (when generating embeddings).
    """
    return EmbeddingService()
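For orientation, a minimal usage sketch of the module above (illustrative driver code, not part of the diff; it assumes an OpenAI key is configured via OPENAI_API_KEY or OPENAI_INTERNAL_API_KEY):

# Hypothetical usage sketch of backend.integrations.embeddings (not part of the PR).
import asyncio

from backend.integrations.embeddings import create_search_text, get_embedding_service


async def main() -> None:
    service = get_embedding_service()  # cached singleton; no API call happens yet
    text = create_search_text("Trip Planner", "Plan trips fast", "Builds itineraries")
    vector = await service.generate_embedding(text)  # 1536-dim list[float]
    print(len(vector))


asyncio.run(main())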
autogpt_platform/backend/backend/integrations/embeddings_test.py (new file, 235 lines)

"""Tests for the embedding service.

This module tests:
- create_search_text utility function
- EmbeddingService input validation
- EmbeddingService API interaction (mocked)
"""

from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from backend.integrations.embeddings import (
    EMBEDDING_DIMENSIONS,
    MAX_BATCH_SIZE,
    MAX_TEXT_LENGTH,
    EmbeddingService,
    create_search_text,
)


class TestCreateSearchText:
    """Tests for the create_search_text utility function."""

    def test_combines_all_fields(self):
        result = create_search_text("Agent Name", "A cool agent", "Does amazing things")
        assert result == "Agent Name A cool agent Does amazing things"

    def test_handles_empty_name(self):
        result = create_search_text("", "Sub heading", "Description")
        assert result == "Sub heading Description"

    def test_handles_empty_sub_heading(self):
        result = create_search_text("Name", "", "Description")
        assert result == "Name Description"

    def test_handles_empty_description(self):
        result = create_search_text("Name", "Sub heading", "")
        assert result == "Name Sub heading"

    def test_handles_all_empty(self):
        result = create_search_text("", "", "")
        assert result == ""

    def test_handles_none_values(self):
        # The function expects strings but should handle None gracefully
        result = create_search_text(None, None, None)  # type: ignore
        assert result == ""

    def test_preserves_content_strips_outer_whitespace(self):
        # The function joins parts and strips the outer result
        # Internal whitespace in each part is preserved
        result = create_search_text(" Name ", " Sub ", " Desc ")
        # Each part is joined with space, then outer strip applied
        assert result == "Name   Sub   Desc"

    def test_handles_only_whitespace(self):
        # Parts that are only whitespace become empty after filter
        result = create_search_text(" ", " ", " ")
        assert result == ""


class TestEmbeddingServiceValidation:
    """Tests for EmbeddingService input validation."""

    @pytest.fixture
    def mock_settings(self):
        """Mock settings with a test API key."""
        with patch("backend.integrations.embeddings.Settings") as mock:
            mock_instance = MagicMock()
            mock_instance.secrets.openai_internal_api_key = "test-api-key"
            mock_instance.secrets.openai_api_key = ""
            mock.return_value = mock_instance
            yield mock

    @pytest.fixture
    def service(self, mock_settings):
        """Create an EmbeddingService instance with mocked settings."""
        service = EmbeddingService()
        # Inject a mock client by setting the cached_property directly
        service.__dict__["client"] = MagicMock()
        return service

    def test_client_access_requires_api_key(self):
        """Test that accessing client fails without an API key."""
        with patch("backend.integrations.embeddings.Settings") as mock:
            mock_instance = MagicMock()
            mock_instance.secrets.openai_internal_api_key = ""
            mock_instance.secrets.openai_api_key = ""
            mock.return_value = mock_instance

            # Service creation should succeed
            service = EmbeddingService()

            # But accessing client should fail
            with pytest.raises(ValueError, match="OpenAI API key not configured"):
                _ = service.client

    def test_init_accepts_explicit_api_key(self):
        """Test that explicit API key overrides settings."""
        with patch("backend.integrations.embeddings.Settings") as mock:
            mock_instance = MagicMock()
            mock_instance.secrets.openai_internal_api_key = ""
            mock_instance.secrets.openai_api_key = ""
            mock.return_value = mock_instance

            with patch("backend.integrations.embeddings.openai.AsyncOpenAI"):
                service = EmbeddingService(api_key="explicit-key")
                assert service.api_key == "explicit-key"

    @pytest.mark.asyncio
    async def test_generate_embedding_empty_text(self, service):
        """Test that empty text raises ValueError."""
        with pytest.raises(ValueError, match="Text cannot be empty"):
            await service.generate_embedding("")

    @pytest.mark.asyncio
    async def test_generate_embedding_whitespace_only(self, service):
        """Test that whitespace-only text raises ValueError."""
        with pytest.raises(ValueError, match="Text cannot be empty"):
            await service.generate_embedding(" ")

    @pytest.mark.asyncio
    async def test_generate_embedding_exceeds_max_length(self, service):
        """Test that text exceeding max length raises ValueError."""
        long_text = "a" * (MAX_TEXT_LENGTH + 1)
        with pytest.raises(ValueError, match="exceeds maximum length"):
            await service.generate_embedding(long_text)

    @pytest.mark.asyncio
    async def test_generate_embeddings_empty_list(self, service):
        """Test that empty list raises ValueError."""
        with pytest.raises(ValueError, match="Texts list cannot be empty"):
            await service.generate_embeddings([])

    @pytest.mark.asyncio
    async def test_generate_embeddings_exceeds_batch_size(self, service):
        """Test that batch exceeding max size raises ValueError."""
        texts = ["text"] * (MAX_BATCH_SIZE + 1)
        with pytest.raises(ValueError, match="Batch size exceeds maximum"):
            await service.generate_embeddings(texts)

    @pytest.mark.asyncio
    async def test_generate_embeddings_empty_text_in_batch(self, service):
        """Test that empty text in batch raises ValueError with index."""
        with pytest.raises(ValueError, match="Text at index 1 cannot be empty"):
            await service.generate_embeddings(["valid", "", "also valid"])

    @pytest.mark.asyncio
    async def test_generate_embeddings_long_text_in_batch(self, service):
        """Test that long text in batch raises ValueError with index."""
        long_text = "a" * (MAX_TEXT_LENGTH + 1)
        with pytest.raises(ValueError, match="Text at index 2 exceeds maximum length"):
            await service.generate_embeddings(["short", "also short", long_text])


class TestEmbeddingServiceAPI:
    """Tests for EmbeddingService API interaction."""

    @pytest.fixture
    def mock_openai_client(self):
        """Create a mock OpenAI client."""
        mock_client = MagicMock()
        mock_client.embeddings = MagicMock()
        return mock_client

    @pytest.fixture
    def service_with_mock_client(self, mock_openai_client):
        """Create an EmbeddingService with a mocked OpenAI client."""
        with patch("backend.integrations.embeddings.Settings") as mock_settings:
            mock_instance = MagicMock()
            mock_instance.secrets.openai_internal_api_key = "test-key"
            mock_instance.secrets.openai_api_key = ""
            mock_settings.return_value = mock_instance

            service = EmbeddingService()
            # Inject mock client by setting the cached_property directly
            service.__dict__["client"] = mock_openai_client
            return service, mock_openai_client

    @pytest.mark.asyncio
    async def test_generate_embedding_success(self, service_with_mock_client):
        """Test successful embedding generation."""
        service, mock_client = service_with_mock_client

        # Create mock response
        mock_embedding = [0.1] * EMBEDDING_DIMENSIONS
        mock_response = MagicMock()
        mock_response.data = [MagicMock(embedding=mock_embedding)]
        mock_client.embeddings.create = AsyncMock(return_value=mock_response)

        result = await service.generate_embedding("test text")

        assert result == mock_embedding
        mock_client.embeddings.create.assert_called_once()

    @pytest.mark.asyncio
    async def test_generate_embeddings_success(self, service_with_mock_client):
        """Test successful batch embedding generation."""
        service, mock_client = service_with_mock_client

        # Create mock response with multiple embeddings
        mock_embeddings = [[0.1] * EMBEDDING_DIMENSIONS, [0.2] * EMBEDDING_DIMENSIONS]
        mock_response = MagicMock()
        mock_response.data = [
            MagicMock(embedding=mock_embeddings[0], index=0),
            MagicMock(embedding=mock_embeddings[1], index=1),
        ]
        mock_client.embeddings.create = AsyncMock(return_value=mock_response)

        result = await service.generate_embeddings(["text1", "text2"])

        assert result == mock_embeddings
        mock_client.embeddings.create.assert_called_once()

    @pytest.mark.asyncio
    async def test_generate_embeddings_preserves_order(self, service_with_mock_client):
        """Test that batch embeddings are returned in correct order even if API returns out of order."""
        service, mock_client = service_with_mock_client

        # Create mock response with embeddings out of order
        mock_embeddings = [[0.1] * EMBEDDING_DIMENSIONS, [0.2] * EMBEDDING_DIMENSIONS]
        mock_response = MagicMock()
        # Return in reverse order
        mock_response.data = [
            MagicMock(embedding=mock_embeddings[1], index=1),
            MagicMock(embedding=mock_embeddings[0], index=0),
        ]
        mock_client.embeddings.create = AsyncMock(return_value=mock_response)

        result = await service.generate_embeddings(["text1", "text2"])

        # Should be sorted by index
        assert result[0] == mock_embeddings[0]
        assert result[1] == mock_embeddings[1]
Chat service module (exact file path not shown in the capture)

@@ -1,3 +1,4 @@
+import functools
 import logging
 from collections.abc import AsyncGenerator
 from datetime import UTC, datetime
@@ -32,7 +33,12 @@ from backend.util.exceptions import NotFoundError
 logger = logging.getLogger(__name__)

 config = backend.server.v2.chat.config.ChatConfig()
-client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
+
+
+@functools.cache
+def get_openai_client() -> AsyncOpenAI:
+    """Lazily create the OpenAI client singleton."""
+    return AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)


 async def create_chat_session(
@@ -355,7 +361,7 @@ async def _stream_chat_chunks(
     logger.info("Creating OpenAI chat completion stream...")

     # Create the stream with proper types
-    stream = await client.chat.completions.create(
+    stream = await get_openai_client().chat.completions.create(
         model=model,
         messages=session.to_openai_messages(),
         tools=tools,
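The change above replaces a module-level client (constructed at import time) with a lazily created singleton. A small generic sketch of the pattern, not AutoGPT code, showing why functools.cache gives that behaviour:

# Generic illustration of the lazy-singleton pattern used above (not part of the PR):
# construction is deferred to the first call and the same instance is reused afterwards.
import functools


@functools.cache
def get_client() -> dict:
    print("constructing client")  # runs exactly once
    return {"configured": True}


a = get_client()  # prints "constructing client"
b = get_client()  # returns the cached instance; no second construction
assert a is b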
backend.server.v2.store.backfill_embeddings (new file, 168 lines)

"""
Script to backfill embeddings for existing store listing versions.

This script should be run after the migration to add the embedding column
to populate embeddings for all existing store listing versions.

Usage:
    poetry run python -m backend.server.v2.store.backfill_embeddings
    poetry run python -m backend.server.v2.store.backfill_embeddings --dry-run
    poetry run python -m backend.server.v2.store.backfill_embeddings --batch-size 25
"""

import argparse
import asyncio
import logging
import sys

from backend.data.db import connect, disconnect, query_raw_with_schema
from backend.integrations.embeddings import create_search_text, get_embedding_service

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Default batch size for processing
DEFAULT_BATCH_SIZE = 50

# Delay between batches to avoid rate limits (seconds)
BATCH_DELAY_SECONDS = 1.0


async def backfill_embeddings(
    dry_run: bool = False,
    batch_size: int = DEFAULT_BATCH_SIZE,
) -> tuple[int, int]:
    """
    Backfill embeddings for all store listing versions without embeddings.

    Args:
        dry_run: If True, don't make any changes, just report what would be done.
        batch_size: Number of versions to process in each batch.

    Returns:
        Tuple of (processed_count, error_count)
    """
    await connect()

    try:
        embedding_service = get_embedding_service()

        # Get all versions without embeddings
        versions = await query_raw_with_schema(
            """
            SELECT id, name, "subHeading", description
            FROM {schema_prefix}"StoreListingVersion"
            WHERE embedding IS NULL
            ORDER BY "createdAt" DESC
            """
        )

        total = len(versions)
        logger.info(f"Found {total} versions without embeddings")

        if dry_run:
            logger.info("Dry run mode - no changes will be made")
            return (0, 0)

        if total == 0:
            logger.info("No versions need embeddings")
            return (0, 0)

        processed = 0
        errors = 0

        for i in range(0, total, batch_size):
            batch = versions[i : i + batch_size]
            batch_num = (i // batch_size) + 1
            total_batches = (total + batch_size - 1) // batch_size

            logger.info(f"Processing batch {batch_num}/{total_batches}")

            for version in batch:
                version_id = version["id"]
                try:
                    search_text = create_search_text(
                        version["name"] or "",
                        version["subHeading"] or "",
                        version["description"] or "",
                    )

                    if not search_text:
                        logger.warning(f"Skipping {version_id} - no searchable text")
                        continue

                    embedding = await embedding_service.generate_embedding(search_text)
                    embedding_str = "[" + ",".join(map(str, embedding)) + "]"

                    await query_raw_with_schema(
                        """
                        UPDATE {schema_prefix}"StoreListingVersion"
                        SET embedding = $1::vector
                        WHERE id = $2
                        """,
                        embedding_str,
                        version_id,
                    )

                    processed += 1

                except Exception as e:
                    logger.error(f"Error processing {version_id}: {e}")
                    errors += 1

            logger.info(f"Progress: {processed}/{total} processed, {errors} errors")

            # Rate limit: wait between batches to avoid hitting API limits
            if i + batch_size < total:
                await asyncio.sleep(BATCH_DELAY_SECONDS)

        logger.info(f"Backfill complete: {processed} processed, {errors} errors")
        return (processed, errors)

    finally:
        await disconnect()


def main():
    parser = argparse.ArgumentParser(
        description="Backfill embeddings for store listing versions"
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Don't make any changes, just report what would be done",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=DEFAULT_BATCH_SIZE,
        help=f"Number of versions to process in each batch (default: {DEFAULT_BATCH_SIZE})",
    )

    args = parser.parse_args()

    try:
        processed, errors = asyncio.run(
            backfill_embeddings(dry_run=args.dry_run, batch_size=args.batch_size)
        )

        if errors > 0:
            logger.warning(f"Completed with {errors} errors")
            sys.exit(1)
        else:
            logger.info("Completed successfully")
            sys.exit(0)

    except KeyboardInterrupt:
        logger.info("Interrupted by user")
        sys.exit(130)
    except Exception as e:
        logger.error(f"Fatal error: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()
Cached store-agents helper (file path not shown in the capture)

@@ -27,8 +27,9 @@ async def _get_cached_store_agents(
     category: str | None,
     page: int,
     page_size: int,
+    filter_mode: Literal["strict", "permissive", "combined"] = "permissive",
 ):
-    """Cached helper to get store agents."""
+    """Cached helper to get store agents with hybrid search support."""
     return await backend.server.v2.store.db.get_store_agents(
         featured=featured,
         creators=[creator] if creator else None,
@@ -37,6 +38,7 @@ async def _get_cached_store_agents(
         category=category,
         page=page,
         page_size=page_size,
+        filter_mode=filter_mode,
     )
backend.server.v2.store.db (the module providing get_store_agents called above)

@@ -26,6 +26,7 @@ from backend.data.notifications import (
     AgentRejectionData,
     NotificationEventModel,
 )
+from backend.integrations.embeddings import create_search_text, get_embedding_service
 from backend.notifications.notifications import queue_notification_async
 from backend.util.exceptions import DatabaseError
 from backend.util.settings import Settings
@@ -38,6 +39,25 @@ settings = Settings()
 DEFAULT_ADMIN_NAME = "AutoGPT Admin"
 DEFAULT_ADMIN_EMAIL = "admin@autogpt.co"

+# Minimum similarity threshold for vector search results
+# Cosine similarity ranges from -1 to 1, where 1 is identical
+# 0.4 filters loosely related results while keeping semantically relevant ones
+VECTOR_SEARCH_SIMILARITY_THRESHOLD = 0.4
+
+# Minimum relevance threshold for BM25 full-text search results
+# ts_rank_cd returns values typically in range 0-1 (can exceed 1 for exact matches)
+# 0.05 allows partial keyword matches
+BM25_RELEVANCE_THRESHOLD = 0.05
+
+# RRF constant (k) - standard value that balances influence of top vs lower ranks
+# Higher k values reduce the influence of high-ranking items
+RRF_K = 60
+
+# Minimum RRF score threshold for combined mode
+# Filters out results that rank poorly across all signals
+# For reference: rank #1 in all = ~0.041, rank #100 in all = ~0.016
+RRF_SCORE_THRESHOLD = 0.02
+

 async def get_store_agents(
     featured: bool = False,
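The constants above feed the Reciprocal Rank Fusion step that the next hunk builds in SQL. A small Python sketch of the same fusion, illustrative only and with made-up ranks, may help when reading the CTEs below:

# Illustrative Python equivalent of the RRF fusion computed in SQL below
# (made-up inputs; the weights 1.0 / 1.0 / 0.5 and k = 60 mirror the diff).
RRF_K = 60


def rrf_score(bm25_rank: int, vector_rank: int, popularity_rank: int) -> float:
    # Each signal contributes 1 / (k + rank); popularity is down-weighted by 0.5.
    return (
        1.0 / (RRF_K + bm25_rank)
        + 1.0 / (RRF_K + vector_rank)
        + 0.5 / (RRF_K + popularity_rank)
    )


# An agent ranked #1 by every signal scores about 0.041, matching the comment above;
# in "combined" mode, results below RRF_SCORE_THRESHOLD (0.02) are dropped.
print(round(rrf_score(1, 1, 1), 3))  # 0.041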
@@ -47,64 +67,223 @@ async def get_store_agents(
|
|||||||
category: str | None = None,
|
category: str | None = None,
|
||||||
page: int = 1,
|
page: int = 1,
|
||||||
page_size: int = 20,
|
page_size: int = 20,
|
||||||
|
filter_mode: Literal["strict", "permissive", "combined"] = "permissive",
|
||||||
) -> backend.server.v2.store.model.StoreAgentsResponse:
|
) -> backend.server.v2.store.model.StoreAgentsResponse:
|
||||||
"""
|
"""
|
||||||
Get PUBLIC store agents from the StoreAgent view
|
Get PUBLIC store agents from the StoreAgent view.
|
||||||
|
|
||||||
|
When search_query is provided, uses hybrid search combining:
|
||||||
|
- BM25 full-text search (lexical matching via PostgreSQL tsvector)
|
||||||
|
- Vector semantic similarity (meaning-based matching via pgvector)
|
||||||
|
- Popularity signal (run counts as PageRank proxy)
|
||||||
|
|
||||||
|
Results are ranked using Reciprocal Rank Fusion (RRF).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
featured: Filter to only show featured agents.
|
||||||
|
creators: Filter agents by creator usernames.
|
||||||
|
sorted_by: Sort agents by "runs", "rating", "name", or "updated_at".
|
||||||
|
search_query: Search query for hybrid search.
|
||||||
|
category: Filter agents by category.
|
||||||
|
page: Page number for pagination.
|
||||||
|
page_size: Number of agents per page.
|
||||||
|
filter_mode: Controls how results are filtered when searching:
|
||||||
|
- "strict": Must match BOTH BM25 AND vector thresholds
|
||||||
|
- "permissive": Must match EITHER BM25 OR vector threshold
|
||||||
|
- "combined": No threshold filtering, rely on RRF score (default)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
StoreAgentsResponse with paginated list of agents.
|
||||||
"""
|
"""
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Getting store agents. featured={featured}, creators={creators}, sorted_by={sorted_by}, search={search_query}, category={category}, page={page}"
|
f"Getting store agents. featured={featured}, creators={creators}, "
|
||||||
|
f"sorted_by={sorted_by}, search={search_query}, category={category}, "
|
||||||
|
f"page={page}, filter_mode={filter_mode}"
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# If search_query is provided, use full-text search
|
# If search_query is provided, use hybrid search (BM25 + vector + popularity)
|
||||||
if search_query:
|
if search_query:
|
||||||
offset = (page - 1) * page_size
|
offset = (page - 1) * page_size
|
||||||
|
|
||||||
# Whitelist allowed order_by columns
|
# Try to generate embedding for vector search
|
||||||
ALLOWED_ORDER_BY = {
|
# Falls back to BM25-only if embedding service is not available
|
||||||
"rating": "rating DESC, rank DESC",
|
query_embedding: list[float] | None = None
|
||||||
"runs": "runs DESC, rank DESC",
|
try:
|
||||||
"name": "agent_name ASC, rank ASC",
|
embedding_service = get_embedding_service()
|
||||||
"updated_at": "updated_at DESC, rank DESC",
|
query_embedding = await embedding_service.generate_embedding(
|
||||||
}
|
search_query
|
||||||
|
)
|
||||||
|
except (ValueError, Exception) as e:
|
||||||
|
# Embedding service not configured or failed - use BM25 only
|
||||||
|
logger.warning(f"Embedding generation failed, using BM25 only: {e}")
|
||||||
|
|
||||||
# Validate and get order clause
|
# Convert embedding to PostgreSQL array format (or None for BM25-only)
|
||||||
if sorted_by and sorted_by in ALLOWED_ORDER_BY:
|
embedding_str = (
|
||||||
order_by_clause = ALLOWED_ORDER_BY[sorted_by]
|
"[" + ",".join(map(str, query_embedding)) + "]"
|
||||||
else:
|
if query_embedding
|
||||||
order_by_clause = "updated_at DESC, rank DESC"
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
# Build WHERE conditions and parameters list
|
# Build WHERE conditions and parameters list
|
||||||
|
# When embedding is not available (no OpenAI key), $1 will be NULL
|
||||||
where_parts: list[str] = []
|
where_parts: list[str] = []
|
||||||
params: list[typing.Any] = [search_query] # $1 - search term
|
params: list[typing.Any] = [embedding_str] # $1 - query embedding (or NULL)
|
||||||
param_index = 2 # Start at $2 for next parameter
|
param_index = 2 # Start at $2 for next parameter
|
||||||
|
|
||||||
# Always filter for available agents
|
# Always filter for available agents
|
||||||
where_parts.append("is_available = true")
|
where_parts.append("is_available = true")
|
||||||
|
|
||||||
|
# Require search signals to be present
|
||||||
|
if embedding_str is None:
|
||||||
|
# No embedding available - require BM25 search only
|
||||||
|
where_parts.append("search IS NOT NULL")
|
||||||
|
elif filter_mode == "strict":
|
||||||
|
# Strict mode: require both embedding AND search to be available
|
||||||
|
where_parts.append("embedding IS NOT NULL")
|
||||||
|
where_parts.append("search IS NOT NULL")
|
||||||
|
else:
|
||||||
|
# Permissive/combined: require at least one signal
|
||||||
|
where_parts.append("(embedding IS NOT NULL OR search IS NOT NULL)")
|
||||||
|
|
||||||
if featured:
|
if featured:
|
||||||
where_parts.append("featured = true")
|
where_parts.append("featured = true")
|
||||||
|
|
||||||
if creators and creators:
|
if creators:
|
||||||
# Use ANY with array parameter
|
# Use ANY with array parameter
|
||||||
where_parts.append(f"creator_username = ANY(${param_index})")
|
where_parts.append(f"creator_username = ANY(${param_index})")
|
||||||
params.append(creators)
|
params.append(creators)
|
||||||
param_index += 1
|
param_index += 1
|
||||||
|
|
||||||
if category and category:
|
if category:
|
||||||
where_parts.append(f"${param_index} = ANY(categories)")
|
where_parts.append(f"${param_index} = ANY(categories)")
|
||||||
params.append(category)
|
params.append(category)
|
||||||
param_index += 1
|
param_index += 1
|
||||||
|
|
||||||
|
# Add search query for BM25
|
||||||
|
params.append(search_query)
|
||||||
|
bm25_query_param = f"${param_index}"
|
||||||
|
param_index += 1
|
||||||
|
|
||||||
sql_where_clause: str = " AND ".join(where_parts) if where_parts else "1=1"
|
sql_where_clause: str = " AND ".join(where_parts) if where_parts else "1=1"
|
||||||
|
|
||||||
|
# Build score filter based on filter_mode
|
||||||
|
# This filter is applied BEFORE RRF ranking in the filtered_agents CTE
|
||||||
|
if embedding_str is None:
|
||||||
|
# No embedding - filter only on BM25 score
|
||||||
|
score_filter = f"bm25_score >= {BM25_RELEVANCE_THRESHOLD}"
|
||||||
|
elif filter_mode == "strict":
|
||||||
|
score_filter = f"""
|
||||||
|
bm25_score >= {BM25_RELEVANCE_THRESHOLD}
|
||||||
|
AND vector_score >= {VECTOR_SEARCH_SIMILARITY_THRESHOLD}
|
||||||
|
"""
|
||||||
|
elif filter_mode == "permissive":
|
||||||
|
score_filter = f"""
|
||||||
|
bm25_score >= {BM25_RELEVANCE_THRESHOLD}
|
||||||
|
OR vector_score >= {VECTOR_SEARCH_SIMILARITY_THRESHOLD}
|
||||||
|
"""
|
||||||
|
else: # combined - no pre-filtering on individual scores
|
||||||
|
score_filter = "1=1"
|
||||||
|
|
||||||
|
# RRF score filter is applied AFTER ranking to filter irrelevant results
|
||||||
|
rrf_score_filter = f"rrf_score >= {RRF_SCORE_THRESHOLD}"
|
||||||
|
|
||||||
|
# Build ORDER BY clause - sorted_by takes precedence, rrf_score as secondary
|
||||||
|
if sorted_by == "rating":
|
||||||
|
order_by_clause = "rating DESC, rrf_score DESC"
|
||||||
|
elif sorted_by == "runs":
|
||||||
|
order_by_clause = "runs DESC, rrf_score DESC"
|
||||||
|
elif sorted_by == "name":
|
||||||
|
order_by_clause = "agent_name ASC, rrf_score DESC"
|
||||||
|
elif sorted_by == "updated_at":
|
||||||
|
order_by_clause = "updated_at DESC, rrf_score DESC"
|
||||||
|
else:
|
||||||
|
# Default: order by RRF relevance score
|
||||||
|
order_by_clause = "rrf_score DESC, updated_at DESC"
|
||||||
|
|
||||||
     # Add pagination params
     params.extend([page_size, offset])
     limit_param = f"${param_index}"
     offset_param = f"${param_index + 1}"
 
-    # Execute full-text search query with parameterized values
+    # Hybrid search SQL with Reciprocal Rank Fusion (RRF)
+    # CTEs: scored_agents -> filtered_agents -> ranked_agents -> rrf_scored
     sql_query = f"""
+    WITH scored_agents AS (
+        SELECT
+            slug,
+            agent_name,
+            agent_image,
+            creator_username,
+            creator_avatar,
+            sub_heading,
+            description,
+            runs,
+            rating,
+            categories,
+            featured,
+            is_available,
+            updated_at,
+            -- BM25 score using ts_rank_cd (covers density normalization)
+            COALESCE(
+                ts_rank_cd(
+                    search,
+                    plainto_tsquery('english', {bm25_query_param}),
+                    32 -- normalization: divide by document length
+                ),
+                0
+            ) AS bm25_score,
+            -- Vector similarity score (cosine: 1 - distance)
+            -- Returns 0 when query embedding ($1) is NULL (no OpenAI key)
+            CASE
+                WHEN $1 IS NOT NULL AND embedding IS NOT NULL
+                THEN 1 - (embedding <=> $1::vector)
+                ELSE 0
+            END AS vector_score,
+            -- Popularity score (log-normalized run count)
+            CASE
+                WHEN runs > 0
+                THEN LN(runs + 1)
+                ELSE 0
+            END AS popularity_score
+        FROM {{schema_prefix}}"StoreAgent"
+        WHERE {sql_where_clause}
+    ),
+    max_popularity AS (
+        SELECT GREATEST(MAX(popularity_score), 1) AS max_pop
+        FROM scored_agents
+    ),
+    normalized_agents AS (
+        SELECT
+            sa.*,
+            -- Normalize popularity to [0, 1] range
+            sa.popularity_score / mp.max_pop AS norm_popularity_score
+        FROM scored_agents sa
+        CROSS JOIN max_popularity mp
+    ),
+    filtered_agents AS (
+        SELECT *
+        FROM normalized_agents
+        WHERE {score_filter}
+    ),
+    ranked_agents AS (
+        SELECT
+            *,
+            ROW_NUMBER() OVER (ORDER BY bm25_score DESC NULLS LAST) AS bm25_rank,
+            ROW_NUMBER() OVER (ORDER BY vector_score DESC NULLS LAST) AS vector_rank,
+            ROW_NUMBER() OVER (ORDER BY norm_popularity_score DESC NULLS LAST) AS popularity_rank
+        FROM filtered_agents
+    ),
+    rrf_scored AS (
+        SELECT
+            *,
+            -- RRF formula with weighted contributions
+            -- BM25 and vector get full weight, popularity gets 0.5x weight
+            (1.0 / ({RRF_K} + bm25_rank)) +
+            (1.0 / ({RRF_K} + vector_rank)) +
+            (0.5 / ({RRF_K} + popularity_rank)) AS rrf_score
+        FROM ranked_agents
+    )
     SELECT
         slug,
         agent_name,
@@ -119,25 +298,79 @@ async def get_store_agents(
         featured,
         is_available,
         updated_at,
-        ts_rank_cd(search, query) AS rank
+        rrf_score
-    FROM {{schema_prefix}}"StoreAgent",
-        plainto_tsquery('english', $1) AS query
-    WHERE {sql_where_clause}
-        AND search @@ query
+    FROM rrf_scored
+    WHERE {rrf_score_filter}
     ORDER BY {order_by_clause}
     LIMIT {limit_param} OFFSET {offset_param}
     """
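The rrf_scored CTE above is the heart of the ranking: each signal contributes 1 / (k + rank), so agents that rank well on several lists are rewarded without any single raw score dominating. A minimal sketch of the same arithmetic in plain Python (not the repository's code; the RRF_K value of 60 is an assumption, borrowed from the constant commonly used for RRF):

# Hypothetical illustration of the RRF weighting used in the SQL above.
RRF_K = 60  # assumed constant for this sketch; the real value lives in the backend

def rrf_score(bm25_rank: int, vector_rank: int, popularity_rank: int, k: int = RRF_K) -> float:
    # BM25 and vector ranks get full weight, popularity gets half weight.
    return (1.0 / (k + bm25_rank)) + (1.0 / (k + vector_rank)) + (0.5 / (k + popularity_rank))

# An agent ranked 1st on BM25, 3rd on vector similarity and 5th on popularity:
print(rrf_score(1, 3, 5))  # ~0.04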
 
-    # Count query for pagination - only uses search term parameter
+    # Count query (without pagination) - requires same CTE structure because:
+    # 1. RRF scoring requires computing ranks across ALL matching results
+    # 2. The rrf_score_filter threshold must be applied consistently
+    # Note: This is inherent to RRF - there's no way to count without ranking
     count_query = f"""
+    WITH scored_agents AS (
+        SELECT
+            runs,
+            COALESCE(
+                ts_rank_cd(
+                    search,
+                    plainto_tsquery('english', {bm25_query_param}),
+                    32
+                ),
+                0
+            ) AS bm25_score,
+            CASE
+                WHEN $1 IS NOT NULL AND embedding IS NOT NULL
+                THEN 1 - (embedding <=> $1::vector)
+                ELSE 0
+            END AS vector_score,
+            CASE
+                WHEN runs > 0
+                THEN LN(runs + 1)
+                ELSE 0
+            END AS popularity_score
+        FROM {{schema_prefix}}"StoreAgent"
+        WHERE {sql_where_clause}
+    ),
+    max_popularity AS (
+        SELECT GREATEST(MAX(popularity_score), 1) AS max_pop
+        FROM scored_agents
+    ),
+    normalized_agents AS (
+        SELECT
+            sa.*,
+            sa.popularity_score / mp.max_pop AS norm_popularity_score
+        FROM scored_agents sa
+        CROSS JOIN max_popularity mp
+    ),
+    filtered_agents AS (
+        SELECT *
+        FROM normalized_agents
+        WHERE {score_filter}
+    ),
+    ranked_agents AS (
+        SELECT
+            *,
+            ROW_NUMBER() OVER (ORDER BY bm25_score DESC NULLS LAST) AS bm25_rank,
+            ROW_NUMBER() OVER (ORDER BY vector_score DESC NULLS LAST) AS vector_rank,
+            ROW_NUMBER() OVER (ORDER BY norm_popularity_score DESC NULLS LAST) AS popularity_rank
+        FROM filtered_agents
+    ),
+    rrf_scored AS (
+        SELECT
+            (1.0 / ({RRF_K} + bm25_rank)) +
+            (1.0 / ({RRF_K} + vector_rank)) +
+            (0.5 / ({RRF_K} + popularity_rank)) AS rrf_score
+        FROM ranked_agents
+    )
     SELECT COUNT(*) as count
-    FROM {{schema_prefix}}"StoreAgent",
-        plainto_tsquery('english', $1) AS query
-    WHERE {sql_where_clause}
-        AND search @@ query
+    FROM rrf_scored
+    WHERE {rrf_score_filter}
     """
 
-    # Execute both queries with parameters
+    # Execute queries
     agents = await query_raw_with_schema(sql_query, *params)
 
     # For count, use params without pagination (last 2 params)
@@ -255,6 +488,56 @@ async def log_search_term(search_query: str):
         logger.error(f"Error logging search term: {e}")
 
 
+async def _generate_and_store_embedding(
+    store_listing_version_id: str,
+    name: str,
+    sub_heading: str,
+    description: str,
+) -> None:
+    """
+    Generate and store embedding for a store listing version.
+
+    This creates a vector embedding from the agent's name, sub_heading, and
+    description, which is used for semantic search.
+
+    Args:
+        store_listing_version_id: The ID of the store listing version.
+        name: The agent name.
+        sub_heading: The agent sub-heading/tagline.
+        description: The agent description.
+    """
+    try:
+        embedding_service = get_embedding_service()
+        search_text = create_search_text(name, sub_heading, description)
+
+        if not search_text:
+            logger.warning(
+                f"No searchable text for version {store_listing_version_id}, "
+                "skipping embedding generation"
+            )
+            return
+
+        embedding = await embedding_service.generate_embedding(search_text)
+        embedding_str = "[" + ",".join(map(str, embedding)) + "]"
+
+        await query_raw_with_schema(
+            """
+            UPDATE {schema_prefix}"StoreListingVersion"
+            SET embedding = $1::vector
+            WHERE id = $2
+            """,
+            embedding_str,
+            store_listing_version_id,
+        )
+        logger.debug(f"Generated embedding for version {store_listing_version_id}")
+    except Exception as e:
+        # Log error but don't fail the whole operation
+        # Embeddings can be generated later via backfill
+        logger.error(
+            f"Failed to generate embedding for {store_listing_version_id}: {e}"
+        )
+
+
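For context, the helper above serializes the embedding into the bracketed literal form that pgvector parses before the ::vector cast. A minimal sketch with hypothetical values (real embeddings are 1536-dimensional):

# Illustrative only - shows the literal format passed as the $1::vector parameter.
embedding = [0.12, -0.03, 0.88]  # hypothetical 3-dimensional example
embedding_str = "[" + ",".join(map(str, embedding)) + "]"
assert embedding_str == "[0.12,-0.03,0.88]"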
 async def get_store_agent_details(
     username: str, agent_name: str
 ) -> backend.server.v2.store.model.StoreAgentDetails:
@@ -805,6 +1088,12 @@ async def create_store_submission(
         else None
     )
 
+    # Generate embedding for semantic search
+    if store_listing_version_id:
+        await _generate_and_store_embedding(
+            store_listing_version_id, name, sub_heading, description
+        )
+
     logger.debug(f"Created store listing for agent {agent_id}")
     # Return submission details
     return backend.server.v2.store.model.StoreSubmission(
@@ -970,6 +1259,12 @@ async def edit_store_submission(
 
     if not updated_version:
         raise DatabaseError("Failed to update store listing version")
+
+    # Regenerate embedding with updated content
+    await _generate_and_store_embedding(
+        store_listing_version_id, name, sub_heading, description
+    )
+
     return backend.server.v2.store.model.StoreSubmission(
         agent_id=current_version.agentGraphId,
         agent_version=current_version.agentGraphVersion,
@@ -1102,6 +1397,12 @@ async def create_store_version(
     logger.debug(
         f"Created new version for listing {store_listing_id} of agent {agent_id}"
     )
+
+    # Generate embedding for semantic search
+    await _generate_and_store_embedding(
+        new_version.id, name, sub_heading, description
+    )
+
     # Return submission details
     return backend.server.v2.store.model.StoreSubmission(
         agent_id=agent_id,
@@ -405,3 +405,347 @@ async def test_get_store_agents_search_category_array_injection():
     # Verify the query executed without error
     # Category should be parameterized, preventing SQL injection
     assert isinstance(result.agents, list)
+
+
+# Hybrid search tests (BM25 + vector + popularity with RRF ranking)
+
+
+@pytest.mark.asyncio(loop_scope="session")
+async def test_get_store_agents_hybrid_search_mocked(mocker):
+    """Test hybrid search uses embedding service and executes query safely."""
+    from backend.integrations.embeddings import EMBEDDING_DIMENSIONS
+
+    # Mock embedding service
+    mock_embedding = [0.1] * EMBEDDING_DIMENSIONS
+    mock_embedding_service = mocker.MagicMock()
+    mock_embedding_service.generate_embedding = mocker.AsyncMock(
+        return_value=mock_embedding
+    )
+    mocker.patch(
+        "backend.server.v2.store.db.get_embedding_service",
+        mocker.MagicMock(return_value=mock_embedding_service),
+    )
+
+    # Mock query_raw_with_schema to return empty results
+    mocker.patch(
+        "backend.server.v2.store.db.query_raw_with_schema",
+        mocker.AsyncMock(side_effect=[[], [{"count": 0}]]),
+    )
+
+    # Call function with search query
+    result = await db.get_store_agents(search_query="test query")
+
+    # Verify embedding service was called
+    mock_embedding_service.generate_embedding.assert_called_once_with("test query")
+
+    # Verify results
+    assert isinstance(result.agents, list)
+    assert len(result.agents) == 0
+
+
+@pytest.mark.asyncio(loop_scope="session")
+async def test_get_store_agents_hybrid_search_with_results(mocker):
+    """Test hybrid search returns properly formatted results with RRF scoring."""
+    from backend.integrations.embeddings import EMBEDDING_DIMENSIONS
+
+    # Mock embedding service
+    mock_embedding = [0.1] * EMBEDDING_DIMENSIONS
+    mock_embedding_service = mocker.MagicMock()
+    mock_embedding_service.generate_embedding = mocker.AsyncMock(
+        return_value=mock_embedding
+    )
+    mocker.patch(
+        "backend.server.v2.store.db.get_embedding_service",
+        mocker.MagicMock(return_value=mock_embedding_service),
+    )
+
+    # Mock query results (hybrid search returns rrf_score instead of similarity)
+    mock_agents = [
+        {
+            "slug": "test-agent",
+            "agent_name": "Test Agent",
+            "agent_image": ["image.jpg"],
+            "creator_username": "creator",
+            "creator_avatar": "avatar.jpg",
+            "sub_heading": "Test heading",
+            "description": "Test description",
+            "runs": 10,
+            "rating": 4.5,
+            "categories": ["test"],
+            "featured": False,
+            "is_available": True,
+            "updated_at": datetime.now(),
+            "rrf_score": 0.048,  # RRF score from combined rankings
+        }
+    ]
+    mock_count = [{"count": 1}]
+
+    mocker.patch(
+        "backend.server.v2.store.db.query_raw_with_schema",
+        mocker.AsyncMock(side_effect=[mock_agents, mock_count]),
+    )
+
+    # Call function with search query
+    result = await db.get_store_agents(search_query="test query")
+
+    # Verify results
+    assert len(result.agents) == 1
+    assert result.agents[0].slug == "test-agent"
+    assert result.agents[0].agent_name == "Test Agent"
+    assert result.pagination.total_items == 1
+
+
+@pytest.mark.asyncio(loop_scope="session")
+async def test_get_store_agents_hybrid_search_with_filters(mocker):
+    """Test hybrid search works correctly with additional filters."""
+    from backend.integrations.embeddings import EMBEDDING_DIMENSIONS
+
+    # Mock embedding service
+    mock_embedding = [0.1] * EMBEDDING_DIMENSIONS
+    mock_embedding_service = mocker.MagicMock()
+    mock_embedding_service.generate_embedding = mocker.AsyncMock(
+        return_value=mock_embedding
+    )
+    mocker.patch(
+        "backend.server.v2.store.db.get_embedding_service",
+        mocker.MagicMock(return_value=mock_embedding_service),
+    )
+
+    # Mock query_raw_with_schema
+    mock_query = mocker.patch(
+        "backend.server.v2.store.db.query_raw_with_schema",
+        mocker.AsyncMock(side_effect=[[], [{"count": 0}]]),
+    )
+
+    # Call function with search query and filters
+    await db.get_store_agents(
+        search_query="test query",
+        featured=True,
+        creators=["creator1", "creator2"],
+        category="AI",
+    )
+
+    # Verify query was called with parameterized values
+    # First call is the main query, second is count
+    assert mock_query.call_count == 2
+
+    # Check that the SQL query includes proper parameterization
+    first_call_args = mock_query.call_args_list[0]
+    sql_query = first_call_args[0][0]
+
+    # Verify key elements of hybrid search query
+    assert "embedding <=> $1::vector" in sql_query  # Vector search
+    assert "ts_rank_cd" in sql_query  # BM25 search
+    assert "rrf_score" in sql_query  # RRF ranking
+    assert "featured = true" in sql_query
+    assert "creator_username = ANY($" in sql_query
+    assert "= ANY(categories)" in sql_query
+
+
+@pytest.mark.asyncio(loop_scope="session")
+async def test_get_store_agents_hybrid_search_strict_filter_mode(mocker):
+    """Test hybrid search with strict filter mode requires both BM25 and vector matches."""
+    from backend.integrations.embeddings import EMBEDDING_DIMENSIONS
+
+    # Mock embedding service
+    mock_embedding = [0.1] * EMBEDDING_DIMENSIONS
+    mock_embedding_service = mocker.MagicMock()
+    mock_embedding_service.generate_embedding = mocker.AsyncMock(
+        return_value=mock_embedding
+    )
+    mocker.patch(
+        "backend.server.v2.store.db.get_embedding_service",
+        mocker.MagicMock(return_value=mock_embedding_service),
+    )
+
+    # Mock query_raw_with_schema
+    mock_query = mocker.patch(
+        "backend.server.v2.store.db.query_raw_with_schema",
+        mocker.AsyncMock(side_effect=[[], [{"count": 0}]]),
+    )
+
+    # Call function with strict filter mode
+    await db.get_store_agents(search_query="test query", filter_mode="strict")
+
+    # Check that the SQL query includes strict filtering conditions
+    first_call_args = mock_query.call_args_list[0]
+    sql_query = first_call_args[0][0]
+
+    # Strict mode requires both embedding AND search to be present
+    assert "embedding IS NOT NULL" in sql_query
+    assert "search IS NOT NULL" in sql_query
+    # Strict score filter requires both thresholds to be met
+    assert "bm25_score >=" in sql_query
+    assert "AND vector_score >=" in sql_query
+
+
+@pytest.mark.asyncio(loop_scope="session")
+async def test_get_store_agents_hybrid_search_permissive_filter_mode(mocker):
+    """Test hybrid search with permissive filter mode requires either BM25 or vector match."""
+    from backend.integrations.embeddings import EMBEDDING_DIMENSIONS
+
+    # Mock embedding service
+    mock_embedding = [0.1] * EMBEDDING_DIMENSIONS
+    mock_embedding_service = mocker.MagicMock()
+    mock_embedding_service.generate_embedding = mocker.AsyncMock(
+        return_value=mock_embedding
+    )
+    mocker.patch(
+        "backend.server.v2.store.db.get_embedding_service",
+        mocker.MagicMock(return_value=mock_embedding_service),
+    )
+
+    # Mock query_raw_with_schema
+    mock_query = mocker.patch(
+        "backend.server.v2.store.db.query_raw_with_schema",
+        mocker.AsyncMock(side_effect=[[], [{"count": 0}]]),
+    )
+
+    # Call function with permissive filter mode
+    await db.get_store_agents(search_query="test query", filter_mode="permissive")
+
+    # Check that the SQL query includes permissive filtering conditions
+    first_call_args = mock_query.call_args_list[0]
+    sql_query = first_call_args[0][0]
+
+    # Permissive mode requires at least one signal
+    assert "(embedding IS NOT NULL OR search IS NOT NULL)" in sql_query
+    # Permissive score filter requires either threshold to be met
+    assert "OR vector_score >=" in sql_query
+
+
+@pytest.mark.asyncio(loop_scope="session")
+async def test_get_store_agents_hybrid_search_combined_filter_mode(mocker):
+    """Test hybrid search with combined filter mode (default) filters by RRF score."""
+    from backend.integrations.embeddings import EMBEDDING_DIMENSIONS
+
+    # Mock embedding service
+    mock_embedding = [0.1] * EMBEDDING_DIMENSIONS
+    mock_embedding_service = mocker.MagicMock()
+    mock_embedding_service.generate_embedding = mocker.AsyncMock(
+        return_value=mock_embedding
+    )
+    mocker.patch(
+        "backend.server.v2.store.db.get_embedding_service",
+        mocker.MagicMock(return_value=mock_embedding_service),
+    )
+
+    # Mock query_raw_with_schema
+    mock_query = mocker.patch(
+        "backend.server.v2.store.db.query_raw_with_schema",
+        mocker.AsyncMock(side_effect=[[], [{"count": 0}]]),
+    )
+
+    # Call function with combined filter mode (default)
+    await db.get_store_agents(search_query="test query", filter_mode="combined")
+
+    # Check that the SQL query includes combined filtering
+    first_call_args = mock_query.call_args_list[0]
+    sql_query = first_call_args[0][0]
+
+    # Combined mode requires at least one signal
+    assert "(embedding IS NOT NULL OR search IS NOT NULL)" in sql_query
+    # Combined mode uses "1=1" as pre-filter (no individual score filtering)
+    # But applies RRF score threshold to filter irrelevant results
+    assert "rrf_score" in sql_query
+    assert "rrf_score >=" in sql_query  # RRF threshold filter applied
+
+
+@pytest.mark.asyncio(loop_scope="session")
+async def test_generate_and_store_embedding_success(mocker):
+    """Test that embedding generation and storage works correctly."""
+    from backend.integrations.embeddings import EMBEDDING_DIMENSIONS
+
+    # Mock embedding service
+    mock_embedding = [0.1] * EMBEDDING_DIMENSIONS
+    mock_embedding_service = mocker.MagicMock()
+    mock_embedding_service.generate_embedding = mocker.AsyncMock(
+        return_value=mock_embedding
+    )
+    mocker.patch(
+        "backend.server.v2.store.db.get_embedding_service",
+        mocker.MagicMock(return_value=mock_embedding_service),
+    )
+
+    # Mock query_raw_with_schema
+    mock_query = mocker.patch(
+        "backend.server.v2.store.db.query_raw_with_schema",
+        mocker.AsyncMock(return_value=[]),
+    )
+
+    # Call the internal function
+    await db._generate_and_store_embedding(
+        store_listing_version_id="version-123",
+        name="Test Agent",
+        sub_heading="A test agent",
+        description="Does testing",
+    )
+
+    # Verify embedding service was called with combined text
+    mock_embedding_service.generate_embedding.assert_called_once_with(
+        "Test Agent A test agent Does testing"
+    )
+
+    # Verify database update was called
+    mock_query.assert_called_once()
+    call_args = mock_query.call_args
+    assert "UPDATE" in call_args[0][0]
+    assert "embedding = $1::vector" in call_args[0][0]
+    assert call_args[0][2] == "version-123"
+
+
+@pytest.mark.asyncio(loop_scope="session")
+async def test_generate_and_store_embedding_empty_text(mocker):
+    """Test that embedding is not generated for empty text."""
+    # Mock embedding service
+    mock_embedding_service = mocker.MagicMock()
+    mock_embedding_service.generate_embedding = mocker.AsyncMock()
+    mocker.patch(
+        "backend.server.v2.store.db.get_embedding_service",
+        mocker.MagicMock(return_value=mock_embedding_service),
+    )
+
+    # Mock query_raw_with_schema
+    mock_query = mocker.patch(
+        "backend.server.v2.store.db.query_raw_with_schema",
+        mocker.AsyncMock(return_value=[]),
+    )
+
+    # Call with empty fields
+    await db._generate_and_store_embedding(
+        store_listing_version_id="version-123",
+        name="",
+        sub_heading="",
+        description="",
+    )
+
+    # Verify embedding service was NOT called
+    mock_embedding_service.generate_embedding.assert_not_called()
+
+    # Verify database was NOT updated
+    mock_query.assert_not_called()
+
+
+@pytest.mark.asyncio(loop_scope="session")
+async def test_generate_and_store_embedding_handles_error(mocker):
+    """Test that embedding generation errors don't crash the operation."""
+    # Mock embedding service to raise an error
+    mock_embedding_service = mocker.MagicMock()
+    mock_embedding_service.generate_embedding = mocker.AsyncMock(
+        side_effect=Exception("API error")
+    )
+    mocker.patch(
+        "backend.server.v2.store.db.get_embedding_service",
+        mocker.MagicMock(return_value=mock_embedding_service),
+    )
+
+    # Call should not raise - errors are logged but not propagated
+    await db._generate_and_store_embedding(
+        store_listing_version_id="version-123",
+        name="Test Agent",
+        sub_heading="A test agent",
+        description="Does testing",
+    )
+
+    # Verify embedding service was called (and failed)
+    mock_embedding_service.generate_embedding.assert_called_once()
@@ -1,4 +1,5 @@
 import datetime
+from enum import Enum
 from typing import List
 
 import prisma.enums
@@ -7,6 +8,19 @@ import pydantic
 from backend.util.models import Pagination
 
 
+class SearchFilterMode(str, Enum):
+    """How to combine BM25 and vector search results for filtering.
+
+    - STRICT: Must pass BOTH BM25 AND vector similarity thresholds
+    - PERMISSIVE: Must pass EITHER BM25 OR vector similarity threshold
+    - COMBINED: No pre-filtering, only the combined RRF score matters (default)
+    """
+
+    STRICT = "strict"
+    PERMISSIVE = "permissive"
+    COMBINED = "combined"
+
+
 class MyAgent(pydantic.BaseModel):
     agent_id: str
     agent_version: int
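As a usage sketch (assuming the enum above is importable from backend.server.v2.store.model), the string values line up with the filter_mode literals accepted by the db layer:

from backend.server.v2.store.model import SearchFilterMode

mode = SearchFilterMode.PERMISSIVE
assert mode.value == "permissive"  # str-valued enum, safe to pass straight to the query builder
assert SearchFilterMode("strict") is SearchFilterMode.STRICT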
@@ -99,18 +99,30 @@ async def get_agents(
     category: str | None = None,
     page: int = 1,
     page_size: int = 20,
+    filter_mode: Literal["strict", "permissive", "combined"] = "permissive",
 ):
     """
     Get a paginated list of agents from the store with optional filtering and sorting.
 
+    When search_query is provided, uses hybrid search combining:
+    - BM25 full-text search (lexical matching)
+    - Vector semantic similarity (meaning-based matching)
+    - Popularity signal (run counts)
+
+    Results are ranked using Reciprocal Rank Fusion (RRF).
+
     Args:
         featured (bool, optional): Filter to only show featured agents. Defaults to False.
         creator (str | None, optional): Filter agents by creator username. Defaults to None.
         sorted_by (str | None, optional): Sort agents by "runs" or "rating". Defaults to None.
-        search_query (str | None, optional): Search agents by name, subheading and description. Defaults to None.
+        search_query (str | None, optional): Search agents by name, subheading and description.
         category (str | None, optional): Filter agents by category. Defaults to None.
         page (int, optional): Page number for pagination. Defaults to 1.
         page_size (int, optional): Number of agents per page. Defaults to 20.
+        filter_mode (str, optional): Controls result filtering when searching:
+            - "strict": Must match BOTH BM25 AND vector thresholds
+            - "permissive": Must match EITHER BM25 OR vector threshold
+            - "combined": No threshold filtering, rely on RRF score (default)
+
     Returns:
         StoreAgentsResponse: Paginated list of agents matching the filters
@@ -144,6 +156,7 @@ async def get_agents(
         category=category,
         page=page,
         page_size=page_size,
+        filter_mode=filter_mode,
     )
     return agents
 
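A hypothetical client call against the listing endpoint; the base URL and exact path here are assumptions for illustration, but the query parameters mirror the route signature above:

import httpx

resp = httpx.get(
    "http://localhost:8006/api/store/agents",  # assumed host/path, not taken from this diff
    params={"search_query": "summarize emails", "filter_mode": "strict", "page_size": 10},
)
resp.raise_for_status()
print(resp.json()["pagination"])  # paginated StoreAgentsResponse payload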
@@ -65,6 +65,7 @@ def test_get_agents_defaults(
         category=None,
         page=1,
         page_size=20,
+        filter_mode="permissive",
     )
 
 
@@ -112,6 +113,7 @@ def test_get_agents_featured(
         category=None,
         page=1,
         page_size=20,
+        filter_mode="permissive",
     )
 
 
@@ -159,6 +161,7 @@ def test_get_agents_by_creator(
         category=None,
         page=1,
         page_size=20,
+        filter_mode="permissive",
     )
 
 
@@ -206,6 +209,7 @@ def test_get_agents_sorted(
         category=None,
         page=1,
         page_size=20,
+        filter_mode="permissive",
     )
 
 
@@ -253,6 +257,7 @@ def test_get_agents_search(
         category=None,
         page=1,
         page_size=20,
+        filter_mode="permissive",
     )
 
 
@@ -299,6 +304,7 @@ def test_get_agents_category(
         category="test-category",
         page=1,
         page_size=20,
+        filter_mode="permissive",
     )
 
 
@@ -348,6 +354,7 @@ def test_get_agents_pagination(
         category=None,
         page=2,
         page_size=5,
+        filter_mode="permissive",
     )
 
 
@@ -0,0 +1,87 @@
+-- Migration: Replace full-text search with pgvector-based vector search
+-- This migration:
+-- 1. Enables the pgvector extension
+-- 2. Drops the StoreAgent view (depends on search column)
+-- 3. Removes the full-text search infrastructure (trigger, function, tsvector column)
+-- 4. Adds a vector embedding column for semantic search
+-- 5. Creates an index for fast vector similarity search
+-- 6. Recreates the StoreAgent view with the embedding column
+
+-- Enable pgvector extension
+CREATE EXTENSION IF NOT EXISTS vector;
+
+-- First drop the view that depends on the search column
+DROP VIEW IF EXISTS "StoreAgent";
+
+-- Add embedding column for vector search (1536 dimensions for text-embedding-3-small)
+ALTER TABLE "StoreListingVersion"
+ADD COLUMN IF NOT EXISTS "embedding" vector(1536);
+
+-- Create IVFFlat index for fast similarity search
+-- Using cosine distance (vector_cosine_ops) which is standard for text embeddings
+-- lists = 100 is appropriate for datasets under 1M rows
+CREATE INDEX IF NOT EXISTS idx_store_listing_version_embedding
+ON "StoreListingVersion"
+USING ivfflat (embedding vector_cosine_ops)
+WITH (lists = 100);
+
+-- Recreate StoreAgent view WITHOUT search column, WITH embedding column
+CREATE OR REPLACE VIEW "StoreAgent" AS
+WITH latest_versions AS (
+    SELECT
+        "storeListingId",
+        MAX(version) AS max_version
+    FROM "StoreListingVersion"
+    WHERE "submissionStatus" = 'APPROVED'
+    GROUP BY "storeListingId"
+),
+agent_versions AS (
+    SELECT
+        "storeListingId",
+        array_agg(DISTINCT version::text ORDER BY version::text) AS versions
+    FROM "StoreListingVersion"
+    WHERE "submissionStatus" = 'APPROVED'
+    GROUP BY "storeListingId"
+)
+SELECT
+    sl.id AS listing_id,
+    slv.id AS "storeListingVersionId",
+    slv."createdAt" AS updated_at,
+    sl.slug,
+    COALESCE(slv.name, '') AS agent_name,
+    slv."videoUrl" AS agent_video,
+    slv."agentOutputDemoUrl" AS agent_output_demo,
+    COALESCE(slv."imageUrls", ARRAY[]::text[]) AS agent_image,
+    slv."isFeatured" AS featured,
+    p.username AS creator_username,
+    p."avatarUrl" AS creator_avatar,
+    slv."subHeading" AS sub_heading,
+    slv.description,
+    slv.categories,
+    slv.search,
+    slv.embedding,
+    COALESCE(ar.run_count, 0::bigint) AS runs,
+    COALESCE(rs.avg_rating, 0.0)::double precision AS rating,
+    COALESCE(av.versions, ARRAY[slv.version::text]) AS versions,
+    COALESCE(sl."useForOnboarding", false) AS "useForOnboarding",
+    slv."isAvailable" AS is_available
+FROM "StoreListing" sl
+JOIN latest_versions lv
+    ON sl.id = lv."storeListingId"
+JOIN "StoreListingVersion" slv
+    ON slv."storeListingId" = lv."storeListingId"
+    AND slv.version = lv.max_version
+    AND slv."submissionStatus" = 'APPROVED'
+JOIN "AgentGraph" a
+    ON slv."agentGraphId" = a.id
+    AND slv."agentGraphVersion" = a.version
+LEFT JOIN "Profile" p
+    ON sl."owningUserId" = p."userId"
+LEFT JOIN "mv_review_stats" rs
+    ON sl.id = rs."storeListingId"
+LEFT JOIN "mv_agent_run_counts" ar
+    ON a.id = ar."agentGraphId"
+LEFT JOIN agent_versions av
+    ON sl.id = av."storeListingId"
+WHERE sl."isDeleted" = false
+    AND sl."hasApprovedVersion" = true;
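The embedding column and IVFFlat index above exist to serve nearest-neighbour lookups by cosine distance. A sketch of that kind of query, run through asyncpg purely for illustration (the connection details are assumptions; the application itself goes through its own query helpers):

import asyncpg

async def nearest_versions(dsn: str, query_embedding: list[float], limit: int = 5):
    # Smaller cosine distance (<=>) means more similar; 1 - distance gives a similarity score.
    vec = "[" + ",".join(map(str, query_embedding)) + "]"
    conn = await asyncpg.connect(dsn)
    try:
        return await conn.fetch(
            'SELECT id, 1 - (embedding <=> $1::vector) AS similarity '
            'FROM "StoreListingVersion" '
            'WHERE embedding IS NOT NULL '
            'ORDER BY embedding <=> $1::vector '
            'LIMIT $2',
            vec,
            limit,
        )
    finally:
        await conn.close()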
@@ -0,0 +1,35 @@
+-- Migration: Add hybrid search infrastructure (BM25 + vector + popularity)
+-- This migration:
+-- 1. Creates/updates the tsvector trigger with weighted fields
+-- 2. Adds GIN index for full-text search performance
+-- 3. Backfills existing records with tsvector data
+
+-- Create or replace the trigger function with WEIGHTED tsvector
+-- Weight A = name (highest priority), B = subHeading, C = description
+CREATE OR REPLACE FUNCTION update_tsvector_column() RETURNS TRIGGER AS $$
+BEGIN
+    NEW.search := setweight(to_tsvector('english', COALESCE(NEW.name, '')), 'A') ||
+                  setweight(to_tsvector('english', COALESCE(NEW."subHeading", '')), 'B') ||
+                  setweight(to_tsvector('english', COALESCE(NEW.description, '')), 'C');
+    RETURN NEW;
+END;
+$$ LANGUAGE plpgsql;
+
+-- Drop and recreate trigger to ensure it's active with the updated function
+DROP TRIGGER IF EXISTS "update_tsvector" ON "StoreListingVersion";
+CREATE TRIGGER "update_tsvector"
+    BEFORE INSERT OR UPDATE OF name, "subHeading", description ON "StoreListingVersion"
+    FOR EACH ROW
+    EXECUTE FUNCTION update_tsvector_column();
+
+-- Create GIN index for full-text search performance
+CREATE INDEX IF NOT EXISTS idx_store_listing_version_search_gin
+ON "StoreListingVersion" USING GIN (search);
+
+-- Backfill existing records with weighted tsvector
+UPDATE "StoreListingVersion"
+SET search = setweight(to_tsvector('english', COALESCE(name, '')), 'A') ||
+             setweight(to_tsvector('english', COALESCE("subHeading", '')), 'B') ||
+             setweight(to_tsvector('english', COALESCE(description, '')), 'C')
+WHERE search IS NULL
+   OR search = ''::tsvector;
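With the A/B/C weights above, a term that hits the name outranks the same term hitting only the description. A sketch of the kind of lexical query the trigger and GIN index support (illustrative SQL carried in a Python constant, not the application's query builder):

# Normalization flag 32 keeps ts_rank_cd scores bounded in [0, 1), as in the main search query.
WEIGHTED_BM25_SQL = """
SELECT id,
       ts_rank_cd(search, plainto_tsquery('english', $1), 32) AS bm25_score
FROM "StoreListingVersion"
WHERE search @@ plainto_tsquery('english', $1)
ORDER BY bm25_score DESC
LIMIT 10;
"""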
@@ -138,3 +138,4 @@ filterwarnings = [
 
 [tool.ruff]
 target-version = "py310"
+
@@ -727,7 +727,7 @@ view StoreAgent {
   sub_heading String
   description String
   categories  String[]
-  search      Unsupported("tsvector")? @default(dbgenerated("''::tsvector"))
+  embedding   Unsupported("vector(1536)")?
   runs        Int
   rating      Float
   versions    String[]
@@ -863,7 +863,11 @@ model StoreListingVersion {
   // Old versions can be made unavailable by the author if desired
   isAvailable Boolean @default(true)
 
-  search Unsupported("tsvector")? @default(dbgenerated("''::tsvector"))
+  // Full-text search tsvector column
+  search Unsupported("tsvector")?
+
+  // Vector embedding for semantic search
+  embedding Unsupported("vector(1536)")?
 
   // Version workflow state
   submissionStatus SubmissionStatus @default(DRAFT)
@@ -158,3 +158,4 @@
   },
   "packageManager": "pnpm@10.20.0+sha512.cf9998222162dd85864d0a8102e7892e7ba4ceadebbf5a31f9c2fce48dfce317a9c53b9f6464d1ef9042cba2e02ae02a9f7c143a2b438cd93c91840f0192b9dd"
 }
+
@@ -2690,7 +2690,7 @@
     "get": {
       "tags": ["v2", "store", "public"],
      "summary": "List store agents",
-      "description": "Get a paginated list of agents from the store with optional filtering and sorting.\n\nArgs:\n featured (bool, optional): Filter to only show featured agents. Defaults to False.\n creator (str | None, optional): Filter agents by creator username. Defaults to None.\n sorted_by (str | None, optional): Sort agents by \"runs\" or \"rating\". Defaults to None.\n search_query (str | None, optional): Search agents by name, subheading and description. Defaults to None.\n category (str | None, optional): Filter agents by category. Defaults to None.\n page (int, optional): Page number for pagination. Defaults to 1.\n page_size (int, optional): Number of agents per page. Defaults to 20.\n\nReturns:\n StoreAgentsResponse: Paginated list of agents matching the filters\n\nRaises:\n HTTPException: If page or page_size are less than 1\n\nUsed for:\n- Home Page Featured Agents\n- Home Page Top Agents\n- Search Results\n- Agent Details - Other Agents By Creator\n- Agent Details - Similar Agents\n- Creator Details - Agents By Creator",
+      "description": "Get a paginated list of agents from the store with optional filtering and sorting.\n\nWhen search_query is provided, uses hybrid search combining:\n- BM25 full-text search (lexical matching)\n- Vector semantic similarity (meaning-based matching)\n- Popularity signal (run counts)\n\nResults are ranked using Reciprocal Rank Fusion (RRF).\n\nArgs:\n featured (bool, optional): Filter to only show featured agents. Defaults to False.\n creator (str | None, optional): Filter agents by creator username. Defaults to None.\n sorted_by (str | None, optional): Sort agents by \"runs\" or \"rating\". Defaults to None.\n search_query (str | None, optional): Search agents by name, subheading and description.\n category (str | None, optional): Filter agents by category. Defaults to None.\n page (int, optional): Page number for pagination. Defaults to 1.\n page_size (int, optional): Number of agents per page. Defaults to 20.\n filter_mode (str, optional): Controls result filtering when searching:\n - \"strict\": Must match BOTH BM25 AND vector thresholds\n - \"permissive\": Must match EITHER BM25 OR vector threshold\n - \"combined\": No threshold filtering, rely on RRF score (default)\n\nReturns:\n StoreAgentsResponse: Paginated list of agents matching the filters\n\nRaises:\n HTTPException: If page or page_size are less than 1\n\nUsed for:\n- Home Page Featured Agents\n- Home Page Top Agents\n- Search Results\n- Agent Details - Other Agents By Creator\n- Agent Details - Similar Agents\n- Creator Details - Agents By Creator",
       "operationId": "getV2List store agents",
       "parameters": [
         {
@@ -2756,6 +2756,17 @@
           "in": "query",
           "required": false,
           "schema": { "type": "integer", "default": 20, "title": "Page Size" }
+        },
+        {
+          "name": "filter_mode",
+          "in": "query",
+          "required": false,
+          "schema": {
+            "enum": ["strict", "permissive", "combined"],
+            "type": "string",
+            "default": "permissive",
+            "title": "Filter Mode"
+          }
         }
       ],
       "responses": {