Mirror of https://github.com/Significant-Gravitas/AutoGPT.git
Synced 2026-02-03 11:24:57 -05:00

Compare commits: make-old-w...otto/copil (17 commits)
Commits in this comparison (SHA1):

9e604528ea
c3ec7c2880
7d9380a793
678ddde751
aef6f57cfd
14cee1670a
d81d1ce024
2dd341c369
1081590384
7e37de8e30
7ee94d986c
18a1661fa3
b72521daa9
350ad3591b
de0ec3d388
7cb1e588b0
582c6cad36
.github/workflows/classic-autogpt-ci.yml (vendored): 73 lines changed

@@ -6,15 +6,11 @@ on:
     paths:
       - '.github/workflows/classic-autogpt-ci.yml'
       - 'classic/original_autogpt/**'
-      - 'classic/direct_benchmark/**'
-      - 'classic/forge/**'
   pull_request:
     branches: [ master, dev, release-* ]
     paths:
       - '.github/workflows/classic-autogpt-ci.yml'
       - 'classic/original_autogpt/**'
-      - 'classic/direct_benchmark/**'
-      - 'classic/forge/**'

 concurrency:
   group: ${{ format('classic-autogpt-ci-{0}', github.head_ref && format('{0}-{1}', github.event_name, github.event.pull_request.number) || github.sha) }}

@@ -23,22 +19,47 @@ concurrency:
 defaults:
   run:
     shell: bash
-    working-directory: classic
+    working-directory: classic/original_autogpt

 jobs:
   test:
     permissions:
       contents: read
     timeout-minutes: 30
-    runs-on: ubuntu-latest
+    strategy:
+      fail-fast: false
+      matrix:
+        python-version: ["3.10"]
+        platform-os: [ubuntu, macos, macos-arm64, windows]
+    runs-on: ${{ matrix.platform-os != 'macos-arm64' && format('{0}-latest', matrix.platform-os) || 'macos-14' }}

     steps:
-      - name: Start MinIO service
+      # Quite slow on macOS (2~4 minutes to set up Docker)
+      # - name: Set up Docker (macOS)
+      #   if: runner.os == 'macOS'
+      #   uses: crazy-max/ghaction-setup-docker@v3
+
+      - name: Start MinIO service (Linux)
+        if: runner.os == 'Linux'
         working-directory: '.'
         run: |
           docker pull minio/minio:edge-cicd
           docker run -d -p 9000:9000 minio/minio:edge-cicd

+      - name: Start MinIO service (macOS)
+        if: runner.os == 'macOS'
+        working-directory: ${{ runner.temp }}
+        run: |
+          brew install minio/stable/minio
+          mkdir data
+          minio server ./data &
+
+      # No MinIO on Windows:
+      # - Windows doesn't support running Linux Docker containers
+      # - It doesn't seem possible to start background processes on Windows. They are
+      #   killed after the step returns.
+      # See: https://github.com/actions/runner/issues/598#issuecomment-2011890429
+
       - name: Checkout repository
         uses: actions/checkout@v4
         with:

@@ -50,23 +71,41 @@ jobs:
           git config --global user.name "Auto-GPT-Bot"
           git config --global user.email "github-bot@agpt.co"

-      - name: Set up Python 3.12
+      - name: Set up Python ${{ matrix.python-version }}
         uses: actions/setup-python@v5
         with:
-          python-version: "3.12"
+          python-version: ${{ matrix.python-version }}

       - id: get_date
         name: Get date
         run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT

       - name: Set up Python dependency cache
+        # On Windows, unpacking cached dependencies takes longer than just installing them
+        if: runner.os != 'Windows'
         uses: actions/cache@v4
         with:
-          path: ~/.cache/pypoetry
-          key: poetry-${{ runner.os }}-${{ hashFiles('classic/poetry.lock') }}
+          path: ${{ runner.os == 'macOS' && '~/Library/Caches/pypoetry' || '~/.cache/pypoetry' }}
+          key: poetry-${{ runner.os }}-${{ hashFiles('classic/original_autogpt/poetry.lock') }}

-      - name: Install Poetry
-        run: curl -sSL https://install.python-poetry.org | python3 -
+      - name: Install Poetry (Unix)
+        if: runner.os != 'Windows'
+        run: |
+          curl -sSL https://install.python-poetry.org | python3 -
+
+          if [ "${{ runner.os }}" = "macOS" ]; then
+            PATH="$HOME/.local/bin:$PATH"
+            echo "$HOME/.local/bin" >> $GITHUB_PATH
+          fi
+
+      - name: Install Poetry (Windows)
+        if: runner.os == 'Windows'
+        shell: pwsh
+        run: |
+          (Invoke-WebRequest -Uri https://install.python-poetry.org -UseBasicParsing).Content | python -
+
+          $env:PATH += ";$env:APPDATA\Python\Scripts"
+          echo "$env:APPDATA\Python\Scripts" >> $env:GITHUB_PATH

       - name: Install Python dependencies
         run: poetry install

@@ -77,12 +116,12 @@ jobs:
             --cov=autogpt --cov-branch --cov-report term-missing --cov-report xml \
             --numprocesses=logical --durations=10 \
             --junitxml=junit.xml -o junit_family=legacy \
-            original_autogpt/tests/unit original_autogpt/tests/integration
+            tests/unit tests/integration
         env:
           CI: true
           PLAIN_OUTPUT: True
           OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
-          S3_ENDPOINT_URL: http://127.0.0.1:9000
+          S3_ENDPOINT_URL: ${{ runner.os != 'Windows' && 'http://127.0.0.1:9000' || '' }}
           AWS_ACCESS_KEY_ID: minioadmin
           AWS_SECRET_ACCESS_KEY: minioadmin

@@ -96,11 +135,11 @@ jobs:
         uses: codecov/codecov-action@v5
         with:
           token: ${{ secrets.CODECOV_TOKEN }}
-          flags: autogpt-agent
+          flags: autogpt-agent,${{ runner.os }}

       - name: Upload logs to artifact
         if: always()
         uses: actions/upload-artifact@v4
         with:
           name: test-logs
-          path: classic/logs/
+          path: classic/original_autogpt/logs/
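GitHub Actions expressions have no ternary operator, so the `runs-on` and cache-path lines above use the `cond && a || b` idiom instead. A minimal Python sketch of how the `runs-on` expression resolves for each matrix entry (illustrative only; the runner labels come from the diff above):

```python
# Sketch of ${{ matrix.platform-os != 'macos-arm64'
#               && format('{0}-latest', matrix.platform-os) || 'macos-14' }}
def resolve_runner(platform_os: str) -> str:
    # `cond && a || b` picks `a` when cond is truthy, else `b`,
    # much like Python's and/or chaining.
    return f"{platform_os}-latest" if platform_os != "macos-arm64" else "macos-14"

for os_name in ["ubuntu", "macos", "macos-arm64", "windows"]:
    print(os_name, "->", resolve_runner(os_name))
# ubuntu -> ubuntu-latest, macos -> macos-latest,
# macos-arm64 -> macos-14, windows -> windows-latest
```

Like `and`/`or` chaining, the idiom falls through to `b` whenever `a` is falsy; that is safe here because the formatted runner label is always a non-empty string.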
.github/workflows/classic-autogpts-ci.yml (vendored): 36 lines changed

@@ -11,6 +11,9 @@ on:
       - 'classic/original_autogpt/**'
       - 'classic/forge/**'
       - 'classic/benchmark/**'
+      - 'classic/run'
+      - 'classic/cli.py'
+      - 'classic/setup.py'
       - '!**/*.md'
   pull_request:
     branches: [ master, dev, release-* ]

@@ -19,6 +22,9 @@ on:
       - 'classic/original_autogpt/**'
       - 'classic/forge/**'
       - 'classic/benchmark/**'
+      - 'classic/run'
+      - 'classic/cli.py'
+      - 'classic/setup.py'
       - '!**/*.md'

 defaults:

@@ -29,9 +35,13 @@ defaults:
 jobs:
   serve-agent-protocol:
     runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        agent-name: [ original_autogpt ]
+      fail-fast: false
     timeout-minutes: 20
     env:
-      min-python-version: '3.12'
+      min-python-version: '3.10'
     steps:
       - name: Checkout repository
         uses: actions/checkout@v4

@@ -45,22 +55,22 @@ jobs:
           python-version: ${{ env.min-python-version }}

       - name: Install Poetry
+        working-directory: ./classic/${{ matrix.agent-name }}/
         run: |
           curl -sSL https://install.python-poetry.org | python -

-      - name: Install dependencies
-        run: poetry install
-
-      - name: Run smoke tests with direct-benchmark
+      - name: Run regression tests
         run: |
-          poetry run direct-benchmark run \
-            --strategies one_shot \
-            --models claude \
-            --tests ReadFile,WriteFile \
-            --json
+          ./run agent start ${{ matrix.agent-name }}
+          cd ${{ matrix.agent-name }}
+          poetry run agbenchmark --mock --test=BasicRetrieval --test=Battleship --test=WebArenaTask_0
+          poetry run agbenchmark --test=WriteFile
         env:
           OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
-          ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
+          AGENT_NAME: ${{ matrix.agent-name }}
           REQUESTS_CA_BUNDLE: /etc/ssl/certs/ca-certificates.crt
-          NONINTERACTIVE_MODE: "true"
-          CI: true
+          HELICONE_CACHE_ENABLED: false
+          HELICONE_PROPERTY_AGENT: ${{ matrix.agent-name }}
+          REPORTS_FOLDER: ${{ format('../../reports/{0}', matrix.agent-name) }}
+          TELEMETRY_ENVIRONMENT: autogpt-ci
+          TELEMETRY_OPT_IN: ${{ github.ref_name == 'master' }}
.github/workflows/classic-benchmark-ci.yml (vendored): 194 lines changed

@@ -1,21 +1,17 @@
-name: Classic - Direct Benchmark CI
+name: Classic - AGBenchmark CI

 on:
   push:
     branches: [ master, dev, ci-test* ]
     paths:
-      - 'classic/direct_benchmark/**'
-      - 'classic/benchmark/agbenchmark/challenges/**'
-      - 'classic/original_autogpt/**'
-      - 'classic/forge/**'
+      - 'classic/benchmark/**'
+      - '!classic/benchmark/reports/**'
       - .github/workflows/classic-benchmark-ci.yml
   pull_request:
     branches: [ master, dev, release-* ]
     paths:
-      - 'classic/direct_benchmark/**'
-      - 'classic/benchmark/agbenchmark/challenges/**'
-      - 'classic/original_autogpt/**'
-      - 'classic/forge/**'
+      - 'classic/benchmark/**'
+      - '!classic/benchmark/reports/**'
       - .github/workflows/classic-benchmark-ci.yml

 concurrency:

@@ -27,16 +23,23 @@ defaults:
     shell: bash

 env:
-  min-python-version: '3.12'
+  min-python-version: '3.10'

 jobs:
-  benchmark-tests:
-    runs-on: ubuntu-latest
+  test:
+    permissions:
+      contents: read
     timeout-minutes: 30
+    strategy:
+      fail-fast: false
+      matrix:
+        python-version: ["3.10"]
+        platform-os: [ubuntu, macos, macos-arm64, windows]
+    runs-on: ${{ matrix.platform-os != 'macos-arm64' && format('{0}-latest', matrix.platform-os) || 'macos-14' }}
     defaults:
       run:
         shell: bash
-        working-directory: classic
+        working-directory: classic/benchmark
     steps:
       - name: Checkout repository
         uses: actions/checkout@v4

@@ -44,88 +47,71 @@ jobs:
           fetch-depth: 0
           submodules: true

-      - name: Set up Python ${{ env.min-python-version }}
+      - name: Set up Python ${{ matrix.python-version }}
         uses: actions/setup-python@v5
         with:
-          python-version: ${{ env.min-python-version }}
+          python-version: ${{ matrix.python-version }}

       - name: Set up Python dependency cache
+        # On Windows, unpacking cached dependencies takes longer than just installing them
+        if: runner.os != 'Windows'
         uses: actions/cache@v4
         with:
-          path: ~/.cache/pypoetry
-          key: poetry-${{ runner.os }}-${{ hashFiles('classic/poetry.lock') }}
+          path: ${{ runner.os == 'macOS' && '~/Library/Caches/pypoetry' || '~/.cache/pypoetry' }}
+          key: poetry-${{ runner.os }}-${{ hashFiles('classic/benchmark/poetry.lock') }}

-      - name: Install Poetry
+      - name: Install Poetry (Unix)
+        if: runner.os != 'Windows'
         run: |
           curl -sSL https://install.python-poetry.org | python3 -

-      - name: Install dependencies
+          if [ "${{ runner.os }}" = "macOS" ]; then
+            PATH="$HOME/.local/bin:$PATH"
+            echo "$HOME/.local/bin" >> $GITHUB_PATH
+          fi
+
+      - name: Install Poetry (Windows)
+        if: runner.os == 'Windows'
+        shell: pwsh
+        run: |
+          (Invoke-WebRequest -Uri https://install.python-poetry.org -UseBasicParsing).Content | python -
+
+          $env:PATH += ";$env:APPDATA\Python\Scripts"
+          echo "$env:APPDATA\Python\Scripts" >> $env:GITHUB_PATH
+
+      - name: Install Python dependencies
         run: poetry install

-      - name: Run basic benchmark tests
+      - name: Run pytest with coverage
         run: |
-          echo "Testing ReadFile challenge with one_shot strategy..."
-          poetry run direct-benchmark run \
-            --fresh \
-            --strategies one_shot \
-            --models claude \
-            --tests ReadFile \
-            --json
-
-          echo "Testing WriteFile challenge..."
-          poetry run direct-benchmark run \
-            --fresh \
-            --strategies one_shot \
-            --models claude \
-            --tests WriteFile \
-            --json
+          poetry run pytest -vv \
+            --cov=agbenchmark --cov-branch --cov-report term-missing --cov-report xml \
+            --durations=10 \
+            --junitxml=junit.xml -o junit_family=legacy \
+            tests
         env:
           CI: true
-          ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
           OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
-          NONINTERACTIVE_MODE: "true"

-      - name: Test category filtering
-        run: |
-          echo "Testing coding category..."
-          poetry run direct-benchmark run \
-            --fresh \
-            --strategies one_shot \
-            --models claude \
-            --categories coding \
-            --tests ReadFile,WriteFile \
-            --json
-        env:
-          CI: true
-          ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
-          OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
-          NONINTERACTIVE_MODE: "true"
+      - name: Upload test results to Codecov
+        if: ${{ !cancelled() }} # Run even if tests fail
+        uses: codecov/test-results-action@v1
+        with:
+          token: ${{ secrets.CODECOV_TOKEN }}

-      - name: Test multiple strategies
-        run: |
-          echo "Testing multiple strategies..."
-          poetry run direct-benchmark run \
-            --fresh \
-            --strategies one_shot,plan_execute \
-            --models claude \
-            --tests ReadFile \
-            --parallel 2 \
-            --json
-        env:
-          CI: true
-          ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
-          OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
-          NONINTERACTIVE_MODE: "true"
+      - name: Upload coverage reports to Codecov
+        uses: codecov/codecov-action@v5
+        with:
+          token: ${{ secrets.CODECOV_TOKEN }}
+          flags: agbenchmark,${{ runner.os }}

-  # Run regression tests on maintain challenges
-  regression-tests:
+  self-test-with-agent:
     runs-on: ubuntu-latest
-    timeout-minutes: 45
-    if: github.ref == 'refs/heads/master' || github.ref == 'refs/heads/dev'
-    defaults:
-      run:
-        shell: bash
-        working-directory: classic
+    strategy:
+      matrix:
+        agent-name: [forge]
+      fail-fast: false
+    timeout-minutes: 20
     steps:
       - name: Checkout repository
         uses: actions/checkout@v4

@@ -140,23 +126,51 @@ jobs:
       - name: Install Poetry
         run: |
-          curl -sSL https://install.python-poetry.org | python3 -
-
-      - name: Install dependencies
-        run: poetry install
+          curl -sSL https://install.python-poetry.org | python -

       - name: Run regression tests
+        working-directory: classic
         run: |
-          echo "Running regression tests (previously beaten challenges)..."
-          poetry run direct-benchmark run \
-            --fresh \
-            --strategies one_shot \
-            --models claude \
-            --maintain \
-            --parallel 4 \
-            --json
+          ./run agent start ${{ matrix.agent-name }}
+          cd ${{ matrix.agent-name }}
+
+          set +e # Ignore non-zero exit codes and continue execution
+          echo "Running the following command: poetry run agbenchmark --maintain --mock"
+          poetry run agbenchmark --maintain --mock
+          EXIT_CODE=$?
+          set -e # Stop ignoring non-zero exit codes
+          # Check if the exit code was 5, and if so, exit with 0 instead
+          if [ $EXIT_CODE -eq 5 ]; then
+            echo "regression_tests.json is empty."
+          fi
+
+          echo "Running the following command: poetry run agbenchmark --mock"
+          poetry run agbenchmark --mock
+
+          echo "Running the following command: poetry run agbenchmark --mock --category=data"
+          poetry run agbenchmark --mock --category=data
+
+          echo "Running the following command: poetry run agbenchmark --mock --category=coding"
+          poetry run agbenchmark --mock --category=coding
+
+          # echo "Running the following command: poetry run agbenchmark --test=WriteFile"
+          # poetry run agbenchmark --test=WriteFile
+          cd ../benchmark
+          poetry install
+          echo "Adding the BUILD_SKILL_TREE environment variable. This will attempt to add new elements in the skill tree. If new elements are added, the CI fails because they should have been pushed"
+          export BUILD_SKILL_TREE=true
+
+          # poetry run agbenchmark --mock
+
+          # CHANGED=$(git diff --name-only | grep -E '(agbenchmark/challenges)|(../classic/frontend/assets)') || echo "No diffs"
+          # if [ ! -z "$CHANGED" ]; then
+          #   echo "There are unstaged changes please run agbenchmark and commit those changes since they are needed."
+          #   echo "$CHANGED"
+          #   exit 1
+          # else
+          #   echo "No unstaged changes."
+          # fi
         env:
-          CI: true
-          ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
           OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
-          NONINTERACTIVE_MODE: "true"
+          TELEMETRY_ENVIRONMENT: autogpt-benchmark-ci
+          TELEMETRY_OPT_IN: ${{ github.ref_name == 'master' }}
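The `set +e`/`EXIT_CODE` dance in the regression step exists because agbenchmark is pytest-based, and pytest exits with code 5 when no tests were collected, which the workflow wants to treat as success for an empty regression suite. A Python sketch of the same tolerance, assuming that exit-code convention holds:

```python
import subprocess

def run_tolerating_empty_suite(cmd: list[str]) -> int:
    """Treat exit code 5 ("no tests collected") as success, mirroring the
    `set +e` ... `if [ $EXIT_CODE -eq 5 ]` shell logic in the step above."""
    result = subprocess.run(cmd)
    if result.returncode == 5:
        print("regression_tests.json is empty.")
        return 0
    return result.returncode

# Hypothetical invocation matching the workflow step:
code = run_tolerating_empty_suite(
    ["poetry", "run", "agbenchmark", "--maintain", "--mock"]
)
```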
.github/workflows/classic-forge-ci.yml (vendored): 182 lines changed

@@ -6,11 +6,13 @@ on:
     paths:
       - '.github/workflows/classic-forge-ci.yml'
       - 'classic/forge/**'
+      - '!classic/forge/tests/vcr_cassettes'
   pull_request:
     branches: [ master, dev, release-* ]
     paths:
       - '.github/workflows/classic-forge-ci.yml'
       - 'classic/forge/**'
+      - '!classic/forge/tests/vcr_cassettes'

 concurrency:
   group: ${{ format('forge-ci-{0}', github.head_ref && format('{0}-{1}', github.event_name, github.event.pull_request.number) || github.sha) }}

@@ -19,38 +21,115 @@ concurrency:
 defaults:
   run:
     shell: bash
-    working-directory: classic
+    working-directory: classic/forge

 jobs:
   test:
     permissions:
       contents: read
     timeout-minutes: 30
-    runs-on: ubuntu-latest
+    strategy:
+      fail-fast: false
+      matrix:
+        python-version: ["3.10"]
+        platform-os: [ubuntu, macos, macos-arm64, windows]
+    runs-on: ${{ matrix.platform-os != 'macos-arm64' && format('{0}-latest', matrix.platform-os) || 'macos-14' }}

     steps:
-      - name: Start MinIO service
+      # Quite slow on macOS (2~4 minutes to set up Docker)
+      # - name: Set up Docker (macOS)
+      #   if: runner.os == 'macOS'
+      #   uses: crazy-max/ghaction-setup-docker@v3
+
+      - name: Start MinIO service (Linux)
+        if: runner.os == 'Linux'
         working-directory: '.'
         run: |
           docker pull minio/minio:edge-cicd
           docker run -d -p 9000:9000 minio/minio:edge-cicd

+      - name: Start MinIO service (macOS)
+        if: runner.os == 'macOS'
+        working-directory: ${{ runner.temp }}
+        run: |
+          brew install minio/stable/minio
+          mkdir data
+          minio server ./data &
+
+      # No MinIO on Windows:
+      # - Windows doesn't support running Linux Docker containers
+      # - It doesn't seem possible to start background processes on Windows. They are
+      #   killed after the step returns.
+      # See: https://github.com/actions/runner/issues/598#issuecomment-2011890429
+
       - name: Checkout repository
         uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+          submodules: true

-      - name: Set up Python 3.12
+      - name: Checkout cassettes
+        if: ${{ startsWith(github.event_name, 'pull_request') }}
+        env:
+          PR_BASE: ${{ github.event.pull_request.base.ref }}
+          PR_BRANCH: ${{ github.event.pull_request.head.ref }}
+          PR_AUTHOR: ${{ github.event.pull_request.user.login }}
+        run: |
+          cassette_branch="${PR_AUTHOR}-${PR_BRANCH}"
+          cassette_base_branch="${PR_BASE}"
+          cd tests/vcr_cassettes
+
+          if ! git ls-remote --exit-code --heads origin $cassette_base_branch ; then
+            cassette_base_branch="master"
+          fi
+
+          if git ls-remote --exit-code --heads origin $cassette_branch ; then
+            git fetch origin $cassette_branch
+            git fetch origin $cassette_base_branch
+
+            git checkout $cassette_branch
+
+            # Pick non-conflicting cassette updates from the base branch
+            git merge --no-commit --strategy-option=ours origin/$cassette_base_branch
+            echo "Using cassettes from mirror branch '$cassette_branch'," \
+              "synced to upstream branch '$cassette_base_branch'."
+          else
+            git checkout -b $cassette_branch
+            echo "Branch '$cassette_branch' does not exist in cassette submodule." \
+              "Using cassettes from '$cassette_base_branch'."
+          fi
+
+      - name: Set up Python ${{ matrix.python-version }}
         uses: actions/setup-python@v5
         with:
-          python-version: "3.12"
+          python-version: ${{ matrix.python-version }}

       - name: Set up Python dependency cache
+        # On Windows, unpacking cached dependencies takes longer than just installing them
+        if: runner.os != 'Windows'
         uses: actions/cache@v4
         with:
-          path: ~/.cache/pypoetry
-          key: poetry-${{ runner.os }}-${{ hashFiles('classic/poetry.lock') }}
+          path: ${{ runner.os == 'macOS' && '~/Library/Caches/pypoetry' || '~/.cache/pypoetry' }}
+          key: poetry-${{ runner.os }}-${{ hashFiles('classic/forge/poetry.lock') }}

-      - name: Install Poetry
-        run: curl -sSL https://install.python-poetry.org | python3 -
+      - name: Install Poetry (Unix)
+        if: runner.os != 'Windows'
+        run: |
+          curl -sSL https://install.python-poetry.org | python3 -
+
+          if [ "${{ runner.os }}" = "macOS" ]; then
+            PATH="$HOME/.local/bin:$PATH"
+            echo "$HOME/.local/bin" >> $GITHUB_PATH
+          fi
+
+      - name: Install Poetry (Windows)
+        if: runner.os == 'Windows'
+        shell: pwsh
+        run: |
+          (Invoke-WebRequest -Uri https://install.python-poetry.org -UseBasicParsing).Content | python -
+
+          $env:PATH += ";$env:APPDATA\Python\Scripts"
+          echo "$env:APPDATA\Python\Scripts" >> $env:GITHUB_PATH

       - name: Install Python dependencies
         run: poetry install

@@ -61,15 +140,12 @@ jobs:
             --cov=forge --cov-branch --cov-report term-missing --cov-report xml \
             --durations=10 \
             --junitxml=junit.xml -o junit_family=legacy \
-            forge/forge forge/tests
+            forge
         env:
           CI: true
           PLAIN_OUTPUT: True
-          # API keys - tests that need these will skip if not available
-          # Secrets are not available to fork PRs (GitHub security feature)
           OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
-          ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
-          S3_ENDPOINT_URL: http://127.0.0.1:9000
+          S3_ENDPOINT_URL: ${{ runner.os != 'Windows' && 'http://127.0.0.1:9000' || '' }}
           AWS_ACCESS_KEY_ID: minioadmin
           AWS_SECRET_ACCESS_KEY: minioadmin

@@ -83,11 +159,85 @@ jobs:
         uses: codecov/codecov-action@v5
         with:
           token: ${{ secrets.CODECOV_TOKEN }}
-          flags: forge
+          flags: forge,${{ runner.os }}
+
+      - id: setup_git_auth
+        name: Set up git token authentication
+        # Cassettes may be pushed even when tests fail
+        if: success() || failure()
+        run: |
+          config_key="http.${{ github.server_url }}/.extraheader"
+          if [ "${{ runner.os }}" = 'macOS' ]; then
+            base64_pat=$(echo -n "pat:${{ secrets.PAT_REVIEW }}" | base64)
+          else
+            base64_pat=$(echo -n "pat:${{ secrets.PAT_REVIEW }}" | base64 -w0)
+          fi
+
+          git config "$config_key" \
+            "Authorization: Basic $base64_pat"
+
+          cd tests/vcr_cassettes
+          git config "$config_key" \
+            "Authorization: Basic $base64_pat"
+
+          echo "config_key=$config_key" >> $GITHUB_OUTPUT
+
+      - id: push_cassettes
+        name: Push updated cassettes
+        # For pull requests, push updated cassettes even when tests fail
+        if: github.event_name == 'push' || (! github.event.pull_request.head.repo.fork && (success() || failure()))
+        env:
+          PR_BRANCH: ${{ github.event.pull_request.head.ref }}
+          PR_AUTHOR: ${{ github.event.pull_request.user.login }}
+        run: |
+          if [ "${{ startsWith(github.event_name, 'pull_request') }}" = "true" ]; then
+            is_pull_request=true
+            cassette_branch="${PR_AUTHOR}-${PR_BRANCH}"
+          else
+            cassette_branch="${{ github.ref_name }}"
+          fi
+
+          cd tests/vcr_cassettes
+          # Commit & push changes to cassettes if any
+          if ! git diff --quiet; then
+            git add .
+            git commit -m "Auto-update cassettes"
+            git push origin HEAD:$cassette_branch
+            if [ ! $is_pull_request ]; then
+              cd ../..
+              git add tests/vcr_cassettes
+              git commit -m "Update cassette submodule"
+              git push origin HEAD:$cassette_branch
+            fi
+            echo "updated=true" >> $GITHUB_OUTPUT
+          else
+            echo "updated=false" >> $GITHUB_OUTPUT
+            echo "No cassette changes to commit"
+          fi
+
+      - name: Post Set up git token auth
+        if: steps.setup_git_auth.outcome == 'success'
+        run: |
+          git config --unset-all '${{ steps.setup_git_auth.outputs.config_key }}'
+          git submodule foreach git config --unset-all '${{ steps.setup_git_auth.outputs.config_key }}'
+
+      - name: Apply "behaviour change" label and comment on PR
+        if: ${{ startsWith(github.event_name, 'pull_request') }}
+        run: |
+          PR_NUMBER="${{ github.event.pull_request.number }}"
+          TOKEN="${{ secrets.PAT_REVIEW }}"
+          REPO="${{ github.repository }}"
+
+          if [[ "${{ steps.push_cassettes.outputs.updated }}" == "true" ]]; then
+            echo "Adding label and comment..."
+            echo $TOKEN | gh auth login --with-token
+            gh issue edit $PR_NUMBER --add-label "behaviour change"
+            gh issue comment $PR_NUMBER --body "You changed AutoGPT's behaviour on ${{ runner.os }}. The cassettes have been updated and will be merged to the submodule when this Pull Request gets merged."
+          fi

       - name: Upload logs to artifact
         if: always()
         uses: actions/upload-artifact@v4
         with:
           name: test-logs
-          path: classic/logs/
+          path: classic/forge/logs/
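The cassette branches managed by this workflow hold VCR recordings: serialized HTTP request/response pairs that let the forge test suite replay LLM API traffic deterministically instead of hitting live endpoints. A minimal sketch of the pattern with the vcrpy library (the actual forge fixtures and cassette names may differ):

```python
import vcr
import requests

# Cassettes live in the submodule directory referenced by the workflow.
recorder = vcr.VCR(cassette_library_dir="tests/vcr_cassettes")

@recorder.use_cassette("example_request.yaml")
def fetch_status() -> int:
    # First run: the real HTTP exchange is recorded into the cassette file.
    # Later runs: the response is replayed from disk, so the test is
    # deterministic and needs no network access or API key.
    return requests.get("https://api.github.com").status_code
```

When a code change alters the outgoing requests, the cassettes change too, which is why the workflow commits them back and labels the PR "behaviour change".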
.github/workflows/classic-frontend-ci.yml (vendored, new file): 60 lines

@@ -0,0 +1,60 @@
+name: Classic - Frontend CI/CD
+
+on:
+  push:
+    branches:
+      - master
+      - dev
+      - 'ci-test*' # This will match any branch that starts with "ci-test"
+    paths:
+      - 'classic/frontend/**'
+      - '.github/workflows/classic-frontend-ci.yml'
+  pull_request:
+    paths:
+      - 'classic/frontend/**'
+      - '.github/workflows/classic-frontend-ci.yml'
+
+jobs:
+  build:
+    permissions:
+      contents: write
+      pull-requests: write
+    runs-on: ubuntu-latest
+    env:
+      BUILD_BRANCH: ${{ format('classic-frontend-build/{0}', github.ref_name) }}
+
+    steps:
+      - name: Checkout Repo
+        uses: actions/checkout@v4
+
+      - name: Setup Flutter
+        uses: subosito/flutter-action@v2
+        with:
+          flutter-version: '3.13.2'
+
+      - name: Build Flutter to Web
+        run: |
+          cd classic/frontend
+          flutter build web --base-href /app/
+
+      # - name: Commit and Push to ${{ env.BUILD_BRANCH }}
+      #   if: github.event_name == 'push'
+      #   run: |
+      #     git config --local user.email "action@github.com"
+      #     git config --local user.name "GitHub Action"
+      #     git add classic/frontend/build/web
+      #     git checkout -B ${{ env.BUILD_BRANCH }}
+      #     git commit -m "Update frontend build to ${GITHUB_SHA:0:7}" -a
+      #     git push -f origin ${{ env.BUILD_BRANCH }}
+
+      - name: Create PR ${{ env.BUILD_BRANCH }} -> ${{ github.ref_name }}
+        if: github.event_name == 'push'
+        uses: peter-evans/create-pull-request@v7
+        with:
+          add-paths: classic/frontend/build/web
+          base: ${{ github.ref_name }}
+          branch: ${{ env.BUILD_BRANCH }}
+          delete-branch: true
+          title: "Update frontend build in `${{ github.ref_name }}`"
+          body: "This PR updates the frontend build based on commit ${{ github.sha }}."
+          commit-message: "Update frontend build based on commit ${{ github.sha }}"
.github/workflows/classic-python-checks.yml (vendored): 67 lines changed

@@ -7,9 +7,7 @@ on:
       - '.github/workflows/classic-python-checks-ci.yml'
       - 'classic/original_autogpt/**'
       - 'classic/forge/**'
-      - 'classic/direct_benchmark/**'
-      - 'classic/pyproject.toml'
-      - 'classic/poetry.lock'
+      - 'classic/benchmark/**'
       - '**.py'
       - '!classic/forge/tests/vcr_cassettes'
   pull_request:

@@ -18,9 +16,7 @@ on:
       - '.github/workflows/classic-python-checks-ci.yml'
      - 'classic/original_autogpt/**'
       - 'classic/forge/**'
-      - 'classic/direct_benchmark/**'
-      - 'classic/pyproject.toml'
-      - 'classic/poetry.lock'
+      - 'classic/benchmark/**'
       - '**.py'
       - '!classic/forge/tests/vcr_cassettes'

@@ -31,13 +27,44 @@ concurrency:
 defaults:
   run:
     shell: bash
-    working-directory: classic

 jobs:
+  get-changed-parts:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v4
+
+      - id: changes-in
+        name: Determine affected subprojects
+        uses: dorny/paths-filter@v3
+        with:
+          filters: |
+            original_autogpt:
+              - classic/original_autogpt/autogpt/**
+              - classic/original_autogpt/tests/**
+              - classic/original_autogpt/poetry.lock
+            forge:
+              - classic/forge/forge/**
+              - classic/forge/tests/**
+              - classic/forge/poetry.lock
+            benchmark:
+              - classic/benchmark/agbenchmark/**
+              - classic/benchmark/tests/**
+              - classic/benchmark/poetry.lock
+    outputs:
+      changed-parts: ${{ steps.changes-in.outputs.changes }}
+
   lint:
+    needs: get-changed-parts
     runs-on: ubuntu-latest
     env:
-      min-python-version: "3.12"
+      min-python-version: "3.10"
+
+    strategy:
+      matrix:
+        sub-package: ${{ fromJson(needs.get-changed-parts.outputs.changed-parts) }}
+      fail-fast: false

     steps:
       - name: Checkout repository

@@ -54,31 +81,42 @@ jobs:
         uses: actions/cache@v4
         with:
           path: ~/.cache/pypoetry
-          key: ${{ runner.os }}-poetry-${{ hashFiles('classic/poetry.lock') }}
+          key: ${{ runner.os }}-poetry-${{ hashFiles(format('{0}/poetry.lock', matrix.sub-package)) }}

       - name: Install Poetry
         run: curl -sSL https://install.python-poetry.org | python3 -

+      # Install dependencies
+
       - name: Install Python dependencies
-        run: poetry install
+        run: poetry -C classic/${{ matrix.sub-package }} install

       # Lint

       - name: Lint (isort)
         run: poetry run isort --check .
+        working-directory: classic/${{ matrix.sub-package }}

       - name: Lint (Black)
         if: success() || failure()
         run: poetry run black --check .
+        working-directory: classic/${{ matrix.sub-package }}

       - name: Lint (Flake8)
         if: success() || failure()
         run: poetry run flake8 .
+        working-directory: classic/${{ matrix.sub-package }}

   types:
+    needs: get-changed-parts
     runs-on: ubuntu-latest
     env:
-      min-python-version: "3.12"
+      min-python-version: "3.10"
+
+    strategy:
+      matrix:
+        sub-package: ${{ fromJson(needs.get-changed-parts.outputs.changed-parts) }}
+      fail-fast: false

     steps:
       - name: Checkout repository

@@ -95,16 +133,19 @@ jobs:
         uses: actions/cache@v4
         with:
           path: ~/.cache/pypoetry
-          key: ${{ runner.os }}-poetry-${{ hashFiles('classic/poetry.lock') }}
+          key: ${{ runner.os }}-poetry-${{ hashFiles(format('{0}/poetry.lock', matrix.sub-package)) }}

       - name: Install Poetry
         run: curl -sSL https://install.python-poetry.org | python3 -

+      # Install dependencies
+
       - name: Install Python dependencies
-        run: poetry install
+        run: poetry -C classic/${{ matrix.sub-package }} install

       # Typecheck

       - name: Typecheck
         if: success() || failure()
         run: poetry run pyright
+        working-directory: classic/${{ matrix.sub-package }}
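dorny/paths-filter exposes a `changes` output: a JSON array naming every filter whose patterns matched a changed file. The `lint` and `types` jobs feed that array through `fromJson` into their matrix, so only the touched sub-packages get checked. A simplified Python sketch of the fan-out (glob handling reduced to prefix matching for illustration; the file list is hypothetical):

```python
import json

# Hypothetical changed files for one push; the real list comes from the
# diff against the base ref.
changed_files = ["classic/forge/forge/llm.py", "classic/benchmark/poetry.lock"]

# Prefix-based stand-ins for the workflow's glob filters.
filters = {
    "original_autogpt": "classic/original_autogpt/",
    "forge": "classic/forge/",
    "benchmark": "classic/benchmark/",
}

# The `changes` output: names of all filters with at least one matching file.
changes = [
    name for name, prefix in filters.items()
    if any(f.startswith(prefix) for f in changed_files)
]
print(json.dumps(changes))  # ["forge", "benchmark"] -> one matrix job per entry
```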
.gitignore (vendored): 9 lines changed

@@ -3,7 +3,6 @@
 classic/original_autogpt/keys.py
 classic/original_autogpt/*.json
 auto_gpt_workspace/*
-.autogpt/
 *.mpeg
 .env
 # Root .env files

@@ -160,10 +159,6 @@ CURRENT_BULLETIN.md

 # AgBenchmark
 classic/benchmark/agbenchmark/reports/
-classic/reports/
-classic/direct_benchmark/reports/
-classic/.benchmark_workspaces/
-classic/direct_benchmark/.benchmark_workspaces/

 # Nodejs
 package-lock.json

@@ -182,10 +177,6 @@ autogpt_platform/backend/settings.py

 *.ign.*
 .test-contents
-**/.claude/settings.local.json
 .claude/settings.local.json
 CLAUDE.local.md
 /autogpt_platform/backend/logs
-
-# Test database
-test.db
.gitmodules (vendored, new file): 3 lines

@@ -0,0 +1,3 @@
+[submodule "classic/forge/tests/vcr_cassettes"]
+	path = classic/forge/tests/vcr_cassettes
+	url = https://github.com/Significant-Gravitas/Auto-GPT-test-cassettes
@@ -43,10 +43,29 @@ repos:
         pass_filenames: false

       - id: poetry-install
-        name: Check & Install dependencies - Classic
-        alias: poetry-install-classic
-        entry: poetry -C classic install
-        files: ^classic/poetry\.lock$
+        name: Check & Install dependencies - Classic - AutoGPT
+        alias: poetry-install-classic-autogpt
+        entry: poetry -C classic/original_autogpt install
+        # include forge source (since it's a path dependency)
+        files: ^classic/(original_autogpt|forge)/poetry\.lock$
+        types: [file]
+        language: system
+        pass_filenames: false
+
+      - id: poetry-install
+        name: Check & Install dependencies - Classic - Forge
+        alias: poetry-install-classic-forge
+        entry: poetry -C classic/forge install
+        files: ^classic/forge/poetry\.lock$
+        types: [file]
+        language: system
+        pass_filenames: false
+
+      - id: poetry-install
+        name: Check & Install dependencies - Classic - Benchmark
+        alias: poetry-install-classic-benchmark
+        entry: poetry -C classic/benchmark install
+        files: ^classic/benchmark/poetry\.lock$
         types: [file]
         language: system
         pass_filenames: false

@@ -97,10 +116,26 @@ repos:
         language: system

       - id: isort
-        name: Lint (isort) - Classic
-        alias: isort-classic
-        entry: bash -c 'cd classic && poetry run isort $(echo "$@" | sed "s|classic/||g")' --
-        files: ^classic/(original_autogpt|forge|direct_benchmark)/
+        name: Lint (isort) - Classic - AutoGPT
+        alias: isort-classic-autogpt
+        entry: poetry -P classic/original_autogpt run isort -p autogpt
+        files: ^classic/original_autogpt/
+        types: [file, python]
+        language: system
+
+      - id: isort
+        name: Lint (isort) - Classic - Forge
+        alias: isort-classic-forge
+        entry: poetry -P classic/forge run isort -p forge
+        files: ^classic/forge/
+        types: [file, python]
+        language: system
+
+      - id: isort
+        name: Lint (isort) - Classic - Benchmark
+        alias: isort-classic-benchmark
+        entry: poetry -P classic/benchmark run isort -p agbenchmark
+        files: ^classic/benchmark/
         types: [file, python]
         language: system

@@ -114,13 +149,26 @@ repos:

   - repo: https://github.com/PyCQA/flake8
     rev: 7.0.0
-    # Use consolidated flake8 config at classic/.flake8
+    # To have flake8 load the config of the individual subprojects, we have to call
+    # them separately.
     hooks:
       - id: flake8
-        name: Lint (Flake8) - Classic
-        alias: flake8-classic
-        files: ^classic/(original_autogpt|forge|direct_benchmark)/
-        args: [--config=classic/.flake8]
+        name: Lint (Flake8) - Classic - AutoGPT
+        alias: flake8-classic-autogpt
+        files: ^classic/original_autogpt/(autogpt|scripts|tests)/
+        args: [--config=classic/original_autogpt/.flake8]
+
+      - id: flake8
+        name: Lint (Flake8) - Classic - Forge
+        alias: flake8-classic-forge
+        files: ^classic/forge/(forge|tests)/
+        args: [--config=classic/forge/.flake8]
+
+      - id: flake8
+        name: Lint (Flake8) - Classic - Benchmark
+        alias: flake8-classic-benchmark
+        files: ^classic/benchmark/(agbenchmark|tests)/((?!reports).)*[/.]
+        args: [--config=classic/benchmark/.flake8]

   - repo: local
     hooks:

@@ -156,10 +204,29 @@ repos:
         pass_filenames: false

       - id: pyright
-        name: Typecheck - Classic
-        alias: pyright-classic
-        entry: poetry -C classic run pyright
-        files: ^classic/(original_autogpt|forge|direct_benchmark)/.*\.py$|^classic/poetry\.lock$
+        name: Typecheck - Classic - AutoGPT
+        alias: pyright-classic-autogpt
+        entry: poetry -C classic/original_autogpt run pyright
+        # include forge source (since it's a path dependency) but exclude *_test.py files:
+        files: ^(classic/original_autogpt/((autogpt|scripts|tests)/|poetry\.lock$)|classic/forge/(forge/.*(?<!_test)\.py|poetry\.lock)$)
+        types: [file]
+        language: system
+        pass_filenames: false
+
+      - id: pyright
+        name: Typecheck - Classic - Forge
+        alias: pyright-classic-forge
+        entry: poetry -C classic/forge run pyright
+        files: ^classic/forge/(forge/|poetry\.lock$)
+        types: [file]
+        language: system
+        pass_filenames: false
+
+      - id: pyright
+        name: Typecheck - Classic - Benchmark
+        alias: pyright-classic-benchmark
+        entry: poetry -C classic/benchmark run pyright
+        files: ^classic/benchmark/(agbenchmark/|tests/|poetry\.lock$)
         types: [file]
         language: system
         pass_filenames: false
@@ -54,7 +54,7 @@ Before proceeding with the installation, ensure your system meets the following
 ### Updated Setup Instructions:
 We've moved to a fully maintained and regularly updated documentation site.

-👉 [Follow the official self-hosting guide here](https://docs.agpt.co/platform/getting-started/)
+👉 [Follow the official self-hosting guide here](https://agpt.co/docs/platform/getting-started/getting-started)

 This tutorial assumes you have Docker, VSCode, git and npm installed.
@@ -17,6 +17,14 @@ from .model import ChatSession, create_chat_session, get_chat_session, get_user_

 config = ChatConfig()

+# SSE response headers for streaming
+SSE_RESPONSE_HEADERS = {
+    "Cache-Control": "no-cache",
+    "Connection": "keep-alive",
+    "X-Accel-Buffering": "no",
+    "x-vercel-ai-ui-message-stream": "v1",
+}
+
 logger = logging.getLogger(__name__)

@@ -32,6 +40,60 @@ async def _validate_and_get_session(
     return session

+
+async def _create_stream_generator(
+    session_id: str,
+    message: str,
+    user_id: str | None,
+    session: ChatSession,
+    is_user_message: bool = True,
+    context: dict[str, str] | None = None,
+) -> AsyncGenerator[str, None]:
+    """Create SSE event generator for chat streaming.
+
+    Args:
+        session_id: Chat session ID
+        message: User message to process
+        user_id: Optional authenticated user ID
+        session: Pre-fetched chat session
+        is_user_message: Whether the message is from a user
+        context: Optional context dict with url and content
+
+    Yields:
+        SSE-formatted chunks from the chat completion stream
+    """
+    chunk_count = 0
+    first_chunk_type: str | None = None
+    async for chunk in chat_service.stream_chat_completion(
+        session_id,
+        message,
+        is_user_message=is_user_message,
+        user_id=user_id,
+        session=session,
+        context=context,
+    ):
+        if chunk_count < 3:
+            logger.info(
+                "Chat stream chunk",
+                extra={
+                    "session_id": session_id,
+                    "chunk_type": str(chunk.type),
+                },
+            )
+            if not first_chunk_type:
+                first_chunk_type = str(chunk.type)
+        chunk_count += 1
+        yield chunk.to_sse()
+    logger.info(
+        "Chat stream completed",
+        extra={
+            "session_id": session_id,
+            "chunk_count": chunk_count,
+            "first_chunk_type": first_chunk_type,
+        },
+    )
+    yield "data: [DONE]\n\n"
+
+
 router = APIRouter(
     tags=["chat"],
 )

@@ -221,49 +283,17 @@ async def stream_chat_post(
     """
     session = await _validate_and_get_session(session_id, user_id)

-    async def event_generator() -> AsyncGenerator[str, None]:
-        chunk_count = 0
-        first_chunk_type: str | None = None
-        async for chunk in chat_service.stream_chat_completion(
-            session_id,
-            request.message,
-            is_user_message=request.is_user_message,
-            user_id=user_id,
-            session=session,  # Pass pre-fetched session to avoid double-fetch
-            context=request.context,
-        ):
-            if chunk_count < 3:
-                logger.info(
-                    "Chat stream chunk",
-                    extra={
-                        "session_id": session_id,
-                        "chunk_type": str(chunk.type),
-                    },
-                )
-                if not first_chunk_type:
-                    first_chunk_type = str(chunk.type)
-            chunk_count += 1
-            yield chunk.to_sse()
-        logger.info(
-            "Chat stream completed",
-            extra={
-                "session_id": session_id,
-                "chunk_count": chunk_count,
-                "first_chunk_type": first_chunk_type,
-            },
-        )
-        # AI SDK protocol termination
-        yield "data: [DONE]\n\n"
-
     return StreamingResponse(
-        event_generator(),
+        _create_stream_generator(
+            session_id=session_id,
+            message=request.message,
+            user_id=user_id,
+            session=session,
+            is_user_message=request.is_user_message,
+            context=request.context,
+        ),
         media_type="text/event-stream",
-        headers={
-            "Cache-Control": "no-cache",
-            "Connection": "keep-alive",
-            "X-Accel-Buffering": "no",  # Disable nginx buffering
-            "x-vercel-ai-ui-message-stream": "v1",  # AI SDK protocol header
-        },
+        headers=SSE_RESPONSE_HEADERS,
     )

@@ -295,48 +325,16 @@ async def stream_chat_get(
     """
     session = await _validate_and_get_session(session_id, user_id)

-    async def event_generator() -> AsyncGenerator[str, None]:
-        chunk_count = 0
-        first_chunk_type: str | None = None
-        async for chunk in chat_service.stream_chat_completion(
-            session_id,
-            message,
-            is_user_message=is_user_message,
-            user_id=user_id,
-            session=session,  # Pass pre-fetched session to avoid double-fetch
-        ):
-            if chunk_count < 3:
-                logger.info(
-                    "Chat stream chunk",
-                    extra={
-                        "session_id": session_id,
-                        "chunk_type": str(chunk.type),
-                    },
-                )
-                if not first_chunk_type:
-                    first_chunk_type = str(chunk.type)
-            chunk_count += 1
-            yield chunk.to_sse()
-        logger.info(
-            "Chat stream completed",
-            extra={
-                "session_id": session_id,
-                "chunk_count": chunk_count,
-                "first_chunk_type": first_chunk_type,
-            },
-        )
-        # AI SDK protocol termination
-        yield "data: [DONE]\n\n"
-
     return StreamingResponse(
-        event_generator(),
+        _create_stream_generator(
+            session_id=session_id,
+            message=message,
+            user_id=user_id,
+            session=session,
+            is_user_message=is_user_message,
+        ),
         media_type="text/event-stream",
-        headers={
-            "Cache-Control": "no-cache",
-            "Connection": "keep-alive",
-            "X-Accel-Buffering": "no",  # Disable nginx buffering
-            "x-vercel-ai-ui-message-stream": "v1",  # AI SDK protocol header
-        },
+        headers=SSE_RESPONSE_HEADERS,
     )
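For reference, a minimal client-side sketch of consuming the stream this router produces: read SSE `data:` lines until the AI SDK `[DONE]` terminator. The URL and the httpx transport are assumptions; only the framing (`data: ...` events ending in `data: [DONE]`) comes from the code above.

```python
import httpx

async def read_chat_stream(url: str) -> list[str]:
    """Collect SSE payloads until the '[DONE]' terminator arrives."""
    events: list[str] = []
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("GET", url) as response:
            async for line in response.aiter_lines():
                if not line.startswith("data: "):
                    continue  # skip comments and blank keep-alive lines
                payload = line[len("data: "):]
                if payload == "[DONE]":
                    break  # protocol terminator emitted by the server
                events.append(payload)
    return events
```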
@@ -3,9 +3,13 @@ import logging
 import time
 from asyncio import CancelledError
 from collections.abc import AsyncGenerator
-from typing import Any
+from typing import TYPE_CHECKING, Any, cast

 import openai

+if TYPE_CHECKING:
+    from backend.util.prompt import CompressResult
+
 import orjson
 from langfuse import get_client
 from openai import (

@@ -15,7 +19,13 @@ from openai import (
     PermissionDeniedError,
     RateLimitError,
 )
-from openai.types.chat import ChatCompletionChunk, ChatCompletionToolParam
+from openai.types.chat import (
+    ChatCompletionChunk,
+    ChatCompletionMessageParam,
+    ChatCompletionStreamOptionsParam,
+    ChatCompletionSystemMessageParam,
+    ChatCompletionToolParam,
+)

 from backend.data.redis_client import get_redis_async
 from backend.data.understanding import (

@@ -794,207 +804,58 @@ def _is_region_blocked_error(error: Exception) -> bool:
     return "not available in your region" in str(error).lower()


-async def _summarize_messages(
+async def _manage_context_window(
     messages: list,
     model: str,
     api_key: str | None = None,
     base_url: str | None = None,
-    timeout: float = 30.0,
-) -> str:
-    """Summarize a list of messages into concise context.
+) -> "CompressResult":
+    """
+    Manage context window using the unified compress_context function.

-    Uses the same model as the chat for higher quality summaries.
+    This is a thin wrapper that creates an OpenAI client for summarization
+    and delegates to the shared compression logic in prompt.py.

     Args:
-        messages: List of message dicts to summarize
-        model: Model to use for summarization (same as chat model)
-        api_key: API key for OpenAI client
-        base_url: Base URL for OpenAI client
-        timeout: Request timeout in seconds (default: 30.0)
+        messages: List of messages in OpenAI format
+        model: Model name for token counting and summarization
+        api_key: API key for summarization calls
+        base_url: Base URL for summarization calls

     Returns:
-        Summarized text
+        CompressResult with compacted messages and metadata
     """
-    # Format messages for summarization
-    conversation = []
-    for msg in messages:
-        role = msg.get("role", "")
-        content = msg.get("content", "")
-        # Include user, assistant, and tool messages (tool outputs are important context)
-        if content and role in ("user", "assistant", "tool"):
-            conversation.append(f"{role.upper()}: {content}")
-
-    conversation_text = "\n\n".join(conversation)
-
-    # Handle empty conversation
-    if not conversation_text:
-        return "No conversation history available."
-
-    # Truncate conversation to fit within summarization model's context
-    # gpt-4o-mini has 128k context, but we limit to ~25k tokens (~100k chars) for safety
-    MAX_CHARS = 100_000
-    if len(conversation_text) > MAX_CHARS:
-        conversation_text = conversation_text[:MAX_CHARS] + "\n\n[truncated]"
-
-    # Call LLM to summarize
     import openai

-    summarization_client = openai.AsyncOpenAI(
-        api_key=api_key, base_url=base_url, timeout=timeout
-    )
+    from backend.util.prompt import compress_context

-    response = await summarization_client.chat.completions.create(
-        model=model,
-        messages=[
-            {
-                "role": "system",
-                "content": (
-                    "Create a detailed summary of the conversation so far. "
-                    "This summary will be used as context when continuing the conversation.\n\n"
-                    "Before writing the summary, analyze each message chronologically to identify:\n"
-                    "- User requests and their explicit goals\n"
-                    "- Your approach and key decisions made\n"
-                    "- Technical specifics (file names, tool outputs, function signatures)\n"
-                    "- Errors encountered and resolutions applied\n\n"
-                    "You MUST include ALL of the following sections:\n\n"
+    # Convert messages to dict format
+    messages_dict = []
+    for msg in messages:
+        if isinstance(msg, dict):
+            msg_dict = {k: v for k, v in msg.items() if v is not None}
+        else:
+            msg_dict = dict(msg)
+        messages_dict.append(msg_dict)
|
|
||||||
"## 1. Primary Request and Intent\n"
|
|
||||||
"The user's explicit goals and what they are trying to accomplish.\n\n"
|
|
||||||
"## 2. Key Technical Concepts\n"
|
|
||||||
"Technologies, frameworks, tools, and patterns being used or discussed.\n\n"
|
|
||||||
"## 3. Files and Resources Involved\n"
|
|
||||||
"Specific files examined or modified, with relevant snippets and identifiers.\n\n"
|
|
||||||
"## 4. Errors and Fixes\n"
|
|
||||||
"Problems encountered, error messages, and their resolutions. "
|
|
||||||
"Include any user feedback on fixes.\n\n"
|
|
||||||
"## 5. Problem Solving\n"
|
|
||||||
"Issues that have been resolved and how they were addressed.\n\n"
|
|
||||||
"## 6. All User Messages\n"
|
|
||||||
"A complete list of all user inputs (excluding tool outputs) to preserve their exact requests.\n\n"
|
|
||||||
"## 7. Pending Tasks\n"
|
|
||||||
"Work items the user explicitly requested that have not yet been completed.\n\n"
|
|
||||||
"## 8. Current Work\n"
|
|
||||||
"Precise description of what was being worked on most recently, including relevant context.\n\n"
|
|
||||||
"## 9. Next Steps\n"
|
|
||||||
"What should happen next, aligned with the user's most recent requests. "
|
|
||||||
"Include verbatim quotes of recent instructions if relevant."
|
|
||||||
),
|
|
||||||
},
|
|
||||||
{"role": "user", "content": f"Summarize:\n\n{conversation_text}"},
|
|
||||||
],
|
|
||||||
max_tokens=1500,
|
|
||||||
temperature=0.3,
|
|
||||||
)
|
|
||||||
|
|
||||||
summary = response.choices[0].message.content
|
# Only create client if api_key is provided (enables summarization)
|
||||||
return summary or "No summary available."
|
# Use context manager to avoid socket leaks
|
||||||
|
if api_key:
|
||||||
|
async with openai.AsyncOpenAI(
|
||||||
def _ensure_tool_pairs_intact(
|
api_key=api_key, base_url=base_url, timeout=30.0
|
||||||
recent_messages: list[dict],
|
) as client:
|
||||||
all_messages: list[dict],
|
return await compress_context(
|
||||||
start_index: int,
|
messages=messages_dict,
|
||||||
) -> list[dict]:
|
model=model,
|
||||||
"""
|
client=client,
|
||||||
Ensure tool_call/tool_response pairs stay together after slicing.
|
|
||||||
|
|
||||||
When slicing messages for context compaction, a naive slice can separate
|
|
||||||
an assistant message containing tool_calls from its corresponding tool
|
|
||||||
response messages. This causes API validation errors (e.g., Anthropic's
|
|
||||||
"unexpected tool_use_id found in tool_result blocks").
|
|
||||||
|
|
||||||
This function checks for orphan tool responses in the slice and extends
|
|
||||||
backwards to include their corresponding assistant messages.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
recent_messages: The sliced messages to validate
|
|
||||||
all_messages: The complete message list (for looking up missing assistants)
|
|
||||||
start_index: The index in all_messages where recent_messages begins
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A potentially extended list of messages with tool pairs intact
|
|
||||||
"""
|
|
||||||
if not recent_messages:
|
|
||||||
return recent_messages
|
|
||||||
|
|
||||||
# Collect all tool_call_ids from assistant messages in the slice
|
|
||||||
available_tool_call_ids: set[str] = set()
|
|
||||||
for msg in recent_messages:
|
|
||||||
if msg.get("role") == "assistant" and msg.get("tool_calls"):
|
|
||||||
for tc in msg["tool_calls"]:
|
|
||||||
tc_id = tc.get("id")
|
|
||||||
if tc_id:
|
|
||||||
available_tool_call_ids.add(tc_id)
|
|
||||||
|
|
||||||
# Find orphan tool responses (tool messages whose tool_call_id is missing)
|
|
||||||
orphan_tool_call_ids: set[str] = set()
|
|
||||||
for msg in recent_messages:
|
|
||||||
if msg.get("role") == "tool":
|
|
||||||
tc_id = msg.get("tool_call_id")
|
|
||||||
if tc_id and tc_id not in available_tool_call_ids:
|
|
||||||
orphan_tool_call_ids.add(tc_id)
|
|
||||||
|
|
||||||
if not orphan_tool_call_ids:
|
|
||||||
# No orphans, slice is valid
|
|
||||||
return recent_messages
|
|
||||||
|
|
||||||
# Find the assistant messages that contain the orphan tool_call_ids
|
|
||||||
# Search backwards from start_index in all_messages
|
|
||||||
messages_to_prepend: list[dict] = []
|
|
||||||
for i in range(start_index - 1, -1, -1):
|
|
||||||
msg = all_messages[i]
|
|
||||||
if msg.get("role") == "assistant" and msg.get("tool_calls"):
|
|
||||||
msg_tool_ids = {tc.get("id") for tc in msg["tool_calls"] if tc.get("id")}
|
|
||||||
if msg_tool_ids & orphan_tool_call_ids:
|
|
||||||
# This assistant message has tool_calls we need
|
|
||||||
# Also collect its contiguous tool responses that follow it
|
|
||||||
assistant_and_responses: list[dict] = [msg]
|
|
||||||
|
|
||||||
# Scan forward from this assistant to collect tool responses
|
|
||||||
for j in range(i + 1, start_index):
|
|
||||||
following_msg = all_messages[j]
|
|
||||||
if following_msg.get("role") == "tool":
|
|
||||||
tool_id = following_msg.get("tool_call_id")
|
|
||||||
if tool_id and tool_id in msg_tool_ids:
|
|
||||||
assistant_and_responses.append(following_msg)
|
|
||||||
else:
|
|
||||||
# Stop at first non-tool message
|
|
||||||
break
|
|
||||||
|
|
||||||
# Prepend the assistant and its tool responses (maintain order)
|
|
||||||
messages_to_prepend = assistant_and_responses + messages_to_prepend
|
|
||||||
# Mark these as found
|
|
||||||
orphan_tool_call_ids -= msg_tool_ids
|
|
||||||
# Also add this assistant's tool_call_ids to available set
|
|
||||||
available_tool_call_ids |= msg_tool_ids
|
|
||||||
|
|
||||||
if not orphan_tool_call_ids:
|
|
||||||
# Found all missing assistants
|
|
||||||
break
|
|
||||||
|
|
||||||
if orphan_tool_call_ids:
|
|
||||||
# Some tool_call_ids couldn't be resolved - remove those tool responses
|
|
||||||
# This shouldn't happen in normal operation but handles edge cases
|
|
||||||
logger.warning(
|
|
||||||
f"Could not find assistant messages for tool_call_ids: {orphan_tool_call_ids}. "
|
|
||||||
"Removing orphan tool responses."
|
|
||||||
)
|
|
||||||
recent_messages = [
|
|
||||||
msg
|
|
||||||
for msg in recent_messages
|
|
||||||
if not (
|
|
||||||
msg.get("role") == "tool"
|
|
||||||
and msg.get("tool_call_id") in orphan_tool_call_ids
|
|
||||||
)
|
)
|
||||||
]
|
else:
|
||||||
|
# No API key - use truncation-only mode
|
||||||
if messages_to_prepend:
|
return await compress_context(
|
||||||
logger.info(
|
messages=messages_dict,
|
||||||
f"Extended recent messages by {len(messages_to_prepend)} to preserve "
|
model=model,
|
||||||
f"tool_call/tool_response pairs"
|
client=None,
|
||||||
)
|
)
|
||||||
return messages_to_prepend + recent_messages
|
|
||||||
|
|
||||||
return recent_messages
|
|
||||||
|
|
||||||
|
|
||||||
async def _stream_chat_chunks(
|
async def _stream_chat_chunks(
|
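Note: the new `_manage_context_window` defers all compaction decisions to `compress_context` in `backend.util.prompt`, whose definition is outside this diff. The interface below is an assumption reconstructed purely from the call sites and the fields callers read (`messages`, `token_count`, `was_compacted`, `error`); field types and defaults are guesses:

# Assumed interface; reconstructed from call sites, not from the actual module.
from dataclasses import dataclass
from typing import Any


@dataclass
class CompressResult:
    messages: list[dict[str, Any]]  # possibly compacted message list
    token_count: int                # estimated token count after compaction
    was_compacted: bool             # True when summarization/truncation ran
    error: str | None = None        # e.g. "System prompt dropped ..." or None


async def compress_context(
    messages: list[dict[str, Any]],
    model: str,
    client: Any | None = None,  # AsyncOpenAI enables summarization; None = truncate only
) -> CompressResult:
    ...

Centralizing this in one function replaces the duplicated token-counting, summarization, and progressive-truncation logic deleted below, and gives the streaming path and the continuation path identical behavior.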
@@ -1022,11 +883,8 @@ async def _stream_chat_chunks(
 
     logger.info("Starting pure chat stream")
 
-    # Build messages with system prompt prepended
     messages = session.to_openai_messages()
     if system_prompt:
-        from openai.types.chat import ChatCompletionSystemMessageParam
-
         system_message = ChatCompletionSystemMessageParam(
             role="system",
             content=system_prompt,
@@ -1034,314 +892,38 @@ async def _stream_chat_chunks(
         messages = [system_message] + messages
 
     # Apply context window management
-    token_count = 0  # Initialize for exception handler
-    try:
-        from backend.util.prompt import estimate_token_count
-
-        # Convert to dict for token counting
-        # OpenAI message types are TypedDicts, so they're already dict-like
-        messages_dict = []
-        for msg in messages:
-            # TypedDict objects are already dicts, just filter None values
-            if isinstance(msg, dict):
-                msg_dict = {k: v for k, v in msg.items() if v is not None}
-            else:
-                # Fallback for unexpected types
-                msg_dict = dict(msg)
-            messages_dict.append(msg_dict)
-
-        # Estimate tokens using appropriate tokenizer
-        # Normalize model name for token counting (tiktoken only supports OpenAI models)
-        token_count_model = model
-        if "/" in model:
-            # Strip provider prefix (e.g., "anthropic/claude-opus-4.5" -> "claude-opus-4.5")
-            token_count_model = model.split("/")[-1]
-
-        # For Claude and other non-OpenAI models, approximate with gpt-4o tokenizer
-        # Most modern LLMs have similar tokenization (~1 token per 4 chars)
-        if "claude" in token_count_model.lower() or not any(
-            known in token_count_model.lower()
-            for known in ["gpt", "o1", "chatgpt", "text-"]
-        ):
-            token_count_model = "gpt-4o"
-
-        # Attempt token counting with error handling
-        try:
-            token_count = estimate_token_count(messages_dict, model=token_count_model)
-        except Exception as token_error:
-            # If token counting fails, use gpt-4o as fallback approximation
-            logger.warning(
-                f"Token counting failed for model {token_count_model}: {token_error}. "
-                "Using gpt-4o approximation."
-            )
-            token_count = estimate_token_count(messages_dict, model="gpt-4o")
-
-        # If over threshold, summarize old messages
-        if token_count > 120_000:
-            KEEP_RECENT = 15
-
-            # Check if we have a system prompt at the start
-            has_system_prompt = (
-                len(messages) > 0 and messages[0].get("role") == "system"
-            )
-
-            # Always attempt mitigation when over limit, even with few messages
-            if messages:
-                # Split messages based on whether system prompt exists
-                # Calculate start index for the slice
-                slice_start = max(0, len(messages_dict) - KEEP_RECENT)
-                recent_messages = messages_dict[-KEEP_RECENT:]
-
-                # Ensure tool_call/tool_response pairs stay together
-                # This prevents API errors from orphan tool responses
-                recent_messages = _ensure_tool_pairs_intact(
-                    recent_messages, messages_dict, slice_start
-                )
-
-                if has_system_prompt:
-                    # Keep system prompt separate, summarize everything between system and recent
-                    system_msg = messages[0]
-                    old_messages_dict = messages_dict[1:-KEEP_RECENT]
-                else:
-                    # No system prompt, summarize everything except recent
-                    system_msg = None
-                    old_messages_dict = messages_dict[:-KEEP_RECENT]
-
-                # Summarize any non-empty old messages (no minimum threshold)
-                # If we're over the token limit, we need to compress whatever we can
-                if old_messages_dict:
-                    # Summarize old messages using the same model as chat
-                    summary_text = await _summarize_messages(
-                        old_messages_dict,
-                        model=model,
-                        api_key=config.api_key,
-                        base_url=config.base_url,
-                    )
-
-                    # Build new message list
-                    # Use assistant role (not system) to prevent privilege escalation
-                    # of user-influenced content to instruction-level authority
-                    from openai.types.chat import ChatCompletionAssistantMessageParam
-
-                    summary_msg = ChatCompletionAssistantMessageParam(
-                        role="assistant",
-                        content=(
-                            "[Previous conversation summary — for context only]: "
-                            f"{summary_text}"
-                        ),
-                    )
-
-                    # Rebuild messages based on whether we have a system prompt
-                    if has_system_prompt:
-                        # system_prompt + summary + recent_messages
-                        messages = [system_msg, summary_msg] + recent_messages
-                    else:
-                        # summary + recent_messages (no original system prompt)
-                        messages = [summary_msg] + recent_messages
-
-                    logger.info(
-                        f"Context summarized: {token_count} tokens, "
-                        f"summarized {len(old_messages_dict)} old messages, "
-                        f"kept last {KEEP_RECENT} messages"
-                    )
-
-                    # Fallback: If still over limit after summarization, progressively drop recent messages
-                    # This handles edge cases where recent messages are extremely large
-                    new_messages_dict = []
-                    for msg in messages:
-                        if isinstance(msg, dict):
-                            msg_dict = {k: v for k, v in msg.items() if v is not None}
-                        else:
-                            msg_dict = dict(msg)
-                        new_messages_dict.append(msg_dict)
-
-                    new_token_count = estimate_token_count(
-                        new_messages_dict, model=token_count_model
-                    )
-
-                    if new_token_count > 120_000:
-                        # Still over limit - progressively reduce KEEP_RECENT
-                        logger.warning(
-                            f"Still over limit after summarization: {new_token_count} tokens. "
-                            "Reducing number of recent messages kept."
-                        )
-
-                        for keep_count in [12, 10, 8, 5, 3, 2, 1, 0]:
-                            if keep_count == 0:
-                                # Try with just system prompt + summary (no recent messages)
-                                if has_system_prompt:
-                                    messages = [system_msg, summary_msg]
-                                else:
-                                    messages = [summary_msg]
-                                logger.info(
-                                    "Trying with 0 recent messages (system + summary only)"
-                                )
-                            else:
-                                # Slice from ORIGINAL recent_messages to avoid duplicating summary
-                                reduced_recent = (
-                                    recent_messages[-keep_count:]
-                                    if len(recent_messages) >= keep_count
-                                    else recent_messages
-                                )
-                                # Ensure tool pairs stay intact in the reduced slice
-                                reduced_slice_start = max(
-                                    0, len(recent_messages) - keep_count
-                                )
-                                reduced_recent = _ensure_tool_pairs_intact(
-                                    reduced_recent, recent_messages, reduced_slice_start
-                                )
-                                if has_system_prompt:
-                                    messages = [
-                                        system_msg,
-                                        summary_msg,
-                                    ] + reduced_recent
-                                else:
-                                    messages = [summary_msg] + reduced_recent
-
-                            new_messages_dict = []
-                            for msg in messages:
-                                if isinstance(msg, dict):
-                                    msg_dict = {
-                                        k: v for k, v in msg.items() if v is not None
-                                    }
-                                else:
-                                    msg_dict = dict(msg)
-                                new_messages_dict.append(msg_dict)
-
-                            new_token_count = estimate_token_count(
-                                new_messages_dict, model=token_count_model
-                            )
-
-                            if new_token_count <= 120_000:
-                                logger.info(
-                                    f"Reduced to {keep_count} recent messages, "
-                                    f"now {new_token_count} tokens"
-                                )
-                                break
-                        else:
-                            logger.error(
-                                f"Unable to reduce token count below threshold even with 0 messages. "
-                                f"Final count: {new_token_count} tokens"
-                            )
-                            # ABSOLUTE LAST RESORT: Drop system prompt
-                            # This should only happen if summary itself is massive
-                            if has_system_prompt and len(messages) > 1:
-                                messages = messages[1:]  # Drop system prompt
-                                logger.critical(
-                                    "CRITICAL: Dropped system prompt as absolute last resort. "
-                                    "Behavioral consistency may be affected."
-                                )
-                                # Yield error to user
-                                yield StreamError(
-                                    errorText=(
-                                        "Warning: System prompt dropped due to size constraints. "
-                                        "Assistant behavior may be affected."
-                                    )
-                                )
-                else:
-                    # No old messages to summarize - all messages are "recent"
-                    # Apply progressive truncation to reduce token count
-                    logger.warning(
-                        f"Token count {token_count} exceeds threshold but no old messages to summarize. "
-                        f"Applying progressive truncation to recent messages."
-                    )
-
-                    # Create a base list excluding system prompt to avoid duplication
-                    # This is the pool of messages we'll slice from in the loop
-                    # Use messages_dict for type consistency with _ensure_tool_pairs_intact
-                    base_msgs = (
-                        messages_dict[1:] if has_system_prompt else messages_dict
-                    )
-
-                    # Try progressively smaller keep counts
-                    new_token_count = token_count  # Initialize with current count
-                    for keep_count in [12, 10, 8, 5, 3, 2, 1, 0]:
-                        if keep_count == 0:
-                            # Try with just system prompt (no recent messages)
-                            if has_system_prompt:
-                                messages = [system_msg]
-                                logger.info(
-                                    "Trying with 0 recent messages (system prompt only)"
-                                )
-                            else:
-                                # No system prompt and no recent messages = empty messages list
-                                # This is invalid, skip this iteration
-                                continue
-                        else:
-                            if len(base_msgs) < keep_count:
-                                continue  # Skip if we don't have enough messages
-
-                            # Slice from base_msgs to get recent messages (without system prompt)
-                            recent_messages = base_msgs[-keep_count:]
-
-                            # Ensure tool pairs stay intact in the reduced slice
-                            reduced_slice_start = max(0, len(base_msgs) - keep_count)
-                            recent_messages = _ensure_tool_pairs_intact(
-                                recent_messages, base_msgs, reduced_slice_start
-                            )
-
-                            if has_system_prompt:
-                                messages = [system_msg] + recent_messages
-                            else:
-                                messages = recent_messages
-
-                        new_messages_dict = []
-                        for msg in messages:
-                            if msg is None:
-                                continue  # Skip None messages (type safety)
-                            if isinstance(msg, dict):
-                                msg_dict = {
-                                    k: v for k, v in msg.items() if v is not None
-                                }
-                            else:
-                                msg_dict = dict(msg)
-                            new_messages_dict.append(msg_dict)
-
-                        new_token_count = estimate_token_count(
-                            new_messages_dict, model=token_count_model
-                        )
-
-                        if new_token_count <= 120_000:
-                            logger.info(
-                                f"Reduced to {keep_count} recent messages, "
-                                f"now {new_token_count} tokens"
-                            )
-                            break
-                    else:
-                        # Even with 0 messages still over limit
-                        logger.error(
-                            f"Unable to reduce token count below threshold even with 0 messages. "
-                            f"Final count: {new_token_count} tokens. Messages may be extremely large."
-                        )
-                        # ABSOLUTE LAST RESORT: Drop system prompt
-                        if has_system_prompt and len(messages) > 1:
-                            messages = messages[1:]  # Drop system prompt
-                            logger.critical(
-                                "CRITICAL: Dropped system prompt as absolute last resort. "
-                                "Behavioral consistency may be affected."
-                            )
-                            # Yield error to user
-                            yield StreamError(
-                                errorText=(
-                                    "Warning: System prompt dropped due to size constraints. "
-                                    "Assistant behavior may be affected."
-                                )
-                            )
-
-    except Exception as e:
-        logger.error(f"Context summarization failed: {e}", exc_info=True)
-        # If we were over the token limit, yield error to user
-        # Don't silently continue with oversized messages that will fail
-        if token_count > 120_000:
+    context_result = await _manage_context_window(
+        messages=messages,
+        model=model,
+        api_key=config.api_key,
+        base_url=config.base_url,
+    )
+
+    if context_result.error:
+        if "System prompt dropped" in context_result.error:
+            # Warning only - continue with reduced context
             yield StreamError(
                 errorText=(
-                    f"Unable to manage context window (token limit exceeded: {token_count} tokens). "
-                    "Context summarization failed. Please start a new conversation."
+                    "Warning: System prompt dropped due to size constraints. "
+                    "Assistant behavior may be affected."
+                )
+            )
+        else:
+            # Any other error - abort to prevent failed LLM calls
+            yield StreamError(
+                errorText=(
+                    f"Context window management failed: {context_result.error}. "
+                    "Please start a new conversation."
                 )
             )
             yield StreamFinish()
             return
-    # Otherwise, continue with original messages (under limit)
+
+    messages = context_result.messages
+    if context_result.was_compacted:
+        logger.info(
+            f"Context compacted for streaming: {context_result.token_count} tokens"
+        )
 
     # Loop to handle tool calls and continue conversation
     while True:
@@ -1369,14 +951,6 @@ async def _stream_chat_chunks(
                 :128
             ]  # OpenRouter limit
 
-        # Create the stream with proper types
-        from typing import cast
-
-        from openai.types.chat import (
-            ChatCompletionMessageParam,
-            ChatCompletionStreamOptionsParam,
-        )
-
         stream = await client.chat.completions.create(
             model=model,
             messages=cast(list[ChatCompletionMessageParam], messages),
@@ -1834,6 +1408,11 @@ async def _execute_long_running_tool(
             tool_call_id=tool_call_id,
             result=error_response.model_dump_json(),
         )
+        # Generate LLM continuation so user sees explanation even for errors
+        try:
+            await _generate_llm_continuation(session_id=session_id, user_id=user_id)
+        except Exception as llm_err:
+            logger.warning(f"Failed to generate LLM continuation for error: {llm_err}")
     finally:
         await _mark_operation_completed(tool_call_id)
 
@@ -1895,17 +1474,36 @@ async def _generate_llm_continuation(
     # Build system prompt
     system_prompt, _ = await _build_system_prompt(user_id)
 
-    # Build messages in OpenAI format
     messages = session.to_openai_messages()
     if system_prompt:
-        from openai.types.chat import ChatCompletionSystemMessageParam
-
         system_message = ChatCompletionSystemMessageParam(
             role="system",
            content=system_prompt,
         )
         messages = [system_message] + messages
 
+    # Apply context window management to prevent oversized requests
+    context_result = await _manage_context_window(
+        messages=messages,
+        model=config.model,
+        api_key=config.api_key,
+        base_url=config.base_url,
+    )
+
+    if context_result.error and "System prompt dropped" not in context_result.error:
+        logger.error(
+            f"Context window management failed for session {session_id}: "
+            f"{context_result.error} (tokens={context_result.token_count})"
+        )
+        return
+
+    messages = context_result.messages
+    if context_result.was_compacted:
+        logger.info(
+            f"Context compacted for LLM continuation: "
+            f"{context_result.token_count} tokens"
+        )
+
     # Build extra_body for tracing
     extra_body: dict[str, Any] = {
         "posthogProperties": {
@@ -1918,19 +1516,54 @@ async def _generate_llm_continuation(
     if session_id:
         extra_body["session_id"] = session_id[:128]
 
-    # Make non-streaming LLM call (no tools - just text response)
-    from typing import cast
-
-    from openai.types.chat import ChatCompletionMessageParam
-
-    # No tools parameter = text-only response (no tool calls)
-    response = await client.chat.completions.create(
-        model=config.model,
-        messages=cast(list[ChatCompletionMessageParam], messages),
-        extra_body=extra_body,
-    )
+    retry_count = 0
+    last_error: Exception | None = None
+    response = None
+
+    while retry_count <= MAX_RETRIES:
+        try:
+            logger.info(
+                f"Generating LLM continuation for session {session_id}"
+                f"{f' (retry {retry_count}/{MAX_RETRIES})' if retry_count > 0 else ''}"
+            )
+
+            response = await client.chat.completions.create(
+                model=config.model,
+                messages=cast(list[ChatCompletionMessageParam], messages),
+                extra_body=extra_body,
+            )
+            last_error = None  # Clear any previous error on success
+            break  # Success, exit retry loop
+        except Exception as e:
+            last_error = e
+            if _is_retryable_error(e) and retry_count < MAX_RETRIES:
+                retry_count += 1
+                delay = min(
+                    BASE_DELAY_SECONDS * (2 ** (retry_count - 1)),
+                    MAX_DELAY_SECONDS,
+                )
+                logger.warning(
+                    f"Retryable error in LLM continuation: {e!s}. "
+                    f"Retrying in {delay:.1f}s (attempt {retry_count}/{MAX_RETRIES})"
+                )
+                await asyncio.sleep(delay)
+                continue
+            else:
+                # Non-retryable error - log and exit gracefully
+                logger.error(
+                    f"Non-retryable error in LLM continuation: {e!s}",
+                    exc_info=True,
+                )
+                return
 
-    if response.choices and response.choices[0].message.content:
+    if last_error:
+        logger.error(
+            f"Max retries ({MAX_RETRIES}) exceeded for LLM continuation. "
+            f"Last error: {last_error!s}"
+        )
+        return
+
+    if response and response.choices and response.choices[0].message.content:
         assistant_content = response.choices[0].message.content
 
         # Reload session from DB to avoid race condition with user messages
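Note: `MAX_RETRIES`, `BASE_DELAY_SECONDS`, and `MAX_DELAY_SECONDS` are module constants that this hunk does not show. Assuming illustrative values of a 1-second base and a 30-second cap, the delay formula above produces a capped exponential sequence:

# Illustration only: the real constant values are defined elsewhere in the file.
BASE_DELAY_SECONDS = 1.0   # assumed value
MAX_DELAY_SECONDS = 30.0   # assumed value

for retry_count in range(1, 7):
    delay = min(BASE_DELAY_SECONDS * (2 ** (retry_count - 1)), MAX_DELAY_SECONDS)
    print(retry_count, delay)
# 1 1.0, 2 2.0, 3 4.0, 4 8.0, 5 16.0, 6 30.0 (capped)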
@@ -2,30 +2,54 @@
 
 from .core import (
     AgentGeneratorNotConfiguredError,
+    AgentJsonValidationError,
+    AgentSummary,
+    DecompositionResult,
+    DecompositionStep,
+    LibraryAgentSummary,
+    MarketplaceAgentSummary,
     decompose_goal,
+    enrich_library_agents_from_steps,
+    extract_search_terms_from_steps,
+    extract_uuids_from_text,
     generate_agent,
     generate_agent_patch,
     get_agent_as_json,
+    get_all_relevant_agents_for_generation,
+    get_library_agent_by_graph_id,
+    get_library_agent_by_id,
+    get_library_agents_for_generation,
     json_to_graph,
     save_agent_to_library,
+    search_marketplace_agents_for_generation,
 )
 from .errors import get_user_message_for_error
 from .service import health_check as check_external_service_health
 from .service import is_external_service_configured
 
 __all__ = [
-    # Core functions
+    "AgentGeneratorNotConfiguredError",
+    "AgentJsonValidationError",
+    "AgentSummary",
+    "DecompositionResult",
+    "DecompositionStep",
+    "LibraryAgentSummary",
+    "MarketplaceAgentSummary",
+    "check_external_service_health",
     "decompose_goal",
+    "enrich_library_agents_from_steps",
+    "extract_search_terms_from_steps",
+    "extract_uuids_from_text",
     "generate_agent",
     "generate_agent_patch",
-    "save_agent_to_library",
     "get_agent_as_json",
-    "json_to_graph",
-    # Exceptions
-    "AgentGeneratorNotConfiguredError",
-    # Service
-    "is_external_service_configured",
-    "check_external_service_health",
-    # Error handling
+    "get_all_relevant_agents_for_generation",
+    "get_library_agent_by_graph_id",
+    "get_library_agent_by_id",
+    "get_library_agents_for_generation",
     "get_user_message_for_error",
+    "is_external_service_configured",
+    "json_to_graph",
+    "save_agent_to_library",
+    "search_marketplace_agents_for_generation",
 ]
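Note: with the re-exports and the alphabetized `__all__`, the new helpers are importable from the package root. The package path below is hypothetical; the diff only shows the package's `__init__.py`, not where the package lives:

# Hypothetical import path, for illustration only.
from backend.api.features.agent_generator import (
    DecompositionResult,
    decompose_goal,
    extract_uuids_from_text,
    get_all_relevant_agents_for_generation,
)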
@@ -1,11 +1,22 @@
 """Core agent generation functions."""
 
 import logging
+import re
 import uuid
-from typing import Any
+from typing import Any, NotRequired, TypedDict
 
 from backend.api.features.library import db as library_db
-from backend.data.graph import Graph, Link, Node, create_graph
+from backend.api.features.store import db as store_db
+from backend.data.graph import (
+    Graph,
+    Link,
+    Node,
+    create_graph,
+    get_graph,
+    get_graph_all_versions,
+    get_store_listed_graphs,
+)
+from backend.util.exceptions import DatabaseError, NotFoundError
 
 from .service import (
     decompose_goal_external,
@@ -16,6 +27,74 @@ from .service import (
 
 logger = logging.getLogger(__name__)
 
+AGENT_EXECUTOR_BLOCK_ID = "e189baac-8c20-45a1-94a7-55177ea42565"
+
+
+class ExecutionSummary(TypedDict):
+    """Summary of a single execution for quality assessment."""
+
+    status: str
+    correctness_score: NotRequired[float]
+    activity_summary: NotRequired[str]
+
+
+class LibraryAgentSummary(TypedDict):
+    """Summary of a library agent for sub-agent composition.
+
+    Includes recent executions to help the LLM decide whether to use this agent.
+    Each execution shows status, correctness_score (0-1), and activity_summary.
+    """
+
+    graph_id: str
+    graph_version: int
+    name: str
+    description: str
+    input_schema: dict[str, Any]
+    output_schema: dict[str, Any]
+    recent_executions: NotRequired[list[ExecutionSummary]]
+
+
+class MarketplaceAgentSummary(TypedDict):
+    """Summary of a marketplace agent for sub-agent composition."""
+
+    name: str
+    description: str
+    sub_heading: str
+    creator: str
+    is_marketplace_agent: bool
+
+
+class DecompositionStep(TypedDict, total=False):
+    """A single step in decomposed instructions."""
+
+    description: str
+    action: str
+    block_name: str
+    tool: str
+    name: str
+
+
+class DecompositionResult(TypedDict, total=False):
+    """Result from decompose_goal - can be instructions, questions, or error."""
+
+    type: str
+    steps: list[DecompositionStep]
+    questions: list[dict[str, Any]]
+    error: str
+    error_type: str
+
+
+AgentSummary = LibraryAgentSummary | MarketplaceAgentSummary | dict[str, Any]
+
+
+def _to_dict_list(
+    agents: list[AgentSummary] | list[dict[str, Any]] | None,
+) -> list[dict[str, Any]] | None:
+    """Convert typed agent summaries to plain dicts for external service calls."""
+    if agents is None:
+        return None
+    return [dict(a) for a in agents]
+
+
 class AgentGeneratorNotConfiguredError(Exception):
     """Raised when the external Agent Generator service is not configured."""
@@ -36,15 +115,422 @@ def _check_service_configured() -> None:
         )
 
 
-async def decompose_goal(description: str, context: str = "") -> dict[str, Any] | None:
+_UUID_PATTERN = re.compile(
+    r"[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}",
+    re.IGNORECASE,
+)
+
+
+def extract_uuids_from_text(text: str) -> list[str]:
+    """Extract all UUID v4 strings from text.
+
+    Args:
+        text: Text that may contain UUIDs (e.g., user's goal description)
+
+    Returns:
+        List of unique UUIDs found in the text (lowercase)
+    """
+    matches = _UUID_PATTERN.findall(text)
+    return list({m.lower() for m in matches})
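Note: `extract_uuids_from_text` only matches canonical UUIDv4 shapes (version nibble `4`, variant nibble `8`/`9`/`a`/`b`), lowercases matches, and deduplicates through a set, so result order is not guaranteed. For example:

text = (
    "Reuse agent E189BAAC-8C20-45A1-94A7-55177EA42565 and "
    "e189baac-8c20-45a1-94a7-55177ea42565 from my library"
)
print(extract_uuids_from_text(text))
# ['e189baac-8c20-45a1-94a7-55177ea42565']  (case-insensitive, lowercased, deduplicated)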
|
|
||||||
|
async def get_library_agent_by_id(
|
||||||
|
user_id: str, agent_id: str
|
||||||
|
) -> LibraryAgentSummary | None:
|
||||||
|
"""Fetch a specific library agent by its ID (library agent ID or graph_id).
|
||||||
|
|
||||||
|
This function tries multiple lookup strategies:
|
||||||
|
1. First tries to find by graph_id (AgentGraph primary key)
|
||||||
|
2. If not found, tries to find by library agent ID (LibraryAgent primary key)
|
||||||
|
|
||||||
|
This handles both cases:
|
||||||
|
- User provides graph_id (e.g., from AgentExecutorBlock)
|
||||||
|
- User provides library agent ID (e.g., from library URL)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID
|
||||||
|
agent_id: The ID to look up (can be graph_id or library agent ID)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
LibraryAgentSummary if found, None otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
agent = await library_db.get_library_agent_by_graph_id(user_id, agent_id)
|
||||||
|
if agent:
|
||||||
|
logger.debug(f"Found library agent by graph_id: {agent.name}")
|
||||||
|
return LibraryAgentSummary(
|
||||||
|
graph_id=agent.graph_id,
|
||||||
|
graph_version=agent.graph_version,
|
||||||
|
name=agent.name,
|
||||||
|
description=agent.description,
|
||||||
|
input_schema=agent.input_schema,
|
||||||
|
output_schema=agent.output_schema,
|
||||||
|
)
|
||||||
|
except DatabaseError:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Could not fetch library agent by graph_id {agent_id}: {e}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
agent = await library_db.get_library_agent(agent_id, user_id)
|
||||||
|
if agent:
|
||||||
|
logger.debug(f"Found library agent by library_id: {agent.name}")
|
||||||
|
return LibraryAgentSummary(
|
||||||
|
graph_id=agent.graph_id,
|
||||||
|
graph_version=agent.graph_version,
|
||||||
|
name=agent.name,
|
||||||
|
description=agent.description,
|
||||||
|
input_schema=agent.input_schema,
|
||||||
|
output_schema=agent.output_schema,
|
||||||
|
)
|
||||||
|
except NotFoundError:
|
||||||
|
logger.debug(f"Library agent not found by library_id: {agent_id}")
|
||||||
|
except DatabaseError:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Could not fetch library agent by library_id {agent_id}: {e}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
get_library_agent_by_graph_id = get_library_agent_by_id
|
||||||
|
|
||||||
|
|
||||||
|
async def get_library_agents_for_generation(
|
||||||
|
user_id: str,
|
||||||
|
search_query: str | None = None,
|
||||||
|
exclude_graph_id: str | None = None,
|
||||||
|
max_results: int = 15,
|
||||||
|
) -> list[LibraryAgentSummary]:
|
||||||
|
"""Fetch user's library agents formatted for Agent Generator.
|
||||||
|
|
||||||
|
Uses search-based fetching to return relevant agents instead of all agents.
|
||||||
|
This is more scalable for users with large libraries.
|
||||||
|
|
||||||
|
Includes recent_executions list to help the LLM assess agent quality:
|
||||||
|
- Each execution has status, correctness_score (0-1), and activity_summary
|
||||||
|
- This gives the LLM concrete examples of recent performance
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID
|
||||||
|
search_query: Optional search term to find relevant agents (user's goal/description)
|
||||||
|
exclude_graph_id: Optional graph ID to exclude (prevents circular references)
|
||||||
|
max_results: Maximum number of agents to return (default 15)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of LibraryAgentSummary with schemas and recent executions for sub-agent composition
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
response = await library_db.list_library_agents(
|
||||||
|
user_id=user_id,
|
||||||
|
search_term=search_query,
|
||||||
|
page=1,
|
||||||
|
page_size=max_results,
|
||||||
|
include_executions=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
results: list[LibraryAgentSummary] = []
|
||||||
|
for agent in response.agents:
|
||||||
|
if exclude_graph_id is not None and agent.graph_id == exclude_graph_id:
|
||||||
|
continue
|
||||||
|
|
||||||
|
summary = LibraryAgentSummary(
|
||||||
|
graph_id=agent.graph_id,
|
||||||
|
graph_version=agent.graph_version,
|
||||||
|
name=agent.name,
|
||||||
|
description=agent.description,
|
||||||
|
input_schema=agent.input_schema,
|
||||||
|
output_schema=agent.output_schema,
|
||||||
|
)
|
||||||
|
if agent.recent_executions:
|
||||||
|
exec_summaries: list[ExecutionSummary] = []
|
||||||
|
for ex in agent.recent_executions:
|
||||||
|
exec_sum = ExecutionSummary(status=ex.status)
|
||||||
|
if ex.correctness_score is not None:
|
||||||
|
exec_sum["correctness_score"] = ex.correctness_score
|
||||||
|
if ex.activity_summary:
|
||||||
|
exec_sum["activity_summary"] = ex.activity_summary
|
||||||
|
exec_summaries.append(exec_sum)
|
||||||
|
summary["recent_executions"] = exec_summaries
|
||||||
|
results.append(summary)
|
||||||
|
return results
|
||||||
|
except DatabaseError:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to fetch library agents: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
async def search_marketplace_agents_for_generation(
|
||||||
|
search_query: str,
|
||||||
|
max_results: int = 10,
|
||||||
|
) -> list[LibraryAgentSummary]:
|
||||||
|
"""Search marketplace agents formatted for Agent Generator.
|
||||||
|
|
||||||
|
Fetches marketplace agents and their full schemas so they can be used
|
||||||
|
as sub-agents in generated workflows.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
search_query: Search term to find relevant public agents
|
||||||
|
max_results: Maximum number of agents to return (default 10)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of LibraryAgentSummary with full input/output schemas
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
response = await store_db.get_store_agents(
|
||||||
|
search_query=search_query,
|
||||||
|
page=1,
|
||||||
|
page_size=max_results,
|
||||||
|
)
|
||||||
|
|
||||||
|
agents_with_graphs = [
|
||||||
|
agent for agent in response.agents if agent.agent_graph_id
|
||||||
|
]
|
||||||
|
|
||||||
|
if not agents_with_graphs:
|
||||||
|
return []
|
||||||
|
|
||||||
|
graph_ids = [agent.agent_graph_id for agent in agents_with_graphs]
|
||||||
|
graphs = await get_store_listed_graphs(*graph_ids)
|
||||||
|
|
||||||
|
results: list[LibraryAgentSummary] = []
|
||||||
|
for agent in agents_with_graphs:
|
||||||
|
graph_id = agent.agent_graph_id
|
||||||
|
if graph_id and graph_id in graphs:
|
||||||
|
graph = graphs[graph_id]
|
||||||
|
results.append(
|
||||||
|
LibraryAgentSummary(
|
||||||
|
graph_id=graph.id,
|
||||||
|
graph_version=graph.version,
|
||||||
|
name=agent.agent_name,
|
||||||
|
description=agent.description,
|
||||||
|
input_schema=graph.input_schema,
|
||||||
|
output_schema=graph.output_schema,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return results
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to search marketplace agents: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
async def get_all_relevant_agents_for_generation(
|
||||||
|
user_id: str,
|
||||||
|
search_query: str | None = None,
|
||||||
|
exclude_graph_id: str | None = None,
|
||||||
|
include_library: bool = True,
|
||||||
|
include_marketplace: bool = True,
|
||||||
|
max_library_results: int = 15,
|
||||||
|
max_marketplace_results: int = 10,
|
||||||
|
) -> list[AgentSummary]:
|
||||||
|
"""Fetch relevant agents from library and/or marketplace.
|
||||||
|
|
||||||
|
Searches both user's library and marketplace by default.
|
||||||
|
Explicitly mentioned UUIDs in the search query are always looked up.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID
|
||||||
|
search_query: Search term to find relevant agents (user's goal/description)
|
||||||
|
exclude_graph_id: Optional graph ID to exclude (prevents circular references)
|
||||||
|
include_library: Whether to search user's library (default True)
|
||||||
|
include_marketplace: Whether to also search marketplace (default True)
|
||||||
|
max_library_results: Max library agents to return (default 15)
|
||||||
|
max_marketplace_results: Max marketplace agents to return (default 10)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of AgentSummary with full schemas (both library and marketplace agents)
|
||||||
|
"""
|
||||||
|
agents: list[AgentSummary] = []
|
||||||
|
seen_graph_ids: set[str] = set()
|
||||||
|
|
||||||
|
if search_query:
|
||||||
|
mentioned_uuids = extract_uuids_from_text(search_query)
|
||||||
|
for graph_id in mentioned_uuids:
|
||||||
|
if graph_id == exclude_graph_id:
|
||||||
|
continue
|
||||||
|
agent = await get_library_agent_by_graph_id(user_id, graph_id)
|
||||||
|
agent_graph_id = agent.get("graph_id") if agent else None
|
||||||
|
if agent and agent_graph_id and agent_graph_id not in seen_graph_ids:
|
||||||
|
agents.append(agent)
|
||||||
|
seen_graph_ids.add(agent_graph_id)
|
||||||
|
logger.debug(
|
||||||
|
f"Found explicitly mentioned agent: {agent.get('name') or 'Unknown'}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if include_library:
|
||||||
|
library_agents = await get_library_agents_for_generation(
|
||||||
|
user_id=user_id,
|
||||||
|
search_query=search_query,
|
||||||
|
exclude_graph_id=exclude_graph_id,
|
||||||
|
max_results=max_library_results,
|
||||||
|
)
|
||||||
|
for agent in library_agents:
|
||||||
|
graph_id = agent.get("graph_id")
|
||||||
|
if graph_id and graph_id not in seen_graph_ids:
|
||||||
|
agents.append(agent)
|
||||||
|
seen_graph_ids.add(graph_id)
|
||||||
|
|
||||||
|
if include_marketplace and search_query:
|
||||||
|
marketplace_agents = await search_marketplace_agents_for_generation(
|
||||||
|
search_query=search_query,
|
||||||
|
max_results=max_marketplace_results,
|
||||||
|
)
|
||||||
|
for agent in marketplace_agents:
|
||||||
|
graph_id = agent.get("graph_id")
|
||||||
|
if graph_id and graph_id not in seen_graph_ids:
|
||||||
|
agents.append(agent)
|
||||||
|
seen_graph_ids.add(graph_id)
|
||||||
|
|
||||||
|
return agents
|
||||||
|
|
||||||
|
|
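Note: a usage sketch of the aggregator; the user ID and goal text are placeholders, and the call must run inside an event loop:

# Usage sketch with placeholder values.
agents = await get_all_relevant_agents_for_generation(
    user_id="user-123",
    search_query=(
        "Summarize my inbox with agent e189baac-8c20-45a1-94a7-55177ea42565"
    ),
)
# The explicitly mentioned UUID is resolved first via get_library_agent_by_graph_id,
# then library and marketplace search results are appended, deduplicated by graph_id.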
+
+
+def extract_search_terms_from_steps(
+    decomposition_result: DecompositionResult | dict[str, Any],
+) -> list[str]:
+    """Extract search terms from decomposed instruction steps.
+
+    Analyzes the decomposition result to extract relevant keywords
+    for additional library agent searches.
+
+    Args:
+        decomposition_result: Result from decompose_goal containing steps
+
+    Returns:
+        List of unique search terms extracted from steps
+    """
+    search_terms: list[str] = []
+
+    if decomposition_result.get("type") != "instructions":
+        return search_terms
+
+    steps = decomposition_result.get("steps", [])
+    if not steps:
+        return search_terms
+
+    step_keys: list[str] = ["description", "action", "block_name", "tool", "name"]
+
+    for step in steps:
+        for key in step_keys:
+            value = step.get(key)  # type: ignore[union-attr]
+            if isinstance(value, str) and len(value) > 3:
+                search_terms.append(value)
+
+    seen: set[str] = set()
+    unique_terms: list[str] = []
+    for term in search_terms:
+        term_lower = term.lower()
+        if term_lower not in seen:
+            seen.add(term_lower)
+            unique_terms.append(term)
+
+    return unique_terms
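Note: only string values longer than three characters are collected from the keys `description`, `action`, `block_name`, `tool`, and `name`, and duplicates are removed case-insensitively while keeping the first spelling, so a decomposition like the one below yields two terms:

result = {
    "type": "instructions",
    "steps": [
        {"description": "Fetch RSS feed", "block_name": "RSS"},          # "RSS" is too short
        {"description": "fetch rss feed", "action": "Summarize items"},  # case-insensitive dup
    ],
}
print(extract_search_terms_from_steps(result))
# ['Fetch RSS feed', 'Summarize items']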
async def enrich_library_agents_from_steps(
|
||||||
|
user_id: str,
|
||||||
|
decomposition_result: DecompositionResult | dict[str, Any],
|
||||||
|
existing_agents: list[AgentSummary] | list[dict[str, Any]],
|
||||||
|
exclude_graph_id: str | None = None,
|
||||||
|
include_marketplace: bool = True,
|
||||||
|
max_additional_results: int = 10,
|
||||||
|
) -> list[AgentSummary] | list[dict[str, Any]]:
|
||||||
|
"""Enrich library agents list with additional searches based on decomposed steps.
|
||||||
|
|
||||||
|
This implements two-phase search: after decomposition, we search for additional
|
||||||
|
relevant agents based on the specific steps identified.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID
|
||||||
|
decomposition_result: Result from decompose_goal containing steps
|
||||||
|
existing_agents: Already fetched library agents from initial search
|
||||||
|
exclude_graph_id: Optional graph ID to exclude
|
||||||
|
include_marketplace: Whether to also search marketplace
|
||||||
|
max_additional_results: Max additional agents per search term (default 10)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Combined list of library agents (existing + newly discovered)
|
||||||
|
"""
|
||||||
|
search_terms = extract_search_terms_from_steps(decomposition_result)
|
||||||
|
|
||||||
|
if not search_terms:
|
||||||
|
return existing_agents
|
||||||
|
|
||||||
|
existing_ids: set[str] = set()
|
||||||
|
existing_names: set[str] = set()
|
||||||
|
|
||||||
|
for agent in existing_agents:
|
||||||
|
agent_name = agent.get("name")
|
||||||
|
if agent_name and isinstance(agent_name, str):
|
||||||
|
existing_names.add(agent_name.lower())
|
||||||
|
graph_id = agent.get("graph_id") # type: ignore[call-overload]
|
||||||
|
if graph_id and isinstance(graph_id, str):
|
||||||
|
existing_ids.add(graph_id)
|
||||||
|
|
||||||
|
all_agents: list[AgentSummary] | list[dict[str, Any]] = list(existing_agents)
|
||||||
|
|
||||||
|
for term in search_terms[:3]:
|
||||||
|
try:
|
||||||
|
additional_agents = await get_all_relevant_agents_for_generation(
|
||||||
|
user_id=user_id,
|
||||||
|
search_query=term,
|
||||||
|
exclude_graph_id=exclude_graph_id,
|
||||||
|
include_marketplace=include_marketplace,
|
||||||
|
max_library_results=max_additional_results,
|
||||||
|
max_marketplace_results=5,
|
||||||
|
)
|
||||||
|
|
||||||
|
for agent in additional_agents:
|
||||||
|
agent_name = agent.get("name")
|
||||||
|
if not agent_name or not isinstance(agent_name, str):
|
||||||
|
continue
|
||||||
|
agent_name_lower = agent_name.lower()
|
||||||
|
|
||||||
|
if agent_name_lower in existing_names:
|
||||||
|
continue
|
||||||
|
|
||||||
|
graph_id = agent.get("graph_id") # type: ignore[call-overload]
|
||||||
|
if graph_id and graph_id in existing_ids:
|
||||||
|
continue
|
||||||
|
|
||||||
|
all_agents.append(agent)
|
||||||
|
existing_names.add(agent_name_lower)
|
||||||
|
if graph_id and isinstance(graph_id, str):
|
||||||
|
existing_ids.add(graph_id)
|
||||||
|
|
||||||
|
except DatabaseError:
|
||||||
|
logger.error(f"Database error searching for agents with term '{term}'")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Failed to search for additional agents with term '{term}': {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"Enriched library agents: {len(existing_agents)} initial + "
|
||||||
|
f"{len(all_agents) - len(existing_agents)} additional = {len(all_agents)} total"
|
||||||
|
)
|
||||||
|
|
||||||
|
return all_agents
|
||||||
|
|
||||||
|
|
||||||
|
async def decompose_goal(
|
||||||
|
description: str,
|
||||||
|
context: str = "",
|
||||||
|
library_agents: list[AgentSummary] | None = None,
|
||||||
|
) -> DecompositionResult | None:
|
||||||
"""Break down a goal into steps or return clarifying questions.
|
"""Break down a goal into steps or return clarifying questions.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
description: Natural language goal description
|
description: Natural language goal description
|
||||||
context: Additional context (e.g., answers to previous questions)
|
context: Additional context (e.g., answers to previous questions)
|
||||||
|
library_agents: User's library agents available for sub-agent composition
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict with either:
|
DecompositionResult with either:
|
||||||
- {"type": "clarifying_questions", "questions": [...]}
|
- {"type": "clarifying_questions", "questions": [...]}
|
||||||
- {"type": "instructions", "steps": [...]}
|
- {"type": "instructions", "steps": [...]}
|
||||||
Or None on error
|
Or None on error
|
||||||
@@ -54,14 +540,21 @@ async def decompose_goal(description: str, context: str = "") -> dict[str, Any]
|
|||||||
"""
|
"""
|
||||||
_check_service_configured()
|
_check_service_configured()
|
||||||
logger.info("Calling external Agent Generator service for decompose_goal")
|
logger.info("Calling external Agent Generator service for decompose_goal")
|
||||||
return await decompose_goal_external(description, context)
|
result = await decompose_goal_external(
|
||||||
|
description, context, _to_dict_list(library_agents)
|
||||||
|
)
|
||||||
|
return result # type: ignore[return-value]
|
||||||
|
|
||||||
|
|
||||||
-async def generate_agent(instructions: dict[str, Any]) -> dict[str, Any] | None:
+async def generate_agent(
+    instructions: DecompositionResult | dict[str, Any],
+    library_agents: list[AgentSummary] | list[dict[str, Any]] | None = None,
+) -> dict[str, Any] | None:
     """Generate agent JSON from instructions.

     Args:
         instructions: Structured instructions from decompose_goal
+        library_agents: User's library agents available for sub-agent composition

     Returns:
         Agent JSON dict, error dict {"type": "error", ...}, or None on error
@@ -71,12 +564,12 @@ async def generate_agent(instructions: dict[str, Any]) -> dict[str, Any] | None:
     """
     _check_service_configured()
     logger.info("Calling external Agent Generator service for generate_agent")
-    result = await generate_agent_external(instructions)
+    result = await generate_agent_external(
+        dict(instructions), _to_dict_list(library_agents)
+    )
     if result:
-        # Check if it's an error response - pass through as-is
         if isinstance(result, dict) and result.get("type") == "error":
             return result
-        # Ensure required fields for successful agent generation
         if "id" not in result:
             result["id"] = str(uuid.uuid4())
         if "version" not in result:
@@ -86,6 +579,12 @@ async def generate_agent(instructions: dict[str, Any]) -> dict[str, Any] | None:
     return result


+class AgentJsonValidationError(Exception):
+    """Raised when agent JSON is invalid or missing required fields."""
+
+    pass
+
+
 def json_to_graph(agent_json: dict[str, Any]) -> Graph:
     """Convert agent JSON dict to Graph model.

@@ -94,25 +593,55 @@ def json_to_graph(agent_json: dict[str, Any]) -> Graph:

     Returns:
         Graph ready for saving

+    Raises:
+        AgentJsonValidationError: If required fields are missing from nodes or links
     """
     nodes = []
-    for n in agent_json.get("nodes", []):
+    for idx, n in enumerate(agent_json.get("nodes", [])):
+        block_id = n.get("block_id")
+        if not block_id:
+            node_id = n.get("id", f"index_{idx}")
+            raise AgentJsonValidationError(
+                f"Node '{node_id}' is missing required field 'block_id'"
+            )
         node = Node(
             id=n.get("id", str(uuid.uuid4())),
-            block_id=n["block_id"],
+            block_id=block_id,
             input_default=n.get("input_default", {}),
             metadata=n.get("metadata", {}),
         )
         nodes.append(node)

     links = []
-    for link_data in agent_json.get("links", []):
+    for idx, link_data in enumerate(agent_json.get("links", [])):
+        source_id = link_data.get("source_id")
+        sink_id = link_data.get("sink_id")
+        source_name = link_data.get("source_name")
+        sink_name = link_data.get("sink_name")
+
+        missing_fields = []
+        if not source_id:
+            missing_fields.append("source_id")
+        if not sink_id:
+            missing_fields.append("sink_id")
+        if not source_name:
+            missing_fields.append("source_name")
+        if not sink_name:
+            missing_fields.append("sink_name")
+
+        if missing_fields:
+            link_id = link_data.get("id", f"index_{idx}")
+            raise AgentJsonValidationError(
+                f"Link '{link_id}' is missing required fields: {', '.join(missing_fields)}"
+            )
+
         link = Link(
             id=link_data.get("id", str(uuid.uuid4())),
-            source_id=link_data["source_id"],
-            sink_id=link_data["sink_id"],
-            source_name=link_data["source_name"],
-            sink_name=link_data["sink_name"],
+            source_id=source_id,
+            sink_id=sink_id,
+            source_name=source_name,
+            sink_name=sink_name,
             is_static=link_data.get("is_static", False),
         )
         links.append(link)
@@ -133,22 +662,40 @@ def _reassign_node_ids(graph: Graph) -> None:

     This is needed when creating a new version to avoid unique constraint violations.
     """
-    # Create mapping from old node IDs to new UUIDs
     id_map = {node.id: str(uuid.uuid4()) for node in graph.nodes}

-    # Reassign node IDs
     for node in graph.nodes:
         node.id = id_map[node.id]

-    # Update link references to use new node IDs
     for link in graph.links:
-        link.id = str(uuid.uuid4())  # Also give links new IDs
+        link.id = str(uuid.uuid4())
         if link.source_id in id_map:
             link.source_id = id_map[link.source_id]
         if link.sink_id in id_map:
             link.sink_id = id_map[link.sink_id]
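Since `json_to_graph` now fails fast with a named error instead of raising a bare `KeyError`, callers get an actionable message. A minimal illustration, using a made-up agent JSON:

    # Hypothetical minimal input; "block_id" deliberately omitted from the node.
    bad_agent = {"nodes": [{"id": "n1"}], "links": []}
    try:
        json_to_graph(bad_agent)
    except AgentJsonValidationError as e:
        print(e)  # Node 'n1' is missing required field 'block_id'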
+def _populate_agent_executor_user_ids(agent_json: dict[str, Any], user_id: str) -> None:
+    """Populate user_id in AgentExecutorBlock nodes.
+
+    The external agent generator creates AgentExecutorBlock nodes with empty user_id.
+    This function fills in the actual user_id so sub-agents run with correct permissions.
+
+    Args:
+        agent_json: Agent JSON dict (modified in place)
+        user_id: User ID to set
+    """
+    for node in agent_json.get("nodes", []):
+        if node.get("block_id") == AGENT_EXECUTOR_BLOCK_ID:
+            input_default = node.get("input_default") or {}
+            if not input_default.get("user_id"):
+                input_default["user_id"] = user_id
+                node["input_default"] = input_default
+                logger.debug(
+                    f"Set user_id for AgentExecutorBlock node {node.get('id')}"
+                )
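A before/after view of the helper on a toy node dict; `AGENT_EXECUTOR_BLOCK_ID` is the module's real constant, and its value does not matter for the behavior shown:

    # Illustrative node; only the block_id match and empty user_id matter.
    agent_json = {
        "nodes": [{"id": "n1", "block_id": AGENT_EXECUTOR_BLOCK_ID, "input_default": {}}]
    }
    _populate_agent_executor_user_ids(agent_json, user_id="user-123")
    assert agent_json["nodes"][0]["input_default"]["user_id"] == "user-123"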
 async def save_agent_to_library(
     agent_json: dict[str, Any], user_id: str, is_update: bool = False
 ) -> tuple[Graph, Any]:
@@ -162,33 +709,27 @@ async def save_agent_to_library(
     Returns:
         Tuple of (created Graph, LibraryAgent)
     """
-    from backend.data.graph import get_graph_all_versions
+    # Populate user_id in AgentExecutorBlock nodes before conversion
+    _populate_agent_executor_user_ids(agent_json, user_id)
+
     graph = json_to_graph(agent_json)

     if is_update:
-        # For updates, keep the same graph ID but increment version
-        # and reassign node/link IDs to avoid conflicts
         if graph.id:
             existing_versions = await get_graph_all_versions(graph.id, user_id)
             if existing_versions:
                 latest_version = max(v.version for v in existing_versions)
                 graph.version = latest_version + 1
-        # Reassign node IDs (but keep graph ID the same)
         _reassign_node_ids(graph)
         logger.info(f"Updating agent {graph.id} to version {graph.version}")
     else:
-        # For new agents, always generate a fresh UUID to avoid collisions
         graph.id = str(uuid.uuid4())
         graph.version = 1
-        # Reassign all node IDs as well
         _reassign_node_ids(graph)
         logger.info(f"Creating new agent with ID {graph.id}")

-    # Save to database
     created_graph = await create_graph(graph, user_id)

-    # Add to user's library (or update existing library agent)
     library_agents = await library_db.create_library_agent(
         graph=created_graph,
         user_id=user_id,
@@ -200,25 +741,31 @@ async def save_agent_to_library(


 async def get_agent_as_json(
-    graph_id: str, user_id: str | None
+    agent_id: str, user_id: str | None
 ) -> dict[str, Any] | None:
     """Fetch an agent and convert to JSON format for editing.

     Args:
-        graph_id: Graph ID or library agent ID
+        agent_id: Graph ID or library agent ID
         user_id: User ID

     Returns:
         Agent as JSON dict or None if not found
     """
-    from backend.data.graph import get_graph
-
-    # Try to get the graph (version=None gets the active version)
-    graph = await get_graph(graph_id, version=None, user_id=user_id)
+    graph = await get_graph(agent_id, version=None, user_id=user_id)
+
+    if not graph and user_id:
+        try:
+            library_agent = await library_db.get_library_agent(agent_id, user_id)
+            graph = await get_graph(
+                library_agent.graph_id, version=None, user_id=user_id
+            )
+        except NotFoundError:
+            pass
+
     if not graph:
         return None

-    # Convert to JSON format
     nodes = []
     for node in graph.nodes:
         nodes.append(
@@ -256,7 +803,9 @@ async def get_agent_as_json(


 async def generate_agent_patch(
-    update_request: str, current_agent: dict[str, Any]
+    update_request: str,
+    current_agent: dict[str, Any],
+    library_agents: list[AgentSummary] | None = None,
 ) -> dict[str, Any] | None:
     """Update an existing agent using natural language.

@@ -268,6 +817,7 @@ async def generate_agent_patch(
     Args:
         update_request: Natural language description of changes
         current_agent: Current agent JSON
+        library_agents: User's library agents available for sub-agent composition

     Returns:
         Updated agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
@@ -278,4 +828,6 @@ async def generate_agent_patch(
     """
     _check_service_configured()
     logger.info("Calling external Agent Generator service for generate_agent_patch")
-    return await generate_agent_patch_external(update_request, current_agent)
+    return await generate_agent_patch_external(
+        update_request, current_agent, _to_dict_list(library_agents)
+    )
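With `get_agent_as_json` now accepting either kind of ID and `generate_agent_patch` feeding the save path, an edit round-trip looks roughly like this sketch (simplified from `EditAgentTool` further down, error and clarification branches omitted):

    # Sketch under the assumption of a logged-in user_id; values illustrative.
    current = await get_agent_as_json(agent_id, user_id)
    if current is not None:
        patched = await generate_agent_patch("Rename the output file", current)
        if isinstance(patched, dict) and patched.get("type") not in (
            "error",
            "clarifying_questions",
        ):
            graph, library_agent = await save_agent_to_library(
                patched, user_id, is_update=True
            )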
@@ -1,11 +1,43 @@
 """Error handling utilities for agent generator."""

+import re
+
+
+def _sanitize_error_details(details: str) -> str:
+    """Sanitize error details to remove sensitive information.
+
+    Strips common patterns that could expose internal system info:
+    - File paths (Unix and Windows)
+    - Database connection strings
+    - URLs with credentials
+    - Stack trace internals
+
+    Args:
+        details: Raw error details string
+
+    Returns:
+        Sanitized error details safe for user display
+    """
+    sanitized = re.sub(
+        r"/[a-zA-Z0-9_./\-]+\.(py|js|ts|json|yaml|yml)", "[path]", details
+    )
+    sanitized = re.sub(r"[A-Z]:\\[a-zA-Z0-9_\\.\\-]+", "[path]", sanitized)
+    sanitized = re.sub(
+        r"(postgres|mysql|mongodb|redis)://[^\s]+", "[database_url]", sanitized
+    )
+    sanitized = re.sub(r"https?://[^:]+:[^@]+@[^\s]+", "[url]", sanitized)
+    sanitized = re.sub(r", line \d+", "", sanitized)
+    sanitized = re.sub(r'File "[^"]+",?', "", sanitized)
+
+    return sanitized.strip()
+
+
 def get_user_message_for_error(
     error_type: str,
     operation: str = "process the request",
     llm_parse_message: str | None = None,
     validation_message: str | None = None,
+    error_details: str | None = None,
 ) -> str:
     """Get a user-friendly error message based on error type.

@@ -19,25 +51,45 @@ def get_user_message_for_error(
             message (e.g., "analyze the goal", "generate the agent")
         llm_parse_message: Custom message for llm_parse_error type
         validation_message: Custom message for validation_error type
+        error_details: Optional additional details about the error

     Returns:
         User-friendly error message suitable for display to the user
     """
+    base_message = ""
+
     if error_type == "llm_parse_error":
-        return (
+        base_message = (
             llm_parse_message
             or "The AI had trouble processing this request. Please try again."
         )
     elif error_type == "validation_error":
-        return (
+        base_message = (
             validation_message
-            or "The request failed validation. Please try rephrasing."
+            or "The generated agent failed validation. "
+            "This usually happens when the agent structure doesn't match "
+            "what the platform expects. Please try simplifying your goal "
+            "or breaking it into smaller parts."
         )
     elif error_type == "patch_error":
-        return "Failed to apply the changes. Please try a different approach."
+        base_message = (
+            "Failed to apply the changes. The modification couldn't be "
+            "validated. Please try a different approach or simplify the change."
+        )
     elif error_type in ("timeout", "llm_timeout"):
-        return "The request took too long. Please try again."
+        base_message = (
+            "The request took too long to process. This can happen with "
+            "complex agents. Please try again or simplify your goal."
+        )
     elif error_type in ("rate_limit", "llm_rate_limit"):
-        return "The service is currently busy. Please try again in a moment."
+        base_message = "The service is currently busy. Please try again in a moment."
     else:
-        return f"Failed to {operation}. Please try again."
+        base_message = f"Failed to {operation}. Please try again."
+
+    if error_details:
+        details = _sanitize_error_details(error_details)
+        if len(details) > 200:
+            details = details[:200] + "..."
+        base_message += f"\n\nTechnical details: {details}"
+
+    return base_message
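A quick trace of `_sanitize_error_details` on a contrived string shows the intended reductions, step by step through the regexes above:

    raw = 'File "/app/backend/agent.py", line 42: postgres://user:pw@db:5432/x failed'
    print(_sanitize_error_details(raw))
    # -> ': [database_url] failed'
    # The Unix path becomes [path], the DB URL becomes [database_url], and the
    # ', line 42' and 'File "..."' stack-trace fragments are stripped entirely.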
@@ -117,13 +117,16 @@ def _get_client() -> httpx.AsyncClient:


 async def decompose_goal_external(
-    description: str, context: str = ""
+    description: str,
+    context: str = "",
+    library_agents: list[dict[str, Any]] | None = None,
 ) -> dict[str, Any] | None:
     """Call the external service to decompose a goal.

     Args:
         description: Natural language goal description
         context: Additional context (e.g., answers to previous questions)
+        library_agents: User's library agents available for sub-agent composition

     Returns:
         Dict with either:
@@ -136,11 +139,12 @@ async def decompose_goal_external(
     """
     client = _get_client()

-    # Build the request payload
-    payload: dict[str, Any] = {"description": description}
     if context:
-        # The external service uses user_instruction for additional context
-        payload["user_instruction"] = context
+        description = f"{description}\n\nAdditional context from user:\n{context}"
+
+    payload: dict[str, Any] = {"description": description}
+    if library_agents:
+        payload["library_agents"] = library_agents

     try:
         response = await client.post("/api/decompose-description", json=payload)
@@ -207,21 +211,25 @@ async def decompose_goal_external(

 async def generate_agent_external(
     instructions: dict[str, Any],
+    library_agents: list[dict[str, Any]] | None = None,
 ) -> dict[str, Any] | None:
     """Call the external service to generate an agent from instructions.

     Args:
         instructions: Structured instructions from decompose_goal
+        library_agents: User's library agents available for sub-agent composition

     Returns:
         Agent JSON dict on success, or error dict {"type": "error", ...} on error
     """
     client = _get_client()

+    payload: dict[str, Any] = {"instructions": instructions}
+    if library_agents:
+        payload["library_agents"] = library_agents
+
     try:
-        response = await client.post(
-            "/api/generate-agent", json={"instructions": instructions}
-        )
+        response = await client.post("/api/generate-agent", json=payload)
         response.raise_for_status()
         data = response.json()

@@ -229,8 +237,7 @@ async def generate_agent_external(
         error_msg = data.get("error", "Unknown error from Agent Generator")
         error_type = data.get("error_type", "unknown")
         logger.error(
-            f"Agent Generator generation failed: {error_msg} "
-            f"(type: {error_type})"
+            f"Agent Generator generation failed: {error_msg} (type: {error_type})"
         )
         return _create_error_response(error_msg, error_type)

@@ -251,27 +258,31 @@ async def generate_agent_external(


 async def generate_agent_patch_external(
-    update_request: str, current_agent: dict[str, Any]
+    update_request: str,
+    current_agent: dict[str, Any],
+    library_agents: list[dict[str, Any]] | None = None,
 ) -> dict[str, Any] | None:
     """Call the external service to generate a patch for an existing agent.

     Args:
         update_request: Natural language description of changes
         current_agent: Current agent JSON
+        library_agents: User's library agents available for sub-agent composition

     Returns:
         Updated agent JSON, clarifying questions dict, or error dict on error
     """
     client = _get_client()

+    payload: dict[str, Any] = {
+        "update_request": update_request,
+        "current_agent_json": current_agent,
+    }
+    if library_agents:
+        payload["library_agents"] = library_agents
+
     try:
-        response = await client.post(
-            "/api/update-agent",
-            json={
-                "update_request": update_request,
-                "current_agent_json": current_agent,
-            },
-        )
+        response = await client.post("/api/update-agent", json=payload)
         response.raise_for_status()
         data = response.json()
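The wrappers above now build one `payload` dict per endpoint before posting. The JSON body sent to /api/update-agent, for example, ends up shaped like this (field values are illustrative):

    payload = {
        "update_request": "Add a Slack notification at the end",
        "current_agent_json": {"nodes": [...], "links": [...]},
        # Present only when the caller supplied a non-empty list:
        "library_agents": [{"id": "a-1", "name": "Email Reader"}],
    }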
@@ -1,6 +1,7 @@
 """Shared agent search functionality for find_agent and find_library_agent tools."""

 import logging
+import re
 from typing import Literal

 from backend.api.features.library import db as library_db
@@ -19,6 +20,85 @@ logger = logging.getLogger(__name__)

 SearchSource = Literal["marketplace", "library"]

+_UUID_PATTERN = re.compile(
+    r"^[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}$",
+    re.IGNORECASE,
+)
+
+
+def _is_uuid(text: str) -> bool:
+    """Check if text is a valid UUID v4."""
+    return bool(_UUID_PATTERN.match(text.strip()))
+
+
+async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | None:
+    """Fetch a library agent by ID (library agent ID or graph_id).
+
+    Tries multiple lookup strategies:
+    1. First by graph_id (AgentGraph primary key)
+    2. Then by library agent ID (LibraryAgent primary key)
+
+    Args:
+        user_id: The user ID
+        agent_id: The ID to look up (can be graph_id or library agent ID)
+
+    Returns:
+        AgentInfo if found, None otherwise
+    """
+    try:
+        agent = await library_db.get_library_agent_by_graph_id(user_id, agent_id)
+        if agent:
+            logger.debug(f"Found library agent by graph_id: {agent.name}")
+            return AgentInfo(
+                id=agent.id,
+                name=agent.name,
+                description=agent.description or "",
+                source="library",
+                in_library=True,
+                creator=agent.creator_name,
+                status=agent.status.value,
+                can_access_graph=agent.can_access_graph,
+                has_external_trigger=agent.has_external_trigger,
+                new_output=agent.new_output,
+                graph_id=agent.graph_id,
+            )
+    except DatabaseError:
+        raise
+    except Exception as e:
+        logger.warning(
+            f"Could not fetch library agent by graph_id {agent_id}: {e}",
+            exc_info=True,
+        )
+
+    try:
+        agent = await library_db.get_library_agent(agent_id, user_id)
+        if agent:
+            logger.debug(f"Found library agent by library_id: {agent.name}")
+            return AgentInfo(
+                id=agent.id,
+                name=agent.name,
+                description=agent.description or "",
+                source="library",
+                in_library=True,
+                creator=agent.creator_name,
+                status=agent.status.value,
+                can_access_graph=agent.can_access_graph,
+                has_external_trigger=agent.has_external_trigger,
+                new_output=agent.new_output,
+                graph_id=agent.graph_id,
+            )
+    except NotFoundError:
+        logger.debug(f"Library agent not found by library_id: {agent_id}")
+    except DatabaseError:
+        raise
+    except Exception as e:
+        logger.warning(
+            f"Could not fetch library agent by library_id {agent_id}: {e}",
+            exc_info=True,
+        )
+
+    return None
+
+
 async def search_agents(
     query: str,
@@ -69,29 +149,37 @@ async def search_agents(
                     is_featured=False,
                 )
             )
-    else:  # library
-        logger.info(f"Searching user library for: {query}")
-        results = await library_db.list_library_agents(
-            user_id=user_id,  # type: ignore[arg-type]
-            search_term=query,
-            page_size=10,
-        )
-        for agent in results.agents:
-            agents.append(
-                AgentInfo(
-                    id=agent.id,
-                    name=agent.name,
-                    description=agent.description or "",
-                    source="library",
-                    in_library=True,
-                    creator=agent.creator_name,
-                    status=agent.status.value,
-                    can_access_graph=agent.can_access_graph,
-                    has_external_trigger=agent.has_external_trigger,
-                    new_output=agent.new_output,
-                    graph_id=agent.graph_id,
-                )
-            )
+    else:
+        if _is_uuid(query):
+            logger.info(f"Query looks like UUID, trying direct lookup: {query}")
+            agent = await _get_library_agent_by_id(user_id, query)  # type: ignore[arg-type]
+            if agent:
+                agents.append(agent)
+                logger.info(f"Found agent by direct ID lookup: {agent.name}")
+
+        if not agents:
+            logger.info(f"Searching user library for: {query}")
+            results = await library_db.list_library_agents(
+                user_id=user_id,  # type: ignore[arg-type]
+                search_term=query,
+                page_size=10,
+            )
+            for agent in results.agents:
+                agents.append(
+                    AgentInfo(
+                        id=agent.id,
+                        name=agent.name,
+                        description=agent.description or "",
+                        source="library",
+                        in_library=True,
+                        creator=agent.creator_name,
+                        status=agent.status.value,
+                        can_access_graph=agent.can_access_graph,
+                        has_external_trigger=agent.has_external_trigger,
+                        new_output=agent.new_output,
+                        graph_id=agent.graph_id,
+                    )
+                )
         logger.info(f"Found {len(agents)} agents in {source}")
     except NotFoundError:
         pass
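The library branch of `search_agents` now short-circuits on ID-shaped queries before falling back to full-text search. Sketched below with made-up arguments; keyword names beyond `query` follow the variables used in the diff, not a verified signature:

    hits = await search_agents(
        query="123e4567-e89b-42d3-a456-426614174000",  # matches _UUID_PATTERN
        user_id="user-123",
        source="library",
    )
    # Direct graph_id / library-ID lookup runs first; list_library_agents with
    # search_term=query runs only if the direct lookup produced nothing.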
@@ -8,7 +8,9 @@ from backend.api.features.chat.model import ChatSession
 from .agent_generator import (
     AgentGeneratorNotConfiguredError,
     decompose_goal,
+    enrich_library_agents_from_steps,
     generate_agent,
+    get_all_relevant_agents_for_generation,
     get_user_message_for_error,
     save_agent_to_library,
 )
@@ -103,9 +105,24 @@ class CreateAgentTool(BaseTool):
                 session_id=session_id,
             )

-        # Step 1: Decompose goal into steps
+        library_agents = None
+        if user_id:
+            try:
+                library_agents = await get_all_relevant_agents_for_generation(
+                    user_id=user_id,
+                    search_query=description,
+                    include_marketplace=True,
+                )
+                logger.debug(
+                    f"Found {len(library_agents)} relevant agents for sub-agent composition"
+                )
+            except Exception as e:
+                logger.warning(f"Failed to fetch library agents: {e}")
+
         try:
-            decomposition_result = await decompose_goal(description, context)
+            decomposition_result = await decompose_goal(
+                description, context, library_agents
+            )
         except AgentGeneratorNotConfiguredError:
             return ErrorResponse(
                 message=(
@@ -124,7 +141,6 @@ class CreateAgentTool(BaseTool):
                 session_id=session_id,
             )

-        # Check if the result is an error from the external service
         if decomposition_result.get("type") == "error":
             error_msg = decomposition_result.get("error", "Unknown error")
             error_type = decomposition_result.get("error_type", "unknown")
@@ -144,7 +160,6 @@ class CreateAgentTool(BaseTool):
                 session_id=session_id,
             )

-        # Check if LLM returned clarifying questions
         if decomposition_result.get("type") == "clarifying_questions":
             questions = decomposition_result.get("questions", [])
             return ClarificationNeededResponse(
@@ -163,7 +178,6 @@ class CreateAgentTool(BaseTool):
                 session_id=session_id,
             )

-        # Check for unachievable/vague goals
         if decomposition_result.get("type") == "unachievable_goal":
             suggested = decomposition_result.get("suggested_goal", "")
             reason = decomposition_result.get("reason", "")
@@ -190,9 +204,22 @@ class CreateAgentTool(BaseTool):
                 session_id=session_id,
             )

-        # Step 2: Generate agent JSON (external service handles fixing and validation)
+        if user_id and library_agents is not None:
+            try:
+                library_agents = await enrich_library_agents_from_steps(
+                    user_id=user_id,
+                    decomposition_result=decomposition_result,
+                    existing_agents=library_agents,
+                    include_marketplace=True,
+                )
+                logger.debug(
+                    f"After enrichment: {len(library_agents)} total agents for sub-agent composition"
+                )
+            except Exception as e:
+                logger.warning(f"Failed to enrich library agents from steps: {e}")
+
         try:
-            agent_json = await generate_agent(decomposition_result)
+            agent_json = await generate_agent(decomposition_result, library_agents)
         except AgentGeneratorNotConfiguredError:
             return ErrorResponse(
                 message=(
@@ -211,7 +238,6 @@ class CreateAgentTool(BaseTool):
                 session_id=session_id,
             )

-        # Check if the result is an error from the external service
         if isinstance(agent_json, dict) and agent_json.get("type") == "error":
             error_msg = agent_json.get("error", "Unknown error")
             error_type = agent_json.get("error_type", "unknown")
@@ -219,7 +245,12 @@ class CreateAgentTool(BaseTool):
                 error_type,
                 operation="generate the agent",
                 llm_parse_message="The AI had trouble generating the agent. Please try again or simplify your goal.",
-                validation_message="The generated agent failed validation. Please try rephrasing your goal.",
+                validation_message=(
+                    "I wasn't able to create a valid agent for this request. "
+                    "The generated workflow had some structural issues. "
+                    "Please try simplifying your goal or breaking it into smaller steps."
+                ),
+                error_details=error_msg,
             )
             return ErrorResponse(
                 message=user_message,
@@ -237,7 +268,6 @@ class CreateAgentTool(BaseTool):
         node_count = len(agent_json.get("nodes", []))
         link_count = len(agent_json.get("links", []))

-        # Step 3: Preview or save
         if not save:
             return AgentPreviewResponse(
                 message=(
@@ -252,7 +282,6 @@ class CreateAgentTool(BaseTool):
                 session_id=session_id,
             )

-        # Save to library
         if not user_id:
             return ErrorResponse(
                 message="You must be logged in to save agents.",
@@ -270,7 +299,7 @@ class CreateAgentTool(BaseTool):
             agent_id=created_graph.id,
             agent_name=created_graph.name,
             library_agent_id=library_agent.id,
-            library_agent_link=f"/library/{library_agent.id}",
+            library_agent_link=f"/library/agents/{library_agent.id}",
             agent_page_link=f"/build?flowID={created_graph.id}",
             session_id=session_id,
         )
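Isolated from the tool plumbing, the create flow's two-phase agent gathering reads roughly like this sketch (the names are exactly those introduced in the hunks above; error handling omitted):

    candidates = await get_all_relevant_agents_for_generation(
        user_id=user_id, search_query=description, include_marketplace=True
    )
    steps = await decompose_goal(description, context, candidates)
    if steps and steps.get("type") == "instructions":
        # Decomposition may reference agents the first search missed; widen the pool.
        candidates = await enrich_library_agents_from_steps(
            user_id=user_id,
            decomposition_result=steps,
            existing_agents=candidates,
            include_marketplace=True,
        )
        agent_json = await generate_agent(steps, candidates)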
@@ -9,6 +9,7 @@ from .agent_generator import (
     AgentGeneratorNotConfiguredError,
     generate_agent_patch,
     get_agent_as_json,
+    get_all_relevant_agents_for_generation,
     get_user_message_for_error,
     save_agent_to_library,
 )
@@ -117,7 +118,6 @@ class EditAgentTool(BaseTool):
                 session_id=session_id,
             )

-        # Step 1: Fetch current agent
         current_agent = await get_agent_as_json(agent_id, user_id)

         if current_agent is None:
@@ -127,14 +127,30 @@ class EditAgentTool(BaseTool):
                 session_id=session_id,
             )

-        # Build the update request with context
+        library_agents = None
+        if user_id:
+            try:
+                graph_id = current_agent.get("id")
+                library_agents = await get_all_relevant_agents_for_generation(
+                    user_id=user_id,
+                    search_query=changes,
+                    exclude_graph_id=graph_id,
+                    include_marketplace=True,
+                )
+                logger.debug(
+                    f"Found {len(library_agents)} relevant agents for sub-agent composition"
+                )
+            except Exception as e:
+                logger.warning(f"Failed to fetch library agents: {e}")
+
         update_request = changes
         if context:
             update_request = f"{changes}\n\nAdditional context:\n{context}"

-        # Step 2: Generate updated agent (external service handles fixing and validation)
         try:
-            result = await generate_agent_patch(update_request, current_agent)
+            result = await generate_agent_patch(
+                update_request, current_agent, library_agents
+            )
         except AgentGeneratorNotConfiguredError:
             return ErrorResponse(
                 message=(
@@ -153,7 +169,6 @@ class EditAgentTool(BaseTool):
                 session_id=session_id,
             )

-        # Check if the result is an error from the external service
         if isinstance(result, dict) and result.get("type") == "error":
             error_msg = result.get("error", "Unknown error")
             error_type = result.get("error_type", "unknown")
@@ -162,6 +177,7 @@ class EditAgentTool(BaseTool):
                 operation="generate the changes",
                 llm_parse_message="The AI had trouble generating the changes. Please try again or simplify your request.",
                 validation_message="The generated changes failed validation. Please try rephrasing your request.",
+                error_details=error_msg,
             )
             return ErrorResponse(
                 message=user_message,
@@ -175,7 +191,6 @@ class EditAgentTool(BaseTool):
                 session_id=session_id,
             )

-        # Check if LLM returned clarifying questions
         if result.get("type") == "clarifying_questions":
             questions = result.get("questions", [])
             return ClarificationNeededResponse(
@@ -194,7 +209,6 @@ class EditAgentTool(BaseTool):
                 session_id=session_id,
             )

-        # Result is the updated agent JSON
         updated_agent = result

         agent_name = updated_agent.get("name", "Updated Agent")
@@ -202,7 +216,6 @@ class EditAgentTool(BaseTool):
         node_count = len(updated_agent.get("nodes", []))
         link_count = len(updated_agent.get("links", []))

-        # Step 3: Preview or save
         if not save:
             return AgentPreviewResponse(
                 message=(
@@ -218,7 +231,6 @@ class EditAgentTool(BaseTool):
                 session_id=session_id,
             )

-        # Save to library (creates a new version)
         if not user_id:
             return ErrorResponse(
                 message="You must be logged in to save agents.",
@@ -236,7 +248,7 @@ class EditAgentTool(BaseTool):
             agent_id=created_graph.id,
             agent_name=created_graph.name,
             library_agent_id=library_agent.id,
-            library_agent_link=f"/library/{library_agent.id}",
+            library_agent_link=f"/library/agents/{library_agent.id}",
             agent_page_link=f"/build?flowID={created_graph.id}",
             session_id=session_id,
         )
@@ -0,0 +1,77 @@
+"""Shared helpers for chat tools."""
+
+from typing import Any
+
+from .models import ErrorResponse
+
+
+def error_response(
+    message: str, session_id: str | None, **kwargs: Any
+) -> ErrorResponse:
+    """Create standardized error response.
+
+    Args:
+        message: Error message to display
+        session_id: Current session ID
+        **kwargs: Additional fields to pass to ErrorResponse
+
+    Returns:
+        ErrorResponse with the given message and session_id
+    """
+    return ErrorResponse(message=message, session_id=session_id, **kwargs)
+
+
+def get_inputs_from_schema(
+    input_schema: dict[str, Any],
+    exclude_fields: set[str] | None = None,
+) -> list[dict[str, Any]]:
+    """Extract input field info from JSON schema.
+
+    Args:
+        input_schema: JSON schema dict with 'properties' and 'required'
+        exclude_fields: Set of field names to exclude (e.g., credential fields)
+
+    Returns:
+        List of dicts with field info (name, title, type, description, required, default)
+    """
+    exclude = exclude_fields or set()
+    properties = input_schema.get("properties", {})
+    required = set(input_schema.get("required", []))
+
+    return [
+        {
+            "name": name,
+            "title": schema.get("title", name),
+            "type": schema.get("type", "string"),
+            "description": schema.get("description", ""),
+            "required": name in required,
+            "default": schema.get("default"),
+        }
+        for name, schema in properties.items()
+        if name not in exclude
+    ]
+
+
+def format_inputs_as_markdown(inputs: list[dict[str, Any]]) -> str:
+    """Format input fields as a readable markdown list.
+
+    Args:
+        inputs: List of input dicts from get_inputs_from_schema
+
+    Returns:
+        Markdown-formatted string listing the inputs
+    """
+    if not inputs:
+        return "No inputs required."
+
+    lines = []
+    for inp in inputs:
+        required_marker = " (required)" if inp.get("required") else ""
+        default = inp.get("default")
+        default_info = f" [default: {default}]" if default is not None else ""
+        description = inp.get("description", "")
+        desc_info = f" - {description}" if description else ""
+
+        lines.append(f"- **{inp['name']}**{required_marker}{default_info}{desc_info}")
+
+    return "\n".join(lines)
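The new helpers replace two near-identical `_get_inputs_list` implementations, as the run-agent and run-block hunks below show. A self-contained round-trip with a toy schema:

    schema = {
        "properties": {
            "query": {"title": "Query", "type": "string", "description": "Search text"},
            "limit": {"type": "integer", "default": 10},
        },
        "required": ["query"],
    }
    inputs = get_inputs_from_schema(schema)
    print(format_inputs_as_markdown(inputs))
    # - **query** (required) - Search text
    # - **limit** [default: 10]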
@@ -24,6 +24,7 @@ from backend.util.timezone_utils import (
 )

 from .base import BaseTool
+from .helpers import get_inputs_from_schema
 from .models import (
     AgentDetails,
     AgentDetailsResponse,
@@ -354,19 +355,7 @@ class RunAgentTool(BaseTool):

     def _get_inputs_list(self, input_schema: dict[str, Any]) -> list[dict[str, Any]]:
         """Extract inputs list from schema."""
-        inputs_list = []
-        if isinstance(input_schema, dict) and "properties" in input_schema:
-            for field_name, field_schema in input_schema["properties"].items():
-                inputs_list.append(
-                    {
-                        "name": field_name,
-                        "title": field_schema.get("title", field_name),
-                        "type": field_schema.get("type", "string"),
-                        "description": field_schema.get("description", ""),
-                        "required": field_name in input_schema.get("required", []),
-                    }
-                )
-        return inputs_list
+        return get_inputs_from_schema(input_schema)

     def _get_execution_modes(self, graph: GraphModel) -> list[str]:
         """Get available execution modes for the graph."""
@@ -8,12 +8,13 @@ from typing import Any
 from backend.api.features.chat.model import ChatSession
 from backend.data.block import get_block
 from backend.data.execution import ExecutionContext
-from backend.data.model import CredentialsMetaInput
+from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
 from backend.data.workspace import get_or_create_workspace
 from backend.integrations.creds_manager import IntegrationCredentialsManager
 from backend.util.exceptions import BlockError

 from .base import BaseTool
+from .helpers import get_inputs_from_schema
 from .models import (
     BlockOutputResponse,
     ErrorResponse,
@@ -22,7 +23,10 @@ from .models import (
     ToolResponseBase,
     UserReadiness,
 )
-from .utils import build_missing_credentials_from_field_info
+from .utils import (
+    build_missing_credentials_from_field_info,
+    match_credentials_to_requirements,
+)

 logger = logging.getLogger(__name__)

@@ -71,6 +75,22 @@ class RunBlockTool(BaseTool):
     def requires_auth(self) -> bool:
         return True

+    def _get_credentials_requirements(
+        self,
+        block: Any,
+    ) -> dict[str, CredentialsFieldInfo]:
+        """
+        Get credential requirements from block's input schema.
+
+        Args:
+            block: Block to get credentials for
+
+        Returns:
+            Dict mapping field names to CredentialsFieldInfo
+        """
+        credentials_fields_info = block.input_schema.get_credentials_fields_info()
+        return credentials_fields_info if credentials_fields_info else {}
+
     async def _check_block_credentials(
         self,
         user_id: str,
@@ -82,53 +102,12 @@ class RunBlockTool(BaseTool):
         Returns:
             tuple[matched_credentials, missing_credentials]
         """
-        matched_credentials: dict[str, CredentialsMetaInput] = {}
-        missing_credentials: list[CredentialsMetaInput] = []
-
-        # Get credential field info from block's input schema
-        credentials_fields_info = block.input_schema.get_credentials_fields_info()
-
-        if not credentials_fields_info:
-            return matched_credentials, missing_credentials
-
-        # Get user's available credentials
-        creds_manager = IntegrationCredentialsManager()
-        available_creds = await creds_manager.store.get_all_creds(user_id)
-
-        for field_name, field_info in credentials_fields_info.items():
-            # field_info.provider is a frozenset of acceptable providers
-            # field_info.supported_types is a frozenset of acceptable types
-            matching_cred = next(
-                (
-                    cred
-                    for cred in available_creds
-                    if cred.provider in field_info.provider
-                    and cred.type in field_info.supported_types
-                ),
-                None,
-            )
-
-            if matching_cred:
-                matched_credentials[field_name] = CredentialsMetaInput(
-                    id=matching_cred.id,
-                    provider=matching_cred.provider,  # type: ignore
-                    type=matching_cred.type,
-                    title=matching_cred.title,
-                )
-            else:
-                # Create a placeholder for the missing credential
-                provider = next(iter(field_info.provider), "unknown")
-                cred_type = next(iter(field_info.supported_types), "api_key")
-                missing_credentials.append(
-                    CredentialsMetaInput(
-                        id=field_name,
-                        provider=provider,  # type: ignore
-                        type=cred_type,  # type: ignore
-                        title=field_name.replace("_", " ").title(),
-                    )
-                )
-
-        return matched_credentials, missing_credentials
+        requirements = self._get_credentials_requirements(block)
+
+        if not requirements:
+            return {}, []
+
+        return await match_credentials_to_requirements(user_id, requirements)

     async def _execute(
         self,
@@ -320,27 +299,7 @@ class RunBlockTool(BaseTool):

     def _get_inputs_list(self, block: Any) -> list[dict[str, Any]]:
         """Extract non-credential inputs from block schema."""
-        inputs_list = []
         schema = block.input_schema.jsonschema()
-        properties = schema.get("properties", {})
-        required_fields = set(schema.get("required", []))
-
         # Get credential field names to exclude
         credentials_fields = set(block.input_schema.get_credentials_fields().keys())
-
-        for field_name, field_schema in properties.items():
-            # Skip credential fields
-            if field_name in credentials_fields:
-                continue
-
-            inputs_list.append(
-                {
-                    "name": field_name,
-                    "title": field_schema.get("title", field_name),
-                    "type": field_schema.get("type", "string"),
-                    "description": field_schema.get("description", ""),
-                    "required": field_name in required_fields,
-                }
-            )
-
-        return inputs_list
+        return get_inputs_from_schema(schema, exclude_fields=credentials_fields)
@@ -8,7 +8,7 @@ from backend.api.features.library import model as library_model
 from backend.api.features.store import db as store_db
 from backend.data import graph as graph_db
 from backend.data.graph import GraphModel
-from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
+from backend.data.model import Credentials, CredentialsFieldInfo, CredentialsMetaInput
 from backend.integrations.creds_manager import IntegrationCredentialsManager
 from backend.util.exceptions import NotFoundError

@@ -225,6 +225,127 @@ async def get_or_create_library_agent(
     return library_agents[0]


+async def get_user_credentials(user_id: str) -> list:
+    """
+    Get all available credentials for a user.
+
+    Args:
+        user_id: The user's ID
+
+    Returns:
+        List of user's credentials
+    """
+    creds_manager = IntegrationCredentialsManager()
+    return await creds_manager.store.get_all_creds(user_id)
+
+
+def find_matching_credential(
+    available_creds: list,
+    field_info: CredentialsFieldInfo,
+):
+    """
+    Find a credential that matches the required provider, type, and scopes.
+
+    Args:
+        available_creds: List of user's available credentials
+        field_info: CredentialsFieldInfo with provider, type, and scope requirements
+
+    Returns:
+        Matching credential or None
+    """
+    for cred in available_creds:
+        if cred.provider not in field_info.provider:
+            continue
+        if cred.type not in field_info.supported_types:
+            continue
+        if not _credential_has_required_scopes(cred, field_info):
+            continue
+        return cred
+    return None
+
+
+def create_credential_meta_from_match(
+    matching_cred,
+) -> CredentialsMetaInput:
+    """
+    Create a CredentialsMetaInput from a matched credential.
+
+    Args:
+        matching_cred: The matched credential object
+
+    Returns:
+        CredentialsMetaInput instance
+    """
+    return CredentialsMetaInput(
+        id=matching_cred.id,
+        provider=matching_cred.provider,  # type: ignore
+        type=matching_cred.type,
+        title=matching_cred.title,
+    )
+
+
+async def match_credentials_to_requirements(
+    user_id: str,
+    requirements: dict[str, CredentialsFieldInfo],
+) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
+    """
+    Match user's credentials against a dictionary of credential requirements.
+
+    This is the core matching logic shared by both graph and block credential matching.
+
+    Args:
+        user_id: The user's ID
+        requirements: Dict mapping field names to CredentialsFieldInfo
+
+    Returns:
+        tuple[matched_credentials dict, missing_credentials list]
+    """
+    matched: dict[str, CredentialsMetaInput] = {}
+    missing: list[CredentialsMetaInput] = []
+
+    if not requirements:
+        return matched, missing
+
+    available_creds = await get_user_credentials(user_id)
+
+    for field_name, field_info in requirements.items():
+        matching_cred = find_matching_credential(available_creds, field_info)
+
+        if matching_cred:
+            try:
+                matched[field_name] = create_credential_meta_from_match(matching_cred)
+            except Exception as e:
+                logger.error(
+                    f"Failed to create CredentialsMetaInput for field '{field_name}': "
+                    f"provider={matching_cred.provider}, type={matching_cred.type}, "
+                    f"credential_id={matching_cred.id}",
+                    exc_info=True,
+                )
+                provider = next(iter(field_info.provider), "unknown")
+                cred_type = next(iter(field_info.supported_types), "api_key")
+                missing.append(
+                    CredentialsMetaInput(
+                        id=field_name,
+                        provider=provider,  # type: ignore
+                        type=cred_type,  # type: ignore
+                        title=f"{field_name} (validation failed: {e})",
+                    )
+                )
+        else:
+            provider = next(iter(field_info.provider), "unknown")
+            cred_type = next(iter(field_info.supported_types), "api_key")
+            missing.append(
+                CredentialsMetaInput(
+                    id=field_name,
+                    provider=provider,  # type: ignore
+                    type=cred_type,  # type: ignore
+                    title=field_name.replace("_", " ").title(),
+                )
+            )

+    return matched, missing
+
+
 async def match_user_credentials_to_graph(
     user_id: str,
     graph: GraphModel,
@@ -242,9 +363,6 @@ async def match_user_credentials_to_graph(
     Returns:
         tuple[matched_credentials dict, missing_credential_descriptions list]
     """
-    graph_credentials_inputs: dict[str, CredentialsMetaInput] = {}
-    missing_creds: list[str] = []
-
     # Get aggregated credentials requirements from the graph
     aggregated_creds = graph.aggregate_credentials_inputs()
     logger.debug(
@@ -252,61 +370,52 @@ async def match_user_credentials_to_graph(
     )

     if not aggregated_creds:
-        return graph_credentials_inputs, missing_creds
+        return {}, []

-    # Get all available credentials for the user
-    creds_manager = IntegrationCredentialsManager()
-    available_creds = await creds_manager.store.get_all_creds(user_id)
-
-    # For each required credential field, find a matching user credential
-    # field_info.provider is a frozenset because aggregate_credentials_inputs()
-    # combines requirements from multiple nodes. A credential matches if its
-    # provider is in the set of acceptable providers.
+    # Convert aggregated format to simple requirements dict
+    requirements = {
+        field_name: field_info
+        for field_name, (field_info, _node_fields) in aggregated_creds.items()
+    }
+
+    # Use shared matching logic
+    matched, missing_list = await match_credentials_to_requirements(
+        user_id, requirements
+    )
-    for credential_field_name, (
-        credential_requirements,
-        _node_fields,
-    ) in aggregated_creds.items():
-        # Find first matching credential by provider and type
-        matching_cred = next(
-            (
-                cred
-                for cred in available_creds
-                if cred.provider in credential_requirements.provider
-                and cred.type in credential_requirements.supported_types
-            ),
-            None,
-        )
-
-        if matching_cred:
-            try:
-                graph_credentials_inputs[credential_field_name] = CredentialsMetaInput(
-                    id=matching_cred.id,
-                    provider=matching_cred.provider,  # type: ignore
-                    type=matching_cred.type,
-                    title=matching_cred.title,
-                )
-            except Exception as e:
-                logger.error(
-                    f"Failed to create CredentialsMetaInput for field '{credential_field_name}': "
-                    f"provider={matching_cred.provider}, type={matching_cred.type}, "
|
|
||||||
f"credential_id={matching_cred.id}",
|
|
||||||
exc_info=True,
|
|
||||||
)
|
|
||||||
missing_creds.append(
|
|
||||||
f"{credential_field_name} (validation failed: {e})"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
missing_creds.append(
|
|
||||||
f"{credential_field_name} "
|
|
||||||
f"(requires provider in {list(credential_requirements.provider)}, "
|
|
||||||
f"type in {list(credential_requirements.supported_types)})"
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Credential matching complete: {len(graph_credentials_inputs)}/{len(aggregated_creds)} matched"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return graph_credentials_inputs, missing_creds
|
# Convert missing list to string descriptions for backward compatibility
|
||||||
|
missing_descriptions = [
|
||||||
|
f"{cred.id} (requires provider={cred.provider}, type={cred.type})"
|
||||||
|
for cred in missing_list
|
||||||
|
]
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Credential matching complete: {len(matched)}/{len(aggregated_creds)} matched"
|
||||||
|
)
|
||||||
|
|
||||||
|
return matched, missing_descriptions
|
||||||
|
|
||||||
|
|
||||||
|
def _credential_has_required_scopes(
|
||||||
|
credential: Credentials,
|
||||||
|
requirements: CredentialsFieldInfo,
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Check if a credential has all the scopes required by the block.
|
||||||
|
|
||||||
|
For OAuth2 credentials, verifies that the credential's scopes are a superset
|
||||||
|
of the required scopes. For other credential types, returns True (no scope check).
|
||||||
|
"""
|
||||||
|
# Only OAuth2 credentials have scopes to check
|
||||||
|
if credential.type != "oauth2":
|
||||||
|
return True
|
||||||
|
|
||||||
|
# If no scopes are required, any credential matches
|
||||||
|
if not requirements.required_scopes:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Check that credential scopes are a superset of required scopes
|
||||||
|
return set(credential.scopes).issuperset(requirements.required_scopes)
|
||||||
|
|
||||||
|
|
||||||
async def check_user_has_required_credentials(
|
async def check_user_has_required_credentials(
|
||||||
|
|||||||
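A minimal usage sketch of the shared helpers above, assuming they are imported from this credentials-matching module (the exact module path is not shown in this diff); the printing is purely illustrative:

    async def report_credential_gaps(user_id: str, requirements) -> None:
        # requirements: dict[str, CredentialsFieldInfo], as in the helper's signature
        matched, missing = await match_credentials_to_requirements(user_id, requirements)
        for field_name, meta in matched.items():
            print(f"{field_name}: matched credential {meta.id} ({meta.type})")
        for meta in missing:
            # For missing entries, `id` carries the field name (see the helper above)
            print(f"unfilled: {meta.id} -> {meta.title}")

Note that `find_matching_credential` returns the first acceptable credential in list order, so callers that prefer, say, OAuth2 credentials over API keys should sort `available_creds` accordingly first.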
@@ -39,6 +39,7 @@ async def list_library_agents(
     sort_by: library_model.LibraryAgentSort = library_model.LibraryAgentSort.UPDATED_AT,
     page: int = 1,
     page_size: int = 50,
+    include_executions: bool = False,
 ) -> library_model.LibraryAgentResponse:
     """
     Retrieves a paginated list of LibraryAgent records for a given user.
@@ -49,6 +50,9 @@ async def list_library_agents(
         sort_by: Sorting field (createdAt, updatedAt, isFavorite, isCreatedByUser).
         page: Current page (1-indexed).
         page_size: Number of items per page.
+        include_executions: Whether to include execution data for status calculation.
+            Defaults to False for performance (UI fetches status separately).
+            Set to True when accurate status/metrics are needed (e.g., agent generator).

     Returns:
         A LibraryAgentResponse containing the list of agents and pagination details.
@@ -76,7 +80,6 @@ async def list_library_agents(
         "isArchived": False,
     }

-    # Build search filter if applicable
     if search_term:
         where_clause["OR"] = [
             {
@@ -93,7 +96,6 @@ async def list_library_agents(
             },
         ]

-    # Determine sorting
     order_by: prisma.types.LibraryAgentOrderByInput | None = None

     if sort_by == library_model.LibraryAgentSort.CREATED_AT:
@@ -105,7 +107,7 @@ async def list_library_agents(
     library_agents = await prisma.models.LibraryAgent.prisma().find_many(
         where=where_clause,
         include=library_agent_include(
-            user_id, include_nodes=False, include_executions=False
+            user_id, include_nodes=False, include_executions=include_executions
         ),
         order=order_by,
         skip=(page - 1) * page_size,
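Two illustrative call sites for the new flag, assuming the first positional parameter is the user ID as implied by the docstring (both calls are hypothetical):

    agents = await list_library_agents(user_id)                             # UI path: cheap, no executions join
    detailed = await list_library_agents(user_id, include_executions=True)  # agent-generator path: accurate metrics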
@@ -9,6 +9,7 @@ import pydantic
 from backend.data.block import BlockInput
 from backend.data.graph import GraphModel, GraphSettings, GraphTriggerInfo
 from backend.data.model import CredentialsMetaInput, is_credentials_field_name
+from backend.util.json import loads as json_loads
 from backend.util.models import Pagination

 if TYPE_CHECKING:
@@ -16,10 +17,10 @@ if TYPE_CHECKING:


 class LibraryAgentStatus(str, Enum):
-    COMPLETED = "COMPLETED"  # All runs completed
-    HEALTHY = "HEALTHY"  # Agent is running (not all runs have completed)
-    WAITING = "WAITING"  # Agent is queued or waiting to start
-    ERROR = "ERROR"  # Agent is in an error state
+    COMPLETED = "COMPLETED"
+    HEALTHY = "HEALTHY"
+    WAITING = "WAITING"
+    ERROR = "ERROR"


 class MarketplaceListingCreator(pydantic.BaseModel):
@@ -39,6 +40,30 @@ class MarketplaceListing(pydantic.BaseModel):
     creator: MarketplaceListingCreator


+class RecentExecution(pydantic.BaseModel):
+    """Summary of a recent execution for quality assessment.
+
+    Used by the LLM to understand the agent's recent performance with specific examples
+    rather than just aggregate statistics.
+    """
+
+    status: str
+    correctness_score: float | None = None
+    activity_summary: str | None = None
+
+
+def _parse_settings(settings: dict | str | None) -> GraphSettings:
+    """Parse settings from database, handling both dict and string formats."""
+    if settings is None:
+        return GraphSettings()
+    try:
+        if isinstance(settings, str):
+            settings = json_loads(settings)
+        return GraphSettings.model_validate(settings)
+    except Exception:
+        return GraphSettings()
+
+
 class LibraryAgent(pydantic.BaseModel):
     """
     Represents an agent in the library, including metadata for display and
@@ -48,7 +73,7 @@ class LibraryAgent(pydantic.BaseModel):
     id: str
     graph_id: str
     graph_version: int
-    owner_user_id: str  # ID of user who owns/created this agent graph
+    owner_user_id: str

     image_url: str | None

@@ -64,7 +89,7 @@ class LibraryAgent(pydantic.BaseModel):
     description: str
     instructions: str | None = None

-    input_schema: dict[str, Any]  # Should be BlockIOObjectSubSchema in frontend
+    input_schema: dict[str, Any]
     output_schema: dict[str, Any]
     credentials_input_schema: dict[str, Any] | None = pydantic.Field(
         description="Input schema for credentials required by the agent",
@@ -81,25 +106,19 @@ class LibraryAgent(pydantic.BaseModel):
     )
     trigger_setup_info: Optional[GraphTriggerInfo] = None

-    # Indicates whether there's a new output (based on recent runs)
     new_output: bool
-    # Whether the user can access the underlying graph
+    execution_count: int = 0
+    success_rate: float | None = None
+    avg_correctness_score: float | None = None
+    recent_executions: list[RecentExecution] = pydantic.Field(
+        default_factory=list,
+        description="List of recent executions with status, score, and summary",
+    )
     can_access_graph: bool

-    # Indicates if this agent is the latest version
     is_latest_version: bool

-    # Whether the agent is marked as favorite by the user
     is_favorite: bool

-    # Recommended schedule cron (from marketplace agents)
     recommended_schedule_cron: str | None = None

-    # User-specific settings for this library agent
     settings: GraphSettings = pydantic.Field(default_factory=GraphSettings)

-    # Marketplace listing information if the agent has been published
     marketplace_listing: Optional["MarketplaceListing"] = None

     @staticmethod
@@ -123,7 +142,6 @@ class LibraryAgent(pydantic.BaseModel):
         agent_updated_at = agent.AgentGraph.updatedAt
         lib_agent_updated_at = agent.updatedAt

-        # Compute updated_at as the latest between library agent and graph
         updated_at = (
             max(agent_updated_at, lib_agent_updated_at)
             if agent_updated_at
@@ -136,7 +154,6 @@ class LibraryAgent(pydantic.BaseModel):
         creator_name = agent.Creator.name or "Unknown"
         creator_image_url = agent.Creator.avatarUrl or ""

-        # Logic to calculate status and new_output
         week_ago = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(
             days=7
         )
@@ -145,13 +162,55 @@ class LibraryAgent(pydantic.BaseModel):
         status = status_result.status
         new_output = status_result.new_output

-        # Check if user can access the graph
-        can_access_graph = agent.AgentGraph.userId == agent.userId
-
-        # Hard-coded to True until a method to check is implemented
+        execution_count = len(executions)
+        success_rate: float | None = None
+        avg_correctness_score: float | None = None
+        if execution_count > 0:
+            success_count = sum(
+                1
+                for e in executions
+                if e.executionStatus == prisma.enums.AgentExecutionStatus.COMPLETED
+            )
+            success_rate = (success_count / execution_count) * 100
+
+            correctness_scores = []
+            for e in executions:
+                if e.stats and isinstance(e.stats, dict):
+                    score = e.stats.get("correctness_score")
+                    if score is not None and isinstance(score, (int, float)):
+                        correctness_scores.append(float(score))
+            if correctness_scores:
+                avg_correctness_score = sum(correctness_scores) / len(
+                    correctness_scores
+                )
+
+        recent_executions: list[RecentExecution] = []
+        for e in executions:
+            exec_score: float | None = None
+            exec_summary: str | None = None
+            if e.stats and isinstance(e.stats, dict):
+                score = e.stats.get("correctness_score")
+                if score is not None and isinstance(score, (int, float)):
+                    exec_score = float(score)
+                summary = e.stats.get("activity_status")
+                if summary is not None and isinstance(summary, str):
+                    exec_summary = summary
+            exec_status = (
+                e.executionStatus.value
+                if hasattr(e.executionStatus, "value")
+                else str(e.executionStatus)
+            )
+            recent_executions.append(
+                RecentExecution(
+                    status=exec_status,
+                    correctness_score=exec_score,
+                    activity_summary=exec_summary,
+                )
+            )
+
+        can_access_graph = agent.AgentGraph.userId == agent.userId
         is_latest_version = True

-        # Build marketplace_listing if available
         marketplace_listing_data = None
         if store_listing and store_listing.ActiveVersion and profile:
             creator_data = MarketplaceListingCreator(
@@ -190,11 +249,15 @@ class LibraryAgent(pydantic.BaseModel):
             has_sensitive_action=graph.has_sensitive_action,
             trigger_setup_info=graph.trigger_setup_info,
             new_output=new_output,
+            execution_count=execution_count,
+            success_rate=success_rate,
+            avg_correctness_score=avg_correctness_score,
+            recent_executions=recent_executions,
             can_access_graph=can_access_graph,
             is_latest_version=is_latest_version,
             is_favorite=agent.isFavorite,
             recommended_schedule_cron=agent.AgentGraph.recommendedScheduleCron,
-            settings=GraphSettings.model_validate(agent.settings),
+            settings=_parse_settings(agent.settings),
             marketplace_listing=marketplace_listing_data,
         )

@@ -220,18 +283,15 @@ def _calculate_agent_status(
     if not executions:
         return AgentStatusResult(status=LibraryAgentStatus.COMPLETED, new_output=False)

-    # Track how many times each execution status appears
     status_counts = {status: 0 for status in prisma.enums.AgentExecutionStatus}
     new_output = False

     for execution in executions:
-        # Check if there's a completed run more recent than `recent_threshold`
         if execution.createdAt >= recent_threshold:
             if execution.executionStatus == prisma.enums.AgentExecutionStatus.COMPLETED:
                 new_output = True
         status_counts[execution.executionStatus] += 1

-    # Determine the final status based on counts
     if status_counts[prisma.enums.AgentExecutionStatus.FAILED] > 0:
         return AgentStatusResult(status=LibraryAgentStatus.ERROR, new_output=new_output)
     elif status_counts[prisma.enums.AgentExecutionStatus.QUEUED] > 0:
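`_parse_settings` above is deliberately forgiving. A sketch of the three input shapes it accepts, relying on pydantic's field-wise equality; the JSON key below is hypothetical:

    assert _parse_settings(None) == GraphSettings()          # NULL column -> defaults
    parsed = _parse_settings('{"some_flag": true}')          # JSON string rows are decoded first
    assert isinstance(parsed, GraphSettings)
    assert _parse_settings("{not json") == GraphSettings()   # corrupt data falls back to defaults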
@@ -112,6 +112,7 @@ async def get_store_agents(
                 description=agent["description"],
                 runs=agent["runs"],
                 rating=agent["rating"],
+                agent_graph_id=agent.get("agentGraphId", ""),
             )
             store_agents.append(store_agent)
         except Exception as e:
@@ -170,6 +171,7 @@ async def get_store_agents(
                 description=agent.description,
                 runs=agent.runs,
                 rating=agent.rating,
+                agent_graph_id=agent.agentGraphId,
             )
             # Add to the list only if creation was successful
             store_agents.append(store_agent)
@@ -600,6 +600,7 @@ async def hybrid_search(
             sa.featured,
             sa.is_available,
             sa.updated_at,
+            sa."agentGraphId",
             -- Searchable text for BM25 reranking
             COALESCE(sa.agent_name, '') || ' ' || COALESCE(sa.sub_heading, '') || ' ' || COALESCE(sa.description, '') as searchable_text,
             -- Semantic score
@@ -659,6 +660,7 @@ async def hybrid_search(
             featured,
             is_available,
             updated_at,
+            "agentGraphId",
             searchable_text,
             semantic_score,
             lexical_score,
@@ -38,6 +38,7 @@ class StoreAgent(pydantic.BaseModel):
     description: str
     runs: int
     rating: float
+    agent_graph_id: str


 class StoreAgentsResponse(pydantic.BaseModel):
@@ -26,11 +26,13 @@ def test_store_agent():
         description="Test description",
         runs=50,
         rating=4.5,
+        agent_graph_id="test-graph-id",
     )
     assert agent.slug == "test-agent"
     assert agent.agent_name == "Test Agent"
     assert agent.runs == 50
     assert agent.rating == 4.5
+    assert agent.agent_graph_id == "test-graph-id"


 def test_store_agents_response():
@@ -46,6 +48,7 @@ def test_store_agents_response():
                 description="Test description",
                 runs=50,
                 rating=4.5,
+                agent_graph_id="test-graph-id",
             )
         ],
         pagination=store_model.Pagination(
@@ -82,6 +82,7 @@ def test_get_agents_featured(
                 description="Featured agent description",
                 runs=100,
                 rating=4.5,
+                agent_graph_id="test-graph-1",
             )
         ],
         pagination=store_model.Pagination(
@@ -127,6 +128,7 @@ def test_get_agents_by_creator(
                 description="Creator agent description",
                 runs=50,
                 rating=4.0,
+                agent_graph_id="test-graph-2",
             )
         ],
         pagination=store_model.Pagination(
@@ -172,6 +174,7 @@ def test_get_agents_sorted(
                 description="Top agent description",
                 runs=1000,
                 rating=5.0,
+                agent_graph_id="test-graph-3",
             )
         ],
         pagination=store_model.Pagination(
@@ -217,6 +220,7 @@ def test_get_agents_search(
                 description="Specific search term description",
                 runs=75,
                 rating=4.2,
+                agent_graph_id="test-graph-search",
             )
         ],
         pagination=store_model.Pagination(
@@ -262,6 +266,7 @@ def test_get_agents_category(
                 description="Category agent description",
                 runs=60,
                 rating=4.1,
+                agent_graph_id="test-graph-category",
             )
         ],
         pagination=store_model.Pagination(
@@ -306,6 +311,7 @@ def test_get_agents_pagination(
                 description=f"Agent {i} description",
                 runs=i * 10,
                 rating=4.0,
+                agent_graph_id="test-graph-2",
             )
             for i in range(5)
         ],
@@ -33,6 +33,7 @@ class TestCacheDeletion:
                 description="Test description",
                 runs=100,
                 rating=4.5,
+                agent_graph_id="test-graph-id",
             )
         ],
         pagination=Pagination(
@@ -66,18 +66,24 @@ async def event_broadcaster(manager: ConnectionManager):
     execution_bus = AsyncRedisExecutionEventBus()
     notification_bus = AsyncRedisNotificationEventBus()

-    async def execution_worker():
-        async for event in execution_bus.listen("*"):
-            await manager.send_execution_update(event)
-
-    async def notification_worker():
-        async for notification in notification_bus.listen("*"):
-            await manager.send_notification(
-                user_id=notification.user_id,
-                payload=notification.payload,
-            )
-
-    await asyncio.gather(execution_worker(), notification_worker())
+    try:
+
+        async def execution_worker():
+            async for event in execution_bus.listen("*"):
+                await manager.send_execution_update(event)
+
+        async def notification_worker():
+            async for notification in notification_bus.listen("*"):
+                await manager.send_notification(
+                    user_id=notification.user_id,
+                    payload=notification.payload,
+                )
+
+        await asyncio.gather(execution_worker(), notification_worker())
+    finally:
+        # Ensure PubSub connections are closed on any exit to prevent leaks
+        await execution_bus.close()
+        await notification_bus.close()


 async def authenticate_websocket(websocket: WebSocket) -> str:
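The restructure above matters because `asyncio.gather` can exit via an exception or task cancellation; placing the workers inside `try`/`finally` guarantees both buses are closed on every path. The same pattern in isolation, with illustrative names (only the `listen`/`close` interface comes from this diff):

    import asyncio

    async def consume(bus, handler):
        async for event in bus.listen("*"):
            await handler(event)

    async def run_consumers(bus_a, handler_a, bus_b, handler_b):
        try:
            await asyncio.gather(consume(bus_a, handler_a), consume(bus_b, handler_b))
        finally:
            # Runs on normal exit, on exceptions, and on CancelledError alike
            await bus_a.close()
            await bus_b.close()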
@@ -32,7 +32,7 @@ from backend.data.model import (
 from backend.integrations.providers import ProviderName
 from backend.util import json
 from backend.util.logging import TruncatedLogger
-from backend.util.prompt import compress_prompt, estimate_token_count
+from backend.util.prompt import compress_context, estimate_token_count
 from backend.util.text import TextFormatter

 logger = TruncatedLogger(logging.getLogger(__name__), "[LLM-Block]")
@@ -115,7 +115,6 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
     CLAUDE_4_5_OPUS = "claude-opus-4-5-20251101"
     CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
     CLAUDE_4_5_HAIKU = "claude-haiku-4-5-20251001"
-    CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219"
     CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
     # AI/ML API models
     AIML_API_QWEN2_5_72B = "Qwen/Qwen2.5-72B-Instruct-Turbo"
@@ -280,9 +279,6 @@ MODEL_METADATA = {
     LlmModel.CLAUDE_4_5_HAIKU: ModelMetadata(
         "anthropic", 200000, 64000, "Claude Haiku 4.5", "Anthropic", "Anthropic", 2
     ),  # claude-haiku-4-5-20251001
-    LlmModel.CLAUDE_3_7_SONNET: ModelMetadata(
-        "anthropic", 200000, 64000, "Claude 3.7 Sonnet", "Anthropic", "Anthropic", 2
-    ),  # claude-3-7-sonnet-20250219
     LlmModel.CLAUDE_3_HAIKU: ModelMetadata(
         "anthropic", 200000, 4096, "Claude 3 Haiku", "Anthropic", "Anthropic", 1
     ),  # claude-3-haiku-20240307
@@ -638,11 +634,18 @@ async def llm_call(
     context_window = llm_model.context_window

     if compress_prompt_to_fit:
-        prompt = compress_prompt(
+        result = await compress_context(
             messages=prompt,
             target_tokens=llm_model.context_window // 2,
-            lossy_ok=True,
+            client=None,  # Truncation-only, no LLM summarization
+            reserve=0,  # Caller handles response token budget separately
         )
+        if result.error:
+            logger.warning(
+                f"Prompt compression did not meet target: {result.error}. "
+                f"Proceeding with {result.token_count} tokens."
+            )
+        prompt = result.messages

     # Calculate available tokens based on context window and input length
     estimated_input_tokens = estimate_token_count(prompt)
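A hedged sketch of the new contract: `compress_context` is awaited, takes `client=None` for truncation-only mode, and returns a `CompressResult` rather than a bare message list (see the prompt.py changes later in this diff). The result-field names come from that diff; everything else here is illustrative:

    async def fit_prompt(prompt: list[dict], context_window: int) -> list[dict]:
        result = await compress_context(
            messages=prompt,
            target_tokens=context_window // 2,
            client=None,  # truncation-only, as in llm_call above
            reserve=0,
        )
        if result.was_compacted:
            logger.info(
                f"compacted {result.original_token_count} -> {result.token_count} tokens "
                f"({result.messages_summarized} summarized, {result.messages_dropped} dropped)"
            )
        return result.messages

Passing a real OpenAI client presumably enables a summarization pass; that is inferred from the `client=None` comment above, not verified here.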
@@ -83,7 +83,7 @@ class StagehandRecommendedLlmModel(str, Enum):
     GPT41_MINI = "gpt-4.1-mini-2025-04-14"

     # Anthropic
-    CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219"
+    CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"

     @property
     def provider_name(self) -> str:
@@ -137,7 +137,7 @@ class StagehandObserveBlock(Block):
     model: StagehandRecommendedLlmModel = SchemaField(
         title="LLM Model",
         description="LLM to use for Stagehand (provider is inferred)",
-        default=StagehandRecommendedLlmModel.CLAUDE_3_7_SONNET,
+        default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
         advanced=False,
     )
     model_credentials: AICredentials = AICredentialsField()
@@ -230,7 +230,7 @@ class StagehandActBlock(Block):
     model: StagehandRecommendedLlmModel = SchemaField(
         title="LLM Model",
         description="LLM to use for Stagehand (provider is inferred)",
-        default=StagehandRecommendedLlmModel.CLAUDE_3_7_SONNET,
+        default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
         advanced=False,
     )
     model_credentials: AICredentials = AICredentialsField()
@@ -330,7 +330,7 @@ class StagehandExtractBlock(Block):
     model: StagehandRecommendedLlmModel = SchemaField(
         title="LLM Model",
         description="LLM to use for Stagehand (provider is inferred)",
-        default=StagehandRecommendedLlmModel.CLAUDE_3_7_SONNET,
+        default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
         advanced=False,
     )
     model_credentials: AICredentials = AICredentialsField()
@@ -81,7 +81,6 @@ MODEL_COST: dict[LlmModel, int] = {
     LlmModel.CLAUDE_4_5_HAIKU: 4,
     LlmModel.CLAUDE_4_5_OPUS: 14,
     LlmModel.CLAUDE_4_5_SONNET: 9,
-    LlmModel.CLAUDE_3_7_SONNET: 5,
     LlmModel.CLAUDE_3_HAIKU: 1,
     LlmModel.AIML_API_QWEN2_5_72B: 1,
     LlmModel.AIML_API_LLAMA3_1_70B: 1,
@@ -133,10 +133,23 @@ class RedisEventBus(BaseRedisEventBus[M], ABC):


 class AsyncRedisEventBus(BaseRedisEventBus[M], ABC):
+    def __init__(self):
+        self._pubsub: AsyncPubSub | None = None
+
     @property
     async def connection(self) -> redis.AsyncRedis:
         return await redis.get_redis_async()

+    async def close(self) -> None:
+        """Close the PubSub connection if it exists."""
+        if self._pubsub is not None:
+            try:
+                await self._pubsub.close()
+            except Exception:
+                logger.warning("Failed to close PubSub connection", exc_info=True)
+            finally:
+                self._pubsub = None
+
     async def publish_event(self, event: M, channel_key: str):
         """
         Publish an event to Redis. Gracefully handles connection failures
@@ -157,6 +170,7 @@ class AsyncRedisEventBus(BaseRedisEventBus[M], ABC):
             await self.connection, channel_key
         )
         assert isinstance(pubsub, AsyncPubSub)
+        self._pubsub = pubsub

         if "*" in channel_key:
             await pubsub.psubscribe(full_channel_name)
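`close()` above is safe to call repeatedly: the `finally` clause resets `self._pubsub` to `None`, so a second call returns immediately. An illustrative consumer, assuming `AsyncRedisExecutionEventBus` (the concrete subclass used by the WebSocket code earlier in this diff):

    bus = AsyncRedisExecutionEventBus()
    try:
        async for event in bus.listen("*"):
            print(event)
    finally:
        await bus.close()
        await bus.close()  # idempotent no-op on the second call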
@@ -1028,6 +1028,39 @@ async def get_graph(
     return GraphModel.from_db(graph, for_export)


+async def get_store_listed_graphs(*graph_ids: str) -> dict[str, GraphModel]:
+    """Batch-fetch multiple store-listed graphs by their IDs.
+
+    Only returns graphs that have approved store listings (publicly available).
+    Does not require permission checks since store-listed graphs are public.
+
+    Args:
+        *graph_ids: Variable number of graph IDs to fetch
+
+    Returns:
+        Dict mapping graph_id to GraphModel for graphs with approved store listings
+    """
+    if not graph_ids:
+        return {}
+
+    store_listings = await StoreListingVersion.prisma().find_many(
+        where={
+            "agentGraphId": {"in": list(graph_ids)},
+            "submissionStatus": SubmissionStatus.APPROVED,
+            "isDeleted": False,
+        },
+        include={"AgentGraph": {"include": AGENT_GRAPH_INCLUDE}},
+        distinct=["agentGraphId"],
+        order={"agentGraphVersion": "desc"},
+    )
+
+    return {
+        listing.agentGraphId: GraphModel.from_db(listing.AgentGraph)
+        for listing in store_listings
+        if listing.AgentGraph
+    }
+
+
 async def get_graph_as_admin(
     graph_id: str,
     version: int | None = None,
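A usage sketch for the batch fetch above; the graph IDs are placeholders and the `.name` attribute is assumed, not shown in this diff:

    graphs = await get_store_listed_graphs("graph-a", "graph-b")
    for graph_id in ("graph-a", "graph-b"):
        if graph_id not in graphs:
            # Not an error: no approved, non-deleted store listing for this graph
            continue
        print(graphs[graph_id].name)  # attribute assumed

Because the query is `distinct` on `agentGraphId` and ordered by `agentGraphVersion` descending, each returned entry should correspond to the latest approved listing version of that graph.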
@@ -666,10 +666,16 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
         if not (self.discriminator and self.discriminator_mapping):
             return self

+        try:
+            provider = self.discriminator_mapping[discriminator_value]
+        except KeyError:
+            raise ValueError(
+                f"Model '{discriminator_value}' is not supported. "
+                "It may have been deprecated. Please update your agent configuration."
+            )
+
         return CredentialsFieldInfo(
-            credentials_provider=frozenset(
-                [self.discriminator_mapping[discriminator_value]]
-            ),
+            credentials_provider=frozenset([provider]),
             credentials_types=self.supported_types,
             credentials_scopes=self.required_scopes,
             discriminator=self.discriminator,
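The `try`/`except KeyError` above turns a stale discriminator value (for example, a model removed from the enum, like the Claude 3.7 Sonnet removal elsewhere in this PR) into an actionable error. Illustrative behavior only; the method name is not visible in this hunk, so `discriminate(...)` below is a placeholder:

    # Hypothetical: field_info.discriminator_mapping == {"gpt-4o": "openai"}
    field_info.discriminate("gpt-4o")  # -> CredentialsFieldInfo with provider frozenset({"openai"})
    field_info.discriminate("claude-3-7-sonnet-20250219")
    # ValueError: Model 'claude-3-7-sonnet-20250219' is not supported.
    #             It may have been deprecated. Please update your agent configuration.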
@@ -17,6 +17,7 @@ from backend.data.analytics import (
     get_accuracy_trends_and_alerts,
     get_marketplace_graphs_for_monitoring,
 )
+from backend.data.auth.oauth import cleanup_expired_oauth_tokens
 from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
 from backend.data.execution import (
     create_graph_execution,
@@ -219,6 +220,9 @@ class DatabaseManager(AppService):
     # Onboarding
     increment_onboarding_runs = _(increment_onboarding_runs)

+    # OAuth
+    cleanup_expired_oauth_tokens = _(cleanup_expired_oauth_tokens)
+
     # Store
     get_store_agents = _(get_store_agents)
     get_store_agent_details = _(get_store_agent_details)
@@ -349,6 +353,9 @@ class DatabaseManagerAsyncClient(AppServiceClient):
     # Onboarding
     increment_onboarding_runs = d.increment_onboarding_runs

+    # OAuth
+    cleanup_expired_oauth_tokens = d.cleanup_expired_oauth_tokens
+
     # Store
     get_store_agents = d.get_store_agents
     get_store_agent_details = d.get_store_agent_details
@@ -24,11 +24,9 @@ from dotenv import load_dotenv
 from pydantic import BaseModel, Field, ValidationError
 from sqlalchemy import MetaData, create_engine

-from backend.data.auth.oauth import cleanup_expired_oauth_tokens
 from backend.data.block import BlockInput
 from backend.data.execution import GraphExecutionWithNodes
 from backend.data.model import CredentialsMetaInput
-from backend.data.onboarding import increment_onboarding_runs
 from backend.executor import utils as execution_utils
 from backend.monitoring import (
     NotificationJobArgs,
@@ -38,7 +36,11 @@ from backend.monitoring import (
     report_execution_accuracy_alerts,
     report_late_executions,
 )
-from backend.util.clients import get_database_manager_client, get_scheduler_client
+from backend.util.clients import (
+    get_database_manager_async_client,
+    get_database_manager_client,
+    get_scheduler_client,
+)
 from backend.util.cloud_storage import cleanup_expired_files_async
 from backend.util.exceptions import (
     GraphNotFoundError,
@@ -148,6 +150,7 @@ def execute_graph(**kwargs):
 async def _execute_graph(**kwargs):
     args = GraphExecutionJobArgs(**kwargs)
     start_time = asyncio.get_event_loop().time()
+    db = get_database_manager_async_client()
     try:
         logger.info(f"Executing recurring job for graph #{args.graph_id}")
         graph_exec: GraphExecutionWithNodes = await execution_utils.add_graph_execution(
@@ -157,7 +160,7 @@ async def _execute_graph(**kwargs):
             inputs=args.input_data,
             graph_credentials_inputs=args.input_credentials,
         )
-        await increment_onboarding_runs(args.user_id)
+        await db.increment_onboarding_runs(args.user_id)
         elapsed = asyncio.get_event_loop().time() - start_time
         logger.info(
             f"Graph execution started with ID {graph_exec.id} for graph {args.graph_id} "
@@ -246,8 +249,13 @@ def cleanup_expired_files():

 def cleanup_oauth_tokens():
     """Clean up expired OAuth tokens from the database."""
+
+    async def _cleanup():
+        db = get_database_manager_async_client()
+        return await db.cleanup_expired_oauth_tokens()
+
     # Wait for completion
-    run_async(cleanup_expired_oauth_tokens())
+    run_async(_cleanup())


 def execution_accuracy_alerts():
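The scheduler now reaches the database through the DatabaseManager service client rather than importing `backend.data.*` functions directly, keeping the scheduler process decoupled from the DB layer. The same pattern, sketched for any future job (only the client accessor and `run_async` come from this diff; the job itself is hypothetical):

    def some_scheduled_job():
        async def _run():
            db = get_database_manager_async_client()
            # Any method exposed on DatabaseManagerAsyncClient can be awaited here
            return await db.cleanup_expired_oauth_tokens()

        run_async(_run())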
@@ -0,0 +1,39 @@
+from urllib.parse import urlparse
+
+import fastapi
+from fastapi.routing import APIRoute
+
+from backend.api.features.integrations.router import router as integrations_router
+from backend.integrations.providers import ProviderName
+from backend.integrations.webhooks import utils as webhooks_utils
+
+
+def test_webhook_ingress_url_matches_route(monkeypatch) -> None:
+    app = fastapi.FastAPI()
+    app.include_router(integrations_router, prefix="/api/integrations")
+
+    provider = ProviderName.GITHUB
+    webhook_id = "webhook_123"
+    base_url = "https://example.com"
+
+    monkeypatch.setattr(webhooks_utils.app_config, "platform_base_url", base_url)
+
+    route = next(
+        route
+        for route in integrations_router.routes
+        if isinstance(route, APIRoute)
+        and route.path == "/{provider}/webhooks/{webhook_id}/ingress"
+        and "POST" in route.methods
+    )
+    expected_path = f"/api/integrations{route.path}".format(
+        provider=provider.value,
+        webhook_id=webhook_id,
+    )
+    actual_url = urlparse(webhooks_utils.webhook_ingress_url(provider, webhook_id))
+    expected_base = urlparse(base_url)
+
+    assert (actual_url.scheme, actual_url.netloc) == (
+        expected_base.scheme,
+        expected_base.netloc,
+    )
+    assert actual_url.path == expected_path
@@ -1,10 +1,19 @@
+from __future__ import annotations
+
+import logging
 from copy import deepcopy
-from typing import Any
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any

 from tiktoken import encoding_for_model

 from backend.util import json

+if TYPE_CHECKING:
+    from openai import AsyncOpenAI
+
+logger = logging.getLogger(__name__)
+
 # ---------------------------------------------------------------------------#
 # CONSTANTS                                                                   #
 # ---------------------------------------------------------------------------#
@@ -100,9 +109,17 @@ def _is_objective_message(msg: dict) -> bool:
 def _truncate_tool_message_content(msg: dict, enc, max_tokens: int) -> None:
     """
     Carefully truncate tool message content while preserving tool structure.
-    Only truncates tool_result content, leaves tool_use intact.
+    Handles both Anthropic-style (list content) and OpenAI-style (string content) tool messages.
     """
     content = msg.get("content")

+    # OpenAI-style tool message: role="tool" with string content
+    if msg.get("role") == "tool" and isinstance(content, str):
+        if _tok_len(content, enc) > max_tokens:
+            msg["content"] = _truncate_middle_tokens(content, enc, max_tokens)
+        return
+
+    # Anthropic-style: list content with tool_result items
     if not isinstance(content, list):
         return
@@ -140,141 +157,6 @@ def _truncate_middle_tokens(text: str, enc, max_tok: int) -> str:
 # ---------------------------------------------------------------------------#


-def compress_prompt(
-    messages: list[dict],
-    target_tokens: int,
-    *,
-    model: str = "gpt-4o",
-    reserve: int = 2_048,
-    start_cap: int = 8_192,
-    floor_cap: int = 128,
-    lossy_ok: bool = True,
-) -> list[dict]:
-    """
-    Shrink *messages* so that::
-
-        token_count(prompt) + reserve ≤ target_tokens
-
-    Strategy
-    --------
-    1. **Token-aware truncation** – progressively halve a per-message cap
-       (`start_cap`, `start_cap/2`, … `floor_cap`) and apply it to the
-       *content* of every message except the first and last. Tool shells
-       are included: we keep the envelope but shorten huge payloads.
-    2. **Middle-out deletion** – if still over the limit, delete whole
-       messages working outward from the centre, **skipping** any message
-       that contains ``tool_calls`` or has ``role == "tool"``.
-    3. **Last-chance trim** – if still too big, truncate the *first* and
-       *last* message bodies down to `floor_cap` tokens.
-    4. If the prompt is *still* too large:
-       • raise ``ValueError`` when ``lossy_ok == False`` (default)
-       • return the partially-trimmed prompt when ``lossy_ok == True``
-
-    Parameters
-    ----------
-    messages       Complete chat history (will be deep-copied).
-    model          Model name; passed to tiktoken to pick the right
-                   tokenizer (gpt-4o → 'o200k_base', others fallback).
-    target_tokens  Hard ceiling for prompt size **excluding** the model's
-                   forthcoming answer.
-    reserve        How many tokens you want to leave available for that
-                   answer (`max_tokens` in your subsequent completion call).
-    start_cap      Initial per-message truncation ceiling (tokens).
-    floor_cap      Lowest cap we'll accept before moving to deletions.
-    lossy_ok       If *True* return best-effort prompt instead of raising
-                   after all trim passes have been exhausted.
-
-    Returns
-    -------
-    list[dict] – A *new* messages list that abides by the rules above.
-    """
-    enc = encoding_for_model(model)  # best-match tokenizer
-    msgs = deepcopy(messages)  # never mutate caller
-
-    def total_tokens() -> int:
-        """Current size of *msgs* in tokens."""
-        return sum(_msg_tokens(m, enc) for m in msgs)
-
-    original_token_count = total_tokens()
-
-    if original_token_count + reserve <= target_tokens:
-        return msgs
-
-    # ---- STEP 0 : normalise content --------------------------------------
-    # Convert non-string payloads to strings so token counting is coherent.
-    for i, m in enumerate(msgs):
-        if not isinstance(m.get("content"), str) and m.get("content") is not None:
-            if _is_tool_message(m):
-                continue
-
-            # Keep first and last messages intact (unless they're tool messages)
-            if i == 0 or i == len(msgs) - 1:
-                continue
-
-            # Reasonable 20k-char ceiling prevents pathological blobs
-            content_str = json.dumps(m["content"], separators=(",", ":"))
-            if len(content_str) > 20_000:
-                content_str = _truncate_middle_tokens(content_str, enc, 20_000)
-            m["content"] = content_str
-
-    # ---- STEP 1 : token-aware truncation ---------------------------------
-    cap = start_cap
-    while total_tokens() + reserve > target_tokens and cap >= floor_cap:
-        for m in msgs[1:-1]:  # keep first & last intact
-            if _is_tool_message(m):
-                # For tool messages, only truncate tool result content, preserve structure
-                _truncate_tool_message_content(m, enc, cap)
-                continue
-
-            if _is_objective_message(m):
-                # Never truncate objective messages - they contain the core task
-                continue
-
-            content = m.get("content") or ""
-            if _tok_len(content, enc) > cap:
-                m["content"] = _truncate_middle_tokens(content, enc, cap)
-        cap //= 2  # tighten the screw
-
-    # ---- STEP 2 : middle-out deletion -----------------------------------
-    while total_tokens() + reserve > target_tokens and len(msgs) > 2:
-        # Identify all deletable messages (not first/last, not tool messages, not objective messages)
-        deletable_indices = []
-        for i in range(1, len(msgs) - 1):  # Skip first and last
-            if not _is_tool_message(msgs[i]) and not _is_objective_message(msgs[i]):
-                deletable_indices.append(i)
-
-        if not deletable_indices:
-            break  # nothing more we can drop
-
-        # Delete from center outward - find the index closest to center
-        centre = len(msgs) // 2
-        to_delete = min(deletable_indices, key=lambda i: abs(i - centre))
-        del msgs[to_delete]
-
-    # ---- STEP 3 : final safety-net trim on first & last ------------------
-    cap = start_cap
-    while total_tokens() + reserve > target_tokens and cap >= floor_cap:
-        for idx in (0, -1):  # first and last
-            if _is_tool_message(msgs[idx]):
-                # For tool messages at first/last position, truncate tool result content only
-                _truncate_tool_message_content(msgs[idx], enc, cap)
-                continue
-
-            text = msgs[idx].get("content") or ""
-            if _tok_len(text, enc) > cap:
-                msgs[idx]["content"] = _truncate_middle_tokens(text, enc, cap)
-        cap //= 2  # tighten the screw
-
-    # ---- STEP 4 : success or fail-gracefully -----------------------------
-    if total_tokens() + reserve > target_tokens and not lossy_ok:
-        raise ValueError(
-            "compress_prompt: prompt still exceeds budget "
-            f"({total_tokens() + reserve} > {target_tokens})."
-        )
-
-    return msgs
 def estimate_token_count(
     messages: list[dict],
     *,
@@ -293,7 +175,8 @@ def estimate_token_count(
     -------
     int – Token count.
     """
-    enc = encoding_for_model(model)  # best-match tokenizer
+    token_model = _normalize_model_for_tokenizer(model)
+    enc = encoding_for_model(token_model)
     return sum(_msg_tokens(m, enc) for m in messages)


@@ -315,6 +198,543 @@ def estimate_token_count_str(
     -------
     int – Token count.
     """
-    enc = encoding_for_model(model)  # best-match tokenizer
+    token_model = _normalize_model_for_tokenizer(model)
+    enc = encoding_for_model(token_model)
     text = json.dumps(text) if not isinstance(text, str) else text
     return _tok_len(text, enc)
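With `_normalize_model_for_tokenizer` (defined just below), token estimation no longer depends on tiktoken knowing every model string: provider-prefixed and non-OpenAI names fall back to the `gpt-4o` tokenizer as an approximation. Spot checks, per the code as written; the `model` keyword is assumed from the signature above:

    msgs = [{"role": "user", "content": "Summarize this PR."}]
    n_gpt = estimate_token_count(msgs, model="gpt-4o")
    n_claude = estimate_token_count(msgs, model="anthropic/claude-sonnet-4-5-20250929")
    assert n_gpt == n_claude  # claude names are tokenized with the gpt-4o encoding

    assert _normalize_model_for_tokenizer("anthropic/claude-haiku-4-5-20251001") == "gpt-4o"
    assert _normalize_model_for_tokenizer("Qwen/Qwen2.5-72B-Instruct-Turbo") == "gpt-4o"
    assert _normalize_model_for_tokenizer("o1-preview") == "o1-preview"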
+
+
+# ---------------------------------------------------------------------------#
+#                       UNIFIED CONTEXT COMPRESSION                           #
+# ---------------------------------------------------------------------------#
+
+# Default thresholds
+DEFAULT_TOKEN_THRESHOLD = 120_000
+DEFAULT_KEEP_RECENT = 15
+
+
+@dataclass
+class CompressResult:
+    """Result of context compression."""
+
+    messages: list[dict]
+    token_count: int
+    was_compacted: bool
+    error: str | None = None
+    original_token_count: int = 0
+    messages_summarized: int = 0
+    messages_dropped: int = 0
+
+
+def _normalize_model_for_tokenizer(model: str) -> str:
+    """Normalize model name for tiktoken tokenizer selection."""
+    if "/" in model:
+        model = model.split("/")[-1]
+    if "claude" in model.lower() or not any(
+        known in model.lower() for known in ["gpt", "o1", "chatgpt", "text-"]
+    ):
+        return "gpt-4o"
+    return model
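
(Reviewer note: a quick sketch of how this normalization behaves, matching the tests further down in this diff; OpenRouter-style "provider/model" paths are stripped, and Claude or unknown models fall back to the gpt-4o tokenizer:)

    from backend.util.prompt import _normalize_model_for_tokenizer

    assert _normalize_model_for_tokenizer("openai/gpt-4o") == "gpt-4o"
    assert _normalize_model_for_tokenizer("anthropic/claude-3-opus") == "gpt-4o"  # Claude -> gpt-4o tokenizer
    assert _normalize_model_for_tokenizer("gpt-3.5-turbo") == "gpt-3.5-turbo"  # known OpenAI names pass through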
+
+
+def _extract_tool_call_ids_from_message(msg: dict) -> set[str]:
+    """
+    Extract tool_call IDs from an assistant message.
+
+    Supports both formats:
+    - OpenAI: {"role": "assistant", "tool_calls": [{"id": "..."}]}
+    - Anthropic: {"role": "assistant", "content": [{"type": "tool_use", "id": "..."}]}
+
+    Returns:
+        Set of tool_call IDs found in the message.
+    """
+    ids: set[str] = set()
+    if msg.get("role") != "assistant":
+        return ids
+
+    # OpenAI format: tool_calls array
+    if msg.get("tool_calls"):
+        for tc in msg["tool_calls"]:
+            tc_id = tc.get("id")
+            if tc_id:
+                ids.add(tc_id)
+
+    # Anthropic format: content list with tool_use blocks
+    content = msg.get("content")
+    if isinstance(content, list):
+        for block in content:
+            if isinstance(block, dict) and block.get("type") == "tool_use":
+                tc_id = block.get("id")
+                if tc_id:
+                    ids.add(tc_id)
+
+    return ids
+
+
+def _extract_tool_response_ids_from_message(msg: dict) -> set[str]:
+    """
+    Extract tool_call IDs that this message is responding to.
+
+    Supports both formats:
+    - OpenAI: {"role": "tool", "tool_call_id": "..."}
+    - Anthropic: {"role": "user", "content": [{"type": "tool_result", "tool_use_id": "..."}]}
+
+    Returns:
+        Set of tool_call IDs this message responds to.
+    """
+    ids: set[str] = set()
+
+    # OpenAI format: role=tool with tool_call_id
+    if msg.get("role") == "tool":
+        tc_id = msg.get("tool_call_id")
+        if tc_id:
+            ids.add(tc_id)
+
+    # Anthropic format: content list with tool_result blocks
+    content = msg.get("content")
+    if isinstance(content, list):
+        for block in content:
+            if isinstance(block, dict) and block.get("type") == "tool_result":
+                tc_id = block.get("tool_use_id")
+                if tc_id:
+                    ids.add(tc_id)
+
+    return ids
+
+
+def _is_tool_response_message(msg: dict) -> bool:
+    """Check if message is a tool response (OpenAI or Anthropic format)."""
+    # OpenAI format
+    if msg.get("role") == "tool":
+        return True
+    # Anthropic format
+    content = msg.get("content")
+    if isinstance(content, list):
+        for block in content:
+            if isinstance(block, dict) and block.get("type") == "tool_result":
+                return True
+    return False
+
+
+def _remove_orphan_tool_responses(
+    messages: list[dict], orphan_ids: set[str]
+) -> list[dict]:
+    """
+    Remove tool response messages/blocks that reference orphan tool_call IDs.
+
+    Supports both OpenAI and Anthropic formats.
+    For Anthropic messages with mixed valid/orphan tool_result blocks,
+    filters out only the orphan blocks instead of dropping the entire message.
+    """
+    result = []
+    for msg in messages:
+        # OpenAI format: role=tool - drop entire message if orphan
+        if msg.get("role") == "tool":
+            tc_id = msg.get("tool_call_id")
+            if tc_id and tc_id in orphan_ids:
+                continue
+            result.append(msg)
+            continue
+
+        # Anthropic format: content list may have mixed tool_result blocks
+        content = msg.get("content")
+        if isinstance(content, list):
+            has_tool_results = any(
+                isinstance(b, dict) and b.get("type") == "tool_result" for b in content
+            )
+            if has_tool_results:
+                # Filter out orphan tool_result blocks, keep valid ones
+                filtered_content = [
+                    block
+                    for block in content
+                    if not (
+                        isinstance(block, dict)
+                        and block.get("type") == "tool_result"
+                        and block.get("tool_use_id") in orphan_ids
+                    )
+                ]
+                # Only keep message if it has remaining content
+                if filtered_content:
+                    msg = msg.copy()
+                    msg["content"] = filtered_content
+                    result.append(msg)
+                continue
+
+        result.append(msg)
+    return result
+
+
+def _ensure_tool_pairs_intact(
+    recent_messages: list[dict],
+    all_messages: list[dict],
+    start_index: int,
+) -> list[dict]:
+    """
+    Ensure tool_call/tool_response pairs stay together after slicing.
+
+    When slicing messages for context compaction, a naive slice can separate
+    an assistant message containing tool_calls from its corresponding tool
+    response messages. This causes API validation errors (e.g., Anthropic's
+    "unexpected tool_use_id found in tool_result blocks").
+
+    This function checks for orphan tool responses in the slice and extends
+    backwards to include their corresponding assistant messages.
+
+    Supports both formats:
+    - OpenAI: tool_calls array + role="tool" responses
+    - Anthropic: tool_use blocks + tool_result blocks
+
+    Args:
+        recent_messages: The sliced messages to validate
+        all_messages: The complete message list (for looking up missing assistants)
+        start_index: The index in all_messages where recent_messages begins
+
+    Returns:
+        A potentially extended list of messages with tool pairs intact
+    """
+    if not recent_messages:
+        return recent_messages
+
+    # Collect all tool_call_ids from assistant messages in the slice
+    available_tool_call_ids: set[str] = set()
+    for msg in recent_messages:
+        available_tool_call_ids |= _extract_tool_call_ids_from_message(msg)
+
+    # Find orphan tool responses (responses whose tool_call_id is missing)
+    orphan_tool_call_ids: set[str] = set()
+    for msg in recent_messages:
+        response_ids = _extract_tool_response_ids_from_message(msg)
+        for tc_id in response_ids:
+            if tc_id not in available_tool_call_ids:
+                orphan_tool_call_ids.add(tc_id)
+
+    if not orphan_tool_call_ids:
+        # No orphans, slice is valid
+        return recent_messages
+
+    # Find the assistant messages that contain the orphan tool_call_ids
+    # Search backwards from start_index in all_messages
+    messages_to_prepend: list[dict] = []
+    for i in range(start_index - 1, -1, -1):
+        msg = all_messages[i]
+        msg_tool_ids = _extract_tool_call_ids_from_message(msg)
+        if msg_tool_ids & orphan_tool_call_ids:
+            # This assistant message has tool_calls we need
+            # Also collect its contiguous tool responses that follow it
+            assistant_and_responses: list[dict] = [msg]
+
+            # Scan forward from this assistant to collect tool responses
+            for j in range(i + 1, start_index):
+                following_msg = all_messages[j]
+                following_response_ids = _extract_tool_response_ids_from_message(
+                    following_msg
+                )
+                if following_response_ids and following_response_ids & msg_tool_ids:
+                    assistant_and_responses.append(following_msg)
+                elif not _is_tool_response_message(following_msg):
+                    # Stop at first non-tool-response message
+                    break
+
+            # Prepend the assistant and its tool responses (maintain order)
+            messages_to_prepend = assistant_and_responses + messages_to_prepend
+            # Mark these as found
+            orphan_tool_call_ids -= msg_tool_ids
+            # Also add this assistant's tool_call_ids to available set
+            available_tool_call_ids |= msg_tool_ids
+
+            if not orphan_tool_call_ids:
+                # Found all missing assistants
+                break
+
+    if orphan_tool_call_ids:
+        # Some tool_call_ids couldn't be resolved - remove those tool responses
+        # This shouldn't happen in normal operation but handles edge cases
+        logger.warning(
+            f"Could not find assistant messages for tool_call_ids: {orphan_tool_call_ids}. "
+            "Removing orphan tool responses."
+        )
+        recent_messages = _remove_orphan_tool_responses(
+            recent_messages, orphan_tool_call_ids
+        )
+
+    if messages_to_prepend:
+        logger.info(
+            f"Extended recent messages by {len(messages_to_prepend)} to preserve "
+            f"tool_call/tool_response pairs"
+        )
+        return messages_to_prepend + recent_messages
+
+    return recent_messages
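
(Reviewer note: a minimal sketch of the repair _ensure_tool_pairs_intact performs, in the OpenAI format; names are taken from the diff above:)

    all_msgs = [
        {"role": "assistant", "tool_calls": [{"id": "call_1", "type": "function", "function": {"name": "f"}}]},
        {"role": "tool", "tool_call_id": "call_1", "content": "result"},
    ]
    # A naive slice that keeps only the tool response is extended backwards
    # to re-include the assistant message that issued call_1.
    fixed = _ensure_tool_pairs_intact([all_msgs[1]], all_msgs, start_index=1)
    assert fixed[0]["role"] == "assistant"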
+
+
+async def _summarize_messages_llm(
+    messages: list[dict],
+    client: AsyncOpenAI,
+    model: str,
+    timeout: float = 30.0,
+) -> str:
+    """Summarize messages using an LLM."""
+    conversation = []
+    for msg in messages:
+        role = msg.get("role", "")
+        content = msg.get("content", "")
+        if content and role in ("user", "assistant", "tool"):
+            conversation.append(f"{role.upper()}: {content}")
+
+    conversation_text = "\n\n".join(conversation)
+
+    if not conversation_text:
+        return "No conversation history available."
+
+    # Limit to ~100k chars for safety
+    MAX_CHARS = 100_000
+    if len(conversation_text) > MAX_CHARS:
+        conversation_text = conversation_text[:MAX_CHARS] + "\n\n[truncated]"
+
+    response = await client.with_options(timeout=timeout).chat.completions.create(
+        model=model,
+        messages=[
+            {
+                "role": "system",
+                "content": (
+                    "Create a detailed summary of the conversation so far. "
+                    "This summary will be used as context when continuing the conversation.\n\n"
+                    "Before writing the summary, analyze each message chronologically to identify:\n"
+                    "- User requests and their explicit goals\n"
+                    "- Your approach and key decisions made\n"
+                    "- Technical specifics (file names, tool outputs, function signatures)\n"
+                    "- Errors encountered and resolutions applied\n\n"
+                    "You MUST include ALL of the following sections:\n\n"
+                    "## 1. Primary Request and Intent\n"
+                    "The user's explicit goals and what they are trying to accomplish.\n\n"
+                    "## 2. Key Technical Concepts\n"
+                    "Technologies, frameworks, tools, and patterns being used or discussed.\n\n"
+                    "## 3. Files and Resources Involved\n"
+                    "Specific files examined or modified, with relevant snippets and identifiers.\n\n"
+                    "## 4. Errors and Fixes\n"
+                    "Problems encountered, error messages, and their resolutions. "
+                    "Include any user feedback on fixes.\n\n"
+                    "## 5. Problem Solving\n"
+                    "Issues that have been resolved and how they were addressed.\n\n"
+                    "## 6. All User Messages\n"
+                    "A complete list of all user inputs (excluding tool outputs) to preserve their exact requests.\n\n"
+                    "## 7. Pending Tasks\n"
+                    "Work items the user explicitly requested that have not yet been completed.\n\n"
+                    "## 8. Current Work\n"
+                    "Precise description of what was being worked on most recently, including relevant context.\n\n"
+                    "## 9. Next Steps\n"
+                    "What should happen next, aligned with the user's most recent requests. "
+                    "Include verbatim quotes of recent instructions if relevant."
+                ),
+            },
+            {"role": "user", "content": f"Summarize:\n\n{conversation_text}"},
+        ],
+        max_tokens=1500,
+        temperature=0.3,
+    )
+
+    return response.choices[0].message.content or "No summary available."
+
+
+async def compress_context(
+    messages: list[dict],
+    target_tokens: int = DEFAULT_TOKEN_THRESHOLD,
+    *,
+    model: str = "gpt-4o",
+    client: AsyncOpenAI | None = None,
+    keep_recent: int = DEFAULT_KEEP_RECENT,
+    reserve: int = 2_048,
+    start_cap: int = 8_192,
+    floor_cap: int = 128,
+) -> CompressResult:
+    """
+    Unified context compression that combines summarization and truncation strategies.
+
+    Strategy (in order):
+    1. **LLM summarization** – If client provided, summarize old messages into a
+       single context message while keeping recent messages intact. This is the
+       primary strategy for chat service.
+    2. **Content truncation** – Progressively halve a per-message cap and truncate
+       bloated message content (tool outputs, large pastes). Preserves all messages
+       but shortens their content. Primary strategy when client=None (LLM blocks).
+    3. **Middle-out deletion** – Delete whole messages one at a time from the center
+       outward, skipping tool messages and objective messages.
+    4. **First/last trim** – Truncate first and last message content as last resort.
+
+    Parameters
+    ----------
+    messages        Complete chat history (will be deep-copied).
+    target_tokens   Hard ceiling for prompt size.
+    model           Model name for tokenization and summarization.
+    client          AsyncOpenAI client. If provided, enables LLM summarization
+                    as the first strategy. If None, skips to truncation strategies.
+    keep_recent     Number of recent messages to preserve during summarization.
+    reserve         Tokens to reserve for model response.
+    start_cap       Initial per-message truncation ceiling (tokens).
+    floor_cap       Lowest cap before moving to deletions.
+
+    Returns
+    -------
+    CompressResult with compressed messages and metadata.
+    """
+    # Guard clause for empty messages
+    if not messages:
+        return CompressResult(
+            messages=[],
+            token_count=0,
+            was_compacted=False,
+            original_token_count=0,
+        )
+
+    token_model = _normalize_model_for_tokenizer(model)
+    enc = encoding_for_model(token_model)
+    msgs = deepcopy(messages)
+
+    def total_tokens() -> int:
+        return sum(_msg_tokens(m, enc) for m in msgs)
+
+    original_count = total_tokens()
+
+    # Already under limit
+    if original_count + reserve <= target_tokens:
+        return CompressResult(
+            messages=msgs,
+            token_count=original_count,
+            was_compacted=False,
+            original_token_count=original_count,
+        )
+
+    messages_summarized = 0
+    messages_dropped = 0
+
+    # ---- STEP 1: LLM summarization (if client provided) -------------------
+    # This is the primary compression strategy for chat service.
+    # Summarize old messages while keeping recent ones intact.
+    if client is not None:
+        has_system = len(msgs) > 0 and msgs[0].get("role") == "system"
+        system_msg = msgs[0] if has_system else None
+
+        # Calculate old vs recent messages
+        if has_system:
+            if len(msgs) > keep_recent + 1:
+                old_msgs = msgs[1:-keep_recent]
+                recent_msgs = msgs[-keep_recent:]
+            else:
+                old_msgs = []
+                recent_msgs = msgs[1:] if len(msgs) > 1 else []
+        else:
+            if len(msgs) > keep_recent:
+                old_msgs = msgs[:-keep_recent]
+                recent_msgs = msgs[-keep_recent:]
+            else:
+                old_msgs = []
+                recent_msgs = msgs
+
+        # Ensure tool pairs stay intact
+        slice_start = max(0, len(msgs) - keep_recent)
+        recent_msgs = _ensure_tool_pairs_intact(recent_msgs, msgs, slice_start)
+
+        if old_msgs:
+            try:
+                summary_text = await _summarize_messages_llm(old_msgs, client, model)
+                summary_msg = {
+                    "role": "assistant",
+                    "content": f"[Previous conversation summary — for context only]: {summary_text}",
+                }
+                messages_summarized = len(old_msgs)
+
+                if has_system:
+                    msgs = [system_msg, summary_msg] + recent_msgs
+                else:
+                    msgs = [summary_msg] + recent_msgs
+
+                logger.info(
+                    f"Context summarized: {original_count} -> {total_tokens()} tokens, "
+                    f"summarized {messages_summarized} messages"
+                )
+            except Exception as e:
+                logger.warning(f"Summarization failed, continuing with truncation: {e}")
+                # Fall through to content truncation
+
+    # ---- STEP 2: Normalize content ----------------------------------------
+    # Convert non-string payloads to strings so token counting is coherent.
+    # Always run this before truncation to ensure consistent token counting.
+    for i, m in enumerate(msgs):
+        if not isinstance(m.get("content"), str) and m.get("content") is not None:
+            if _is_tool_message(m):
+                continue
+            if i == 0 or i == len(msgs) - 1:
+                continue
+            content_str = json.dumps(m["content"], separators=(",", ":"))
+            if len(content_str) > 20_000:
+                content_str = _truncate_middle_tokens(content_str, enc, 20_000)
+            m["content"] = content_str
+
+    # ---- STEP 3: Token-aware content truncation ---------------------------
+    # Progressively halve per-message cap and truncate bloated content.
+    # This preserves all messages but shortens their content.
+    cap = start_cap
+    while total_tokens() + reserve > target_tokens and cap >= floor_cap:
+        for m in msgs[1:-1]:
+            if _is_tool_message(m):
+                _truncate_tool_message_content(m, enc, cap)
+                continue
+            if _is_objective_message(m):
+                continue
+            content = m.get("content") or ""
+            if _tok_len(content, enc) > cap:
+                m["content"] = _truncate_middle_tokens(content, enc, cap)
+        cap //= 2
+
+    # ---- STEP 4: Middle-out deletion --------------------------------------
+    # Delete messages one at a time from the center outward.
+    # This is more granular than dropping all old messages at once.
+    while total_tokens() + reserve > target_tokens and len(msgs) > 2:
+        deletable: list[int] = []
+        for i in range(1, len(msgs) - 1):
+            msg = msgs[i]
+            if (
+                msg is not None
+                and not _is_tool_message(msg)
+                and not _is_objective_message(msg)
+            ):
+                deletable.append(i)
+        if not deletable:
+            break
+        centre = len(msgs) // 2
+        to_delete = min(deletable, key=lambda i: abs(i - centre))
+        del msgs[to_delete]
+        messages_dropped += 1
+
+    # ---- STEP 5: Final trim on first/last ---------------------------------
+    cap = start_cap
+    while total_tokens() + reserve > target_tokens and cap >= floor_cap:
+        for idx in (0, -1):
+            msg = msgs[idx]
+            if msg is None:
+                continue
+            if _is_tool_message(msg):
+                _truncate_tool_message_content(msg, enc, cap)
+                continue
+            text = msg.get("content") or ""
+            if _tok_len(text, enc) > cap:
+                msg["content"] = _truncate_middle_tokens(text, enc, cap)
+        cap //= 2
+
+    # Filter out any None values that may have been introduced
+    final_msgs: list[dict] = [m for m in msgs if m is not None]
+    final_count = sum(_msg_tokens(m, enc) for m in final_msgs)
+    error = None
+    if final_count + reserve > target_tokens:
+        error = f"Could not compress below target ({final_count + reserve} > {target_tokens})"
+        logger.warning(error)
+
+    return CompressResult(
+        messages=final_msgs,
+        token_count=final_count,
+        was_compacted=True,
+        error=error,
+        original_token_count=original_count,
+        messages_summarized=messages_summarized,
+        messages_dropped=messages_dropped,
+    )
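
(Reviewer note: a minimal usage sketch of the new entry point; the call shape follows the signature above, and the summarization path additionally needs a real AsyncOpenAI client:)

    from backend.util.prompt import compress_context

    result = await compress_context(
        messages,          # full chat history; deep-copied internally
        target_tokens=120_000,
        model="gpt-4o",
        client=None,       # None skips LLM summarization and goes straight to truncation
    )
    if result.was_compacted:
        messages = result.messages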
@@ -1,10 +1,21 @@
 """Tests for prompt utility functions, especially tool call token counting."""
 
+from unittest.mock import AsyncMock, MagicMock
+
 import pytest
 from tiktoken import encoding_for_model
 
 from backend.util import json
-from backend.util.prompt import _msg_tokens, estimate_token_count
+from backend.util.prompt import (
+    CompressResult,
+    _ensure_tool_pairs_intact,
+    _msg_tokens,
+    _normalize_model_for_tokenizer,
+    _truncate_middle_tokens,
+    _truncate_tool_message_content,
+    compress_context,
+    estimate_token_count,
+)
 
 
 class TestMsgTokens:
@@ -276,3 +287,690 @@ class TestEstimateTokenCount:
 
         assert total_tokens == expected_total
         assert total_tokens > 20  # Should be substantial
+
+
+class TestNormalizeModelForTokenizer:
+    """Test model name normalization for tiktoken."""
+
+    def test_openai_models_unchanged(self):
+        """Test that OpenAI models are returned as-is."""
+        assert _normalize_model_for_tokenizer("gpt-4o") == "gpt-4o"
+        assert _normalize_model_for_tokenizer("gpt-4") == "gpt-4"
+        assert _normalize_model_for_tokenizer("gpt-3.5-turbo") == "gpt-3.5-turbo"
+
+    def test_claude_models_normalized(self):
+        """Test that Claude models are normalized to gpt-4o."""
+        assert _normalize_model_for_tokenizer("claude-3-opus") == "gpt-4o"
+        assert _normalize_model_for_tokenizer("claude-3-sonnet") == "gpt-4o"
+        assert _normalize_model_for_tokenizer("anthropic/claude-3-haiku") == "gpt-4o"
+
+    def test_openrouter_paths_extracted(self):
+        """Test that OpenRouter model paths are handled."""
+        assert _normalize_model_for_tokenizer("openai/gpt-4o") == "gpt-4o"
+        assert _normalize_model_for_tokenizer("anthropic/claude-3-opus") == "gpt-4o"
+
+    def test_unknown_models_default_to_gpt4o(self):
+        """Test that unknown models default to gpt-4o."""
+        assert _normalize_model_for_tokenizer("some-random-model") == "gpt-4o"
+        assert _normalize_model_for_tokenizer("llama-3-70b") == "gpt-4o"
+
+
+class TestTruncateToolMessageContent:
+    """Test tool message content truncation."""
+
+    @pytest.fixture
+    def enc(self):
+        return encoding_for_model("gpt-4o")
+
+    def test_truncate_openai_tool_message(self, enc):
+        """Test truncation of OpenAI-style tool message with string content."""
+        long_content = "x" * 10000
+        msg = {"role": "tool", "tool_call_id": "call_123", "content": long_content}
+
+        _truncate_tool_message_content(msg, enc, max_tokens=100)
+
+        # Content should be truncated
+        assert len(msg["content"]) < len(long_content)
+        assert "…" in msg["content"]  # Has ellipsis marker
+
+    def test_truncate_anthropic_tool_result(self, enc):
+        """Test truncation of Anthropic-style tool_result."""
+        long_content = "y" * 10000
+        msg = {
+            "role": "user",
+            "content": [
+                {
+                    "type": "tool_result",
+                    "tool_use_id": "toolu_123",
+                    "content": long_content,
+                }
+            ],
+        }
+
+        _truncate_tool_message_content(msg, enc, max_tokens=100)
+
+        # Content should be truncated
+        result_content = msg["content"][0]["content"]
+        assert len(result_content) < len(long_content)
+        assert "…" in result_content
+
+    def test_preserve_tool_use_blocks(self, enc):
+        """Test that tool_use blocks are not truncated."""
+        msg = {
+            "role": "assistant",
+            "content": [
+                {
+                    "type": "tool_use",
+                    "id": "toolu_123",
+                    "name": "some_function",
+                    "input": {"key": "value" * 1000},  # Large input
+                }
+            ],
+        }
+
+        original = json.dumps(msg["content"][0]["input"])
+        _truncate_tool_message_content(msg, enc, max_tokens=10)
+
+        # tool_use should be unchanged
+        assert json.dumps(msg["content"][0]["input"]) == original
+
+    def test_no_truncation_when_under_limit(self, enc):
+        """Test that short content is not modified."""
+        msg = {"role": "tool", "tool_call_id": "call_123", "content": "Short content"}
+
+        original = msg["content"]
+        _truncate_tool_message_content(msg, enc, max_tokens=1000)
+
+        assert msg["content"] == original
+
+
+class TestTruncateMiddleTokens:
+    """Test middle truncation of text."""
+
+    @pytest.fixture
+    def enc(self):
+        return encoding_for_model("gpt-4o")
+
+    def test_truncates_long_text(self, enc):
+        """Test that long text is truncated with ellipsis in middle."""
+        long_text = "word " * 1000
+        result = _truncate_middle_tokens(long_text, enc, max_tok=50)
+
+        assert len(enc.encode(result)) <= 52  # Allow some slack for ellipsis
+        assert "…" in result
+        assert result.startswith("word")  # Head preserved
+        assert result.endswith("word ")  # Tail preserved
+
+    def test_preserves_short_text(self, enc):
+        """Test that short text is not modified."""
+        short_text = "Hello world"
+        result = _truncate_middle_tokens(short_text, enc, max_tok=100)
+
+        assert result == short_text
+
+
+class TestEnsureToolPairsIntact:
+    """Test tool call/response pair preservation for both OpenAI and Anthropic formats."""
+
+    # ---- OpenAI Format Tests ----
+
+    def test_openai_adds_missing_tool_call(self):
+        """Test that orphaned OpenAI tool_response gets its tool_call prepended."""
+        all_msgs = [
+            {"role": "system", "content": "You are helpful."},
+            {
+                "role": "assistant",
+                "tool_calls": [
+                    {"id": "call_1", "type": "function", "function": {"name": "f1"}}
+                ],
+            },
+            {"role": "tool", "tool_call_id": "call_1", "content": "result"},
+            {"role": "user", "content": "Thanks!"},
+        ]
+        # Recent messages start at index 2 (the tool response)
+        recent = [all_msgs[2], all_msgs[3]]
+        start_index = 2
+
+        result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
+
+        # Should prepend the tool_call message
+        assert len(result) == 3
+        assert result[0]["role"] == "assistant"
+        assert "tool_calls" in result[0]
+
+    def test_openai_keeps_complete_pairs(self):
+        """Test that complete OpenAI pairs are unchanged."""
+        all_msgs = [
+            {"role": "system", "content": "System"},
+            {
+                "role": "assistant",
+                "tool_calls": [
+                    {"id": "call_1", "type": "function", "function": {"name": "f1"}}
+                ],
+            },
+            {"role": "tool", "tool_call_id": "call_1", "content": "result"},
+        ]
+        recent = all_msgs[1:]  # Include both tool_call and response
+        start_index = 1
+
+        result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
+
+        assert len(result) == 2  # No messages added
+
+    def test_openai_multiple_tool_calls(self):
+        """Test multiple OpenAI tool calls in one assistant message."""
+        all_msgs = [
+            {"role": "system", "content": "System"},
+            {
+                "role": "assistant",
+                "tool_calls": [
+                    {"id": "call_1", "type": "function", "function": {"name": "f1"}},
+                    {"id": "call_2", "type": "function", "function": {"name": "f2"}},
+                ],
+            },
+            {"role": "tool", "tool_call_id": "call_1", "content": "result1"},
+            {"role": "tool", "tool_call_id": "call_2", "content": "result2"},
+            {"role": "user", "content": "Thanks!"},
+        ]
+        # Recent messages start at index 2 (first tool response)
+        recent = [all_msgs[2], all_msgs[3], all_msgs[4]]
+        start_index = 2
+
+        result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
+
+        # Should prepend the assistant message with both tool_calls
+        assert len(result) == 4
+        assert result[0]["role"] == "assistant"
+        assert len(result[0]["tool_calls"]) == 2
+
+    # ---- Anthropic Format Tests ----
+
+    def test_anthropic_adds_missing_tool_use(self):
+        """Test that orphaned Anthropic tool_result gets its tool_use prepended."""
+        all_msgs = [
+            {"role": "system", "content": "You are helpful."},
+            {
+                "role": "assistant",
+                "content": [
+                    {
+                        "type": "tool_use",
+                        "id": "toolu_123",
+                        "name": "get_weather",
+                        "input": {"location": "SF"},
+                    }
+                ],
+            },
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "tool_result",
+                        "tool_use_id": "toolu_123",
+                        "content": "22°C and sunny",
+                    }
+                ],
+            },
+            {"role": "user", "content": "Thanks!"},
+        ]
+        # Recent messages start at index 2 (the tool_result)
+        recent = [all_msgs[2], all_msgs[3]]
+        start_index = 2
+
+        result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
+
+        # Should prepend the tool_use message
+        assert len(result) == 3
+        assert result[0]["role"] == "assistant"
+        assert result[0]["content"][0]["type"] == "tool_use"
+
+    def test_anthropic_keeps_complete_pairs(self):
+        """Test that complete Anthropic pairs are unchanged."""
+        all_msgs = [
+            {"role": "system", "content": "System"},
+            {
+                "role": "assistant",
+                "content": [
+                    {
+                        "type": "tool_use",
+                        "id": "toolu_456",
+                        "name": "calculator",
+                        "input": {"expr": "2+2"},
+                    }
+                ],
+            },
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "tool_result",
+                        "tool_use_id": "toolu_456",
+                        "content": "4",
+                    }
+                ],
+            },
+        ]
+        recent = all_msgs[1:]  # Include both tool_use and result
+        start_index = 1
+
+        result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
+
+        assert len(result) == 2  # No messages added
+
+    def test_anthropic_multiple_tool_uses(self):
+        """Test multiple Anthropic tool_use blocks in one message."""
+        all_msgs = [
+            {"role": "system", "content": "System"},
+            {
+                "role": "assistant",
+                "content": [
+                    {"type": "text", "text": "Let me check both..."},
+                    {
+                        "type": "tool_use",
+                        "id": "toolu_1",
+                        "name": "get_weather",
+                        "input": {"city": "NYC"},
+                    },
+                    {
+                        "type": "tool_use",
+                        "id": "toolu_2",
+                        "name": "get_weather",
+                        "input": {"city": "LA"},
+                    },
+                ],
+            },
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "tool_result",
+                        "tool_use_id": "toolu_1",
+                        "content": "Cold",
+                    },
+                    {
+                        "type": "tool_result",
+                        "tool_use_id": "toolu_2",
+                        "content": "Warm",
+                    },
+                ],
+            },
+            {"role": "user", "content": "Thanks!"},
+        ]
+        # Recent messages start at index 2 (tool_result)
+        recent = [all_msgs[2], all_msgs[3]]
+        start_index = 2
+
+        result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
+
+        # Should prepend the assistant message with both tool_uses
+        assert len(result) == 3
+        assert result[0]["role"] == "assistant"
+        tool_use_count = sum(
+            1 for b in result[0]["content"] if b.get("type") == "tool_use"
+        )
+        assert tool_use_count == 2
+
+    # ---- Mixed/Edge Case Tests ----
+
+    def test_anthropic_with_type_message_field(self):
+        """Test Anthropic format with 'type': 'message' field (smart_decision_maker style)."""
+        all_msgs = [
+            {"role": "system", "content": "You are helpful."},
+            {
+                "role": "assistant",
+                "content": [
+                    {
+                        "type": "tool_use",
+                        "id": "toolu_abc",
+                        "name": "search",
+                        "input": {"q": "test"},
+                    }
+                ],
+            },
+            {
+                "role": "user",
+                "type": "message",  # Extra field from smart_decision_maker
+                "content": [
+                    {
+                        "type": "tool_result",
+                        "tool_use_id": "toolu_abc",
+                        "content": "Found results",
+                    }
+                ],
+            },
+            {"role": "user", "content": "Thanks!"},
+        ]
+        # Recent messages start at index 2 (the tool_result with 'type': 'message')
+        recent = [all_msgs[2], all_msgs[3]]
+        start_index = 2
+
+        result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
+
+        # Should prepend the tool_use message
+        assert len(result) == 3
+        assert result[0]["role"] == "assistant"
+        assert result[0]["content"][0]["type"] == "tool_use"
+
+    def test_handles_no_tool_messages(self):
+        """Test messages without tool calls."""
+        all_msgs = [
+            {"role": "user", "content": "Hello"},
+            {"role": "assistant", "content": "Hi there!"},
+        ]
+        recent = all_msgs
+        start_index = 0
+
+        result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
+
+        assert result == all_msgs
+
+    def test_handles_empty_messages(self):
+        """Test empty message list."""
+        result = _ensure_tool_pairs_intact([], [], 0)
+        assert result == []
+
+    def test_mixed_text_and_tool_content(self):
+        """Test Anthropic message with mixed text and tool_use content."""
+        all_msgs = [
+            {
+                "role": "assistant",
+                "content": [
+                    {"type": "text", "text": "I'll help you with that."},
+                    {
+                        "type": "tool_use",
+                        "id": "toolu_mixed",
+                        "name": "search",
+                        "input": {"q": "test"},
+                    },
+                ],
+            },
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "tool_result",
+                        "tool_use_id": "toolu_mixed",
+                        "content": "Found results",
+                    }
+                ],
+            },
+            {"role": "assistant", "content": "Here are the results..."},
+        ]
+        # Start from tool_result
+        recent = [all_msgs[1], all_msgs[2]]
+        start_index = 1
+
+        result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
+
+        # Should prepend the assistant message with tool_use
+        assert len(result) == 3
+        assert result[0]["content"][0]["type"] == "text"
+        assert result[0]["content"][1]["type"] == "tool_use"
+
+
+class TestCompressContext:
+    """Test the async compress_context function."""
+
+    @pytest.mark.asyncio
+    async def test_no_compression_needed(self):
+        """Test messages under limit return without compression."""
+        messages = [
+            {"role": "system", "content": "You are helpful."},
+            {"role": "user", "content": "Hello!"},
+        ]
+
+        result = await compress_context(messages, target_tokens=100000)
+
+        assert isinstance(result, CompressResult)
+        assert result.was_compacted is False
+        assert len(result.messages) == 2
+        assert result.error is None
+
+    @pytest.mark.asyncio
+    async def test_truncation_without_client(self):
+        """Test that truncation works without LLM client."""
+        long_content = "x" * 50000
+        messages = [
+            {"role": "system", "content": "System"},
+            {"role": "user", "content": long_content},
+            {"role": "assistant", "content": "Response"},
+        ]
+
+        result = await compress_context(
+            messages, target_tokens=1000, client=None, reserve=100
+        )
+
+        assert result.was_compacted is True
+        # Should have truncated without summarization
+        assert result.messages_summarized == 0
+
+    @pytest.mark.asyncio
+    async def test_with_mocked_llm_client(self):
+        """Test summarization with mocked LLM client."""
+        # Create many messages to trigger summarization
+        messages = [{"role": "system", "content": "System prompt"}]
+        for i in range(30):
+            messages.append({"role": "user", "content": f"User message {i} " * 100})
+            messages.append(
+                {"role": "assistant", "content": f"Assistant response {i} " * 100}
+            )
+
+        # Mock the AsyncOpenAI client
+        mock_client = AsyncMock()
+        mock_response = MagicMock()
+        mock_response.choices = [MagicMock()]
+        mock_response.choices[0].message.content = "Summary of conversation"
+        mock_client.with_options.return_value.chat.completions.create = AsyncMock(
+            return_value=mock_response
+        )
+
+        result = await compress_context(
+            messages,
+            target_tokens=5000,
+            client=mock_client,
+            keep_recent=5,
+            reserve=500,
+        )
+
+        assert result.was_compacted is True
+        # Should have attempted summarization
+        assert mock_client.with_options.called or result.messages_summarized > 0
+
+    @pytest.mark.asyncio
+    async def test_preserves_tool_pairs(self):
+        """Test that tool call/response pairs stay together."""
+        messages = [
+            {"role": "system", "content": "System"},
+            {"role": "user", "content": "Do something"},
+            {
+                "role": "assistant",
+                "tool_calls": [
+                    {"id": "call_1", "type": "function", "function": {"name": "func"}}
+                ],
+            },
+            {"role": "tool", "tool_call_id": "call_1", "content": "Result " * 1000},
+            {"role": "assistant", "content": "Done!"},
+        ]
+
+        result = await compress_context(
+            messages, target_tokens=500, client=None, reserve=50
+        )
+
+        # Check that if tool response exists, its call exists too
+        tool_call_ids = set()
+        tool_response_ids = set()
+        for msg in result.messages:
+            if "tool_calls" in msg:
+                for tc in msg["tool_calls"]:
+                    tool_call_ids.add(tc["id"])
+            if msg.get("role") == "tool":
+                tool_response_ids.add(msg.get("tool_call_id"))
+
+        # All tool responses should have their calls
+        assert tool_response_ids <= tool_call_ids
+
+    @pytest.mark.asyncio
+    async def test_returns_error_when_cannot_compress(self):
+        """Test that error is returned when compression fails."""
+        # Single huge message that can't be compressed enough
+        messages = [
+            {"role": "user", "content": "x" * 100000},
+        ]
+
+        result = await compress_context(
+            messages, target_tokens=100, client=None, reserve=50
+        )
+
+        # Should have an error since we can't get below 100 tokens
+        assert result.error is not None
+        assert result.was_compacted is True
+
+    @pytest.mark.asyncio
+    async def test_empty_messages(self):
+        """Test that empty messages list returns early without error."""
+        result = await compress_context([], target_tokens=1000)
+
+        assert result.messages == []
+        assert result.token_count == 0
+        assert result.was_compacted is False
+        assert result.error is None
+
+
+class TestRemoveOrphanToolResponses:
+    """Test _remove_orphan_tool_responses helper function."""
+
+    def test_removes_openai_orphan(self):
+        """Test removal of orphan OpenAI tool response."""
+        from backend.util.prompt import _remove_orphan_tool_responses
+
+        messages = [
+            {"role": "tool", "tool_call_id": "call_orphan", "content": "result"},
+            {"role": "user", "content": "Hello"},
+        ]
+        orphan_ids = {"call_orphan"}
+
+        result = _remove_orphan_tool_responses(messages, orphan_ids)
+
+        assert len(result) == 1
+        assert result[0]["role"] == "user"
+
+    def test_keeps_valid_openai_tool(self):
+        """Test that valid OpenAI tool responses are kept."""
+        from backend.util.prompt import _remove_orphan_tool_responses
+
+        messages = [
+            {"role": "tool", "tool_call_id": "call_valid", "content": "result"},
+        ]
+        orphan_ids = {"call_other"}
+
+        result = _remove_orphan_tool_responses(messages, orphan_ids)
+
+        assert len(result) == 1
+        assert result[0]["tool_call_id"] == "call_valid"
+
+    def test_filters_anthropic_mixed_blocks(self):
+        """Test filtering individual orphan blocks from Anthropic message with mixed valid/orphan."""
+        from backend.util.prompt import _remove_orphan_tool_responses
+
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "tool_result",
+                        "tool_use_id": "toolu_valid",
+                        "content": "valid result",
+                    },
+                    {
+                        "type": "tool_result",
+                        "tool_use_id": "toolu_orphan",
+                        "content": "orphan result",
+                    },
+                ],
+            },
+        ]
+        orphan_ids = {"toolu_orphan"}
+
+        result = _remove_orphan_tool_responses(messages, orphan_ids)
+
+        assert len(result) == 1
+        # Should only have the valid tool_result, orphan filtered out
+        assert len(result[0]["content"]) == 1
+        assert result[0]["content"][0]["tool_use_id"] == "toolu_valid"
+
+    def test_removes_anthropic_all_orphan(self):
+        """Test removal of Anthropic message when all tool_results are orphans."""
+        from backend.util.prompt import _remove_orphan_tool_responses
+
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "tool_result",
+                        "tool_use_id": "toolu_orphan1",
+                        "content": "result1",
+                    },
+                    {
+                        "type": "tool_result",
+                        "tool_use_id": "toolu_orphan2",
+                        "content": "result2",
+                    },
+                ],
+            },
+        ]
+        orphan_ids = {"toolu_orphan1", "toolu_orphan2"}
+
+        result = _remove_orphan_tool_responses(messages, orphan_ids)
+
+        # Message should be completely removed since no content left
+        assert len(result) == 0
+
+    def test_preserves_non_tool_messages(self):
+        """Test that non-tool messages are preserved."""
+        from backend.util.prompt import _remove_orphan_tool_responses
+
+        messages = [
+            {"role": "user", "content": "Hello"},
+            {"role": "assistant", "content": "Hi there!"},
+        ]
+        orphan_ids = {"some_id"}
+
+        result = _remove_orphan_tool_responses(messages, orphan_ids)
+
+        assert result == messages
+
+
+class TestCompressResultDataclass:
+    """Test CompressResult dataclass."""
+
+    def test_default_values(self):
+        """Test default values are set correctly."""
+        result = CompressResult(
+            messages=[{"role": "user", "content": "test"}],
+            token_count=10,
+            was_compacted=False,
+        )
+
+        assert result.error is None
+        assert result.original_token_count == 0  # Defaults to 0, not None
+        assert result.messages_summarized == 0
+        assert result.messages_dropped == 0
+
+    def test_all_fields(self):
+        """Test all fields can be set."""
+        result = CompressResult(
+            messages=[{"role": "user", "content": "test"}],
+            token_count=100,
+            was_compacted=True,
+            error="Some error",
+            original_token_count=500,
+            messages_summarized=10,
+            messages_dropped=5,
+        )
+
+        assert result.token_count == 100
+        assert result.was_compacted is True
+        assert result.error == "Some error"
+        assert result.original_token_count == 500
+        assert result.messages_summarized == 10
+        assert result.messages_dropped == 5
autogpt_platform/backend/backend/util/validation.py (new file, 32 lines)
@@ -0,0 +1,32 @@
+"""Validation utilities."""
+
+import re
+
+_UUID_V4_PATTERN = re.compile(
+    r"[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}",
+    re.IGNORECASE,
+)
+
+
+def is_uuid_v4(text: str) -> bool:
+    """Check if text is a valid UUID v4.
+
+    Args:
+        text: String to validate
+
+    Returns:
+        True if the text is a valid UUID v4, False otherwise
+    """
+    return bool(_UUID_V4_PATTERN.fullmatch(text.strip()))
+
+
+def extract_uuids(text: str) -> list[str]:
+    """Extract all UUID v4 strings from text.
+
+    Args:
+        text: String to search for UUIDs
+
+    Returns:
+        List of unique UUIDs found (lowercase)
+    """
+    return list({m.lower() for m in _UUID_V4_PATTERN.findall(text)})
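
(Reviewer note: a usage sketch for the new helpers:)

    from backend.util.validation import extract_uuids, is_uuid_v4

    assert is_uuid_v4("123e4567-e89b-42d3-a456-426614174000")
    assert extract_uuids("run 123e4567-e89b-42d3-a456-426614174000 done") == [
        "123e4567-e89b-42d3-a456-426614174000"
    ]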
@@ -0,0 +1,22 @@
+-- Migrate Claude 3.7 Sonnet to Claude 4.5 Sonnet
+-- This updates all AgentNode blocks that use the deprecated Claude 3.7 Sonnet model
+-- Anthropic is retiring claude-3-7-sonnet-20250219 on February 19, 2026
+
+-- Update AgentNode constant inputs
+UPDATE "AgentNode"
+SET "constantInput" = JSONB_SET(
+    "constantInput"::jsonb,
+    '{model}',
+    '"claude-sonnet-4-5-20250929"'::jsonb
+)
+WHERE "constantInput"::jsonb->>'model' = 'claude-3-7-sonnet-20250219';
+
+-- Update AgentPreset input overrides (stored in AgentNodeExecutionInputOutput)
+UPDATE "AgentNodeExecutionInputOutput"
+SET "data" = JSONB_SET(
+    "data"::jsonb,
+    '{model}',
+    '"claude-sonnet-4-5-20250929"'::jsonb
+)
+WHERE "agentPresetId" IS NOT NULL
+  AND "data"::jsonb->>'model' = 'claude-3-7-sonnet-20250219';
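
(Reviewer note: a hypothetical follow-up check, not part of this diff, to confirm nothing still references the retired model after the migration runs:)

    SELECT COUNT(*) FROM "AgentNode"
    WHERE "constantInput"::jsonb->>'model' = 'claude-3-7-sonnet-20250219';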
@@ -9,7 +9,8 @@
     "sub_heading": "Creator agent subheading",
     "description": "Creator agent description",
     "runs": 50,
-    "rating": 4.0
+    "rating": 4.0,
+    "agent_graph_id": "test-graph-2"
   }
 ],
 "pagination": {

@@ -9,7 +9,8 @@
     "sub_heading": "Category agent subheading",
     "description": "Category agent description",
     "runs": 60,
-    "rating": 4.1
+    "rating": 4.1,
+    "agent_graph_id": "test-graph-category"
   }
 ],
 "pagination": {

@@ -9,7 +9,8 @@
     "sub_heading": "Agent 0 subheading",
     "description": "Agent 0 description",
     "runs": 0,
-    "rating": 4.0
+    "rating": 4.0,
+    "agent_graph_id": "test-graph-2"
   },
   {
     "slug": "agent-1",
@@ -20,7 +21,8 @@
     "sub_heading": "Agent 1 subheading",
     "description": "Agent 1 description",
     "runs": 10,
-    "rating": 4.0
+    "rating": 4.0,
+    "agent_graph_id": "test-graph-2"
   },
   {
     "slug": "agent-2",
@@ -31,7 +33,8 @@
     "sub_heading": "Agent 2 subheading",
     "description": "Agent 2 description",
     "runs": 20,
-    "rating": 4.0
+    "rating": 4.0,
+    "agent_graph_id": "test-graph-2"
   },
   {
     "slug": "agent-3",
@@ -42,7+45,8 @@
     "sub_heading": "Agent 3 subheading",
     "description": "Agent 3 description",
     "runs": 30,
-    "rating": 4.0
+    "rating": 4.0,
+    "agent_graph_id": "test-graph-2"
   },
   {
     "slug": "agent-4",
@@ -53,7 +57,8 @@
     "sub_heading": "Agent 4 subheading",
     "description": "Agent 4 description",
     "runs": 40,
-    "rating": 4.0
+    "rating": 4.0,
+    "agent_graph_id": "test-graph-2"
   }
 ],
 "pagination": {

@@ -9,7 +9,8 @@
     "sub_heading": "Search agent subheading",
     "description": "Specific search term description",
     "runs": 75,
-    "rating": 4.2
+    "rating": 4.2,
+    "agent_graph_id": "test-graph-search"
   }
 ],
 "pagination": {

@@ -9,7 +9,8 @@
     "sub_heading": "Top agent subheading",
     "description": "Top agent description",
     "runs": 1000,
-    "rating": 5.0
+    "rating": 5.0,
+    "agent_graph_id": "test-graph-3"
   }
 ],
 "pagination": {

@@ -9,7 +9,8 @@
     "sub_heading": "Featured agent subheading",
     "description": "Featured agent description",
     "runs": 100,
-    "rating": 4.5
+    "rating": 4.5,
+    "agent_graph_id": "test-graph-1"
   }
 ],
 "pagination": {

@@ -31,6 +31,10 @@
     "has_sensitive_action": false,
     "trigger_setup_info": null,
     "new_output": false,
+    "execution_count": 0,
+    "success_rate": null,
+    "avg_correctness_score": null,
+    "recent_executions": [],
     "can_access_graph": true,
     "is_latest_version": true,
     "is_favorite": false,
@@ -72,6 +76,10 @@
     "has_sensitive_action": false,
     "trigger_setup_info": null,
     "new_output": false,
+    "execution_count": 0,
+    "success_rate": null,
+    "avg_correctness_score": null,
+    "recent_executions": [],
     "can_access_graph": false,
     "is_latest_version": true,
     "is_favorite": false,
@@ -57,7 +57,8 @@ class TestDecomposeGoal:
 
         result = await core.decompose_goal("Build a chatbot")
 
-        mock_external.assert_called_once_with("Build a chatbot", "")
+        # library_agents defaults to None
+        mock_external.assert_called_once_with("Build a chatbot", "", None)
         assert result == expected_result
 
     @pytest.mark.asyncio
@@ -74,7 +75,8 @@ class TestDecomposeGoal:
 
         await core.decompose_goal("Build a chatbot", "Use Python")
 
-        mock_external.assert_called_once_with("Build a chatbot", "Use Python")
+        # library_agents defaults to None
+        mock_external.assert_called_once_with("Build a chatbot", "Use Python", None)
 
     @pytest.mark.asyncio
     async def test_returns_none_on_service_failure(self):
@@ -109,7 +111,8 @@ class TestGenerateAgent:
         instructions = {"type": "instructions", "steps": ["Step 1"]}
         result = await core.generate_agent(instructions)
 
-        mock_external.assert_called_once_with(instructions)
+        # library_agents defaults to None
+        mock_external.assert_called_once_with(instructions, None)
         # Result should have id, version, is_active added if not present
         assert result is not None
         assert result["name"] == "Test Agent"
@@ -174,7 +177,8 @@ class TestGenerateAgentPatch:
         current_agent = {"nodes": [], "links": []}
         result = await core.generate_agent_patch("Add a node", current_agent)
 
-        mock_external.assert_called_once_with("Add a node", current_agent)
+        # library_agents defaults to None
+        mock_external.assert_called_once_with("Add a node", current_agent, None)
         assert result == expected_result
 
     @pytest.mark.asyncio
@@ -0,0 +1,857 @@
+"""
+Tests for library agent fetching functionality in agent generator.
+
+This test suite verifies the search-based library agent fetching,
+including the combination of library and marketplace agents.
+"""
+
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+from backend.api.features.chat.tools.agent_generator import core
+
+
+class TestGetLibraryAgentsForGeneration:
+    """Test get_library_agents_for_generation function."""
+
+    @pytest.mark.asyncio
+    async def test_fetches_agents_with_search_term(self):
+        """Test that search_term is passed to the library db."""
+        # Create a mock agent with proper attribute values
+        mock_agent = MagicMock()
+        mock_agent.graph_id = "agent-123"
+        mock_agent.graph_version = 1
+        mock_agent.name = "Email Agent"
+        mock_agent.description = "Sends emails"
+        mock_agent.input_schema = {"properties": {}}
+        mock_agent.output_schema = {"properties": {}}
+        mock_agent.recent_executions = []
+
+        mock_response = MagicMock()
+        mock_response.agents = [mock_agent]
+
+        with patch.object(
+            core.library_db,
+            "list_library_agents",
+            new_callable=AsyncMock,
+            return_value=mock_response,
+        ) as mock_list:
+            result = await core.get_library_agents_for_generation(
+                user_id="user-123",
+                search_query="send email",
+            )
+
+        mock_list.assert_called_once_with(
+            user_id="user-123",
+            search_term="send email",
+            page=1,
+            page_size=15,
+            include_executions=True,
+        )
+
+        # Verify result format
+        assert len(result) == 1
+        assert result[0]["graph_id"] == "agent-123"
+        assert result[0]["name"] == "Email Agent"
+
+    @pytest.mark.asyncio
+    async def test_excludes_specified_graph_id(self):
+        """Test that agents with excluded graph_id are filtered out."""
+        mock_response = MagicMock()
+        mock_response.agents = [
+            MagicMock(
+                graph_id="agent-123",
+                graph_version=1,
+                name="Agent 1",
+                description="First agent",
+                input_schema={},
+                output_schema={},
+                recent_executions=[],
+            ),
+            MagicMock(
+                graph_id="agent-456",
+                graph_version=1,
+                name="Agent 2",
+                description="Second agent",
+                input_schema={},
+                output_schema={},
+                recent_executions=[],
+            ),
+        ]
+
+        with patch.object(
+            core.library_db,
+            "list_library_agents",
+            new_callable=AsyncMock,
+            return_value=mock_response,
+        ):
+            result = await core.get_library_agents_for_generation(
+                user_id="user-123",
+                exclude_graph_id="agent-123",
+            )
+
+        # Verify the excluded agent is not in results
+        assert len(result) == 1
+        assert result[0]["graph_id"] == "agent-456"
+
+    @pytest.mark.asyncio
+    async def test_respects_max_results(self):
+        """Test that max_results parameter limits the page_size."""
+        mock_response = MagicMock()
+        mock_response.agents = []
+
+        with patch.object(
+            core.library_db,
+            "list_library_agents",
+            new_callable=AsyncMock,
+            return_value=mock_response,
+        ) as mock_list:
+            await core.get_library_agents_for_generation(
+                user_id="user-123",
+                max_results=5,
+            )
+
+        mock_list.assert_called_once_with(
+            user_id="user-123",
+            search_term=None,
+            page=1,
+            page_size=5,
+            include_executions=True,
+        )
+
+
+class TestSearchMarketplaceAgentsForGeneration:
+    """Test search_marketplace_agents_for_generation function."""
+
+    @pytest.mark.asyncio
+    async def test_searches_marketplace_with_query(self):
+        """Test that marketplace is searched with the query."""
+        mock_response = MagicMock()
+        mock_response.agents = [
+            MagicMock(
+                agent_name="Public Agent",
+                description="A public agent",
+                sub_heading="Does something useful",
+                creator="creator-1",
+                agent_graph_id="graph-123",
+            )
+        ]
+
+        mock_graph = MagicMock()
+        mock_graph.id = "graph-123"
+        mock_graph.version = 1
+        mock_graph.input_schema = {"type": "object"}
+        mock_graph.output_schema = {"type": "object"}
+
+        with (
+            patch(
+                "backend.api.features.store.db.get_store_agents",
+                new_callable=AsyncMock,
+                return_value=mock_response,
+            ) as mock_search,
+            patch(
+                "backend.api.features.chat.tools.agent_generator.core.get_store_listed_graphs",
+                new_callable=AsyncMock,
+                return_value={"graph-123": mock_graph},
+            ),
+        ):
+            result = await core.search_marketplace_agents_for_generation(
+                search_query="automation",
+                max_results=10,
+            )
+
+        mock_search.assert_called_once_with(
+            search_query="automation",
+            page=1,
+            page_size=10,
+        )
+
+        assert len(result) == 1
+        assert result[0]["name"] == "Public Agent"
+        assert result[0]["graph_id"] == "graph-123"
+
+    @pytest.mark.asyncio
+    async def test_handles_marketplace_error_gracefully(self):
+        """Test that marketplace errors don't crash the function."""
+        with patch(
+            "backend.api.features.store.db.get_store_agents",
+            new_callable=AsyncMock,
+            side_effect=Exception("Marketplace unavailable"),
+        ):
+            result = await core.search_marketplace_agents_for_generation(
+                search_query="test"
+            )
+
+        # Should return empty list, not raise exception
+        assert result == []
+
+
+class TestGetAllRelevantAgentsForGeneration:
+    """Test get_all_relevant_agents_for_generation function."""
+
+    @pytest.mark.asyncio
+    async def test_combines_library_and_marketplace_agents(self):
+        """Test that agents from both sources are combined."""
+        library_agents = [
+            {
+                "graph_id": "lib-123",
+                "graph_version": 1,
+                "name": "Library Agent",
+                "description": "From library",
+                "input_schema": {},
+                "output_schema": {},
+            }
+        ]
+
+        marketplace_agents = [
+            {
+                "graph_id": "market-456",
+                "graph_version": 1,
+                "name": "Market Agent",
+                "description": "From marketplace",
+                "input_schema": {},
+                "output_schema": {},
+            }
+        ]
+
+        with patch.object(
+            core,
+            "get_library_agents_for_generation",
+            new_callable=AsyncMock,
+            return_value=library_agents,
+        ):
+            with patch.object(
+                core,
+                "search_marketplace_agents_for_generation",
+                new_callable=AsyncMock,
+                return_value=marketplace_agents,
+            ):
+                result = await core.get_all_relevant_agents_for_generation(
+                    user_id="user-123",
+                    search_query="test query",
+                    include_marketplace=True,
+                )
+
+        # Library agents should come first
+        assert len(result) == 2
+        assert result[0]["name"] == "Library Agent"
+        assert result[1]["name"] == "Market Agent"
+
+    @pytest.mark.asyncio
+    async def test_deduplicates_by_graph_id(self):
+        """Test that marketplace agents with same graph_id as library are excluded."""
+        library_agents = [
+            {
+                "graph_id": "shared-123",
+                "graph_version": 1,
+                "name": "Shared Agent",
+                "description": "From library",
+                "input_schema": {},
+                "output_schema": {},
+            }
+        ]
+
+        marketplace_agents = [
+            {
+                "graph_id": "shared-123",  # Same graph_id, should be deduplicated
+                "graph_version": 1,
+                "name": "Shared Agent",
+                "description": "From marketplace",
+                "input_schema": {},
+                "output_schema": {},
+            },
+            {
+                "graph_id": "unique-456",
+                "graph_version": 1,
+                "name": "Unique Agent",
+                "description": "Only in marketplace",
+                "input_schema": {},
+                "output_schema": {},
+            },
+        ]
+
+        with patch.object(
+            core,
+            "get_library_agents_for_generation",
+            new_callable=AsyncMock,
+            return_value=library_agents,
+        ):
+            with patch.object(
+                core,
+                "search_marketplace_agents_for_generation",
+                new_callable=AsyncMock,
+                return_value=marketplace_agents,
+            ):
+                result = await core.get_all_relevant_agents_for_generation(
+                    user_id="user-123",
+                    search_query="test",
+                    include_marketplace=True,
+                )
+
+        # Shared Agent from marketplace should be excluded by graph_id
+        assert len(result) == 2
+        names = [a["name"] for a in result]
+        assert "Shared Agent" in names
+        assert "Unique Agent" in names
+
+    @pytest.mark.asyncio
+    async def test_skips_marketplace_when_disabled(self):
+        """Test that marketplace is not searched when include_marketplace=False."""
+        library_agents = [
+            {
+                "graph_id": "lib-123",
+                "graph_version": 1,
+                "name": "Library Agent",
+                "description": "From library",
+                "input_schema": {},
+                "output_schema": {},
+            }
+        ]
+
+        with patch.object(
+            core,
+            "get_library_agents_for_generation",
+            new_callable=AsyncMock,
+            return_value=library_agents,
+        ):
+            with patch.object(
+                core,
+                "search_marketplace_agents_for_generation",
+                new_callable=AsyncMock,
+            ) as mock_marketplace:
+                result = await core.get_all_relevant_agents_for_generation(
+                    user_id="user-123",
+                    search_query="test",
+                    include_marketplace=False,
+                )
+
+        # Marketplace should not be called
+        mock_marketplace.assert_not_called()
+        assert len(result) == 1
+
+    @pytest.mark.asyncio
+    async def test_skips_marketplace_when_no_search_query(self):
+        """Test that marketplace is not searched without a search query."""
+        library_agents = [
+            {
+                "graph_id": "lib-123",
+                "graph_version": 1,
+                "name": "Library Agent",
+                "description": "From library",
+                "input_schema": {},
+                "output_schema": {},
+            }
+        ]
+
+        with patch.object(
+            core,
+            "get_library_agents_for_generation",
+            new_callable=AsyncMock,
+            return_value=library_agents,
+        ):
+            with patch.object(
+                core,
+                "search_marketplace_agents_for_generation",
+                new_callable=AsyncMock,
+            ) as mock_marketplace:
+                result = await core.get_all_relevant_agents_for_generation(
+                    user_id="user-123",
+                    search_query=None,  # No search query
+                    include_marketplace=True,
+                )
+
+        # Marketplace should not be called without search query
+        mock_marketplace.assert_not_called()
+        assert len(result) == 1
+
+
+class TestExtractSearchTermsFromSteps:
+    """Test extract_search_terms_from_steps function."""
+
+    def test_extracts_terms_from_instructions_type(self):
+        """Test extraction from valid instructions decomposition result."""
+        decomposition_result = {
+            "type": "instructions",
+            "steps": [
+                {
+                    "description": "Send an email notification",
+                    "block_name": "GmailSendBlock",
+                },
+                {"description": "Fetch weather data", "action": "Get weather API"},
+            ],
+        }
+
+        result = core.extract_search_terms_from_steps(decomposition_result)
+
+        assert "Send an email notification" in result
+        assert "GmailSendBlock" in result
+        assert "Fetch weather data" in result
+        assert "Get weather API" in result
+
+    def test_returns_empty_for_non_instructions_type(self):
+        """Test that non-instructions types return empty list."""
+        decomposition_result = {
+            "type": "clarifying_questions",
+            "questions": [{"question": "What email?"}],
+        }
+
+        result = core.extract_search_terms_from_steps(decomposition_result)
+
+        assert result == []
+
+    def test_deduplicates_terms_case_insensitively(self):
+        """Test that duplicate terms are removed (case-insensitive)."""
+        decomposition_result = {
+            "type": "instructions",
+            "steps": [
+                {"description": "Send Email", "name": "send email"},
+                {"description": "Other task"},
+            ],
+        }
+
+        result = core.extract_search_terms_from_steps(decomposition_result)
+
+        # Should only have one "send email" variant
+        email_terms = [t for t in result if "email" in t.lower()]
+        assert len(email_terms) == 1
+
+    def test_filters_short_terms(self):
+        """Test that terms with 3 or fewer characters are filtered out."""
+        decomposition_result = {
+            "type": "instructions",
+            "steps": [
+                {"description": "ab", "action": "xyz"},  # Both too short
+                {"description": "Valid term here"},
+            ],
+        }
+
+        result = core.extract_search_terms_from_steps(decomposition_result)
+
+        assert "ab" not in result
+        assert "xyz" not in result
+        assert "Valid term here" in result
+
+    def test_handles_empty_steps(self):
+        """Test handling of empty steps list."""
+        decomposition_result = {
+            "type": "instructions",
+            "steps": [],
+        }
+
+        result = core.extract_search_terms_from_steps(decomposition_result)
+
+        assert result == []
+
+
+class TestEnrichLibraryAgentsFromSteps:
+    """Test enrich_library_agents_from_steps function."""
+
+    @pytest.mark.asyncio
+    async def test_enriches_with_additional_agents(self):
+        """Test that additional agents are found based on steps."""
+        existing_agents = [
+            {
+                "graph_id": "existing-123",
+                "graph_version": 1,
+                "name": "Existing Agent",
+                "description": "Already fetched",
+                "input_schema": {},
+                "output_schema": {},
+            }
+        ]
+
+        additional_agents = [
+            {
+                "graph_id": "new-456",
+                "graph_version": 1,
+                "name": "Email Agent",
+                "description": "For sending emails",
+                "input_schema": {},
+                "output_schema": {},
+            }
+        ]
+
+        decomposition_result = {
+            "type": "instructions",
+            "steps": [
+                {"description": "Send email notification"},
+            ],
+        }
+
+        with patch.object(
+            core,
+            "get_all_relevant_agents_for_generation",
+            new_callable=AsyncMock,
+            return_value=additional_agents,
+        ):
+            result = await core.enrich_library_agents_from_steps(
+                user_id="user-123",
+                decomposition_result=decomposition_result,
+                existing_agents=existing_agents,
+            )
+
+        # Should have both existing and new agents
+        assert len(result) == 2
+        names = [a["name"] for a in result]
+        assert "Existing Agent" in names
+        assert "Email Agent" in names
+
+    @pytest.mark.asyncio
+    async def test_deduplicates_by_graph_id(self):
+        """Test that agents with same graph_id are not duplicated."""
+        existing_agents = [
+            {
+                "graph_id": "agent-123",
+                "graph_version": 1,
+                "name": "Existing Agent",
+                "description": "Already fetched",
+                "input_schema": {},
+                "output_schema": {},
+            }
+        ]
+
+        # Additional search returns same agent
+        additional_agents = [
+            {
+                "graph_id": "agent-123",  # Same ID
+                "graph_version": 1,
+                "name": "Existing Agent Copy",
+                "description": "Same agent different name",
+                "input_schema": {},
+                "output_schema": {},
+            }
+        ]
+
+        decomposition_result = {
+            "type": "instructions",
+            "steps": [{"description": "Some action"}],
+        }
+
+        with patch.object(
+            core,
+            "get_all_relevant_agents_for_generation",
+            new_callable=AsyncMock,
+            return_value=additional_agents,
+        ):
+            result = await core.enrich_library_agents_from_steps(
+                user_id="user-123",
+                decomposition_result=decomposition_result,
+                existing_agents=existing_agents,
+            )
+
+        # Should not duplicate
+        assert len(result) == 1
+
+    @pytest.mark.asyncio
+    async def test_deduplicates_by_name(self):
+        """Test that agents with same name are not duplicated."""
+        existing_agents = [
+            {
+                "graph_id": "agent-123",
+                "graph_version": 1,
+                "name": "Email Agent",
+                "description": "Already fetched",
+                "input_schema": {},
+                "output_schema": {},
+            }
+        ]
+
+        # Additional search returns agent with same name but different ID
+        additional_agents = [
+            {
+                "graph_id": "agent-456",  # Different ID
+                "graph_version": 1,
+                "name": "Email Agent",  # Same name
+                "description": "Different agent same name",
+                "input_schema": {},
+                "output_schema": {},
+            }
+        ]
+
+        decomposition_result = {
+            "type": "instructions",
+            "steps": [{"description": "Send email"}],
+        }
+
+        with patch.object(
+            core,
+            "get_all_relevant_agents_for_generation",
+            new_callable=AsyncMock,
+            return_value=additional_agents,
+        ):
+            result = await core.enrich_library_agents_from_steps(
+                user_id="user-123",
+                decomposition_result=decomposition_result,
+                existing_agents=existing_agents,
+            )
+
+        # Should not duplicate by name
+        assert len(result) == 1
+        assert result[0].get("graph_id") == "agent-123"  # Original kept
+
+    @pytest.mark.asyncio
+    async def test_returns_existing_when_no_steps(self):
+        """Test that existing agents are returned when no search terms extracted."""
+        existing_agents = [
+            {
+                "graph_id": "existing-123",
+                "graph_version": 1,
+                "name": "Existing Agent",
+                "description": "Already fetched",
+                "input_schema": {},
+                "output_schema": {},
+            }
+        ]
+
+        decomposition_result = {
+            "type": "clarifying_questions",  # Not instructions type
+            "questions": [],
+        }
+
+        result = await core.enrich_library_agents_from_steps(
+            user_id="user-123",
+            decomposition_result=decomposition_result,
+            existing_agents=existing_agents,
+        )
+
+        # Should return existing unchanged
+        assert result == existing_agents
+
+    @pytest.mark.asyncio
+    async def test_limits_search_terms_to_three(self):
+        """Test that only first 3 search terms are used."""
+        existing_agents = []
+
+        decomposition_result = {
+            "type": "instructions",
+            "steps": [
+                {"description": "First action"},
+                {"description": "Second action"},
+                {"description": "Third action"},
+                {"description": "Fourth action"},
+                {"description": "Fifth action"},
+            ],
+        }
+
+        call_count = 0
+
+        async def mock_get_agents(*args, **kwargs):
+            nonlocal call_count
+            call_count += 1
+            return []
+
+        with patch.object(
+            core,
+            "get_all_relevant_agents_for_generation",
+            side_effect=mock_get_agents,
+        ):
+            await core.enrich_library_agents_from_steps(
+                user_id="user-123",
+                decomposition_result=decomposition_result,
+                existing_agents=existing_agents,
+            )
+
+        # Should only make 3 calls (limited to first 3 terms)
+        assert call_count == 3
+
+
+class TestExtractUuidsFromText:
+    """Test extract_uuids_from_text function."""
+
+    def test_extracts_single_uuid(self):
+        """Test extraction of a single UUID from text."""
+        text = "Use my agent 46631191-e8a8-486f-ad90-84f89738321d for this task"
+        result = core.extract_uuids_from_text(text)
+        assert len(result) == 1
+        assert "46631191-e8a8-486f-ad90-84f89738321d" in result
+
+    def test_extracts_multiple_uuids(self):
+        """Test extraction of multiple UUIDs from text."""
+        text = (
+            "Combine agents 11111111-1111-4111-8111-111111111111 "
+            "and 22222222-2222-4222-9222-222222222222"
+        )
+        result = core.extract_uuids_from_text(text)
+        assert len(result) == 2
+        assert "11111111-1111-4111-8111-111111111111" in result
+        assert "22222222-2222-4222-9222-222222222222" in result
+
+    def test_deduplicates_uuids(self):
+        """Test that duplicate UUIDs are deduplicated."""
+        text = (
+            "Use 46631191-e8a8-486f-ad90-84f89738321d twice: "
+            "46631191-e8a8-486f-ad90-84f89738321d"
+        )
+        result = core.extract_uuids_from_text(text)
+        assert len(result) == 1
+
+    def test_normalizes_to_lowercase(self):
+        """Test that UUIDs are normalized to lowercase."""
+        text = "Use 46631191-E8A8-486F-AD90-84F89738321D"
+        result = core.extract_uuids_from_text(text)
+        assert result[0] == "46631191-e8a8-486f-ad90-84f89738321d"
+
+    def test_returns_empty_for_no_uuids(self):
+        """Test that empty list is returned when no UUIDs found."""
+        text = "Create an email agent that sends notifications"
+        result = core.extract_uuids_from_text(text)
+        assert result == []
+
+    def test_ignores_invalid_uuids(self):
+        """Test that invalid UUID-like strings are ignored."""
+        text = "Not a valid UUID: 12345678-1234-1234-1234-123456789abc"
+        result = core.extract_uuids_from_text(text)
+        # UUID v4 requires specific patterns (4 in third group, 8/9/a/b in fourth)
+        assert len(result) == 0
+
+
+class TestGetLibraryAgentById:
+    """Test get_library_agent_by_id function (and its alias get_library_agent_by_graph_id)."""
+
+    @pytest.mark.asyncio
+    async def test_returns_agent_when_found_by_graph_id(self):
+        """Test that agent is returned when found by graph_id."""
+        mock_agent = MagicMock()
+        mock_agent.graph_id = "agent-123"
+        mock_agent.graph_version = 1
+        mock_agent.name = "Test Agent"
+        mock_agent.description = "Test description"
+        mock_agent.input_schema = {"properties": {}}
+        mock_agent.output_schema = {"properties": {}}
+
+        with patch.object(
+            core.library_db,
+            "get_library_agent_by_graph_id",
+            new_callable=AsyncMock,
+            return_value=mock_agent,
+        ):
+            result = await core.get_library_agent_by_id("user-123", "agent-123")
+
+        assert result is not None
+        assert result["graph_id"] == "agent-123"
+        assert result["name"] == "Test Agent"
+
+    @pytest.mark.asyncio
+    async def test_falls_back_to_library_agent_id(self):
+        """Test that lookup falls back to library agent ID when graph_id not found."""
+        mock_agent = MagicMock()
+        mock_agent.graph_id = "graph-456"  # Different from the lookup ID
+        mock_agent.graph_version = 1
+        mock_agent.name = "Library Agent"
+        mock_agent.description = "Found by library ID"
+        mock_agent.input_schema = {"properties": {}}
+        mock_agent.output_schema = {"properties": {}}
+
+        with (
+            patch.object(
+                core.library_db,
+                "get_library_agent_by_graph_id",
+                new_callable=AsyncMock,
+                return_value=None,  # Not found by graph_id
+            ),
+            patch.object(
+                core.library_db,
+                "get_library_agent",
+                new_callable=AsyncMock,
+                return_value=mock_agent,  # Found by library ID
+            ),
+        ):
+            result = await core.get_library_agent_by_id("user-123", "library-id-123")
+
+        assert result is not None
+        assert result["graph_id"] == "graph-456"
+        assert result["name"] == "Library Agent"
+
+    @pytest.mark.asyncio
+    async def test_returns_none_when_not_found_by_either_method(self):
+        """Test that None is returned when agent not found by either method."""
+        with (
+            patch.object(
+                core.library_db,
+                "get_library_agent_by_graph_id",
+                new_callable=AsyncMock,
+                return_value=None,
+            ),
+            patch.object(
+                core.library_db,
+                "get_library_agent",
+                new_callable=AsyncMock,
+                side_effect=core.NotFoundError("Not found"),
+            ),
+        ):
+            result = await core.get_library_agent_by_id("user-123", "nonexistent")
+
+        assert result is None
+
+    @pytest.mark.asyncio
+    async def test_returns_none_on_exception(self):
+        """Test that None is returned when exception occurs in both lookups."""
+        with (
+            patch.object(
+                core.library_db,
+                "get_library_agent_by_graph_id",
+                new_callable=AsyncMock,
+                side_effect=Exception("Database error"),
+            ),
+            patch.object(
+                core.library_db,
+                "get_library_agent",
+                new_callable=AsyncMock,
+                side_effect=Exception("Database error"),
+            ),
+        ):
+            result = await core.get_library_agent_by_id("user-123", "agent-123")
+
+        assert result is None
+
+    @pytest.mark.asyncio
+    async def test_alias_works(self):
+        """Test that get_library_agent_by_graph_id is an alias for get_library_agent_by_id."""
+        assert core.get_library_agent_by_graph_id is core.get_library_agent_by_id
+
+
+class TestGetAllRelevantAgentsWithUuids:
+    """Test UUID extraction in get_all_relevant_agents_for_generation."""
+
+    @pytest.mark.asyncio
+    async def test_fetches_explicitly_mentioned_agents(self):
+        """Test that agents mentioned by UUID are fetched directly."""
+        mock_agent = MagicMock()
+        mock_agent.graph_id = "46631191-e8a8-486f-ad90-84f89738321d"
+        mock_agent.graph_version = 1
+        mock_agent.name = "Mentioned Agent"
+        mock_agent.description = "Explicitly mentioned"
+        mock_agent.input_schema = {}
+        mock_agent.output_schema = {}
+
+        mock_response = MagicMock()
+        mock_response.agents = []
+
+        with (
+            patch.object(
+                core.library_db,
+                "get_library_agent_by_graph_id",
+                new_callable=AsyncMock,
+                return_value=mock_agent,
+            ),
+            patch.object(
+                core.library_db,
+                "list_library_agents",
+                new_callable=AsyncMock,
+                return_value=mock_response,
+            ),
+        ):
+            result = await core.get_all_relevant_agents_for_generation(
+                user_id="user-123",
+                search_query="Use agent 46631191-e8a8-486f-ad90-84f89738321d",
+                include_marketplace=False,
+            )
+
+        assert len(result) == 1
+        assert result[0].get("graph_id") == "46631191-e8a8-486f-ad90-84f89738321d"
+        assert result[0].get("name") == "Mentioned Agent"
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])

@@ -102,7 +102,7 @@ class TestDecomposeGoalExternal:

     @pytest.mark.asyncio
     async def test_decompose_goal_with_context(self):
-        """Test decomposition with additional context."""
+        """Test decomposition with additional context enriched into description."""
        mock_response = MagicMock()
        mock_response.json.return_value = {
            "success": True,
@@ -119,9 +119,12 @@
             "Build a chatbot", context="Use Python"
         )

+        expected_description = (
+            "Build a chatbot\n\nAdditional context from user:\nUse Python"
+        )
         mock_client.post.assert_called_once_with(
             "/api/decompose-description",
-            json={"description": "Build a chatbot", "user_instruction": "Use Python"},
+            json={"description": expected_description},
         )

     @pytest.mark.asyncio
@@ -433,5 +436,139 @@ class TestGetBlocksExternal:
         assert result is None


+class TestLibraryAgentsPassthrough:
+    """Test that library_agents are passed correctly in all requests."""
+
+    def setup_method(self):
+        """Reset client singleton before each test."""
+        service._settings = None
+        service._client = None
+
+    @pytest.mark.asyncio
+    async def test_decompose_goal_passes_library_agents(self):
+        """Test that library_agents are included in decompose goal payload."""
+        library_agents = [
+            {
+                "graph_id": "agent-123",
+                "graph_version": 1,
+                "name": "Email Sender",
+                "description": "Sends emails",
+                "input_schema": {"properties": {"to": {"type": "string"}}},
+                "output_schema": {"properties": {"sent": {"type": "boolean"}}},
+            },
+        ]
+
+        mock_response = MagicMock()
+        mock_response.json.return_value = {
+            "success": True,
+            "type": "instructions",
+            "steps": ["Step 1"],
+        }
+        mock_response.raise_for_status = MagicMock()
+
+        mock_client = AsyncMock()
+        mock_client.post.return_value = mock_response
+
+        with patch.object(service, "_get_client", return_value=mock_client):
+            await service.decompose_goal_external(
+                "Send an email",
+                library_agents=library_agents,
+            )
+
+        # Verify library_agents was passed in the payload
+        call_args = mock_client.post.call_args
+        assert call_args[1]["json"]["library_agents"] == library_agents
+
+    @pytest.mark.asyncio
+    async def test_generate_agent_passes_library_agents(self):
+        """Test that library_agents are included in generate agent payload."""
+        library_agents = [
+            {
+                "graph_id": "agent-456",
+                "graph_version": 2,
+                "name": "Data Fetcher",
+                "description": "Fetches data from API",
+                "input_schema": {"properties": {"url": {"type": "string"}}},
+                "output_schema": {"properties": {"data": {"type": "object"}}},
+            },
+        ]
+
+        mock_response = MagicMock()
+        mock_response.json.return_value = {
+            "success": True,
+            "agent_json": {"name": "Test Agent", "nodes": []},
+        }
+        mock_response.raise_for_status = MagicMock()
+
+        mock_client = AsyncMock()
+        mock_client.post.return_value = mock_response
+
+        with patch.object(service, "_get_client", return_value=mock_client):
+            await service.generate_agent_external(
+                {"steps": ["Step 1"]},
+                library_agents=library_agents,
+            )
+
+        # Verify library_agents was passed in the payload
+        call_args = mock_client.post.call_args
+        assert call_args[1]["json"]["library_agents"] == library_agents
+
+    @pytest.mark.asyncio
+    async def test_generate_agent_patch_passes_library_agents(self):
+        """Test that library_agents are included in patch generation payload."""
+        library_agents = [
+            {
+                "graph_id": "agent-789",
+                "graph_version": 1,
+                "name": "Slack Notifier",
+                "description": "Sends Slack messages",
+                "input_schema": {"properties": {"message": {"type": "string"}}},
+                "output_schema": {"properties": {"success": {"type": "boolean"}}},
+            },
+        ]
+
+        mock_response = MagicMock()
+        mock_response.json.return_value = {
+            "success": True,
+            "agent_json": {"name": "Updated Agent", "nodes": []},
+        }
+        mock_response.raise_for_status = MagicMock()
+
+        mock_client = AsyncMock()
+        mock_client.post.return_value = mock_response
+
+        with patch.object(service, "_get_client", return_value=mock_client):
+            await service.generate_agent_patch_external(
+                "Add error handling",
+                {"name": "Original Agent", "nodes": []},
+                library_agents=library_agents,
+            )
+
+        # Verify library_agents was passed in the payload
+        call_args = mock_client.post.call_args
+        assert call_args[1]["json"]["library_agents"] == library_agents
+
+    @pytest.mark.asyncio
+    async def test_decompose_goal_without_library_agents(self):
+        """Test that decompose goal works without library_agents."""
+        mock_response = MagicMock()
+        mock_response.json.return_value = {
+            "success": True,
+            "type": "instructions",
+            "steps": ["Step 1"],
+        }
+        mock_response.raise_for_status = MagicMock()
+
+        mock_client = AsyncMock()
+        mock_client.post.return_value = mock_response
+
+        with patch.object(service, "_get_client", return_value=mock_client):
+            await service.decompose_goal_external("Build a workflow")
+
+        # Verify library_agents was NOT passed when not provided
+        call_args = mock_client.post.call_args
+        assert "library_agents" not in call_args[1]["json"]
+
+
 if __name__ == "__main__":
     pytest.main([__file__, "-v"])

@@ -43,19 +43,24 @@ faker = Faker()
 # Constants for data generation limits (reduced for E2E tests)
 NUM_USERS = 15
 NUM_AGENT_BLOCKS = 30
-MIN_GRAPHS_PER_USER = 15
-MAX_GRAPHS_PER_USER = 15
+MIN_GRAPHS_PER_USER = 25
+MAX_GRAPHS_PER_USER = 25
 MIN_NODES_PER_GRAPH = 3
 MAX_NODES_PER_GRAPH = 6
 MIN_PRESETS_PER_USER = 2
 MAX_PRESETS_PER_USER = 3
-MIN_AGENTS_PER_USER = 15
-MAX_AGENTS_PER_USER = 15
+MIN_AGENTS_PER_USER = 25
+MAX_AGENTS_PER_USER = 25
 MIN_EXECUTIONS_PER_GRAPH = 2
 MAX_EXECUTIONS_PER_GRAPH = 8
 MIN_REVIEWS_PER_VERSION = 2
 MAX_REVIEWS_PER_VERSION = 5
+
+# Guaranteed minimums for marketplace tests (deterministic)
+GUARANTEED_FEATURED_AGENTS = 8
+GUARANTEED_FEATURED_CREATORS = 5
+GUARANTEED_TOP_AGENTS = 10


 def get_image():
     """Generate a consistent image URL using picsum.photos service."""
@@ -385,7 +390,7 @@ class TestDataCreator:

         library_agents = []
         for user in self.users:
-            num_agents = 10  # Create exactly 10 agents per user
+            num_agents = random.randint(MIN_AGENTS_PER_USER, MAX_AGENTS_PER_USER)

             # Get available graphs for this user
             user_graphs = [
@@ -507,14 +512,17 @@ class TestDataCreator:
             existing_profiles, min(num_creators, len(existing_profiles))
         )

-        # Mark about 50% of creators as featured (more for testing)
-        num_featured = max(2, int(num_creators * 0.5))
+        # Guarantee at least GUARANTEED_FEATURED_CREATORS featured creators
+        num_featured = max(GUARANTEED_FEATURED_CREATORS, int(num_creators * 0.5))
         num_featured = min(
             num_featured, len(selected_profiles)
         )  # Don't exceed available profiles
         featured_profile_ids = set(
             random.sample([p.id for p in selected_profiles], num_featured)
         )
+        print(
+            f"🎯 Creating {num_featured} featured creators (min: {GUARANTEED_FEATURED_CREATORS})"
+        )

         for profile in selected_profiles:
             try:
@@ -545,21 +553,25 @@ class TestDataCreator:
         return profiles

     async def create_test_store_submissions(self) -> List[Dict[str, Any]]:
-        """Create test store submissions using the API function."""
+        """Create test store submissions using the API function.
+
+        DETERMINISTIC: Guarantees minimum featured agents for E2E tests.
+        """
         print("Creating test store submissions...")

         submissions = []
         approved_submissions = []
+        featured_count = 0
+        submission_counter = 0

-        # Create a special test submission for test123@gmail.com
+        # Create a special test submission for test123@gmail.com (ALWAYS approved + featured)
         test_user = next(
             (user for user in self.users if user["email"] == "test123@gmail.com"), None
         )
-        if test_user:
-            # Special test data for consistent testing
+        if test_user and self.agent_graphs:
             test_submission_data = {
                 "user_id": test_user["id"],
-                "agent_id": self.agent_graphs[0]["id"],  # Use first available graph
+                "agent_id": self.agent_graphs[0]["id"],
                 "agent_version": 1,
                 "slug": "test-agent-submission",
                 "name": "Test Agent Submission",
@@ -580,37 +592,24 @@ class TestDataCreator:
                 submissions.append(test_submission.model_dump())
                 print("✅ Created special test store submission for test123@gmail.com")

-                # Randomly approve, reject, or leave pending the test submission
+                # ALWAYS approve and feature the test submission
                 if test_submission.store_listing_version_id:
-                    random_value = random.random()
-                    if random_value < 0.4:  # 40% chance to approve
-                        approved_submission = await review_store_submission(
-                            store_listing_version_id=test_submission.store_listing_version_id,
-                            is_approved=True,
-                            external_comments="Test submission approved",
-                            internal_comments="Auto-approved test submission",
-                            reviewer_id=test_user["id"],
-                        )
-                        approved_submissions.append(approved_submission.model_dump())
-                        print("✅ Approved test store submission")
-
-                        # Mark approved submission as featured
-                        await prisma.storelistingversion.update(
-                            where={"id": test_submission.store_listing_version_id},
-                            data={"isFeatured": True},
-                        )
-                        print("🌟 Marked test agent as FEATURED")
-                    elif random_value < 0.7:  # 30% chance to reject (40% to 70%)
-                        await review_store_submission(
-                            store_listing_version_id=test_submission.store_listing_version_id,
-                            is_approved=False,
-                            external_comments="Test submission rejected - needs improvements",
-                            internal_comments="Auto-rejected test submission for E2E testing",
-                            reviewer_id=test_user["id"],
-                        )
-                        print("❌ Rejected test store submission")
-                    else:  # 30% chance to leave pending (70% to 100%)
-                        print("⏳ Left test submission pending for review")
+                    approved_submission = await review_store_submission(
+                        store_listing_version_id=test_submission.store_listing_version_id,
+                        is_approved=True,
+                        external_comments="Test submission approved",
+                        internal_comments="Auto-approved test submission",
+                        reviewer_id=test_user["id"],
+                    )
+                    approved_submissions.append(approved_submission.model_dump())
+                    print("✅ Approved test store submission")
+
+                    await prisma.storelistingversion.update(
+                        where={"id": test_submission.store_listing_version_id},
+                        data={"isFeatured": True},
+                    )
+                    featured_count += 1
+                    print("🌟 Marked test agent as FEATURED")

             except Exception as e:
                 print(f"Error creating test store submission: {e}")
@@ -620,7 +619,6 @@ class TestDataCreator:

         # Create regular submissions for all users
         for user in self.users:
-            # Get available graphs for this specific user
             user_graphs = [
                 g for g in self.agent_graphs if g.get("userId") == user["id"]
             ]
@@ -631,18 +629,17 @@ class TestDataCreator:
                 )
                 continue

-            # Create exactly 4 store submissions per user
             for submission_index in range(4):
                 graph = random.choice(user_graphs)
+                submission_counter += 1

                 try:
                     print(
-                        f"Creating store submission for user {user['id']} with graph {graph['id']} (owner: {graph.get('userId')})"
+                        f"Creating store submission for user {user['id']} with graph {graph['id']}"
                     )

-                    # Use the API function to create store submission with correct parameters
                     submission = await create_store_submission(
-                        user_id=user["id"],  # Must match graph's userId
+                        user_id=user["id"],
                         agent_id=graph["id"],
                         agent_version=graph.get("version", 1),
                         slug=faker.slug(),
@@ -651,22 +648,24 @@ class TestDataCreator:
                         video_url=get_video_url() if random.random() < 0.3 else None,
                         image_urls=[get_image() for _ in range(3)],
                         description=faker.text(),
-                        categories=[
-                            get_category()
-                        ],  # Single category from predefined list
+                        categories=[get_category()],
                         changes_summary="Initial E2E test submission",
                     )
                     submissions.append(submission.model_dump())
                     print(f"✅ Created store submission: {submission.name}")

-                    # Randomly approve, reject, or leave pending the submission
                     if submission.store_listing_version_id:
-                        random_value = random.random()
-                        if random_value < 0.4:  # 40% chance to approve
-                            try:
-                                # Pick a random user as the reviewer (admin)
-                                reviewer_id = random.choice(self.users)["id"]
+                        # DETERMINISTIC: First N submissions are always approved
+                        # First GUARANTEED_FEATURED_AGENTS of those are always featured
+                        should_approve = (
+                            submission_counter <= GUARANTEED_TOP_AGENTS
+                            or random.random() < 0.4
+                        )
+                        should_feature = featured_count < GUARANTEED_FEATURED_AGENTS
+
+                        if should_approve:
+                            try:
+                                reviewer_id = random.choice(self.users)["id"]
                                 approved_submission = await review_store_submission(
                                     store_listing_version_id=submission.store_listing_version_id,
                                     is_approved=True,
@@ -681,16 +680,7 @@ class TestDataCreator:
                                     f"✅ Approved store submission: {submission.name}"
                                 )

-                                # Mark some agents as featured during creation (30% chance)
-                                # More likely for creators and first submissions
-                                is_creator = user["id"] in [
-                                    p.get("userId") for p in self.profiles
-                                ]
-                                feature_chance = (
-                                    0.5 if is_creator else 0.2
-                                )  # 50% for creators, 20% for others
-
-                                if random.random() < feature_chance:
+                                if should_feature:
                                     try:
                                         await prisma.storelistingversion.update(
                                             where={
@@ -698,8 +688,25 @@ class TestDataCreator:
                                             },
                                             data={"isFeatured": True},
                                         )
+                                        featured_count += 1
                                         print(
-                                            f"🌟 Marked agent as FEATURED: {submission.name}"
+                                            f"🌟 Marked agent as FEATURED ({featured_count}/{GUARANTEED_FEATURED_AGENTS}): {submission.name}"
+                                        )
+                                    except Exception as e:
+                                        print(
+                                            f"Warning: Could not mark submission as featured: {e}"
+                                        )
+                                elif random.random() < 0.2:
+                                    try:
+                                        await prisma.storelistingversion.update(
+                                            where={
+                                                "id": submission.store_listing_version_id
+                                            },
+                                            data={"isFeatured": True},
+                                        )
+                                        featured_count += 1
+                                        print(
+                                            f"🌟 Marked agent as FEATURED (bonus): {submission.name}"
                                         )
                                     except Exception as e:
                                         print(
@@ -710,11 +717,9 @@ class TestDataCreator:
                                 print(
                                     f"Warning: Could not approve submission {submission.name}: {e}"
                                 )
-                        elif random_value < 0.7:  # 30% chance to reject (40% to 70%)
+                        elif random.random() < 0.5:
                             try:
-                                # Pick a random user as the reviewer (admin)
                                 reviewer_id = random.choice(self.users)["id"]
-
                                 await review_store_submission(
                                     store_listing_version_id=submission.store_listing_version_id,
                                     is_approved=False,
@@ -729,7 +734,7 @@ class TestDataCreator:
                                 print(
                                     f"Warning: Could not reject submission {submission.name}: {e}"
                                 )
-                        else:  # 30% chance to leave pending (70% to 100%)
+                        else:
                             print(
                                 f"⏳ Left submission pending for review: {submission.name}"
                             )
@@ -743,9 +748,13 @@ class TestDataCreator:
                     traceback.print_exc()
                     continue

+        print("\n📊 Store Submissions Summary:")
+        print(f"   Created: {len(submissions)}")
+        print(f"   Approved: {len(approved_submissions)}")
         print(
-            f"Created {len(submissions)} store submissions, approved {len(approved_submissions)}"
+            f"   Featured: {featured_count} (guaranteed min: {GUARANTEED_FEATURED_AGENTS})"
        )

         self.store_submissions = submissions
         return submissions
@@ -825,12 +834,15 @@ class TestDataCreator:
         print(f"✅ Agent blocks available: {len(self.agent_blocks)}")
         print(f"✅ Agent graphs created: {len(self.agent_graphs)}")
         print(f"✅ Library agents created: {len(self.library_agents)}")
-        print(f"✅ Creator profiles updated: {len(self.profiles)} (some featured)")
-        print(
-            f"✅ Store submissions created: {len(self.store_submissions)} (some marked as featured during creation)"
-        )
+        print(f"✅ Creator profiles updated: {len(self.profiles)}")
+        print(f"✅ Store submissions created: {len(self.store_submissions)}")
         print(f"✅ API keys created: {len(self.api_keys)}")
         print(f"✅ Presets created: {len(self.presets)}")
+        print("\n🎯 Deterministic Guarantees:")
+        print(f"   • Featured agents: >= {GUARANTEED_FEATURED_AGENTS}")
+        print(f"   • Featured creators: >= {GUARANTEED_FEATURED_CREATORS}")
+        print(f"   • Top agents (approved): >= {GUARANTEED_TOP_AGENTS}")
+        print(f"   • Library agents per user: >= {MIN_AGENTS_PER_USER}")
         print("\n🚀 Your E2E test database is ready to use!")

@@ -857,7 +857,7 @@ export const CustomNode = React.memo(
     })();

     const hasAdvancedFields =
-      data.inputSchema &&
+      data.inputSchema?.properties &&
       Object.entries(data.inputSchema.properties).some(([key, value]) => {
         return (
           value.advanced === true && !data.inputSchema.required?.includes(key)

@@ -7981,6 +7981,25 @@
|
|||||||
]
|
]
|
||||||
},
|
},
|
||||||
"new_output": { "type": "boolean", "title": "New Output" },
|
"new_output": { "type": "boolean", "title": "New Output" },
|
||||||
|
"execution_count": {
|
||||||
|
"type": "integer",
|
||||||
|
"title": "Execution Count",
|
||||||
|
"default": 0
|
||||||
|
},
|
||||||
|
"success_rate": {
|
||||||
|
"anyOf": [{ "type": "number" }, { "type": "null" }],
|
||||||
|
"title": "Success Rate"
|
||||||
|
},
|
||||||
|
"avg_correctness_score": {
|
||||||
|
"anyOf": [{ "type": "number" }, { "type": "null" }],
|
||||||
|
"title": "Avg Correctness Score"
|
||||||
|
},
|
||||||
|
"recent_executions": {
|
||||||
|
"items": { "$ref": "#/components/schemas/RecentExecution" },
|
||||||
|
"type": "array",
|
||||||
|
"title": "Recent Executions",
|
||||||
|
"description": "List of recent executions with status, score, and summary"
|
||||||
|
},
|
||||||
"can_access_graph": {
|
"can_access_graph": {
|
||||||
"type": "boolean",
|
"type": "boolean",
|
||||||
"title": "Can Access Graph"
|
"title": "Can Access Graph"
|
||||||
@@ -9374,6 +9393,23 @@
       "required": ["providers", "pagination"],
       "title": "ProviderResponse"
     },
+    "RecentExecution": {
+      "properties": {
+        "status": { "type": "string", "title": "Status" },
+        "correctness_score": {
+          "anyOf": [{ "type": "number" }, { "type": "null" }],
+          "title": "Correctness Score"
+        },
+        "activity_summary": {
+          "anyOf": [{ "type": "string" }, { "type": "null" }],
+          "title": "Activity Summary"
+        }
+      },
+      "type": "object",
+      "required": ["status"],
+      "title": "RecentExecution",
+      "description": "Summary of a recent execution for quality assessment.\n\nUsed by the LLM to understand the agent's recent performance with specific examples\nrather than just aggregate statistics."
+    },
    "RefundRequest": {
      "properties": {
        "id": { "type": "string", "title": "Id" },
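Read as client-side types, the schema additions above are roughly the following sketch (hand-written here for illustration; the project's real API types are generated, and the enclosing interface name is assumed):

```typescript
// Sketch of the shapes described by the OpenAPI additions above.
interface RecentExecution {
  status: string; // the only required field
  correctness_score?: number | null; // anyOf number/null
  activity_summary?: string | null; // anyOf string/null
}

// Fields added alongside `new_output` in the enclosing schema:
interface ExecutionQualityFields {
  execution_count: number; // defaults to 0
  success_rate?: number | null;
  avg_correctness_score?: number | null;
  recent_executions?: RecentExecution[]; // recent runs with status, score, summary
}
```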
@@ -9797,7 +9833,8 @@
         "sub_heading": { "type": "string", "title": "Sub Heading" },
         "description": { "type": "string", "title": "Description" },
         "runs": { "type": "integer", "title": "Runs" },
-        "rating": { "type": "number", "title": "Rating" }
+        "rating": { "type": "number", "title": "Rating" },
+        "agent_graph_id": { "type": "string", "title": "Agent Graph Id" }
       },
       "type": "object",
       "required": [
@@ -9809,7 +9846,8 @@
         "sub_heading",
         "description",
         "runs",
-        "rating"
+        "rating",
+        "agent_graph_id"
       ],
       "title": "StoreAgent"
     },
@@ -57,6 +57,7 @@ export function ChatInput({
     isStreaming,
     value,
     baseHandleKeyDown,
+    inputId,
   });

   return (
@@ -15,6 +15,7 @@ interface Args {
   isStreaming?: boolean;
   value: string;
   baseHandleKeyDown: (event: KeyboardEvent<HTMLTextAreaElement>) => void;
+  inputId?: string;
 }

 export function useVoiceRecording({
@@ -23,6 +24,7 @@ export function useVoiceRecording({
   isStreaming = false,
   value,
   baseHandleKeyDown,
+  inputId,
 }: Args) {
   const [isRecording, setIsRecording] = useState(false);
   const [isTranscribing, setIsTranscribing] = useState(false);
@@ -103,7 +105,7 @@ export function useVoiceRecording({
         setIsTranscribing(false);
       }
     },
-    [handleTranscription],
+    [handleTranscription, inputId],
   );

   const stopRecording = useCallback(() => {
@@ -201,6 +203,15 @@ export function useVoiceRecording({
     }
   }, [error, toast]);

+  useEffect(() => {
+    if (!isTranscribing && inputId) {
+      const inputElement = document.getElementById(inputId);
+      if (inputElement) {
+        inputElement.focus();
+      }
+    }
+  }, [isTranscribing, inputId]);
+
   const handleKeyDown = useCallback(
     (event: KeyboardEvent<HTMLTextAreaElement>) => {
       if (event.key === " " && !value.trim() && !isTranscribing) {
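Taken together, these hunks thread an optional `inputId` from `ChatInput` into `useVoiceRecording` so that, once transcription finishes, keyboard focus returns to the text box. A reduced sketch of the new effect (the hook name below is hypothetical; the logic matches the diff):

```typescript
import { useEffect } from "react";

// Sketch: refocus the chat input once transcription completes.
function useRefocusAfterTranscription(isTranscribing: boolean, inputId?: string) {
  useEffect(() => {
    if (!isTranscribing && inputId) {
      // The input is looked up by DOM id, which is why ChatInput must pass inputId down.
      document.getElementById(inputId)?.focus();
    }
  }, [isTranscribing, inputId]);
}
```

Note that `inputId` is also added to the transcription callback's dependency array, which is what the `[handleTranscription, inputId]` change above is doing.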
@@ -156,11 +156,19 @@ export function ChatMessage({
   }

   if (isClarificationNeeded && message.type === "clarification_needed") {
+    const hasUserReplyAfter =
+      index >= 0 &&
+      messages
+        .slice(index + 1)
+        .some((m) => m.type === "message" && m.role === "user");
+
     return (
       <ClarificationQuestionsWidget
         questions={message.questions}
         message={message.message}
+        sessionId={message.sessionId}
         onSubmitAnswers={handleClarificationAnswers}
+        isAnswered={hasUserReplyAfter}
         className={className}
       />
     );
@@ -6,7 +6,7 @@ import { Input } from "@/components/atoms/Input/Input";
 import { Text } from "@/components/atoms/Text/Text";
 import { cn } from "@/lib/utils";
 import { CheckCircleIcon, QuestionIcon } from "@phosphor-icons/react";
-import { useState } from "react";
+import { useState, useEffect, useRef } from "react";

 export interface ClarifyingQuestion {
   question: string;
@@ -17,39 +17,96 @@ export interface ClarifyingQuestion {
 interface Props {
   questions: ClarifyingQuestion[];
   message: string;
+  sessionId?: string;
   onSubmitAnswers: (answers: Record<string, string>) => void;
   onCancel?: () => void;
+  isAnswered?: boolean;
   className?: string;
 }

+function getStorageKey(sessionId?: string): string | null {
+  if (!sessionId) return null;
+  return `clarification_answers_${sessionId}`;
+}
+
 export function ClarificationQuestionsWidget({
   questions,
   message,
+  sessionId,
   onSubmitAnswers,
   onCancel,
+  isAnswered = false,
   className,
 }: Props) {
   const [answers, setAnswers] = useState<Record<string, string>>({});
   const [isSubmitted, setIsSubmitted] = useState(false);
+  const lastSessionIdRef = useRef<string | undefined>(undefined);
+
+  useEffect(() => {
+    const storageKey = getStorageKey(sessionId);
+    if (!storageKey) {
+      setAnswers({});
+      setIsSubmitted(false);
+      lastSessionIdRef.current = sessionId;
+      return;
+    }
+
+    try {
+      const saved = localStorage.getItem(storageKey);
+      if (saved) {
+        const parsed = JSON.parse(saved) as Record<string, string>;
+        setAnswers(parsed);
+      } else {
+        setAnswers({});
+      }
+      setIsSubmitted(false);
+    } catch {
+      setAnswers({});
+      setIsSubmitted(false);
+    }
+    lastSessionIdRef.current = sessionId;
+  }, [sessionId]);
+
+  useEffect(() => {
+    if (lastSessionIdRef.current !== sessionId) {
+      return;
+    }
+    const storageKey = getStorageKey(sessionId);
+    if (!storageKey) return;
+
+    const hasAnswers = Object.values(answers).some((v) => v.trim());
+    try {
+      if (hasAnswers) {
+        localStorage.setItem(storageKey, JSON.stringify(answers));
+      } else {
+        localStorage.removeItem(storageKey);
+      }
+    } catch {}
+  }, [answers, sessionId]);
+
   function handleAnswerChange(keyword: string, value: string) {
     setAnswers((prev) => ({ ...prev, [keyword]: value }));
   }

   function handleSubmit() {
-    // Check if all questions are answered
     const allAnswered = questions.every((q) => answers[q.keyword]?.trim());
     if (!allAnswered) {
       return;
     }
     setIsSubmitted(true);
     onSubmitAnswers(answers);
+
+    const storageKey = getStorageKey(sessionId);
+    try {
+      if (storageKey) {
+        localStorage.removeItem(storageKey);
+      }
+    } catch {}
   }

   const allAnswered = questions.every((q) => answers[q.keyword]?.trim());

-  // Show submitted state after answers are submitted
-  if (isSubmitted) {
+  if (isAnswered || isSubmitted) {
     return (
       <div
         className={cn(
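The effect pair above gives the widget draft persistence per chat session: answers are restored from `localStorage` when `sessionId` changes, saved while the user types, and cleared on submit. The key round-trip in isolation, as a sketch using the storage-key convention from the diff (helper names here are hypothetical):

```typescript
const storageKey = (sessionId: string) => `clarification_answers_${sessionId}`;

// Save a draft while the user types (ignore quota/private-mode errors).
function saveDraft(sessionId: string, answers: Record<string, string>) {
  try {
    localStorage.setItem(storageKey(sessionId), JSON.stringify(answers));
  } catch {}
}

// Restore on mount or session switch; missing or corrupt data yields an empty draft.
function loadDraft(sessionId: string): Record<string, string> {
  try {
    return JSON.parse(localStorage.getItem(storageKey(sessionId)) ?? "{}");
  } catch {
    return {};
  }
}
```

The `lastSessionIdRef` guard in the component serves the same purpose as separating these two helpers: it prevents the save effect from immediately overwriting a freshly restored draft when the session changes.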
@@ -30,9 +30,9 @@ export function getErrorMessage(result: unknown): string {
   }
   if (typeof result === "object" && result !== null) {
     const response = result as Record<string, unknown>;
-    if (response.error) return stripInternalReasoning(String(response.error));
     if (response.message)
       return stripInternalReasoning(String(response.message));
+    if (response.error) return stripInternalReasoning(String(response.error));
   }
   return "An error occurred";
 }
@@ -363,8 +363,8 @@ export function formatToolResponse(result: unknown, toolName: string): string {

     case "error":
       const errorMsg =
-        (response.error as string) || response.message || "An error occurred";
-      return `Error: ${errorMsg}`;
+        (response.message as string) || response.error || "An error occurred";
+      return stripInternalReasoning(String(errorMsg));

     case "no_results":
       const suggestions = (response.suggestions as string[]) || [];
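Both hunks make the same swap: when a tool response carries both `message` and `error`, the user-facing `message` now wins, and the result is always passed through `stripInternalReasoning`. For example (illustrative values only):

```typescript
const result = {
  error: "ToolExecutionError: traceback ...", // internal detail
  message: "The agent could not reach that URL.", // user-facing text
};

// Before: getErrorMessage(result) surfaced the raw error string.
// After:  getErrorMessage(result) returns "The agent could not reach that URL."
```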
@@ -59,12 +59,13 @@ test.describe("Library", () => {
   });

   test("pagination works correctly", async ({ page }, testInfo) => {
-    test.setTimeout(testInfo.timeout * 3); // Increase timeout for pagination operations
+    test.setTimeout(testInfo.timeout * 3);
     await page.goto("/library");

+    const PAGE_SIZE = 20;
     const paginationResult = await libraryPage.testPagination();

-    if (paginationResult.initialCount >= 10) {
+    if (paginationResult.initialCount >= PAGE_SIZE) {
       expect(paginationResult.finalCount).toBeGreaterThanOrEqual(
         paginationResult.initialCount,
       );
@@ -133,7 +134,10 @@ test.describe("Library", () => {
     test.expect(clearedSearchValue).toBe("");
   });

-  test("pagination while searching works correctly", async ({ page }) => {
+  test("pagination while searching works correctly", async ({
+    page,
+  }, testInfo) => {
+    test.setTimeout(testInfo.timeout * 3);
     await page.goto("/library");

     const allAgents = await libraryPage.getAgents();
@@ -152,9 +156,10 @@ test.describe("Library", () => {
     );
     expect(matchingResults.length).toEqual(initialSearchResults.length);

+    const PAGE_SIZE = 20;
     const searchPaginationResult = await libraryPage.testPagination();

-    if (searchPaginationResult.initialCount >= 10) {
+    if (searchPaginationResult.initialCount >= PAGE_SIZE) {
       expect(searchPaginationResult.finalCount).toBeGreaterThanOrEqual(
         searchPaginationResult.initialCount,
       );
@@ -69,9 +69,12 @@ test.describe("Marketplace Creator Page – Basic Functionality", () => {
       await marketplacePage.getFirstCreatorProfile(page);
     await firstCreatorProfile.click();
     await page.waitForURL("**/marketplace/creator/**");
+    await page.waitForLoadState("networkidle").catch(() => {});

     const firstAgent = page
       .locator('[data-testid="store-card"]:visible')
       .first();
+    await firstAgent.waitFor({ state: "visible", timeout: 30000 });
+
     await firstAgent.click();
     await page.waitForURL("**/marketplace/agent/**");
@@ -77,7 +77,6 @@ test.describe("Marketplace – Basic Functionality", () => {

     const firstFeaturedAgent =
       await marketplacePage.getFirstFeaturedAgent(page);
-    await firstFeaturedAgent.waitFor({ state: "visible" });
     await firstFeaturedAgent.click();
     await page.waitForURL("**/marketplace/agent/**");
     await matchesUrl(page, /\/marketplace\/agent\/.+/);
@@ -116,7 +115,15 @@ test.describe("Marketplace – Basic Functionality", () => {
     const searchTerm = page.getByText("DummyInput").first();
     await isVisible(searchTerm);

-    await page.waitForTimeout(10000);
+    await page.waitForLoadState("networkidle").catch(() => {});
+
+    await page
+      .waitForFunction(
+        () =>
+          document.querySelectorAll('[data-testid="store-card"]').length > 0,
+        { timeout: 15000 },
+      )
+      .catch(() => console.log("No search results appeared within timeout"));

     const results = await marketplacePage.getSearchResultsCount(page);
     expect(results).toBeGreaterThan(0);
@@ -300,21 +300,27 @@ export class LibraryPage extends BasePage {
   async scrollToLoadMore(): Promise<void> {
     console.log(`scrolling to load more agents`);

-    // Get initial agent count
-    const initialCount = await this.getAgentCount();
-    console.log(`Initial agent count: ${initialCount}`);
+    const initialCount = await this.getAgentCountByListLength();
+    console.log(`Initial agent count (DOM cards): ${initialCount}`);

-    // Scroll down to trigger pagination
     await this.scrollToBottom();

-    // Wait for potential new agents to load
-    await this.page.waitForTimeout(2000);
+    await this.page
+      .waitForLoadState("networkidle", { timeout: 10000 })
+      .catch(() => console.log("Network idle timeout, continuing..."));

-    // Check if more agents loaded
-    const newCount = await this.getAgentCount();
-    console.log(`New agent count after scroll: ${newCount}`);
+    await this.page
+      .waitForFunction(
+        (prevCount) =>
+          document.querySelectorAll('[data-testid="library-agent-card"]')
+            .length > prevCount,
+        initialCount,
+        { timeout: 5000 },
+      )
+      .catch(() => {});

-    return;
+    const newCount = await this.getAgentCountByListLength();
+    console.log(`New agent count after scroll (DOM cards): ${newCount}`);
   }

   async testPagination(): Promise<{
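The rewritten `scrollToLoadMore` replaces a fixed `waitForTimeout(2000)` with event-driven waits. Note the Playwright pattern in use: `page.waitForFunction(fn, arg, options)` evaluates `fn` inside the browser and passes `arg` (here, the pre-scroll card count) into it, so the wait resolves as soon as more cards than before are in the DOM. A minimal standalone sketch (function name hypothetical; selector taken from the diff):

```typescript
import { Page } from "@playwright/test";

// Sketch: wait until more library cards than `prevCount` are rendered.
async function waitForMoreCards(page: Page, prevCount: number): Promise<void> {
  await page
    .waitForFunction(
      (count) =>
        document.querySelectorAll('[data-testid="library-agent-card"]').length >
        count,
      prevCount,
      { timeout: 5_000 },
    )
    .catch(() => {}); // pagination may legitimately have no more pages to load
}
```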
@@ -9,6 +9,7 @@ export class MarketplacePage extends BasePage {

   async goto(page: Page) {
     await page.goto("/marketplace");
+    await page.waitForLoadState("networkidle").catch(() => {});
   }

   async getMarketplaceTitle(page: Page) {
@@ -109,16 +110,24 @@ export class MarketplacePage extends BasePage {

   async getFirstFeaturedAgent(page: Page) {
     const { getId } = getSelectors(page);
-    return getId("featured-store-card").first();
+    const card = getId("featured-store-card").first();
+    await card.waitFor({ state: "visible", timeout: 30000 });
+    return card;
   }

   async getFirstTopAgent() {
-    return this.page.locator('[data-testid="store-card"]:visible').first();
+    const card = this.page
+      .locator('[data-testid="store-card"]:visible')
+      .first();
+    await card.waitFor({ state: "visible", timeout: 30000 });
+    return card;
   }

   async getFirstCreatorProfile(page: Page) {
     const { getId } = getSelectors(page);
-    return getId("creator-card").first();
+    const card = getId("creator-card").first();
+    await card.waitFor({ state: "visible", timeout: 30000 });
+    return card;
   }

   async getSearchResultsCount(page: Page) {
@@ -1,15 +1,12 @@
 [flake8]
 max-line-length = 88
-extend-ignore = E203
 exclude =
     .tox,
     __pycache__,
     *.pyc,
-    .env,
-    venv*,
-    .venv,
-    reports,
-    dist,
-    data,
-    .benchmark_workspaces,
-    .autogpt,
+    .env
+    venv*/*,
+    .venv/*,
+    reports/*,
+    dist/*,
+    data/*,
@@ -1,291 +0,0 @@
-# CLAUDE.md
-
-This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
-
-## Project Overview
-
-AutoGPT Classic is an experimental, **unsupported** project demonstrating autonomous GPT-4 operation. Dependencies will not be updated, and the codebase contains known vulnerabilities. This is preserved for educational/historical purposes.
-
-## Repository Structure
-
-```
-classic/
-├── pyproject.toml        # Single consolidated Poetry project
-├── poetry.lock           # Single lock file
-├── forge/
-│   └── forge/            # Core agent framework package
-├── original_autogpt/
-│   └── autogpt/          # AutoGPT agent package
-├── direct_benchmark/
-│   └── direct_benchmark/ # Benchmark harness package
-└── benchmark/            # Challenge definitions (data, not code)
-```
-
-All packages are managed by a single `pyproject.toml` at the classic/ root.
-
-## Common Commands
-
-### Setup & Install
-```bash
-# Install everything from classic/ directory
-cd classic
-poetry install
-```
-
-### Running Agents
-```bash
-# Run forge agent
-poetry run python -m forge
-
-# Run original autogpt server
-poetry run serve --debug
-
-# Run autogpt CLI
-poetry run autogpt
-```
-
-Agents run on `http://localhost:8000` by default.
-
-### Benchmarking
-```bash
-# Run benchmarks
-poetry run direct-benchmark run
-
-# Run specific strategies and models
-poetry run direct-benchmark run \
-  --strategies one_shot,rewoo \
-  --models claude \
-  --parallel 4
-
-# Run a single test
-poetry run direct-benchmark run --tests ReadFile
-
-# List available commands
-poetry run direct-benchmark --help
-```
-
-### Testing
-```bash
-poetry run pytest                          # All tests
-poetry run pytest forge/tests/             # Forge tests only
-poetry run pytest original_autogpt/tests/  # AutoGPT tests only
-poetry run pytest -k test_name             # Single test by name
-poetry run pytest path/to/test.py          # Specific test file
-poetry run pytest --cov                    # With coverage
-```
-
-### Linting & Formatting
-
-Run from the classic/ directory:
-
-```bash
-# Format everything (recommended to run together)
-poetry run black . && poetry run isort .
-
-# Check formatting (CI-style, no changes)
-poetry run black --check . && poetry run isort --check-only .
-
-# Lint
-poetry run flake8  # Style linting
-
-# Type check
-poetry run pyright  # Type checking (some errors are expected in infrastructure code)
-```
-
-Note: Always run linters over the entire directory, not specific files, for best results.
-
-## Architecture
-
-### Forge (Core Framework)
-The `forge` package is the foundation that other components depend on:
-- `forge/agent/` - Agent implementation and protocols
-- `forge/llm/` - Multi-provider LLM integrations (OpenAI, Anthropic, Groq, LiteLLM)
-- `forge/components/` - Reusable agent components
-- `forge/file_storage/` - File system abstraction
-- `forge/config/` - Configuration management
-
-### Original AutoGPT
-- `original_autogpt/autogpt/app/` - CLI application entry points
-- `original_autogpt/autogpt/agents/` - Agent implementations
-- `original_autogpt/autogpt/agent_factory/` - Agent creation logic
-
-### Direct Benchmark
-Benchmark harness for testing agent performance:
-- `direct_benchmark/direct_benchmark/` - CLI and harness code
-- `benchmark/agbenchmark/challenges/` - Test cases organized by category (code, retrieval, data, etc.)
-- Reports generated in `direct_benchmark/reports/`
-
-### Package Structure
-All three packages are included in a single Poetry project. Imports are fully qualified:
-- `from forge.agent.base import BaseAgent`
-- `from autogpt.agents.agent import Agent`
-- `from direct_benchmark.harness import BenchmarkHarness`
-
-## Code Style
-
-- Python 3.12 target
-- Line length: 88 characters (Black default)
-- Black for formatting, isort for imports (profile="black")
-- Type hints with Pyright checking
-
-## Testing Patterns
-
-- Async support via pytest-asyncio
-- Fixtures defined in `conftest.py` files provide: `tmp_project_root`, `storage`, `config`, `llm_provider`, `agent`
-- Tests requiring API keys (OPENAI_API_KEY, ANTHROPIC_API_KEY) will skip if not set
-
-## Environment Setup
-
-Copy `.env.example` to `.env` in the relevant directory and add your API keys:
-```bash
-cp .env.example .env
-# Edit .env with your OPENAI_API_KEY, etc.
-```
-
-## Workspaces
-
-Agents operate within a **workspace** - a directory containing all agent data and files. The workspace root defaults to the current working directory.
-
-### Workspace Structure
-
-```
-{workspace}/
-├── .autogpt/
-│   ├── autogpt.yaml          # Workspace-level permissions
-│   ├── ap_server.db          # Agent Protocol database (server mode)
-│   └── agents/
-│       └── AutoGPT-{agent_id}/
-│           ├── state.json        # Agent profile, directives, action history
-│           ├── permissions.yaml  # Agent-specific permission overrides
-│           └── workspace/        # Agent's sandboxed working directory
-```
-
-### Key Concepts
-
-- **Multiple agents** can coexist in the same workspace (each gets its own subdirectory)
-- **File access** is sandboxed to the agent's `workspace/` directory by default
-- **State persistence** - agent state saves to `state.json` and survives across sessions
-- **Storage backends** - supports local filesystem, S3, and GCS (via `FILE_STORAGE_BACKEND` env var)
-
-### Specifying a Workspace
-
-```bash
-# Default: uses current directory
-cd /path/to/my/project && poetry run autogpt
-
-# Or specify explicitly via CLI (if supported)
-poetry run autogpt --workspace /path/to/workspace
-```
-
-## Settings Location
-
-Configuration uses a **layered system** with three levels (in order of precedence):
-
-### 1. Environment Variables (Global)
-
-Loaded from `.env` file in the working directory:
-
-```bash
-# Required
-OPENAI_API_KEY=sk-...
-
-# Optional LLM settings
-SMART_LLM=gpt-4o        # Model for complex reasoning
-FAST_LLM=gpt-4o-mini    # Model for simple tasks
-EMBEDDING_MODEL=text-embedding-3-small
-
-# Optional search providers (for web search component)
-TAVILY_API_KEY=tvly-...
-SERPER_API_KEY=...
-GOOGLE_API_KEY=...
-GOOGLE_CUSTOM_SEARCH_ENGINE_ID=...
-
-# Optional infrastructure
-LOG_LEVEL=DEBUG                     # DEBUG, INFO, WARNING, ERROR
-DATABASE_STRING=sqlite:///agent.db  # Agent Protocol database
-PORT=8000                           # Server port
-FILE_STORAGE_BACKEND=local          # local, s3, or gcs
-```
-
-### 2. Workspace Settings (`{workspace}/.autogpt/autogpt.yaml`)
-
-Workspace-wide permissions that apply to **all agents** in this workspace:
-
-```yaml
-allow:
-  - read_file({workspace}/**)
-  - write_to_file({workspace}/**)
-  - list_folder({workspace}/**)
-  - web_search(*)
-
-deny:
-  - read_file(**.env)
-  - read_file(**.env.*)
-  - read_file(**.key)
-  - read_file(**.pem)
-  - execute_shell(rm -rf:*)
-  - execute_shell(sudo:*)
-```
-
-Auto-generated with sensible defaults if missing.
-
-### 3. Agent Settings (`{workspace}/.autogpt/agents/{id}/permissions.yaml`)
-
-Agent-specific permission overrides:
-
-```yaml
-allow:
-  - execute_python(*)
-  - web_search(*)
-
-deny:
-  - execute_shell(*)
-```
-
-## Permissions
-
-The permission system uses **pattern matching** with a **first-match-wins** evaluation order.
-
-### Permission Check Order
-
-1. Agent deny list → **Block**
-2. Workspace deny list → **Block**
-3. Agent allow list → **Allow**
-4. Workspace allow list → **Allow**
-5. Session denied list → **Block** (commands denied during this session)
-6. **Prompt user** → Interactive approval (if in interactive mode)
-
-### Pattern Syntax
-
-Format: `command_name(glob_pattern)`
-
-| Pattern | Description |
-|---------|-------------|
-| `read_file({workspace}/**)` | Read any file in workspace (recursive) |
-| `write_to_file({workspace}/*.txt)` | Write only .txt files in workspace root |
-| `execute_shell(python:**)` | Execute Python commands only |
-| `execute_shell(git:*)` | Execute any git command |
-| `web_search(*)` | Allow all web searches |
-
-Special tokens:
-- `{workspace}` - Replaced with actual workspace path
-- `**` - Matches any path including `/`
-- `*` - Matches any characters except `/`
-
-### Interactive Approval Scopes
-
-When prompted for permission, users can choose:
-
-| Scope | Effect |
-|-------|--------|
-| **Once** | Allow this one time only (not saved) |
-| **Agent** | Always allow for this agent (saves to agent `permissions.yaml`) |
-| **Workspace** | Always allow for all agents (saves to `autogpt.yaml`) |
-| **Deny** | Deny this command (saves to appropriate deny list) |
-
-### Default Security
-
-Out of the box, the following are **denied by default**:
-- Reading sensitive files (`.env`, `.key`, `.pem`)
-- Destructive shell commands (`rm -rf`, `sudo`)
-- Operations outside the workspace directory
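The deleted CLAUDE.md was the one place the first-match-wins permission order was written down, so the idea is worth keeping legible. A compact sketch of that evaluation order (helper and type names are hypothetical; the real implementation lives in the forge package):

```typescript
type Verdict = "allow" | "block" | "prompt";

interface PermissionLists {
  agentDeny: string[];
  workspaceDeny: string[];
  agentAllow: string[];
  workspaceAllow: string[];
  sessionDenied: string[];
}

// Sketch: the first matching list decides; nothing matching falls through to a prompt.
function checkPermission(
  command: string,
  lists: PermissionLists,
  matches: (pattern: string, command: string) => boolean,
): Verdict {
  const ordered: Array<[string[], Verdict]> = [
    [lists.agentDeny, "block"],
    [lists.workspaceDeny, "block"],
    [lists.agentAllow, "allow"],
    [lists.workspaceAllow, "allow"],
    [lists.sessionDenied, "block"],
  ];
  for (const [patterns, verdict] of ordered) {
    if (patterns.some((p) => matches(p, command))) return verdict;
  }
  return "prompt"; // interactive approval, when running in interactive mode
}
```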
classic/CLI-USAGE.md (new executable file, 182 lines)
@@ -0,0 +1,182 @@
+## CLI Documentation
+
+This document describes how to interact with the project's CLI (Command Line Interface). It includes the types of outputs you can expect from each command. Note that the `agent stop` command will terminate any process running on port 8000.
+
+### 1. Entry Point for the CLI
+
+Running the `./run` command without any parameters will display the help message, which provides a list of available commands and options. Additionally, you can append `--help` to any command to view help information specific to that command.
+
+```sh
+./run
+```
+
+**Output**:
+
+```
+Usage: cli.py [OPTIONS] COMMAND [ARGS]...
+
+Options:
+  --help  Show this message and exit.
+
+Commands:
+  agent      Commands to create, start and stop agents
+  benchmark  Commands to start the benchmark and list tests and categories
+  setup      Installs dependencies needed for your system.
+```
+
+If you need assistance with any command, simply add the `--help` parameter to the end of your command, like so:
+
+```sh
+./run COMMAND --help
+```
+
+This will display a detailed help message regarding that specific command, including a list of any additional options and arguments it accepts.
+
+### 2. Setup Command
+
+```sh
+./run setup
+```
+
+**Output**:
+
+```
+Setup initiated
+Installation has been completed.
+```
+
+This command initializes the setup of the project.
+
+### 3. Agents Commands
+
+**a. List All Agents**
+
+```sh
+./run agent list
+```
+
+**Output**:
+
+```
+Available agents: 🤖
+  🐙 forge
+  🐙 autogpt
+```
+
+Lists all the available agents.
+
+**b. Create a New Agent**
+
+```sh
+./run agent create my_agent
+```
+
+**Output**:
+
+```
+🎉 New agent 'my_agent' created and switched to the new directory in agents folder.
+```
+
+Creates a new agent named 'my_agent'.
+
+**c. Start an Agent**
+
+```sh
+./run agent start my_agent
+```
+
+**Output**:
+
+```
+... (ASCII Art representing the agent startup)
+[Date and Time] [forge.sdk.db] [DEBUG] 🐛 Initializing AgentDB with database_string: sqlite:///agent.db
+[Date and Time] [forge.sdk.agent] [INFO] 📝 Agent server starting on http://0.0.0.0:8000
+```
+
+Starts 'my_agent' and displays startup ASCII art and logs.
+
+**d. Stop an Agent**
+
+```sh
+./run agent stop
+```
+
+**Output**:
+
+```
+Agent stopped
+```
+
+Stops the running agent.
+
+### 4. Benchmark Commands
+
+**a. List Benchmark Categories**
+
+```sh
+./run benchmark categories list
+```
+
+**Output**:
+
+```
+Available categories: 📚
+  📖 code
+  📖 safety
+  📖 memory
+  ... (and so on)
+```
+
+Lists all available benchmark categories.
+
+**b. List Benchmark Tests**
+
+```sh
+./run benchmark tests list
+```
+
+**Output**:
+
+```
+Available tests: 📚
+  📖 interface
+    🔬 Search - TestSearch
+    🔬 Write File - TestWriteFile
+  ... (and so on)
+```
+
+Lists all available benchmark tests.
+
+**c. Show Details of a Benchmark Test**
+
+```sh
+./run benchmark tests details TestWriteFile
+```
+
+**Output**:
+
+```
+TestWriteFile
+-------------
+
+Category: interface
+Task: Write the word 'Washington' to a .txt file
+... (and other details)
+```
+
+Displays the details of the 'TestWriteFile' benchmark test.
+
+**d. Start Benchmark for the Agent**
+
+```sh
+./run benchmark start my_agent
+```
+
+**Output**:
+
+```
+(more details about the testing process shown while the tests are running)
+============= 13 failed, 1 passed in 0.97s ============...
+```
+
+Displays the results of the benchmark tests on 'my_agent'.
@@ -2,7 +2,7 @@
 ARG BUILD_TYPE=dev

 # Use an official Python base image from the Docker Hub
-FROM python:3.12-slim AS autogpt-base
+FROM python:3.10-slim AS autogpt-base

 # Install browsers
 RUN apt-get update && apt-get install -y \
@@ -34,6 +34,9 @@ COPY original_autogpt/pyproject.toml original_autogpt/poetry.lock ./
 # Include forge so it can be used as a path dependency
 COPY forge/ ../forge

+# Include frontend
+COPY frontend/ ../frontend
+
 # Set the entrypoint
 ENTRYPOINT ["poetry", "run", "autogpt"]
 CMD []
classic/FORGE-QUICKSTART.md (new file, 173 lines)
@@ -0,0 +1,173 @@
+# Quickstart Guide
+
+> For the complete getting started [tutorial series](https://aiedge.medium.com/autogpt-forge-e3de53cc58ec) <- click here
+
+Welcome to the Quickstart Guide! This guide will walk you through setting up, building, and running your own AutoGPT agent. Whether you're a seasoned AI developer or just starting out, this guide will provide you with the steps to jumpstart your journey in AI development with AutoGPT.
+
+## System Requirements
+
+This project supports Linux (Debian-based), Mac, and Windows Subsystem for Linux (WSL). If you use a Windows system, you must install WSL. You can find the installation instructions for WSL [here](https://learn.microsoft.com/en-us/windows/wsl/).
+
+
+## Getting Setup
+1. **Fork the Repository**
+   To fork the repository, follow these steps:
+   - Navigate to the main page of the repository.
+
+   
+   - In the top-right corner of the page, click Fork.
+
+   
+   - On the next page, select your GitHub account to create the fork.
+   - Wait for the forking process to complete. You now have a copy of the repository in your GitHub account.
+
+2. **Clone the Repository**
+   To clone the repository, you need to have Git installed on your system. If you don't have Git installed, download it from [here](https://git-scm.com/downloads). Once you have Git installed, follow these steps:
+   - Open your terminal.
+   - Navigate to the directory where you want to clone the repository.
+   - Run the git clone command for the fork you just created.
+
+   
+
+   - Then open your project in your IDE.
+
+   
+
+3. **Setup the Project**
+   Next, we need to set up the required dependencies. We have a tool to help you perform all the tasks on the repo.
+   It can be accessed by typing `./run` in the terminal.
+
+   The first command you need to use is `./run setup`. This will guide you through setting up your system.
+   Initially, you will get instructions for installing Flutter and Chrome and setting up your GitHub access token like the following image:
+
+   
+
+### For Windows Users
+
+If you're a Windows user and experience issues after installing WSL, follow the steps below to resolve them.
+
+#### Update WSL
+Run the following command in PowerShell or Command Prompt to:
+1. Enable the optional WSL and Virtual Machine Platform components.
+2. Download and install the latest Linux kernel.
+3. Set WSL 2 as the default.
+4. Download and install the Ubuntu Linux distribution (a reboot may be required).
+
+```shell
+wsl --install
+```
+
+For more detailed information and additional steps, refer to [Microsoft's WSL Setup Environment Documentation](https://learn.microsoft.com/en-us/windows/wsl/setup/environment).
+
+#### Resolve FileNotFoundError or "No such file or directory" Errors
+When you run `./run setup`, if you encounter errors like `No such file or directory` or `FileNotFoundError`, it might be because Windows-style line endings (CRLF - Carriage Return Line Feed) are not compatible with Unix/Linux style line endings (LF - Line Feed).
+
+To resolve this, you can use the `dos2unix` utility to convert the line endings in your script from CRLF to LF. Here's how to install and run `dos2unix` on the script:
+
+```shell
+sudo apt update
+sudo apt install dos2unix
+dos2unix ./run
+```
+
+After executing the above commands, running `./run setup` should work successfully.
+
+#### Store Project Files within the WSL File System
+If you continue to experience issues, consider storing your project files within the WSL file system instead of the Windows file system. This method avoids path translations and permissions issues and provides a more consistent development environment.
+
+You can keep running the command to get feedback on where you are up to with your setup.
+When setup has been completed, the command will return an output like this:
+
+
+
+## Creating Your Agent
+
+After completing the setup, the next step is to create your agent template.
+Execute the command `./run agent create YOUR_AGENT_NAME`, where `YOUR_AGENT_NAME` should be replaced with your chosen name.
+
+Tips for naming your agent:
+* Give it its own unique name, or name it after yourself
+* Include an important aspect of your agent in the name, such as its purpose
+
+Examples: `SwiftyosAssistant`, `PwutsPRAgent`, `MySuperAgent`
+
+
+
+## Running your Agent
+
+Your agent can be started using the command: `./run agent start YOUR_AGENT_NAME`
+
+This starts the agent on the URL: `http://localhost:8000/`
+
+
+
+The front end can be accessed from `http://localhost:8000/`; first, you must log in using either a Google account or your GitHub account.
+
+
+
+Upon logging in, you will get a page that looks something like this: your task history down the left-hand side of the page, and the 'chat' window to send tasks to your agent.
+
+
+
+When you have finished with your agent or just need to restart it, use Ctrl-C to end the session. Then, you can re-run the start command.
+
+If you are having issues and want to ensure the agent has been stopped, there is a `./run agent stop` command, which kills the process using port 8000, which should be the agent.
+
+## Benchmarking your Agent
+
+The benchmarking system can also be accessed using the CLI:
+
+```bash
+agpt % ./run benchmark
+Usage: cli.py benchmark [OPTIONS] COMMAND [ARGS]...
+
+  Commands to start the benchmark and list tests and categories
+
+Options:
+  --help  Show this message and exit.
+
+Commands:
+  categories  Benchmark categories group command
+  start       Starts the benchmark command
+  tests       Benchmark tests group command
+agpt % ./run benchmark categories
+Usage: cli.py benchmark categories [OPTIONS] COMMAND [ARGS]...
+
+  Benchmark categories group command
+
+Options:
+  --help  Show this message and exit.
+
+Commands:
+  list  List benchmark categories command
+agpt % ./run benchmark tests
+Usage: cli.py benchmark tests [OPTIONS] COMMAND [ARGS]...
+
+  Benchmark tests group command
+
+Options:
+  --help  Show this message and exit.
+
+Commands:
+  details  Benchmark test details command
+  list     List benchmark tests command
+```
+
+The benchmark has been split into different categories of skills you can test your agent on. You can see what categories are available with
+```bash
+./run benchmark categories list
+# And what tests are available with
+./run benchmark tests list
+```
+
+
+
+Finally, you can run the benchmark with
+
+```bash
+./run benchmark start YOUR_AGENT_NAME
+
+```
+
+>
@@ -4,7 +4,7 @@ AutoGPT Classic was an experimental project to demonstrate autonomous GPT-4 oper

 ## Project Status

-**This project is unsupported, and dependencies will not be updated.** It was an experiment that has concluded its initial research phase. If you want to use AutoGPT, you should use the [AutoGPT Platform](/autogpt_platform).
+⚠️ **This project is unsupported, and dependencies will not be updated. It was an experiment that has concluded its initial research phase. If you want to use AutoGPT, you should use the [AutoGPT Platform](/autogpt_platform)**

 For those interested in autonomous AI agents, we recommend exploring more actively maintained alternatives or referring to this codebase for educational purposes only.

@@ -16,171 +16,37 @@ AutoGPT Classic was one of the first implementations of autonomous AI agents - A
 - Learn from the results and adjust its approach
 - Chain multiple actions together to achieve an objective

+## Key Features
+
+- 🔄 Autonomous task chaining
+- 🛠 Tool and API integration capabilities
+- 💾 Memory management for context retention
+- 🔍 Web browsing and information gathering
+- 📝 File operations and content creation
+- 🔄 Self-prompting and task breakdown
+
 ## Structure

-```
-classic/
-├── pyproject.toml       # Single consolidated Poetry project
-├── poetry.lock          # Single lock file
-├── forge/               # Core autonomous agent framework
-├── original_autogpt/    # Original implementation
-├── direct_benchmark/    # Benchmark harness
-└── benchmark/           # Challenge definitions (data)
-```
+The project is organized into several key components:
+- `/benchmark` - Performance testing tools
+- `/forge` - Core autonomous agent framework
+- `/frontend` - User interface components
+- `/original_autogpt` - Original implementation

 ## Getting Started

-### Prerequisites
+While this project is no longer actively maintained, you can still explore the codebase:

-- Python 3.12+
-- [Poetry](https://python-poetry.org/docs/#installation)
-
-### Installation
-
+1. Clone the repository:
 ```bash
-# Clone the repository
 git clone https://github.com/Significant-Gravitas/AutoGPT.git
 cd classic
-
-# Install everything
-poetry install
 ```

-### Configuration
-
-Configuration uses a layered system:
-
-1. **Environment variables** (`.env` file)
-2. **Workspace settings** (`.autogpt/autogpt.yaml`)
-3. **Agent settings** (`.autogpt/agents/{id}/permissions.yaml`)
-
-Copy the example environment file and add your API keys:
-
-```bash
-cp .env.example .env
-```
-
-Key environment variables:
-```bash
-# Required
-OPENAI_API_KEY=sk-...
-
-# Optional LLM settings
-SMART_LLM=gpt-4o              # Model for complex reasoning
-FAST_LLM=gpt-4o-mini          # Model for simple tasks
-
-# Optional search providers
-TAVILY_API_KEY=tvly-...
-SERPER_API_KEY=...
-
-# Optional infrastructure
-LOG_LEVEL=DEBUG
-PORT=8000
-FILE_STORAGE_BACKEND=local  # local, s3, or gcs
-```
-
-### Running
-
-All commands run from the `classic/` directory:
-
-```bash
-# Run forge agent
-poetry run python -m forge
-
-# Run original autogpt server
-poetry run serve --debug
-
-# Run autogpt CLI
-poetry run autogpt
-```
-
-Agents run on `http://localhost:8000` by default.
-
-### Benchmarking
-
-```bash
-poetry run direct-benchmark run
-```
-
-### Testing
-
-```bash
-poetry run pytest                          # All tests
-poetry run pytest forge/tests/             # Forge tests only
-poetry run pytest original_autogpt/tests/  # AutoGPT tests only
-```
-
-## Workspaces
-
-Agents operate within a **workspace** directory that contains all agent data and files:
-
-```
-{workspace}/
-├── .autogpt/
-│   ├── autogpt.yaml          # Workspace-level permissions
-│   ├── ap_server.db          # Agent Protocol database (server mode)
-│   └── agents/
-│       └── AutoGPT-{agent_id}/
-│           ├── state.json        # Agent profile, directives, history
-│           ├── permissions.yaml  # Agent-specific permissions
-│           └── workspace/        # Agent's sandboxed working directory
-```
-
-- The workspace defaults to the current working directory
-- Multiple agents can coexist in the same workspace
-- Agent file access is sandboxed to their `workspace/` subdirectory
-- State persists across sessions via `state.json`
-
-## Permissions
-
-AutoGPT uses a **layered permission system** with pattern matching:
-
-### Permission Files
-
-| File | Scope | Location |
-|------|-------|----------|
-| `autogpt.yaml` | All agents in workspace | `.autogpt/autogpt.yaml` |
-| `permissions.yaml` | Single agent | `.autogpt/agents/{id}/permissions.yaml` |
-
-### Permission Format
-
-```yaml
-allow:
-  - read_file({workspace}/**)     # Read any file in workspace
-  - write_to_file({workspace}/**) # Write any file in workspace
-  - web_search(*)                 # All web searches
-
-deny:
-  - read_file(**.env)        # Block .env files
-  - execute_shell(sudo:*)    # Block sudo commands
-```
-
-### Check Order (First Match Wins)
-
-1. Agent deny → Block
-2. Workspace deny → Block
-3. Agent allow → Allow
-4. Workspace allow → Allow
-5. Prompt user → Interactive approval
-
-### Interactive Approval
-
-When prompted, users can approve commands with different scopes:
-- **Once** - Allow this one time only
-- **Agent** - Always allow for this agent
-- **Workspace** - Always allow for all agents
-- **Deny** - Block this command
-
-### Default Security
-
-Denied by default:
-- Sensitive files (`.env`, `.key`, `.pem`)
-- Destructive commands (`rm -rf`, `sudo`)
-- Operations outside the workspace
-
-## Security Notice
-
-This codebase has **known vulnerabilities** and issues with its dependencies. It will not be updated to new dependencies. Use for educational purposes only.
+2. Review the documentation:
+   - For reference, see the [documentation](https://docs.agpt.co). You can browse it as it was at this commit, so the docs don't change.
+   - Check `CLI-USAGE.md` for command-line interface details
+   - Refer to `TROUBLESHOOTING.md` for common issues

 ## License
@@ -189,3 +55,27 @@ This project segment is licensed under the MIT License - see the [LICENSE](LICEN
 ## Documentation

 Please refer to the [documentation](https://docs.agpt.co) for more detailed information about the project's architecture and concepts.
+You can browse it as it was at this commit, so the docs don't change.
+
+## Historical Impact
+
+AutoGPT Classic played a significant role in advancing the field of autonomous AI agents:
+- Demonstrated practical implementation of AI autonomy
+- Inspired numerous derivative projects and research
+- Contributed to the development of AI agent architectures
+- Helped identify key challenges in AI autonomy
+
+## Security Notice
+
+If you're studying this codebase, please understand that it has KNOWN vulnerabilities and issues with its dependencies. It will not be updated to new dependencies.
+
+## Community & Support
+
+While active development has concluded:
+- The codebase remains available for study and reference
+- Historical discussions can be found in project issues
+- Related research and developments continue in the broader AI agent community
+
+## Acknowledgments
+
+Thanks to all contributors who participated in this experimental project and helped advance the field of autonomous AI agents.
classic/benchmark/.env.example (new file, 4 lines)
@@ -0,0 +1,4 @@
+AGENT_NAME=mini-agi
+REPORTS_FOLDER="reports/mini-agi"
+OPENAI_API_KEY="sk-"     # for LLM eval
+BUILD_SKILL_TREE=false   # set to true to build the skill tree.
classic/benchmark/.flake8 (new file, 12 lines)
@@ -0,0 +1,12 @@
+[flake8]
+max-line-length = 88
+# Ignore rules that conflict with Black code style
+extend-ignore = E203, W503
+exclude =
+    __pycache__/,
+    *.pyc,
+    .pytest_cache/,
+    venv*/,
+    .venv/,
+    reports/,
+    agbenchmark/reports/,
classic/benchmark/.gitignore (new file, vendored, 174 lines)
@@ -0,0 +1,174 @@
+agbenchmark_config/workspace/
+backend/backend_stdout.txt
+reports/df*.pkl
+reports/raw*
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
||||||
|
|
||||||
|
# pytype static type analyzer
|
||||||
|
.pytype/
|
||||||
|
|
||||||
|
# Cython debug symbols
|
||||||
|
cython_debug/
|
||||||
|
|
||||||
|
# PyCharm
|
||||||
|
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||||
|
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||||
|
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||||
|
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||||
|
.idea/
|
||||||
|
.DS_Store
|
||||||
|
```
|
||||||
|
secrets.json
|
||||||
|
agbenchmark_config/challenges_already_beaten.json
|
||||||
|
agbenchmark_config/challenges/pri_*
|
||||||
|
agbenchmark_config/updates.json
|
||||||
|
agbenchmark_config/reports/*
|
||||||
|
agbenchmark_config/reports/success_rate.json
|
||||||
|
agbenchmark_config/reports/regression_tests.json
|
||||||
21
classic/benchmark/LICENSE
Normal file
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 AutoGPT

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
25
classic/benchmark/README.md
Normal file
@@ -0,0 +1,25 @@
# Auto-GPT Benchmarks

Built for the purpose of benchmarking the performance of agents, regardless of how they work.

Objectively know how well your agent is performing in categories like code, retrieval, memory, and safety.

Save time and money while doing it through smart dependencies. The best part? It's all automated.

## Scores:

<img width="733" alt="Screenshot 2023-07-25 at 10 35 01 AM" src="https://github.com/Significant-Gravitas/Auto-GPT-Benchmarks/assets/9652976/98963e0b-18b9-4b17-9a6a-4d3e4418af70">

## Ranking overall:

1. [Beebot](https://github.com/AutoPackAI/beebot)
2. [mini-agi](https://github.com/muellerberndt/mini-agi)
3. [Auto-GPT](https://github.com/Significant-Gravitas/AutoGPT)

## Detailed results:

<img width="733" alt="Screenshot 2023-07-25 at 10 42 15 AM" src="https://github.com/Significant-Gravitas/Auto-GPT-Benchmarks/assets/9652976/39be464c-c842-4437-b28a-07d878542a83">

[Click here to see the results and the raw data!](https://docs.google.com/spreadsheets/d/1WXm16P2AHNbKpkOI0LYBpcsGG0O7D8HYTG5Uj0PaJjA/edit#gid=203558751)

More agents coming soon!
69
classic/benchmark/agbenchmark/README.md
Normal file
@@ -0,0 +1,69 @@
## As a user

1. `pip install auto-gpt-benchmarks`
2. Add boilerplate code to run and kill your agent
3. `agbenchmark`
   - `--category challenge_category` to run tests in a specific category
   - `--mock` to only run mock tests, if they exist for each test
   - `--noreg` to skip any tests that have passed in the past. When you run without this flag and a previously passing challenge fails, it will no longer count as a regression test
4. We call the boilerplate code for your agent
5. Show pass rate of tests, logs, and any other metrics

## Contributing

##### Diagrams: https://whimsical.com/agbenchmark-5n4hXBq1ZGzBwRsK4TVY7x

### To run the existing mocks

1. clone the repo `auto-gpt-benchmarks`
2. `pip install poetry`
3. `poetry shell`
4. `poetry install`
5. `cp .env_example .env`
6. `git submodule update --init --remote --recursive`
7. `uvicorn server:app --reload`
8. `agbenchmark --mock`

Keep the config the same and watch the logs :)

### To run with mini-agi

1. Navigate to `auto-gpt-benchmarks/agent/mini-agi`
2. `pip install -r requirements.txt`
3. `cp .env_example .env`, set `PROMPT_USER=false` and add your `OPENAI_API_KEY=`. Set `MODEL="gpt-3.5-turbo"` if you don't have access to `gpt-4` yet. Also make sure you have Python 3.10+ installed
4. Set `AGENT_NAME=mini-agi` in the `.env` file, along with where you want your `REPORTS_FOLDER` to be
5. Make sure to follow the commands above, then run `agbenchmark` without the mock flag

- To add requirements: `poetry add requirement`.

Feel free to create PRs to merge with `main` at will (but also feel free to ask for review) - if you can't, send a message in the R&D chat for access.

If you push at any point and break things - it'll happen to everyone - fix it asap. Step 1 is to revert `master` to the last working commit.

Let people know what the beautiful code you write does, and document everything well.

Share your progress :)

#### Dataset

Manually created, existing challenges within Auto-GPT, https://osu-nlp-group.github.io/Mind2Web/

## How do I add new agents to agbenchmark?

Example with smol developer.

1. Create a GitHub branch with your agent, following the same pattern as this example:

   https://github.com/smol-ai/developer/pull/114/files

2. Create the submodule and the GitHub workflow by following the same pattern as this example:

   https://github.com/Significant-Gravitas/Auto-GPT-Benchmarks/pull/48/files

## How do I run an agent in different environments?

**To just use it as the benchmark for your agent**: `pip install` the package and run `agbenchmark`.

**For internal Auto-GPT CI runs**, specify the `AGENT_NAME` you want to use and set the `HOME_ENV`.
Ex. `AGENT_NAME=mini-agi`

**To develop an agent alongside the benchmark**, specify the `AGENT_NAME` you want to use and add it as a submodule to the repo.
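The "boilerplate code to run and kill your agent" mentioned in the README above is left to each integration. A minimal sketch of what such a wrapper could look like, assuming the agent is launched as a subprocess (the command and timeout below are illustrative, not part of agbenchmark's API):

```python
import subprocess


def run_agent(task: str, timeout: int = 300) -> int:
    """Start the agent on a task and kill it if it overruns the timeout."""
    # Hypothetical launch command; replace with however your agent is started.
    process = subprocess.Popen(["python", "-m", "my_agent", task])
    try:
        return process.wait(timeout=timeout)
    except subprocess.TimeoutExpired:
        process.kill()  # hard-stop the agent so the benchmark can move on
        process.wait()
        return -1
```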
352
classic/benchmark/agbenchmark/__main__.py
Normal file
@@ -0,0 +1,352 @@
import logging
import os
import sys
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Optional

import click
from click_default_group import DefaultGroup
from dotenv import load_dotenv

from agbenchmark.config import AgentBenchmarkConfig
from agbenchmark.utils.logging import configure_logging

load_dotenv()

# try:
#     if os.getenv("HELICONE_API_KEY"):
#         import helicone  # noqa

#         helicone_enabled = True
#     else:
#         helicone_enabled = False
# except ImportError:
#     helicone_enabled = False


class InvalidInvocationError(ValueError):
    pass


logger = logging.getLogger(__name__)

BENCHMARK_START_TIME_DT = datetime.now(timezone.utc)
BENCHMARK_START_TIME = BENCHMARK_START_TIME_DT.strftime("%Y-%m-%dT%H:%M:%S+00:00")


# if helicone_enabled:
#     from helicone.lock import HeliconeLockManager

#     HeliconeLockManager.write_custom_property(
#         "benchmark_start_time", BENCHMARK_START_TIME
#     )


@click.group(cls=DefaultGroup, default_if_no_args=True)
@click.option("--debug", is_flag=True, help="Enable debug output")
def cli(
    debug: bool,
) -> Any:
    configure_logging(logging.DEBUG if debug else logging.INFO)


@cli.command(hidden=True)
def start():
    raise DeprecationWarning(
        "`agbenchmark start` is deprecated. Use `agbenchmark run` instead."
    )


@cli.command(default=True)
@click.option(
    "-N", "--attempts", default=1, help="Number of times to run each challenge."
)
@click.option(
    "-c",
    "--category",
    multiple=True,
    help="(+) Select a category to run.",
)
@click.option(
    "-s",
    "--skip-category",
    multiple=True,
    help="(+) Exclude a category from running.",
)
@click.option("--test", multiple=True, help="(+) Select a test to run.")
@click.option("--maintain", is_flag=True, help="Run only regression tests.")
@click.option("--improve", is_flag=True, help="Run only non-regression tests.")
@click.option(
    "--explore",
    is_flag=True,
    help="Run only challenges that have never been beaten.",
)
@click.option(
    "--no-dep",
    is_flag=True,
    help="Run all (selected) challenges, regardless of dependency success/failure.",
)
@click.option("--cutoff", type=int, help="Override the challenge time limit (seconds).")
@click.option("--nc", is_flag=True, help="Disable the challenge time limit.")
@click.option("--mock", is_flag=True, help="Run with mock")
@click.option("--keep-answers", is_flag=True, help="Keep answers")
@click.option(
    "--backend",
    is_flag=True,
    help="Write log output to a file instead of the terminal.",
)
# @click.argument(
#     "agent_path",
#     type=click.Path(exists=True, file_okay=False, path_type=Path),
#     required=False,
# )
def run(
    maintain: bool,
    improve: bool,
    explore: bool,
    mock: bool,
    no_dep: bool,
    nc: bool,
    keep_answers: bool,
    test: tuple[str],
    category: tuple[str],
    skip_category: tuple[str],
    attempts: int,
    cutoff: Optional[int] = None,
    backend: Optional[bool] = False,
    # agent_path: Optional[Path] = None,
) -> None:
    """
    Run the benchmark on the agent in the current directory.

    Options marked with (+) can be specified multiple times, to select multiple items.
    """
    from agbenchmark.main import run_benchmark, validate_args

    agbenchmark_config = AgentBenchmarkConfig.load()
    logger.debug(f"agbenchmark_config: {agbenchmark_config.agbenchmark_config_dir}")
    try:
        validate_args(
            maintain=maintain,
            improve=improve,
            explore=explore,
            tests=test,
            categories=category,
            skip_categories=skip_category,
            no_cutoff=nc,
            cutoff=cutoff,
        )
    except InvalidInvocationError as e:
        logger.error("Error: " + "\n".join(e.args))
        sys.exit(1)

    original_stdout = sys.stdout  # Save the original standard output
    exit_code = None

    if backend:
        with open("backend/backend_stdout.txt", "w") as f:
            sys.stdout = f
            exit_code = run_benchmark(
                config=agbenchmark_config,
                maintain=maintain,
                improve=improve,
                explore=explore,
                mock=mock,
                no_dep=no_dep,
                no_cutoff=nc,
                keep_answers=keep_answers,
                tests=test,
                categories=category,
                skip_categories=skip_category,
                attempts_per_challenge=attempts,
                cutoff=cutoff,
            )

        sys.stdout = original_stdout

    else:
        exit_code = run_benchmark(
            config=agbenchmark_config,
            maintain=maintain,
            improve=improve,
            explore=explore,
            mock=mock,
            no_dep=no_dep,
            no_cutoff=nc,
            keep_answers=keep_answers,
            tests=test,
            categories=category,
            skip_categories=skip_category,
            attempts_per_challenge=attempts,
            cutoff=cutoff,
        )

    sys.exit(exit_code)


@cli.command()
@click.option("--port", type=int, help="Port to run the API on.")
def serve(port: Optional[int] = None):
    """Serve the benchmark frontend and API on port 8080."""
    import uvicorn

    from agbenchmark.app import setup_fastapi_app

    config = AgentBenchmarkConfig.load()
    app = setup_fastapi_app(config)

    # Run the FastAPI application using uvicorn
    port = port or int(os.getenv("PORT", 8080))
    uvicorn.run(app, host="0.0.0.0", port=port)


@cli.command()
def config():
    """Displays info regarding the present AGBenchmark config."""
    from .utils.utils import pretty_print_model

    try:
        config = AgentBenchmarkConfig.load()
    except FileNotFoundError as e:
        click.echo(e, err=True)
        return 1

    pretty_print_model(config, include_header=False)


@cli.group()
def challenge():
    logging.getLogger().setLevel(logging.WARNING)


@challenge.command("list")
@click.option(
    "--all", "include_unavailable", is_flag=True, help="Include unavailable challenges."
)
@click.option(
    "--names", "only_names", is_flag=True, help="List only the challenge names."
)
@click.option("--json", "output_json", is_flag=True)
def list_challenges(include_unavailable: bool, only_names: bool, output_json: bool):
    """Lists [available|all] challenges."""
    import json

    from tabulate import tabulate

    from .challenges.builtin import load_builtin_challenges
    from .challenges.webarena import load_webarena_challenges
    from .utils.data_types import Category, DifficultyLevel
    from .utils.utils import sorted_by_enum_index

    DIFFICULTY_COLORS = {
        difficulty: color
        for difficulty, color in zip(
            DifficultyLevel,
            ["black", "blue", "cyan", "green", "yellow", "red", "magenta", "white"],
        )
    }
    CATEGORY_COLORS = {
        category: f"bright_{color}"
        for category, color in zip(
            Category,
            ["blue", "cyan", "green", "yellow", "magenta", "red", "white", "black"],
        )
    }

    # Load challenges
    challenges = filter(
        lambda c: c.info.available or include_unavailable,
        [
            *load_builtin_challenges(),
            *load_webarena_challenges(skip_unavailable=False),
        ],
    )
    challenges = sorted_by_enum_index(
        challenges, DifficultyLevel, key=lambda c: c.info.difficulty
    )

    if only_names:
        if output_json:
            click.echo(json.dumps([c.info.name for c in challenges]))
            return

        for c in challenges:
            click.echo(
                click.style(c.info.name, fg=None if c.info.available else "black")
            )
        return

    if output_json:
        click.echo(
            json.dumps([json.loads(c.info.model_dump_json()) for c in challenges])
        )
        return

    headers = tuple(
        click.style(h, bold=True) for h in ("Name", "Difficulty", "Categories")
    )
    table = [
        tuple(
            v if challenge.info.available else click.style(v, fg="black")
            for v in (
                challenge.info.name,
                (
                    click.style(
                        challenge.info.difficulty.value,
                        fg=DIFFICULTY_COLORS[challenge.info.difficulty],
                    )
                    if challenge.info.difficulty
                    else click.style("-", fg="black")
                ),
                " ".join(
                    click.style(cat.value, fg=CATEGORY_COLORS[cat])
                    for cat in sorted_by_enum_index(challenge.info.category, Category)
                ),
            )
        )
        for challenge in challenges
    ]
    click.echo(tabulate(table, headers=headers))


@challenge.command()
@click.option("--json", is_flag=True)
@click.argument("name")
def info(name: str, json: bool):
    from itertools import chain

    from .challenges.builtin import load_builtin_challenges
    from .challenges.webarena import load_webarena_challenges
    from .utils.utils import pretty_print_model

    for challenge in chain(
        load_builtin_challenges(),
        load_webarena_challenges(skip_unavailable=False),
    ):
        if challenge.info.name != name:
            continue

        if json:
            click.echo(challenge.info.model_dump_json())
            break

        pretty_print_model(challenge.info)
        break
    else:
        click.echo(click.style(f"Unknown challenge '{name}'", fg="red"), err=True)


@cli.command()
def version():
    """Print version info for the AGBenchmark application."""
    import toml

    package_root = Path(__file__).resolve().parent.parent
    pyproject = toml.load(package_root / "pyproject.toml")
    version = pyproject["tool"]["poetry"]["version"]
    click.echo(f"AGBenchmark version {version}")


if __name__ == "__main__":
    cli()
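Since `cli` is a regular Click group (with `run` as the default command via `DefaultGroup`), it can also be exercised programmatically with Click's test runner. A sketch, assuming the package is installed and an agbenchmark config exists on disk:

```python
from click.testing import CliRunner

from agbenchmark.__main__ import cli

runner = CliRunner()

# Equivalent to `agbenchmark challenge list --names` on the command line
result = runner.invoke(cli, ["challenge", "list", "--names"])
print(result.output)

# Equivalent to `agbenchmark --mock`, which resolves to the default `run` command
result = runner.invoke(cli, ["run", "--mock"])
print(result.exit_code)
```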
111
classic/benchmark/agbenchmark/agent_api_interface.py
Normal file
@@ -0,0 +1,111 @@
import logging
import time
from pathlib import Path
from typing import AsyncIterator, Optional

from agent_protocol_client import (
    AgentApi,
    ApiClient,
    Configuration,
    Step,
    TaskRequestBody,
)

from agbenchmark.agent_interface import get_list_of_file_paths
from agbenchmark.config import AgentBenchmarkConfig

logger = logging.getLogger(__name__)


async def run_api_agent(
    task: str,
    config: AgentBenchmarkConfig,
    timeout: int,
    artifacts_location: Optional[Path] = None,
    *,
    mock: bool = False,
) -> AsyncIterator[Step]:
    configuration = Configuration(host=config.host)
    async with ApiClient(configuration) as api_client:
        api_instance = AgentApi(api_client)
        task_request_body = TaskRequestBody(input=task, additional_input=None)

        start_time = time.time()
        response = await api_instance.create_agent_task(
            task_request_body=task_request_body
        )
        task_id = response.task_id

        if artifacts_location:
            logger.debug("Uploading task input artifacts to agent...")
            await upload_artifacts(
                api_instance, artifacts_location, task_id, "artifacts_in"
            )

        logger.debug("Running agent until finished or timeout...")
        while True:
            step = await api_instance.execute_agent_task_step(task_id=task_id)
            yield step

            if time.time() - start_time > timeout:
                raise TimeoutError("Time limit exceeded")
            if step and mock:
                step.is_last = True
            if not step or step.is_last:
                break

        if artifacts_location:
            # In "mock" mode, we cheat by giving the correct artifacts to pass the test
            if mock:
                logger.debug("Uploading mock artifacts to agent...")
                await upload_artifacts(
                    api_instance, artifacts_location, task_id, "artifacts_out"
                )

            logger.debug("Downloading agent artifacts...")
            await download_agent_artifacts_into_folder(
                api_instance, task_id, config.temp_folder
            )


async def download_agent_artifacts_into_folder(
    api_instance: AgentApi, task_id: str, folder: Path
):
    artifacts = await api_instance.list_agent_task_artifacts(task_id=task_id)

    for artifact in artifacts.artifacts:
        # current absolute path of the directory of the file
        if artifact.relative_path:
            path: str = (
                artifact.relative_path
                if not artifact.relative_path.startswith("/")
                else artifact.relative_path[1:]
            )
            folder = (folder / path).parent

        if not folder.exists():
            folder.mkdir(parents=True)

        file_path = folder / artifact.file_name
        logger.debug(f"Downloading agent artifact {artifact.file_name} to {folder}")
        with open(file_path, "wb") as f:
            content = await api_instance.download_agent_task_artifact(
                task_id=task_id, artifact_id=artifact.artifact_id
            )

            f.write(content)


async def upload_artifacts(
    api_instance: AgentApi, artifacts_location: Path, task_id: str, type: str
) -> None:
    for file_path in get_list_of_file_paths(artifacts_location, type):
        relative_path: Optional[str] = "/".join(
            str(file_path).split(f"{type}/", 1)[-1].split("/")[:-1]
        )
        if not relative_path:
            relative_path = None

        await api_instance.upload_agent_task_artifacts(
            task_id=task_id, file=str(file_path), relative_path=relative_path
        )
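`run_api_agent` is an async generator, so a caller drives the agent step by step with `async for`. A minimal consumption sketch, assuming a valid agbenchmark config on disk and an agent listening on the configured host (the task string is illustrative):

```python
import asyncio

from agbenchmark.agent_api_interface import run_api_agent
from agbenchmark.config import AgentBenchmarkConfig


async def main() -> None:
    config = AgentBenchmarkConfig.load()  # expects an agbenchmark config on disk
    async for step in run_api_agent(
        task="Write the word 'Washington' to a .txt file",
        config=config,
        timeout=60,
    ):
        print(step.output)  # each yielded Step reports the agent's progress
        if step.is_last:
            break


asyncio.run(main())
```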
27
classic/benchmark/agbenchmark/agent_interface.py
Normal file
@@ -0,0 +1,27 @@
import os
import shutil
from pathlib import Path

from dotenv import load_dotenv

load_dotenv()

HELICONE_GRAPHQL_LOGS = os.getenv("HELICONE_GRAPHQL_LOGS", "").lower() == "true"


def get_list_of_file_paths(
    challenge_dir_path: str | Path, artifact_folder_name: str
) -> list[Path]:
    source_dir = Path(challenge_dir_path) / artifact_folder_name
    if not source_dir.exists():
        return []
    return list(source_dir.iterdir())


def copy_challenge_artifacts_into_workspace(
    challenge_dir_path: str | Path, artifact_folder_name: str, workspace: str | Path
) -> None:
    file_paths = get_list_of_file_paths(challenge_dir_path, artifact_folder_name)
    for file_path in file_paths:
        if file_path.is_file():
            shutil.copy(file_path, workspace)
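Both helpers operate on a challenge's artifact folders by convention (`artifacts_in` holds the task's input files). A short usage sketch with illustrative paths; the workspace directory is assumed to exist:

```python
from agbenchmark.agent_interface import (
    copy_challenge_artifacts_into_workspace,
    get_list_of_file_paths,
)

challenge_dir = "challenges/abilities/write_file"  # illustrative path

# List the input artifacts bundled with the challenge
print(get_list_of_file_paths(challenge_dir, "artifacts_in"))

# Copy the challenge's input files into the agent's workspace before the run
copy_challenge_artifacts_into_workspace(challenge_dir, "artifacts_in", "workspace/")
```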
339
classic/benchmark/agbenchmark/app.py
Normal file
@@ -0,0 +1,339 @@
import datetime
import glob
import json
import logging
import sys
import time
import uuid
from collections import deque
from multiprocessing import Process
from pathlib import Path
from typing import Optional

import httpx
import psutil
from agent_protocol_client import AgentApi, ApiClient, ApiException, Configuration
from agent_protocol_client.models import Task, TaskRequestBody
from fastapi import APIRouter, FastAPI, HTTPException, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, ConfigDict, ValidationError

from agbenchmark.challenges import ChallengeInfo
from agbenchmark.config import AgentBenchmarkConfig
from agbenchmark.reports.processing.report_types_v2 import (
    BenchmarkRun,
    Metrics,
    RepositoryInfo,
    RunDetails,
    TaskInfo,
)
from agbenchmark.schema import TaskEvalRequestBody
from agbenchmark.utils.utils import write_pretty_json

sys.path.append(str(Path(__file__).parent.parent))

logger = logging.getLogger(__name__)

CHALLENGES: dict[str, ChallengeInfo] = {}
challenges_path = Path(__file__).parent / "challenges"
challenge_spec_files = deque(
    glob.glob(
        f"{challenges_path}/**/data.json",
        recursive=True,
    )
)

logger.debug("Loading challenges...")
while challenge_spec_files:
    challenge_spec_file = Path(challenge_spec_files.popleft())
    challenge_relpath = challenge_spec_file.relative_to(challenges_path.parent)
    if challenge_relpath.is_relative_to("challenges/deprecated"):
        continue

    logger.debug(f"Loading {challenge_relpath}...")
    try:
        challenge_info = ChallengeInfo.model_validate_json(
            challenge_spec_file.read_text()
        )
    except ValidationError as e:
        if logging.getLogger().level == logging.DEBUG:
            logger.warning(f"Spec file {challenge_relpath} failed to load:\n{e}")
            logger.debug(f"Invalid challenge spec: {challenge_spec_file.read_text()}")
        continue

    if not challenge_info.eval_id:
        challenge_info.eval_id = str(uuid.uuid4())
        # this will sort all the keys of the JSON systematically
        # so that the order is always the same
        write_pretty_json(challenge_info.model_dump(), challenge_spec_file)

    CHALLENGES[challenge_info.eval_id] = challenge_info


class BenchmarkTaskInfo(BaseModel):
    task_id: str
    start_time: datetime.datetime
    challenge_info: ChallengeInfo


task_informations: dict[str, BenchmarkTaskInfo] = {}


def find_agbenchmark_without_uvicorn():
    pids = []
    for process in psutil.process_iter(
        attrs=[
            "pid",
            "cmdline",
            "name",
            "username",
            "status",
            "cpu_percent",
            "memory_info",
            "create_time",
            "cwd",
            "connections",
        ]
    ):
        try:
            # Convert the process.info dictionary values to strings and concatenate them
            full_info = " ".join([str(v) for k, v in process.as_dict().items()])

            if "agbenchmark" in full_info and "uvicorn" not in full_info:
                pids.append(process.pid)
        except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
            pass
    return pids


class CreateReportRequest(BaseModel):
    test: str
    test_run_id: str
    # category: Optional[str] = []
    mock: Optional[bool] = False

    model_config = ConfigDict(extra="forbid")


updates_list = []

origins = [
    "http://localhost:8000",
    "http://localhost:8080",
    "http://127.0.0.1:5000",
    "http://localhost:5000",
]


def stream_output(pipe):
    for line in pipe:
        print(line, end="")


def setup_fastapi_app(agbenchmark_config: AgentBenchmarkConfig) -> FastAPI:
    from agbenchmark.agent_api_interface import upload_artifacts
    from agbenchmark.challenges import get_challenge_from_source_uri
    from agbenchmark.main import run_benchmark

    configuration = Configuration(
        host=agbenchmark_config.host or "http://localhost:8000"
    )
    app = FastAPI()
    app.add_middleware(
        CORSMiddleware,
        allow_origins=origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    router = APIRouter()

    @router.post("/reports")
    def run_single_test(body: CreateReportRequest) -> dict:
        pids = find_agbenchmark_without_uvicorn()
        logger.info(f"pids already running with agbenchmark: {pids}")

        logger.debug(f"Request to /reports: {body.model_dump()}")

        # Start the benchmark in a separate thread
        benchmark_process = Process(
            target=lambda: run_benchmark(
                config=agbenchmark_config,
                tests=(body.test,),
                mock=body.mock or False,
            )
        )
        benchmark_process.start()

        # Wait for the benchmark to finish, with a timeout of 200 seconds
        timeout = 200
        start_time = time.time()
        while benchmark_process.is_alive():
            if time.time() - start_time > timeout:
                logger.warning(f"Benchmark run timed out after {timeout} seconds")
                benchmark_process.terminate()
                break
            time.sleep(1)
        else:
            logger.debug(f"Benchmark finished running in {time.time() - start_time} s")

        # List all folders in the current working directory
        reports_folder = agbenchmark_config.reports_folder
        folders = [folder for folder in reports_folder.iterdir() if folder.is_dir()]

        # Sort the folders based on their names
        sorted_folders = sorted(folders, key=lambda x: x.name)

        # Get the last folder
        latest_folder = sorted_folders[-1] if sorted_folders else None

        # Read report.json from this folder
        if latest_folder:
            report_path = latest_folder / "report.json"
            logger.debug(f"Getting latest report from {report_path}")
            if report_path.exists():
                with report_path.open() as file:
                    data = json.load(file)
                logger.debug(f"Report data: {data}")
            else:
                raise HTTPException(
                    502,
                    "Could not get result after running benchmark: "
                    f"'report.json' does not exist in '{latest_folder}'",
                )
        else:
            raise HTTPException(
                504, "Could not get result after running benchmark: no reports found"
            )

        return data

    @router.post("/agent/tasks", tags=["agent"])
    async def create_agent_task(task_eval_request: TaskEvalRequestBody) -> Task:
        """
        Creates a new task using the provided TaskEvalRequestBody and returns a Task.

        Args:
            task_eval_request: `TaskRequestBody` including an eval_id.

        Returns:
            Task: A new task with task_id, input, additional_input,
                and empty lists for artifacts and steps.

        Example:
            Request (TaskEvalRequestBody defined in schema.py):
                {
                    ...,
                    "eval_id": "50da533e-3904-4401-8a07-c49adf88b5eb"
                }

            Response (Task defined in `agent_protocol_client.models`):
                {
                    "task_id": "50da533e-3904-4401-8a07-c49adf88b5eb",
                    "input": "Write the word 'Washington' to a .txt file",
                    "artifacts": []
                }
        """
        try:
            challenge_info = CHALLENGES[task_eval_request.eval_id]
            async with ApiClient(configuration) as api_client:
                api_instance = AgentApi(api_client)
                task_input = challenge_info.task

                task_request_body = TaskRequestBody(
                    input=task_input, additional_input=None
                )
                task_response = await api_instance.create_agent_task(
                    task_request_body=task_request_body
                )
                task_info = BenchmarkTaskInfo(
                    task_id=task_response.task_id,
                    start_time=datetime.datetime.now(datetime.timezone.utc),
                    challenge_info=challenge_info,
                )
                task_informations[task_info.task_id] = task_info

                if input_artifacts_dir := challenge_info.task_artifacts_dir:
                    await upload_artifacts(
                        api_instance,
                        input_artifacts_dir,
                        task_response.task_id,
                        "artifacts_in",
                    )
                return task_response
        except ApiException as e:
            logger.error(f"Error whilst trying to create a task:\n{e}")
            logger.error(
                "The above error was caused while processing request: "
                f"{task_eval_request}"
            )
            raise HTTPException(500)

    @router.post("/agent/tasks/{task_id}/steps")
    async def proxy(request: Request, task_id: str):
        timeout = httpx.Timeout(300.0, read=300.0)  # 5 minutes
        async with httpx.AsyncClient(timeout=timeout) as client:
            # Construct the new URL
            new_url = f"{configuration.host}/ap/v1/agent/tasks/{task_id}/steps"

            # Forward the request
            response = await client.post(
                new_url,
                content=await request.body(),
                headers=dict(request.headers),
            )

            # Return the response from the forwarded request
            return Response(content=response.content, status_code=response.status_code)

    @router.post("/agent/tasks/{task_id}/evaluations")
    async def create_evaluation(task_id: str) -> BenchmarkRun:
        task_info = task_informations[task_id]
        challenge = get_challenge_from_source_uri(task_info.challenge_info.source_uri)
        try:
            async with ApiClient(configuration) as api_client:
                api_instance = AgentApi(api_client)
                eval_results = await challenge.evaluate_task_state(
                    api_instance, task_id
                )

            eval_info = BenchmarkRun(
                repository_info=RepositoryInfo(),
                run_details=RunDetails(
                    command=f"agbenchmark --test={challenge.info.name}",
                    benchmark_start_time=(
                        task_info.start_time.strftime("%Y-%m-%dT%H:%M:%S+00:00")
                    ),
                    test_name=challenge.info.name,
                ),
                task_info=TaskInfo(
                    data_path=challenge.info.source_uri,
                    is_regression=None,
                    category=[c.value for c in challenge.info.category],
                    task=challenge.info.task,
                    answer=challenge.info.reference_answer or "",
                    description=challenge.info.description or "",
                ),
                metrics=Metrics(
                    success=all(e.passed for e in eval_results),
                    success_percentage=(
                        100 * sum(e.score for e in eval_results) / len(eval_results)
                        if eval_results  # avoid division by 0
                        else 0
                    ),
                    attempted=True,
                ),
                config={},
            )

            logger.debug(
                f"Returning evaluation data:\n{eval_info.model_dump_json(indent=4)}"
            )
            return eval_info
        except ApiException as e:
            logger.error(f"Error {e} whilst trying to evaluate task: {task_id}")
            raise HTTPException(500)

    app.include_router(router, prefix="/ap/v1")

    return app
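The app wires its routes under the `/ap/v1` prefix, so it can be smoke-tested in-process with FastAPI's `TestClient`. A sketch, assuming a valid benchmark config and an agent listening on the configured host; the request body follows `TaskEvalRequestBody` from `schema.py`, and the `eval_id` below is the illustrative UUID from the endpoint's docstring:

```python
from fastapi.testclient import TestClient

from agbenchmark.app import setup_fastapi_app
from agbenchmark.config import AgentBenchmarkConfig

app = setup_fastapi_app(AgentBenchmarkConfig.load())
client = TestClient(app)

# Kick off a task for one challenge; the challenge's own task text is used as input
response = client.post(
    "/ap/v1/agent/tasks",
    json={"input": None, "eval_id": "50da533e-3904-4401-8a07-c49adf88b5eb"},
)
print(response.status_code, response.json())
```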
Some files were not shown because too many files have changed in this diff.