mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-21 04:57:58 -05:00
Compare commits
60 Commits
feat/sensi
...
make-old-w
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
326554d89a | ||
|
|
5e22a1888a | ||
|
|
a4d7b0142f | ||
|
|
7d6375f59c | ||
|
|
aeec0ce509 | ||
|
|
b32bfcaac5 | ||
|
|
5373a6eb6e | ||
|
|
98cde46ccb | ||
|
|
bd10da10d9 | ||
|
|
60fdee1345 | ||
|
|
6f2783468c | ||
|
|
c1031b286d | ||
|
|
b849eafb7f | ||
|
|
572c3f5e0d | ||
|
|
89003a585d | ||
|
|
0e65785228 | ||
|
|
f07dff1cdd | ||
|
|
00e02a4696 | ||
|
|
634bff8277 | ||
|
|
d591f36c7b | ||
|
|
a347bed0b1 | ||
|
|
4eeb6ee2b0 | ||
|
|
7db962b9f9 | ||
|
|
9108b21541 | ||
|
|
ffe9325296 | ||
|
|
0a616d9267 | ||
|
|
ab95077e5b | ||
|
|
e477150979 | ||
|
|
804430e243 | ||
|
|
acb320d32d | ||
|
|
32f68d5999 | ||
|
|
49f56b4e8d | ||
|
|
bead811e73 | ||
|
|
013f728ebf | ||
|
|
cda9572acd | ||
|
|
e0784f8f6b | ||
|
|
3040f39136 | ||
|
|
515504c604 | ||
|
|
18edeaeaf4 | ||
|
|
44182aff9c | ||
|
|
864c5a7846 | ||
|
|
699fffb1a8 | ||
|
|
f0641c2d26 | ||
|
|
94b6f74c95 | ||
|
|
46aabab3ea | ||
|
|
0a65df5102 | ||
|
|
6fbd208fe3 | ||
|
|
8fc174ca87 | ||
|
|
cacc89790f | ||
|
|
b9113bee02 | ||
|
|
3f65da03e7 | ||
|
|
9e96d11b2d | ||
|
|
4c264b7ae9 | ||
|
|
0adbc0bd05 | ||
|
|
8f3291bc92 | ||
|
|
7a20de880d | ||
|
|
ef8a6d2528 | ||
|
|
fd66be2aaa | ||
|
|
ae2cc97dc4 | ||
|
|
ea521eed26 |
73
.github/workflows/classic-autogpt-ci.yml
vendored
73
.github/workflows/classic-autogpt-ci.yml
vendored
@@ -6,11 +6,15 @@ on:
|
||||
paths:
|
||||
- '.github/workflows/classic-autogpt-ci.yml'
|
||||
- 'classic/original_autogpt/**'
|
||||
- 'classic/direct_benchmark/**'
|
||||
- 'classic/forge/**'
|
||||
pull_request:
|
||||
branches: [ master, dev, release-* ]
|
||||
paths:
|
||||
- '.github/workflows/classic-autogpt-ci.yml'
|
||||
- 'classic/original_autogpt/**'
|
||||
- 'classic/direct_benchmark/**'
|
||||
- 'classic/forge/**'
|
||||
|
||||
concurrency:
|
||||
group: ${{ format('classic-autogpt-ci-{0}', github.head_ref && format('{0}-{1}', github.event_name, github.event.pull_request.number) || github.sha) }}
|
||||
@@ -19,47 +23,22 @@ concurrency:
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: classic/original_autogpt
|
||||
working-directory: classic
|
||||
|
||||
jobs:
|
||||
test:
|
||||
permissions:
|
||||
contents: read
|
||||
timeout-minutes: 30
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.10"]
|
||||
platform-os: [ubuntu, macos, macos-arm64, windows]
|
||||
runs-on: ${{ matrix.platform-os != 'macos-arm64' && format('{0}-latest', matrix.platform-os) || 'macos-14' }}
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
# Quite slow on macOS (2~4 minutes to set up Docker)
|
||||
# - name: Set up Docker (macOS)
|
||||
# if: runner.os == 'macOS'
|
||||
# uses: crazy-max/ghaction-setup-docker@v3
|
||||
|
||||
- name: Start MinIO service (Linux)
|
||||
if: runner.os == 'Linux'
|
||||
- name: Start MinIO service
|
||||
working-directory: '.'
|
||||
run: |
|
||||
docker pull minio/minio:edge-cicd
|
||||
docker run -d -p 9000:9000 minio/minio:edge-cicd
|
||||
|
||||
- name: Start MinIO service (macOS)
|
||||
if: runner.os == 'macOS'
|
||||
working-directory: ${{ runner.temp }}
|
||||
run: |
|
||||
brew install minio/stable/minio
|
||||
mkdir data
|
||||
minio server ./data &
|
||||
|
||||
# No MinIO on Windows:
|
||||
# - Windows doesn't support running Linux Docker containers
|
||||
# - It doesn't seem possible to start background processes on Windows. They are
|
||||
# killed after the step returns.
|
||||
# See: https://github.com/actions/runner/issues/598#issuecomment-2011890429
|
||||
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
@@ -71,41 +50,23 @@ jobs:
|
||||
git config --global user.name "Auto-GPT-Bot"
|
||||
git config --global user.email "github-bot@agpt.co"
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
- name: Set up Python 3.12
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
python-version: "3.12"
|
||||
|
||||
- id: get_date
|
||||
name: Get date
|
||||
run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
# On Windows, unpacking cached dependencies takes longer than just installing them
|
||||
if: runner.os != 'Windows'
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ${{ runner.os == 'macOS' && '~/Library/Caches/pypoetry' || '~/.cache/pypoetry' }}
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('classic/original_autogpt/poetry.lock') }}
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('classic/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry (Unix)
|
||||
if: runner.os != 'Windows'
|
||||
run: |
|
||||
curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
if [ "${{ runner.os }}" = "macOS" ]; then
|
||||
PATH="$HOME/.local/bin:$PATH"
|
||||
echo "$HOME/.local/bin" >> $GITHUB_PATH
|
||||
fi
|
||||
|
||||
- name: Install Poetry (Windows)
|
||||
if: runner.os == 'Windows'
|
||||
shell: pwsh
|
||||
run: |
|
||||
(Invoke-WebRequest -Uri https://install.python-poetry.org -UseBasicParsing).Content | python -
|
||||
|
||||
$env:PATH += ";$env:APPDATA\Python\Scripts"
|
||||
echo "$env:APPDATA\Python\Scripts" >> $env:GITHUB_PATH
|
||||
- name: Install Poetry
|
||||
run: curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: poetry install
|
||||
@@ -116,12 +77,12 @@ jobs:
|
||||
--cov=autogpt --cov-branch --cov-report term-missing --cov-report xml \
|
||||
--numprocesses=logical --durations=10 \
|
||||
--junitxml=junit.xml -o junit_family=legacy \
|
||||
tests/unit tests/integration
|
||||
original_autogpt/tests/unit original_autogpt/tests/integration
|
||||
env:
|
||||
CI: true
|
||||
PLAIN_OUTPUT: True
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
S3_ENDPOINT_URL: ${{ runner.os != 'Windows' && 'http://127.0.0.1:9000' || '' }}
|
||||
S3_ENDPOINT_URL: http://127.0.0.1:9000
|
||||
AWS_ACCESS_KEY_ID: minioadmin
|
||||
AWS_SECRET_ACCESS_KEY: minioadmin
|
||||
|
||||
@@ -135,11 +96,11 @@ jobs:
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
flags: autogpt-agent,${{ runner.os }}
|
||||
flags: autogpt-agent
|
||||
|
||||
- name: Upload logs to artifact
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: test-logs
|
||||
path: classic/original_autogpt/logs/
|
||||
path: classic/logs/
|
||||
|
||||
36
.github/workflows/classic-autogpts-ci.yml
vendored
36
.github/workflows/classic-autogpts-ci.yml
vendored
@@ -11,9 +11,6 @@ on:
|
||||
- 'classic/original_autogpt/**'
|
||||
- 'classic/forge/**'
|
||||
- 'classic/benchmark/**'
|
||||
- 'classic/run'
|
||||
- 'classic/cli.py'
|
||||
- 'classic/setup.py'
|
||||
- '!**/*.md'
|
||||
pull_request:
|
||||
branches: [ master, dev, release-* ]
|
||||
@@ -22,9 +19,6 @@ on:
|
||||
- 'classic/original_autogpt/**'
|
||||
- 'classic/forge/**'
|
||||
- 'classic/benchmark/**'
|
||||
- 'classic/run'
|
||||
- 'classic/cli.py'
|
||||
- 'classic/setup.py'
|
||||
- '!**/*.md'
|
||||
|
||||
defaults:
|
||||
@@ -35,13 +29,9 @@ defaults:
|
||||
jobs:
|
||||
serve-agent-protocol:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
agent-name: [ original_autogpt ]
|
||||
fail-fast: false
|
||||
timeout-minutes: 20
|
||||
env:
|
||||
min-python-version: '3.10'
|
||||
min-python-version: '3.12'
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
@@ -55,22 +45,22 @@ jobs:
|
||||
python-version: ${{ env.min-python-version }}
|
||||
|
||||
- name: Install Poetry
|
||||
working-directory: ./classic/${{ matrix.agent-name }}/
|
||||
run: |
|
||||
curl -sSL https://install.python-poetry.org | python -
|
||||
|
||||
- name: Run regression tests
|
||||
- name: Install dependencies
|
||||
run: poetry install
|
||||
|
||||
- name: Run smoke tests with direct-benchmark
|
||||
run: |
|
||||
./run agent start ${{ matrix.agent-name }}
|
||||
cd ${{ matrix.agent-name }}
|
||||
poetry run agbenchmark --mock --test=BasicRetrieval --test=Battleship --test=WebArenaTask_0
|
||||
poetry run agbenchmark --test=WriteFile
|
||||
poetry run direct-benchmark run \
|
||||
--strategies one_shot \
|
||||
--models claude \
|
||||
--tests ReadFile,WriteFile \
|
||||
--json
|
||||
env:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
AGENT_NAME: ${{ matrix.agent-name }}
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
REQUESTS_CA_BUNDLE: /etc/ssl/certs/ca-certificates.crt
|
||||
HELICONE_CACHE_ENABLED: false
|
||||
HELICONE_PROPERTY_AGENT: ${{ matrix.agent-name }}
|
||||
REPORTS_FOLDER: ${{ format('../../reports/{0}', matrix.agent-name) }}
|
||||
TELEMETRY_ENVIRONMENT: autogpt-ci
|
||||
TELEMETRY_OPT_IN: ${{ github.ref_name == 'master' }}
|
||||
NONINTERACTIVE_MODE: "true"
|
||||
CI: true
|
||||
|
||||
189
.github/workflows/classic-benchmark-ci.yml
vendored
189
.github/workflows/classic-benchmark-ci.yml
vendored
@@ -1,17 +1,21 @@
|
||||
name: Classic - AGBenchmark CI
|
||||
name: Classic - Direct Benchmark CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ master, dev, ci-test* ]
|
||||
paths:
|
||||
- 'classic/benchmark/**'
|
||||
- '!classic/benchmark/reports/**'
|
||||
- 'classic/direct_benchmark/**'
|
||||
- 'classic/benchmark/agbenchmark/challenges/**'
|
||||
- 'classic/original_autogpt/**'
|
||||
- 'classic/forge/**'
|
||||
- .github/workflows/classic-benchmark-ci.yml
|
||||
pull_request:
|
||||
branches: [ master, dev, release-* ]
|
||||
paths:
|
||||
- 'classic/benchmark/**'
|
||||
- '!classic/benchmark/reports/**'
|
||||
- 'classic/direct_benchmark/**'
|
||||
- 'classic/benchmark/agbenchmark/challenges/**'
|
||||
- 'classic/original_autogpt/**'
|
||||
- 'classic/forge/**'
|
||||
- .github/workflows/classic-benchmark-ci.yml
|
||||
|
||||
concurrency:
|
||||
@@ -23,23 +27,16 @@ defaults:
|
||||
shell: bash
|
||||
|
||||
env:
|
||||
min-python-version: '3.10'
|
||||
min-python-version: '3.12'
|
||||
|
||||
jobs:
|
||||
test:
|
||||
permissions:
|
||||
contents: read
|
||||
benchmark-tests:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 30
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.10"]
|
||||
platform-os: [ubuntu, macos, macos-arm64, windows]
|
||||
runs-on: ${{ matrix.platform-os != 'macos-arm64' && format('{0}-latest', matrix.platform-os) || 'macos-14' }}
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: classic/benchmark
|
||||
working-directory: classic
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
@@ -47,71 +44,84 @@ jobs:
|
||||
fetch-depth: 0
|
||||
submodules: true
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
- name: Set up Python ${{ env.min-python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
python-version: ${{ env.min-python-version }}
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
# On Windows, unpacking cached dependencies takes longer than just installing them
|
||||
if: runner.os != 'Windows'
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ${{ runner.os == 'macOS' && '~/Library/Caches/pypoetry' || '~/.cache/pypoetry' }}
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('classic/benchmark/poetry.lock') }}
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('classic/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry (Unix)
|
||||
if: runner.os != 'Windows'
|
||||
- name: Install Poetry
|
||||
run: |
|
||||
curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
if [ "${{ runner.os }}" = "macOS" ]; then
|
||||
PATH="$HOME/.local/bin:$PATH"
|
||||
echo "$HOME/.local/bin" >> $GITHUB_PATH
|
||||
fi
|
||||
|
||||
- name: Install Poetry (Windows)
|
||||
if: runner.os == 'Windows'
|
||||
shell: pwsh
|
||||
run: |
|
||||
(Invoke-WebRequest -Uri https://install.python-poetry.org -UseBasicParsing).Content | python -
|
||||
|
||||
$env:PATH += ";$env:APPDATA\Python\Scripts"
|
||||
echo "$env:APPDATA\Python\Scripts" >> $env:GITHUB_PATH
|
||||
|
||||
- name: Install Python dependencies
|
||||
- name: Install dependencies
|
||||
run: poetry install
|
||||
|
||||
- name: Run pytest with coverage
|
||||
- name: Run basic benchmark tests
|
||||
run: |
|
||||
poetry run pytest -vv \
|
||||
--cov=agbenchmark --cov-branch --cov-report term-missing --cov-report xml \
|
||||
--durations=10 \
|
||||
--junitxml=junit.xml -o junit_family=legacy \
|
||||
tests
|
||||
echo "Testing ReadFile challenge with one_shot strategy..."
|
||||
poetry run direct-benchmark run \
|
||||
--strategies one_shot \
|
||||
--models claude \
|
||||
--tests ReadFile \
|
||||
--json
|
||||
|
||||
echo "Testing WriteFile challenge..."
|
||||
poetry run direct-benchmark run \
|
||||
--strategies one_shot \
|
||||
--models claude \
|
||||
--tests WriteFile \
|
||||
--json
|
||||
env:
|
||||
CI: true
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
NONINTERACTIVE_MODE: "true"
|
||||
|
||||
- name: Upload test results to Codecov
|
||||
if: ${{ !cancelled() }} # Run even if tests fail
|
||||
uses: codecov/test-results-action@v1
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
- name: Test category filtering
|
||||
run: |
|
||||
echo "Testing coding category..."
|
||||
poetry run direct-benchmark run \
|
||||
--strategies one_shot \
|
||||
--models claude \
|
||||
--categories coding \
|
||||
--tests ReadFile,WriteFile \
|
||||
--json
|
||||
env:
|
||||
CI: true
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
NONINTERACTIVE_MODE: "true"
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
flags: agbenchmark,${{ runner.os }}
|
||||
- name: Test multiple strategies
|
||||
run: |
|
||||
echo "Testing multiple strategies..."
|
||||
poetry run direct-benchmark run \
|
||||
--strategies one_shot,plan_execute \
|
||||
--models claude \
|
||||
--tests ReadFile \
|
||||
--parallel 2 \
|
||||
--json
|
||||
env:
|
||||
CI: true
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
NONINTERACTIVE_MODE: "true"
|
||||
|
||||
self-test-with-agent:
|
||||
# Run regression tests on maintain challenges
|
||||
regression-tests:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
agent-name: [forge]
|
||||
fail-fast: false
|
||||
timeout-minutes: 20
|
||||
timeout-minutes: 45
|
||||
if: github.ref == 'refs/heads/master' || github.ref == 'refs/heads/dev'
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: classic
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
@@ -126,51 +136,22 @@ jobs:
|
||||
|
||||
- name: Install Poetry
|
||||
run: |
|
||||
curl -sSL https://install.python-poetry.org | python -
|
||||
curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
- name: Install dependencies
|
||||
run: poetry install
|
||||
|
||||
- name: Run regression tests
|
||||
working-directory: classic
|
||||
run: |
|
||||
./run agent start ${{ matrix.agent-name }}
|
||||
cd ${{ matrix.agent-name }}
|
||||
|
||||
set +e # Ignore non-zero exit codes and continue execution
|
||||
echo "Running the following command: poetry run agbenchmark --maintain --mock"
|
||||
poetry run agbenchmark --maintain --mock
|
||||
EXIT_CODE=$?
|
||||
set -e # Stop ignoring non-zero exit codes
|
||||
# Check if the exit code was 5, and if so, exit with 0 instead
|
||||
if [ $EXIT_CODE -eq 5 ]; then
|
||||
echo "regression_tests.json is empty."
|
||||
fi
|
||||
|
||||
echo "Running the following command: poetry run agbenchmark --mock"
|
||||
poetry run agbenchmark --mock
|
||||
|
||||
echo "Running the following command: poetry run agbenchmark --mock --category=data"
|
||||
poetry run agbenchmark --mock --category=data
|
||||
|
||||
echo "Running the following command: poetry run agbenchmark --mock --category=coding"
|
||||
poetry run agbenchmark --mock --category=coding
|
||||
|
||||
# echo "Running the following command: poetry run agbenchmark --test=WriteFile"
|
||||
# poetry run agbenchmark --test=WriteFile
|
||||
cd ../benchmark
|
||||
poetry install
|
||||
echo "Adding the BUILD_SKILL_TREE environment variable. This will attempt to add new elements in the skill tree. If new elements are added, the CI fails because they should have been pushed"
|
||||
export BUILD_SKILL_TREE=true
|
||||
|
||||
# poetry run agbenchmark --mock
|
||||
|
||||
# CHANGED=$(git diff --name-only | grep -E '(agbenchmark/challenges)|(../classic/frontend/assets)') || echo "No diffs"
|
||||
# if [ ! -z "$CHANGED" ]; then
|
||||
# echo "There are unstaged changes please run agbenchmark and commit those changes since they are needed."
|
||||
# echo "$CHANGED"
|
||||
# exit 1
|
||||
# else
|
||||
# echo "No unstaged changes."
|
||||
# fi
|
||||
echo "Running regression tests (previously beaten challenges)..."
|
||||
poetry run direct-benchmark run \
|
||||
--strategies one_shot \
|
||||
--models claude \
|
||||
--maintain \
|
||||
--parallel 4 \
|
||||
--json
|
||||
env:
|
||||
CI: true
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
TELEMETRY_ENVIRONMENT: autogpt-benchmark-ci
|
||||
TELEMETRY_OPT_IN: ${{ github.ref_name == 'master' }}
|
||||
NONINTERACTIVE_MODE: "true"
|
||||
|
||||
182
.github/workflows/classic-forge-ci.yml
vendored
182
.github/workflows/classic-forge-ci.yml
vendored
@@ -6,13 +6,11 @@ on:
|
||||
paths:
|
||||
- '.github/workflows/classic-forge-ci.yml'
|
||||
- 'classic/forge/**'
|
||||
- '!classic/forge/tests/vcr_cassettes'
|
||||
pull_request:
|
||||
branches: [ master, dev, release-* ]
|
||||
paths:
|
||||
- '.github/workflows/classic-forge-ci.yml'
|
||||
- 'classic/forge/**'
|
||||
- '!classic/forge/tests/vcr_cassettes'
|
||||
|
||||
concurrency:
|
||||
group: ${{ format('forge-ci-{0}', github.head_ref && format('{0}-{1}', github.event_name, github.event.pull_request.number) || github.sha) }}
|
||||
@@ -21,115 +19,38 @@ concurrency:
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: classic/forge
|
||||
working-directory: classic
|
||||
|
||||
jobs:
|
||||
test:
|
||||
permissions:
|
||||
contents: read
|
||||
timeout-minutes: 30
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.10"]
|
||||
platform-os: [ubuntu, macos, macos-arm64, windows]
|
||||
runs-on: ${{ matrix.platform-os != 'macos-arm64' && format('{0}-latest', matrix.platform-os) || 'macos-14' }}
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
# Quite slow on macOS (2~4 minutes to set up Docker)
|
||||
# - name: Set up Docker (macOS)
|
||||
# if: runner.os == 'macOS'
|
||||
# uses: crazy-max/ghaction-setup-docker@v3
|
||||
|
||||
- name: Start MinIO service (Linux)
|
||||
if: runner.os == 'Linux'
|
||||
- name: Start MinIO service
|
||||
working-directory: '.'
|
||||
run: |
|
||||
docker pull minio/minio:edge-cicd
|
||||
docker run -d -p 9000:9000 minio/minio:edge-cicd
|
||||
|
||||
- name: Start MinIO service (macOS)
|
||||
if: runner.os == 'macOS'
|
||||
working-directory: ${{ runner.temp }}
|
||||
run: |
|
||||
brew install minio/stable/minio
|
||||
mkdir data
|
||||
minio server ./data &
|
||||
|
||||
# No MinIO on Windows:
|
||||
# - Windows doesn't support running Linux Docker containers
|
||||
# - It doesn't seem possible to start background processes on Windows. They are
|
||||
# killed after the step returns.
|
||||
# See: https://github.com/actions/runner/issues/598#issuecomment-2011890429
|
||||
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
submodules: true
|
||||
|
||||
- name: Checkout cassettes
|
||||
if: ${{ startsWith(github.event_name, 'pull_request') }}
|
||||
env:
|
||||
PR_BASE: ${{ github.event.pull_request.base.ref }}
|
||||
PR_BRANCH: ${{ github.event.pull_request.head.ref }}
|
||||
PR_AUTHOR: ${{ github.event.pull_request.user.login }}
|
||||
run: |
|
||||
cassette_branch="${PR_AUTHOR}-${PR_BRANCH}"
|
||||
cassette_base_branch="${PR_BASE}"
|
||||
cd tests/vcr_cassettes
|
||||
|
||||
if ! git ls-remote --exit-code --heads origin $cassette_base_branch ; then
|
||||
cassette_base_branch="master"
|
||||
fi
|
||||
|
||||
if git ls-remote --exit-code --heads origin $cassette_branch ; then
|
||||
git fetch origin $cassette_branch
|
||||
git fetch origin $cassette_base_branch
|
||||
|
||||
git checkout $cassette_branch
|
||||
|
||||
# Pick non-conflicting cassette updates from the base branch
|
||||
git merge --no-commit --strategy-option=ours origin/$cassette_base_branch
|
||||
echo "Using cassettes from mirror branch '$cassette_branch'," \
|
||||
"synced to upstream branch '$cassette_base_branch'."
|
||||
else
|
||||
git checkout -b $cassette_branch
|
||||
echo "Branch '$cassette_branch' does not exist in cassette submodule." \
|
||||
"Using cassettes from '$cassette_base_branch'."
|
||||
fi
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
- name: Set up Python 3.12
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
python-version: "3.12"
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
# On Windows, unpacking cached dependencies takes longer than just installing them
|
||||
if: runner.os != 'Windows'
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ${{ runner.os == 'macOS' && '~/Library/Caches/pypoetry' || '~/.cache/pypoetry' }}
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('classic/forge/poetry.lock') }}
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('classic/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry (Unix)
|
||||
if: runner.os != 'Windows'
|
||||
run: |
|
||||
curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
if [ "${{ runner.os }}" = "macOS" ]; then
|
||||
PATH="$HOME/.local/bin:$PATH"
|
||||
echo "$HOME/.local/bin" >> $GITHUB_PATH
|
||||
fi
|
||||
|
||||
- name: Install Poetry (Windows)
|
||||
if: runner.os == 'Windows'
|
||||
shell: pwsh
|
||||
run: |
|
||||
(Invoke-WebRequest -Uri https://install.python-poetry.org -UseBasicParsing).Content | python -
|
||||
|
||||
$env:PATH += ";$env:APPDATA\Python\Scripts"
|
||||
echo "$env:APPDATA\Python\Scripts" >> $env:GITHUB_PATH
|
||||
- name: Install Poetry
|
||||
run: curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: poetry install
|
||||
@@ -140,12 +61,15 @@ jobs:
|
||||
--cov=forge --cov-branch --cov-report term-missing --cov-report xml \
|
||||
--durations=10 \
|
||||
--junitxml=junit.xml -o junit_family=legacy \
|
||||
forge
|
||||
forge/forge forge/tests
|
||||
env:
|
||||
CI: true
|
||||
PLAIN_OUTPUT: True
|
||||
# API keys - tests that need these will skip if not available
|
||||
# Secrets are not available to fork PRs (GitHub security feature)
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
S3_ENDPOINT_URL: ${{ runner.os != 'Windows' && 'http://127.0.0.1:9000' || '' }}
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
S3_ENDPOINT_URL: http://127.0.0.1:9000
|
||||
AWS_ACCESS_KEY_ID: minioadmin
|
||||
AWS_SECRET_ACCESS_KEY: minioadmin
|
||||
|
||||
@@ -159,85 +83,11 @@ jobs:
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
flags: forge,${{ runner.os }}
|
||||
|
||||
- id: setup_git_auth
|
||||
name: Set up git token authentication
|
||||
# Cassettes may be pushed even when tests fail
|
||||
if: success() || failure()
|
||||
run: |
|
||||
config_key="http.${{ github.server_url }}/.extraheader"
|
||||
if [ "${{ runner.os }}" = 'macOS' ]; then
|
||||
base64_pat=$(echo -n "pat:${{ secrets.PAT_REVIEW }}" | base64)
|
||||
else
|
||||
base64_pat=$(echo -n "pat:${{ secrets.PAT_REVIEW }}" | base64 -w0)
|
||||
fi
|
||||
|
||||
git config "$config_key" \
|
||||
"Authorization: Basic $base64_pat"
|
||||
|
||||
cd tests/vcr_cassettes
|
||||
git config "$config_key" \
|
||||
"Authorization: Basic $base64_pat"
|
||||
|
||||
echo "config_key=$config_key" >> $GITHUB_OUTPUT
|
||||
|
||||
- id: push_cassettes
|
||||
name: Push updated cassettes
|
||||
# For pull requests, push updated cassettes even when tests fail
|
||||
if: github.event_name == 'push' || (! github.event.pull_request.head.repo.fork && (success() || failure()))
|
||||
env:
|
||||
PR_BRANCH: ${{ github.event.pull_request.head.ref }}
|
||||
PR_AUTHOR: ${{ github.event.pull_request.user.login }}
|
||||
run: |
|
||||
if [ "${{ startsWith(github.event_name, 'pull_request') }}" = "true" ]; then
|
||||
is_pull_request=true
|
||||
cassette_branch="${PR_AUTHOR}-${PR_BRANCH}"
|
||||
else
|
||||
cassette_branch="${{ github.ref_name }}"
|
||||
fi
|
||||
|
||||
cd tests/vcr_cassettes
|
||||
# Commit & push changes to cassettes if any
|
||||
if ! git diff --quiet; then
|
||||
git add .
|
||||
git commit -m "Auto-update cassettes"
|
||||
git push origin HEAD:$cassette_branch
|
||||
if [ ! $is_pull_request ]; then
|
||||
cd ../..
|
||||
git add tests/vcr_cassettes
|
||||
git commit -m "Update cassette submodule"
|
||||
git push origin HEAD:$cassette_branch
|
||||
fi
|
||||
echo "updated=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "updated=false" >> $GITHUB_OUTPUT
|
||||
echo "No cassette changes to commit"
|
||||
fi
|
||||
|
||||
- name: Post Set up git token auth
|
||||
if: steps.setup_git_auth.outcome == 'success'
|
||||
run: |
|
||||
git config --unset-all '${{ steps.setup_git_auth.outputs.config_key }}'
|
||||
git submodule foreach git config --unset-all '${{ steps.setup_git_auth.outputs.config_key }}'
|
||||
|
||||
- name: Apply "behaviour change" label and comment on PR
|
||||
if: ${{ startsWith(github.event_name, 'pull_request') }}
|
||||
run: |
|
||||
PR_NUMBER="${{ github.event.pull_request.number }}"
|
||||
TOKEN="${{ secrets.PAT_REVIEW }}"
|
||||
REPO="${{ github.repository }}"
|
||||
|
||||
if [[ "${{ steps.push_cassettes.outputs.updated }}" == "true" ]]; then
|
||||
echo "Adding label and comment..."
|
||||
echo $TOKEN | gh auth login --with-token
|
||||
gh issue edit $PR_NUMBER --add-label "behaviour change"
|
||||
gh issue comment $PR_NUMBER --body "You changed AutoGPT's behaviour on ${{ runner.os }}. The cassettes have been updated and will be merged to the submodule when this Pull Request gets merged."
|
||||
fi
|
||||
flags: forge
|
||||
|
||||
- name: Upload logs to artifact
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: test-logs
|
||||
path: classic/forge/logs/
|
||||
path: classic/logs/
|
||||
|
||||
60
.github/workflows/classic-frontend-ci.yml
vendored
60
.github/workflows/classic-frontend-ci.yml
vendored
@@ -1,60 +0,0 @@
|
||||
name: Classic - Frontend CI/CD
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
- dev
|
||||
- 'ci-test*' # This will match any branch that starts with "ci-test"
|
||||
paths:
|
||||
- 'classic/frontend/**'
|
||||
- '.github/workflows/classic-frontend-ci.yml'
|
||||
pull_request:
|
||||
paths:
|
||||
- 'classic/frontend/**'
|
||||
- '.github/workflows/classic-frontend-ci.yml'
|
||||
|
||||
jobs:
|
||||
build:
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
BUILD_BRANCH: ${{ format('classic-frontend-build/{0}', github.ref_name) }}
|
||||
|
||||
steps:
|
||||
- name: Checkout Repo
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Flutter
|
||||
uses: subosito/flutter-action@v2
|
||||
with:
|
||||
flutter-version: '3.13.2'
|
||||
|
||||
- name: Build Flutter to Web
|
||||
run: |
|
||||
cd classic/frontend
|
||||
flutter build web --base-href /app/
|
||||
|
||||
# - name: Commit and Push to ${{ env.BUILD_BRANCH }}
|
||||
# if: github.event_name == 'push'
|
||||
# run: |
|
||||
# git config --local user.email "action@github.com"
|
||||
# git config --local user.name "GitHub Action"
|
||||
# git add classic/frontend/build/web
|
||||
# git checkout -B ${{ env.BUILD_BRANCH }}
|
||||
# git commit -m "Update frontend build to ${GITHUB_SHA:0:7}" -a
|
||||
# git push -f origin ${{ env.BUILD_BRANCH }}
|
||||
|
||||
- name: Create PR ${{ env.BUILD_BRANCH }} -> ${{ github.ref_name }}
|
||||
if: github.event_name == 'push'
|
||||
uses: peter-evans/create-pull-request@v7
|
||||
with:
|
||||
add-paths: classic/frontend/build/web
|
||||
base: ${{ github.ref_name }}
|
||||
branch: ${{ env.BUILD_BRANCH }}
|
||||
delete-branch: true
|
||||
title: "Update frontend build in `${{ github.ref_name }}`"
|
||||
body: "This PR updates the frontend build based on commit ${{ github.sha }}."
|
||||
commit-message: "Update frontend build based on commit ${{ github.sha }}"
|
||||
67
.github/workflows/classic-python-checks.yml
vendored
67
.github/workflows/classic-python-checks.yml
vendored
@@ -7,7 +7,9 @@ on:
|
||||
- '.github/workflows/classic-python-checks-ci.yml'
|
||||
- 'classic/original_autogpt/**'
|
||||
- 'classic/forge/**'
|
||||
- 'classic/benchmark/**'
|
||||
- 'classic/direct_benchmark/**'
|
||||
- 'classic/pyproject.toml'
|
||||
- 'classic/poetry.lock'
|
||||
- '**.py'
|
||||
- '!classic/forge/tests/vcr_cassettes'
|
||||
pull_request:
|
||||
@@ -16,7 +18,9 @@ on:
|
||||
- '.github/workflows/classic-python-checks-ci.yml'
|
||||
- 'classic/original_autogpt/**'
|
||||
- 'classic/forge/**'
|
||||
- 'classic/benchmark/**'
|
||||
- 'classic/direct_benchmark/**'
|
||||
- 'classic/pyproject.toml'
|
||||
- 'classic/poetry.lock'
|
||||
- '**.py'
|
||||
- '!classic/forge/tests/vcr_cassettes'
|
||||
|
||||
@@ -27,44 +31,13 @@ concurrency:
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: classic
|
||||
|
||||
jobs:
|
||||
get-changed-parts:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- id: changes-in
|
||||
name: Determine affected subprojects
|
||||
uses: dorny/paths-filter@v3
|
||||
with:
|
||||
filters: |
|
||||
original_autogpt:
|
||||
- classic/original_autogpt/autogpt/**
|
||||
- classic/original_autogpt/tests/**
|
||||
- classic/original_autogpt/poetry.lock
|
||||
forge:
|
||||
- classic/forge/forge/**
|
||||
- classic/forge/tests/**
|
||||
- classic/forge/poetry.lock
|
||||
benchmark:
|
||||
- classic/benchmark/agbenchmark/**
|
||||
- classic/benchmark/tests/**
|
||||
- classic/benchmark/poetry.lock
|
||||
outputs:
|
||||
changed-parts: ${{ steps.changes-in.outputs.changes }}
|
||||
|
||||
lint:
|
||||
needs: get-changed-parts
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
min-python-version: "3.10"
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
sub-package: ${{ fromJson(needs.get-changed-parts.outputs.changed-parts) }}
|
||||
fail-fast: false
|
||||
min-python-version: "3.12"
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
@@ -81,42 +54,31 @@ jobs:
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: ${{ runner.os }}-poetry-${{ hashFiles(format('{0}/poetry.lock', matrix.sub-package)) }}
|
||||
key: ${{ runner.os }}-poetry-${{ hashFiles('classic/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry
|
||||
run: curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
# Install dependencies
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: poetry -C classic/${{ matrix.sub-package }} install
|
||||
run: poetry install
|
||||
|
||||
# Lint
|
||||
|
||||
- name: Lint (isort)
|
||||
run: poetry run isort --check .
|
||||
working-directory: classic/${{ matrix.sub-package }}
|
||||
|
||||
- name: Lint (Black)
|
||||
if: success() || failure()
|
||||
run: poetry run black --check .
|
||||
working-directory: classic/${{ matrix.sub-package }}
|
||||
|
||||
- name: Lint (Flake8)
|
||||
if: success() || failure()
|
||||
run: poetry run flake8 .
|
||||
working-directory: classic/${{ matrix.sub-package }}
|
||||
|
||||
types:
|
||||
needs: get-changed-parts
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
min-python-version: "3.10"
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
sub-package: ${{ fromJson(needs.get-changed-parts.outputs.changed-parts) }}
|
||||
fail-fast: false
|
||||
min-python-version: "3.12"
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
@@ -133,19 +95,16 @@ jobs:
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: ${{ runner.os }}-poetry-${{ hashFiles(format('{0}/poetry.lock', matrix.sub-package)) }}
|
||||
key: ${{ runner.os }}-poetry-${{ hashFiles('classic/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry
|
||||
run: curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
# Install dependencies
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: poetry -C classic/${{ matrix.sub-package }} install
|
||||
run: poetry install
|
||||
|
||||
# Typecheck
|
||||
|
||||
- name: Typecheck
|
||||
if: success() || failure()
|
||||
run: poetry run pyright
|
||||
working-directory: classic/${{ matrix.sub-package }}
|
||||
|
||||
10
.gitignore
vendored
10
.gitignore
vendored
@@ -3,6 +3,7 @@
|
||||
classic/original_autogpt/keys.py
|
||||
classic/original_autogpt/*.json
|
||||
auto_gpt_workspace/*
|
||||
.autogpt/
|
||||
*.mpeg
|
||||
.env
|
||||
# Root .env files
|
||||
@@ -159,6 +160,10 @@ CURRENT_BULLETIN.md
|
||||
|
||||
# AgBenchmark
|
||||
classic/benchmark/agbenchmark/reports/
|
||||
classic/reports/
|
||||
classic/direct_benchmark/reports/
|
||||
classic/.benchmark_workspaces/
|
||||
classic/direct_benchmark/.benchmark_workspaces/
|
||||
|
||||
# Nodejs
|
||||
package-lock.json
|
||||
@@ -177,5 +182,8 @@ autogpt_platform/backend/settings.py
|
||||
|
||||
*.ign.*
|
||||
.test-contents
|
||||
.claude/settings.local.json
|
||||
**/.claude/settings.local.json
|
||||
/autogpt_platform/backend/logs
|
||||
|
||||
# Test database
|
||||
test.db
|
||||
|
||||
3
.gitmodules
vendored
3
.gitmodules
vendored
@@ -1,3 +0,0 @@
|
||||
[submodule "classic/forge/tests/vcr_cassettes"]
|
||||
path = classic/forge/tests/vcr_cassettes
|
||||
url = https://github.com/Significant-Gravitas/Auto-GPT-test-cassettes
|
||||
@@ -43,29 +43,10 @@ repos:
|
||||
pass_filenames: false
|
||||
|
||||
- id: poetry-install
|
||||
name: Check & Install dependencies - Classic - AutoGPT
|
||||
alias: poetry-install-classic-autogpt
|
||||
entry: poetry -C classic/original_autogpt install
|
||||
# include forge source (since it's a path dependency)
|
||||
files: ^classic/(original_autogpt|forge)/poetry\.lock$
|
||||
types: [file]
|
||||
language: system
|
||||
pass_filenames: false
|
||||
|
||||
- id: poetry-install
|
||||
name: Check & Install dependencies - Classic - Forge
|
||||
alias: poetry-install-classic-forge
|
||||
entry: poetry -C classic/forge install
|
||||
files: ^classic/forge/poetry\.lock$
|
||||
types: [file]
|
||||
language: system
|
||||
pass_filenames: false
|
||||
|
||||
- id: poetry-install
|
||||
name: Check & Install dependencies - Classic - Benchmark
|
||||
alias: poetry-install-classic-benchmark
|
||||
entry: poetry -C classic/benchmark install
|
||||
files: ^classic/benchmark/poetry\.lock$
|
||||
name: Check & Install dependencies - Classic
|
||||
alias: poetry-install-classic
|
||||
entry: poetry -C classic install
|
||||
files: ^classic/poetry\.lock$
|
||||
types: [file]
|
||||
language: system
|
||||
pass_filenames: false
|
||||
@@ -116,26 +97,10 @@ repos:
|
||||
language: system
|
||||
|
||||
- id: isort
|
||||
name: Lint (isort) - Classic - AutoGPT
|
||||
alias: isort-classic-autogpt
|
||||
entry: poetry -P classic/original_autogpt run isort -p autogpt
|
||||
files: ^classic/original_autogpt/
|
||||
types: [file, python]
|
||||
language: system
|
||||
|
||||
- id: isort
|
||||
name: Lint (isort) - Classic - Forge
|
||||
alias: isort-classic-forge
|
||||
entry: poetry -P classic/forge run isort -p forge
|
||||
files: ^classic/forge/
|
||||
types: [file, python]
|
||||
language: system
|
||||
|
||||
- id: isort
|
||||
name: Lint (isort) - Classic - Benchmark
|
||||
alias: isort-classic-benchmark
|
||||
entry: poetry -P classic/benchmark run isort -p agbenchmark
|
||||
files: ^classic/benchmark/
|
||||
name: Lint (isort) - Classic
|
||||
alias: isort-classic
|
||||
entry: bash -c 'cd classic && poetry run isort $(echo "$@" | sed "s|classic/||g")' --
|
||||
files: ^classic/(original_autogpt|forge|direct_benchmark)/
|
||||
types: [file, python]
|
||||
language: system
|
||||
|
||||
@@ -149,26 +114,13 @@ repos:
|
||||
|
||||
- repo: https://github.com/PyCQA/flake8
|
||||
rev: 7.0.0
|
||||
# To have flake8 load the config of the individual subprojects, we have to call
|
||||
# them separately.
|
||||
# Use consolidated flake8 config at classic/.flake8
|
||||
hooks:
|
||||
- id: flake8
|
||||
name: Lint (Flake8) - Classic - AutoGPT
|
||||
alias: flake8-classic-autogpt
|
||||
files: ^classic/original_autogpt/(autogpt|scripts|tests)/
|
||||
args: [--config=classic/original_autogpt/.flake8]
|
||||
|
||||
- id: flake8
|
||||
name: Lint (Flake8) - Classic - Forge
|
||||
alias: flake8-classic-forge
|
||||
files: ^classic/forge/(forge|tests)/
|
||||
args: [--config=classic/forge/.flake8]
|
||||
|
||||
- id: flake8
|
||||
name: Lint (Flake8) - Classic - Benchmark
|
||||
alias: flake8-classic-benchmark
|
||||
files: ^classic/benchmark/(agbenchmark|tests)/((?!reports).)*[/.]
|
||||
args: [--config=classic/benchmark/.flake8]
|
||||
name: Lint (Flake8) - Classic
|
||||
alias: flake8-classic
|
||||
files: ^classic/(original_autogpt|forge|direct_benchmark)/
|
||||
args: [--config=classic/.flake8]
|
||||
|
||||
- repo: local
|
||||
hooks:
|
||||
@@ -204,29 +156,10 @@ repos:
|
||||
pass_filenames: false
|
||||
|
||||
- id: pyright
|
||||
name: Typecheck - Classic - AutoGPT
|
||||
alias: pyright-classic-autogpt
|
||||
entry: poetry -C classic/original_autogpt run pyright
|
||||
# include forge source (since it's a path dependency) but exclude *_test.py files:
|
||||
files: ^(classic/original_autogpt/((autogpt|scripts|tests)/|poetry\.lock$)|classic/forge/(forge/.*(?<!_test)\.py|poetry\.lock)$)
|
||||
types: [file]
|
||||
language: system
|
||||
pass_filenames: false
|
||||
|
||||
- id: pyright
|
||||
name: Typecheck - Classic - Forge
|
||||
alias: pyright-classic-forge
|
||||
entry: poetry -C classic/forge run pyright
|
||||
files: ^classic/forge/(forge/|poetry\.lock$)
|
||||
types: [file]
|
||||
language: system
|
||||
pass_filenames: false
|
||||
|
||||
- id: pyright
|
||||
name: Typecheck - Classic - Benchmark
|
||||
alias: pyright-classic-benchmark
|
||||
entry: poetry -C classic/benchmark run pyright
|
||||
files: ^classic/benchmark/(agbenchmark/|tests/|poetry\.lock$)
|
||||
name: Typecheck - Classic
|
||||
alias: pyright-classic
|
||||
entry: poetry -C classic run pyright
|
||||
files: ^classic/(original_autogpt|forge|direct_benchmark)/.*\.py$|^classic/poetry\.lock$
|
||||
types: [file]
|
||||
language: system
|
||||
pass_filenames: false
|
||||
|
||||
@@ -218,7 +218,6 @@ async def save_agent_to_library(
|
||||
library_agents = await library_db.create_library_agent(
|
||||
graph=created_graph,
|
||||
user_id=user_id,
|
||||
sensitive_action_safe_mode=True,
|
||||
create_library_agents_for_sub_graphs=False,
|
||||
)
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ from .models import (
|
||||
UserReadiness,
|
||||
)
|
||||
from .utils import (
|
||||
build_missing_credentials_from_graph,
|
||||
check_user_has_required_credentials,
|
||||
extract_credentials_from_schema,
|
||||
fetch_graph_from_store_slug,
|
||||
get_or_create_library_agent,
|
||||
@@ -237,13 +237,15 @@ class RunAgentTool(BaseTool):
|
||||
# Return credentials needed response with input data info
|
||||
# The UI handles credential setup automatically, so the message
|
||||
# focuses on asking about input data
|
||||
requirements_creds_dict = build_missing_credentials_from_graph(
|
||||
graph, None
|
||||
credentials = extract_credentials_from_schema(
|
||||
graph.credentials_input_schema
|
||||
)
|
||||
missing_credentials_dict = build_missing_credentials_from_graph(
|
||||
graph, graph_credentials
|
||||
missing_creds_check = await check_user_has_required_credentials(
|
||||
user_id, credentials
|
||||
)
|
||||
requirements_creds_list = list(requirements_creds_dict.values())
|
||||
missing_credentials_dict = {
|
||||
c.id: c.model_dump() for c in missing_creds_check
|
||||
}
|
||||
|
||||
return SetupRequirementsResponse(
|
||||
message=self._build_inputs_message(graph, MSG_WHAT_VALUES_TO_USE),
|
||||
@@ -257,7 +259,7 @@ class RunAgentTool(BaseTool):
|
||||
ready_to_run=False,
|
||||
),
|
||||
requirements={
|
||||
"credentials": requirements_creds_list,
|
||||
"credentials": [c.model_dump() for c in credentials],
|
||||
"inputs": self._get_inputs_list(graph.input_schema),
|
||||
"execution_modes": self._get_execution_modes(graph),
|
||||
},
|
||||
|
||||
@@ -22,7 +22,6 @@ from .models import (
|
||||
ToolResponseBase,
|
||||
UserReadiness,
|
||||
)
|
||||
from .utils import build_missing_credentials_from_field_info
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -190,11 +189,7 @@ class RunBlockTool(BaseTool):
|
||||
|
||||
if missing_credentials:
|
||||
# Return setup requirements response with missing credentials
|
||||
credentials_fields_info = block.input_schema.get_credentials_fields_info()
|
||||
missing_creds_dict = build_missing_credentials_from_field_info(
|
||||
credentials_fields_info, set(matched_credentials.keys())
|
||||
)
|
||||
missing_creds_list = list(missing_creds_dict.values())
|
||||
missing_creds_dict = {c.id: c.model_dump() for c in missing_credentials}
|
||||
|
||||
return SetupRequirementsResponse(
|
||||
message=(
|
||||
@@ -211,7 +206,7 @@ class RunBlockTool(BaseTool):
|
||||
ready_to_run=False,
|
||||
),
|
||||
requirements={
|
||||
"credentials": missing_creds_list,
|
||||
"credentials": [c.model_dump() for c in missing_credentials],
|
||||
"inputs": self._get_inputs_list(block),
|
||||
"execution_modes": ["immediate"],
|
||||
},
|
||||
|
||||
@@ -8,7 +8,7 @@ from backend.api.features.library import model as library_model
|
||||
from backend.api.features.store import db as store_db
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data.graph import GraphModel
|
||||
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
@@ -89,59 +89,6 @@ def extract_credentials_from_schema(
|
||||
return credentials
|
||||
|
||||
|
||||
def _serialize_missing_credential(
|
||||
field_key: str, field_info: CredentialsFieldInfo
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Convert credential field info into a serializable dict that preserves all supported
|
||||
credential types (e.g., api_key + oauth2) so the UI can offer multiple options.
|
||||
"""
|
||||
supported_types = sorted(field_info.supported_types)
|
||||
provider = next(iter(field_info.provider), "unknown")
|
||||
scopes = sorted(field_info.required_scopes or [])
|
||||
|
||||
return {
|
||||
"id": field_key,
|
||||
"title": field_key.replace("_", " ").title(),
|
||||
"provider": provider,
|
||||
"provider_name": provider.replace("_", " ").title(),
|
||||
"type": supported_types[0] if supported_types else "api_key",
|
||||
"types": supported_types,
|
||||
"scopes": scopes,
|
||||
}
|
||||
|
||||
|
||||
def build_missing_credentials_from_graph(
|
||||
graph: GraphModel, matched_credentials: dict[str, CredentialsMetaInput] | None
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Build a missing_credentials mapping from a graph's aggregated credentials inputs,
|
||||
preserving all supported credential types for each field.
|
||||
"""
|
||||
matched_keys = set(matched_credentials.keys()) if matched_credentials else set()
|
||||
aggregated_fields = graph.aggregate_credentials_inputs()
|
||||
|
||||
return {
|
||||
field_key: _serialize_missing_credential(field_key, field_info)
|
||||
for field_key, (field_info, _node_fields) in aggregated_fields.items()
|
||||
if field_key not in matched_keys
|
||||
}
|
||||
|
||||
|
||||
def build_missing_credentials_from_field_info(
|
||||
credential_fields: dict[str, CredentialsFieldInfo],
|
||||
matched_keys: set[str],
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Build missing_credentials mapping from a simple credentials field info dictionary.
|
||||
"""
|
||||
return {
|
||||
field_key: _serialize_missing_credential(field_key, field_info)
|
||||
for field_key, field_info in credential_fields.items()
|
||||
if field_key not in matched_keys
|
||||
}
|
||||
|
||||
|
||||
def extract_credentials_as_dict(
|
||||
credentials_input_schema: dict[str, Any] | None,
|
||||
) -> dict[str, CredentialsMetaInput]:
|
||||
|
||||
@@ -41,7 +41,6 @@ class PendingHumanReviewModel(BaseModel):
|
||||
graph_exec_id: str = Field(description="Graph execution ID")
|
||||
graph_id: str = Field(description="Graph ID")
|
||||
graph_version: int = Field(description="Graph version")
|
||||
node_id: str = Field(description="Node ID in the graph definition")
|
||||
payload: SafeJsonData = Field(description="The actual data payload awaiting review")
|
||||
instructions: str | None = Field(
|
||||
description="Instructions or message for the reviewer", default=None
|
||||
@@ -82,7 +81,6 @@ class PendingHumanReviewModel(BaseModel):
|
||||
graph_exec_id=review.graphExecId,
|
||||
graph_id=review.graphId,
|
||||
graph_version=review.graphVersion,
|
||||
node_id=review.nodeId,
|
||||
payload=review.payload,
|
||||
instructions=review.instructions,
|
||||
editable=review.editable,
|
||||
@@ -181,15 +179,6 @@ class ReviewRequest(BaseModel):
|
||||
reviews: List[ReviewItem] = Field(
|
||||
description="All reviews with their approval status, data, and messages"
|
||||
)
|
||||
auto_approve_node_ids: List[str] = Field(
|
||||
default_factory=list,
|
||||
description=(
|
||||
"List of node IDs (from the graph definition) to auto-approve for "
|
||||
"the remainder of this execution. Future reviews from these specific "
|
||||
"nodes will be automatically approved. This only affects the current "
|
||||
"execution run."
|
||||
),
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_review_completeness(self):
|
||||
|
||||
@@ -41,7 +41,6 @@ def sample_pending_review(test_user_id: str) -> PendingHumanReviewModel:
|
||||
graph_exec_id="test_graph_exec_456",
|
||||
graph_id="test_graph_789",
|
||||
graph_version=1,
|
||||
node_id="node_def_123",
|
||||
payload={"data": "test payload", "value": 42},
|
||||
instructions="Please review this data",
|
||||
editable=True,
|
||||
@@ -161,7 +160,6 @@ def test_process_review_action_approve_success(
|
||||
graph_exec_id="test_graph_exec_456",
|
||||
graph_id="test_graph_789",
|
||||
graph_version=1,
|
||||
node_id="node_def_123",
|
||||
payload={"data": "modified payload", "value": 50},
|
||||
instructions="Please review this data",
|
||||
editable=True,
|
||||
@@ -225,7 +223,6 @@ def test_process_review_action_reject_success(
|
||||
graph_exec_id="test_graph_exec_456",
|
||||
graph_id="test_graph_789",
|
||||
graph_version=1,
|
||||
node_id="node_def_123",
|
||||
payload={"data": "test payload"},
|
||||
instructions="Please review",
|
||||
editable=True,
|
||||
@@ -277,7 +274,6 @@ def test_process_review_action_mixed_success(
|
||||
graph_exec_id="test_graph_exec_456",
|
||||
graph_id="test_graph_789",
|
||||
graph_version=1,
|
||||
node_id="node_def_456",
|
||||
payload={"data": "second payload"},
|
||||
instructions="Second review",
|
||||
editable=False,
|
||||
@@ -307,7 +303,6 @@ def test_process_review_action_mixed_success(
|
||||
graph_exec_id="test_graph_exec_456",
|
||||
graph_id="test_graph_789",
|
||||
graph_version=1,
|
||||
node_id="node_def_123",
|
||||
payload={"data": "modified"},
|
||||
instructions="Please review",
|
||||
editable=True,
|
||||
@@ -326,7 +321,6 @@ def test_process_review_action_mixed_success(
|
||||
graph_exec_id="test_graph_exec_456",
|
||||
graph_id="test_graph_789",
|
||||
graph_version=1,
|
||||
node_id="node_def_456",
|
||||
payload={"data": "second payload"},
|
||||
instructions="Second review",
|
||||
editable=False,
|
||||
|
||||
@@ -5,7 +5,7 @@ import autogpt_libs.auth as autogpt_auth_lib
|
||||
from fastapi import APIRouter, HTTPException, Query, Security, status
|
||||
from prisma.enums import ReviewStatus
|
||||
|
||||
from backend.data.execution import ExecutionContext, get_graph_execution_meta
|
||||
from backend.data.execution import get_graph_execution_meta
|
||||
from backend.data.human_review import (
|
||||
get_pending_reviews_for_execution,
|
||||
get_pending_reviews_for_user,
|
||||
@@ -169,23 +169,10 @@ async def process_review_action(
|
||||
if not still_has_pending:
|
||||
# Resume execution
|
||||
try:
|
||||
# If auto_approve_node_ids is set, create a context that will
|
||||
# automatically approve future reviews from these specific nodes
|
||||
execution_context = None
|
||||
if request.auto_approve_node_ids:
|
||||
execution_context = ExecutionContext(
|
||||
auto_approved_node_ids=set(request.auto_approve_node_ids),
|
||||
)
|
||||
logger.info(
|
||||
f"Auto-approving future reviews for nodes "
|
||||
f"{request.auto_approve_node_ids} in execution {graph_exec_id}"
|
||||
)
|
||||
|
||||
await add_graph_execution(
|
||||
graph_id=first_review.graph_id,
|
||||
user_id=user_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
execution_context=execution_context,
|
||||
)
|
||||
logger.info(f"Resumed execution {graph_exec_id}")
|
||||
except Exception as e:
|
||||
|
||||
@@ -401,11 +401,27 @@ async def add_generated_agent_image(
|
||||
)
|
||||
|
||||
|
||||
def _initialize_graph_settings(graph: graph_db.GraphModel) -> GraphSettings:
|
||||
"""
|
||||
Initialize GraphSettings based on graph content.
|
||||
|
||||
Args:
|
||||
graph: The graph to analyze
|
||||
|
||||
Returns:
|
||||
GraphSettings with appropriate human_in_the_loop_safe_mode value
|
||||
"""
|
||||
if graph.has_human_in_the_loop:
|
||||
# Graph has HITL blocks - set safe mode to True by default
|
||||
return GraphSettings(human_in_the_loop_safe_mode=True)
|
||||
else:
|
||||
# Graph has no HITL blocks - keep None
|
||||
return GraphSettings(human_in_the_loop_safe_mode=None)
|
||||
|
||||
|
||||
async def create_library_agent(
|
||||
graph: graph_db.GraphModel,
|
||||
user_id: str,
|
||||
hitl_safe_mode: bool = True,
|
||||
sensitive_action_safe_mode: bool = False,
|
||||
create_library_agents_for_sub_graphs: bool = True,
|
||||
) -> list[library_model.LibraryAgent]:
|
||||
"""
|
||||
@@ -414,8 +430,6 @@ async def create_library_agent(
|
||||
Args:
|
||||
agent: The agent/Graph to add to the library.
|
||||
user_id: The user to whom the agent will be added.
|
||||
hitl_safe_mode: Whether HITL blocks require manual review (default True).
|
||||
sensitive_action_safe_mode: Whether sensitive action blocks require review.
|
||||
create_library_agents_for_sub_graphs: If True, creates LibraryAgent records for sub-graphs as well.
|
||||
|
||||
Returns:
|
||||
@@ -451,11 +465,7 @@ async def create_library_agent(
|
||||
}
|
||||
},
|
||||
settings=SafeJson(
|
||||
GraphSettings.from_graph(
|
||||
graph_entry,
|
||||
hitl_safe_mode=hitl_safe_mode,
|
||||
sensitive_action_safe_mode=sensitive_action_safe_mode,
|
||||
).model_dump()
|
||||
_initialize_graph_settings(graph_entry).model_dump()
|
||||
),
|
||||
),
|
||||
include=library_agent_include(
|
||||
@@ -617,6 +627,33 @@ async def update_library_agent(
|
||||
raise DatabaseError("Failed to update library agent") from e
|
||||
|
||||
|
||||
async def update_library_agent_settings(
|
||||
user_id: str,
|
||||
agent_id: str,
|
||||
settings: GraphSettings,
|
||||
) -> library_model.LibraryAgent:
|
||||
"""
|
||||
Updates the settings for a specific LibraryAgent.
|
||||
|
||||
Args:
|
||||
user_id: The owner of the LibraryAgent.
|
||||
agent_id: The ID of the LibraryAgent to update.
|
||||
settings: New GraphSettings to apply.
|
||||
|
||||
Returns:
|
||||
The updated LibraryAgent.
|
||||
|
||||
Raises:
|
||||
NotFoundError: If the specified LibraryAgent does not exist.
|
||||
DatabaseError: If there's an error in the update operation.
|
||||
"""
|
||||
return await update_library_agent(
|
||||
library_agent_id=agent_id,
|
||||
user_id=user_id,
|
||||
settings=settings,
|
||||
)
|
||||
|
||||
|
||||
async def delete_library_agent(
|
||||
library_agent_id: str, user_id: str, soft_delete: bool = True
|
||||
) -> None:
|
||||
@@ -801,7 +838,7 @@ async def add_store_agent_to_library(
|
||||
"isCreatedByUser": False,
|
||||
"useGraphIsActiveVersion": False,
|
||||
"settings": SafeJson(
|
||||
GraphSettings.from_graph(graph_model).model_dump()
|
||||
_initialize_graph_settings(graph_model).model_dump()
|
||||
),
|
||||
},
|
||||
include=library_agent_include(
|
||||
@@ -1191,15 +1228,8 @@ async def fork_library_agent(
|
||||
)
|
||||
new_graph = await on_graph_activate(new_graph, user_id=user_id)
|
||||
|
||||
# Create a library agent for the new graph, preserving safe mode settings
|
||||
return (
|
||||
await create_library_agent(
|
||||
new_graph,
|
||||
user_id,
|
||||
hitl_safe_mode=original_agent.settings.human_in_the_loop_safe_mode,
|
||||
sensitive_action_safe_mode=original_agent.settings.sensitive_action_safe_mode,
|
||||
)
|
||||
)[0]
|
||||
# Create a library agent for the new graph
|
||||
return (await create_library_agent(new_graph, user_id))[0]
|
||||
except prisma.errors.PrismaError as e:
|
||||
logger.error(f"Database error cloning library agent: {e}")
|
||||
raise DatabaseError("Failed to fork library agent") from e
|
||||
|
||||
@@ -73,12 +73,6 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
has_external_trigger: bool = pydantic.Field(
|
||||
description="Whether the agent has an external trigger (e.g. webhook) node"
|
||||
)
|
||||
has_human_in_the_loop: bool = pydantic.Field(
|
||||
description="Whether the agent has human-in-the-loop blocks"
|
||||
)
|
||||
has_sensitive_action: bool = pydantic.Field(
|
||||
description="Whether the agent has sensitive action blocks"
|
||||
)
|
||||
trigger_setup_info: Optional[GraphTriggerInfo] = None
|
||||
|
||||
# Indicates whether there's a new output (based on recent runs)
|
||||
@@ -186,8 +180,6 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
graph.credentials_input_schema if sub_graphs is not None else None
|
||||
),
|
||||
has_external_trigger=graph.has_external_trigger,
|
||||
has_human_in_the_loop=graph.has_human_in_the_loop,
|
||||
has_sensitive_action=graph.has_sensitive_action,
|
||||
trigger_setup_info=graph.trigger_setup_info,
|
||||
new_output=new_output,
|
||||
can_access_graph=can_access_graph,
|
||||
|
||||
@@ -52,8 +52,6 @@ async def test_get_library_agents_success(
|
||||
output_schema={"type": "object", "properties": {}},
|
||||
credentials_input_schema={"type": "object", "properties": {}},
|
||||
has_external_trigger=False,
|
||||
has_human_in_the_loop=False,
|
||||
has_sensitive_action=False,
|
||||
status=library_model.LibraryAgentStatus.COMPLETED,
|
||||
recommended_schedule_cron=None,
|
||||
new_output=False,
|
||||
@@ -77,8 +75,6 @@ async def test_get_library_agents_success(
|
||||
output_schema={"type": "object", "properties": {}},
|
||||
credentials_input_schema={"type": "object", "properties": {}},
|
||||
has_external_trigger=False,
|
||||
has_human_in_the_loop=False,
|
||||
has_sensitive_action=False,
|
||||
status=library_model.LibraryAgentStatus.COMPLETED,
|
||||
recommended_schedule_cron=None,
|
||||
new_output=False,
|
||||
@@ -154,8 +150,6 @@ async def test_get_favorite_library_agents_success(
|
||||
output_schema={"type": "object", "properties": {}},
|
||||
credentials_input_schema={"type": "object", "properties": {}},
|
||||
has_external_trigger=False,
|
||||
has_human_in_the_loop=False,
|
||||
has_sensitive_action=False,
|
||||
status=library_model.LibraryAgentStatus.COMPLETED,
|
||||
recommended_schedule_cron=None,
|
||||
new_output=False,
|
||||
@@ -224,8 +218,6 @@ def test_add_agent_to_library_success(
|
||||
output_schema={"type": "object", "properties": {}},
|
||||
credentials_input_schema={"type": "object", "properties": {}},
|
||||
has_external_trigger=False,
|
||||
has_human_in_the_loop=False,
|
||||
has_sensitive_action=False,
|
||||
status=library_model.LibraryAgentStatus.COMPLETED,
|
||||
new_output=False,
|
||||
can_access_graph=True,
|
||||
|
||||
@@ -154,16 +154,15 @@ async def store_content_embedding(
|
||||
|
||||
# Upsert the embedding
|
||||
# WHERE clause in DO UPDATE prevents PostgreSQL 15 bug with NULLS NOT DISTINCT
|
||||
# Use {pgvector_schema}.vector for explicit pgvector type qualification
|
||||
await execute_raw_with_schema(
|
||||
"""
|
||||
INSERT INTO {schema_prefix}"UnifiedContentEmbedding" (
|
||||
"id", "contentType", "contentId", "userId", "embedding", "searchableText", "metadata", "createdAt", "updatedAt"
|
||||
)
|
||||
VALUES (gen_random_uuid()::text, $1::{schema_prefix}"ContentType", $2, $3, $4::{pgvector_schema}.vector, $5, $6::jsonb, NOW(), NOW())
|
||||
VALUES (gen_random_uuid()::text, $1::{schema_prefix}"ContentType", $2, $3, $4::vector, $5, $6::jsonb, NOW(), NOW())
|
||||
ON CONFLICT ("contentType", "contentId", "userId")
|
||||
DO UPDATE SET
|
||||
"embedding" = $4::{pgvector_schema}.vector,
|
||||
"embedding" = $4::vector,
|
||||
"searchableText" = $5,
|
||||
"metadata" = $6::jsonb,
|
||||
"updatedAt" = NOW()
|
||||
@@ -178,6 +177,7 @@ async def store_content_embedding(
|
||||
searchable_text,
|
||||
metadata_json,
|
||||
client=client,
|
||||
set_public_search_path=True,
|
||||
)
|
||||
|
||||
logger.info(f"Stored embedding for {content_type}:{content_id}")
|
||||
@@ -236,6 +236,7 @@ async def get_content_embedding(
|
||||
content_type,
|
||||
content_id,
|
||||
user_id,
|
||||
set_public_search_path=True,
|
||||
)
|
||||
|
||||
if result and len(result) > 0:
|
||||
@@ -870,46 +871,31 @@ async def semantic_search(
|
||||
# Add content type parameters and build placeholders dynamically
|
||||
content_type_start_idx = len(params) + 1
|
||||
content_type_placeholders = ", ".join(
|
||||
"$" + str(content_type_start_idx + i) + '::{schema_prefix}"ContentType"'
|
||||
f'${content_type_start_idx + i}::{{{{schema_prefix}}}}"ContentType"'
|
||||
for i in range(len(content_types))
|
||||
)
|
||||
params.extend([ct.value for ct in content_types])
|
||||
|
||||
# Build min_similarity param index before appending
|
||||
min_similarity_idx = len(params) + 1
|
||||
params.append(min_similarity)
|
||||
|
||||
# Use regular string (not f-string) for template to preserve {schema_prefix} and {schema} placeholders
|
||||
# Use OPERATOR({pgvector_schema}.<=>) for explicit operator schema qualification
|
||||
sql = (
|
||||
"""
|
||||
sql = f"""
|
||||
SELECT
|
||||
"contentId" as content_id,
|
||||
"contentType" as content_type,
|
||||
"searchableText" as searchable_text,
|
||||
metadata,
|
||||
1 - (embedding OPERATOR({pgvector_schema}.<=>) '"""
|
||||
+ embedding_str
|
||||
+ """'::{pgvector_schema}.vector) as similarity
|
||||
FROM {schema_prefix}"UnifiedContentEmbedding"
|
||||
WHERE "contentType" IN ("""
|
||||
+ content_type_placeholders
|
||||
+ """)
|
||||
"""
|
||||
+ user_filter
|
||||
+ """
|
||||
AND 1 - (embedding OPERATOR({pgvector_schema}.<=>) '"""
|
||||
+ embedding_str
|
||||
+ """'::{pgvector_schema}.vector) >= $"""
|
||||
+ str(min_similarity_idx)
|
||||
+ """
|
||||
1 - (embedding <=> '{embedding_str}'::vector) as similarity
|
||||
FROM {{{{schema_prefix}}}}"UnifiedContentEmbedding"
|
||||
WHERE "contentType" IN ({content_type_placeholders})
|
||||
{user_filter}
|
||||
AND 1 - (embedding <=> '{embedding_str}'::vector) >= ${len(params) + 1}
|
||||
ORDER BY similarity DESC
|
||||
LIMIT $1
|
||||
"""
|
||||
)
|
||||
params.append(min_similarity)
|
||||
|
||||
try:
|
||||
results = await query_raw_with_schema(sql, *params)
|
||||
results = await query_raw_with_schema(
|
||||
sql, *params, set_public_search_path=True
|
||||
)
|
||||
return [
|
||||
{
|
||||
"content_id": row["content_id"],
|
||||
@@ -936,41 +922,31 @@ async def semantic_search(
|
||||
# Add content type parameters and build placeholders dynamically
|
||||
content_type_start_idx = len(params_lexical) + 1
|
||||
content_type_placeholders_lexical = ", ".join(
|
||||
"$" + str(content_type_start_idx + i) + '::{schema_prefix}"ContentType"'
|
||||
f'${content_type_start_idx + i}::{{{{schema_prefix}}}}"ContentType"'
|
||||
for i in range(len(content_types))
|
||||
)
|
||||
params_lexical.extend([ct.value for ct in content_types])
|
||||
|
||||
# Build query param index before appending
|
||||
query_param_idx = len(params_lexical) + 1
|
||||
params_lexical.append(f"%{query}%")
|
||||
|
||||
# Use regular string (not f-string) for template to preserve {schema_prefix} placeholders
|
||||
sql_lexical = (
|
||||
"""
|
||||
sql_lexical = f"""
|
||||
SELECT
|
||||
"contentId" as content_id,
|
||||
"contentType" as content_type,
|
||||
"searchableText" as searchable_text,
|
||||
metadata,
|
||||
0.0 as similarity
|
||||
FROM {schema_prefix}"UnifiedContentEmbedding"
|
||||
WHERE "contentType" IN ("""
|
||||
+ content_type_placeholders_lexical
|
||||
+ """)
|
||||
"""
|
||||
+ user_filter
|
||||
+ """
|
||||
AND "searchableText" ILIKE $"""
|
||||
+ str(query_param_idx)
|
||||
+ """
|
||||
FROM {{{{schema_prefix}}}}"UnifiedContentEmbedding"
|
||||
WHERE "contentType" IN ({content_type_placeholders_lexical})
|
||||
{user_filter}
|
||||
AND "searchableText" ILIKE ${len(params_lexical) + 1}
|
||||
ORDER BY "updatedAt" DESC
|
||||
LIMIT $1
|
||||
"""
|
||||
)
|
||||
params_lexical.append(f"%{query}%")
|
||||
|
||||
try:
|
||||
results = await query_raw_with_schema(sql_lexical, *params_lexical)
|
||||
results = await query_raw_with_schema(
|
||||
sql_lexical, *params_lexical, set_public_search_path=True
|
||||
)
|
||||
return [
|
||||
{
|
||||
"content_id": row["content_id"],
|
||||
|
||||
@@ -155,14 +155,18 @@ async def test_store_embedding_success(mocker):
|
||||
)
|
||||
|
||||
assert result is True
|
||||
# execute_raw is called once for INSERT (no separate SET search_path needed)
|
||||
assert mock_client.execute_raw.call_count == 1
|
||||
# execute_raw is called twice: once for SET search_path, once for INSERT
|
||||
assert mock_client.execute_raw.call_count == 2
|
||||
|
||||
# Verify the INSERT query with the actual data
|
||||
call_args = mock_client.execute_raw.call_args_list[0][0]
|
||||
assert "test-version-id" in call_args
|
||||
assert "[0.1,0.2,0.3]" in call_args
|
||||
assert None in call_args # userId should be None for store agents
|
||||
# First call: SET search_path
|
||||
first_call_args = mock_client.execute_raw.call_args_list[0][0]
|
||||
assert "SET search_path" in first_call_args[0]
|
||||
|
||||
# Second call: INSERT query with the actual data
|
||||
second_call_args = mock_client.execute_raw.call_args_list[1][0]
|
||||
assert "test-version-id" in second_call_args
|
||||
assert "[0.1,0.2,0.3]" in second_call_args
|
||||
assert None in second_call_args # userId should be None for store agents
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
|
||||
@@ -12,7 +12,7 @@ from dataclasses import dataclass
|
||||
from typing import Any, Literal
|
||||
|
||||
from prisma.enums import ContentType
|
||||
from rank_bm25 import BM25Okapi # type: ignore[import-untyped]
|
||||
from rank_bm25 import BM25Okapi
|
||||
|
||||
from backend.api.features.store.embeddings import (
|
||||
EMBEDDING_DIM,
|
||||
@@ -295,7 +295,7 @@ async def unified_hybrid_search(
|
||||
FROM {{schema_prefix}}"UnifiedContentEmbedding" uce
|
||||
WHERE uce."contentType" = ANY({content_types_param}::{{schema_prefix}}"ContentType"[])
|
||||
{user_filter}
|
||||
ORDER BY uce.embedding OPERATOR({{pgvector_schema}}.<=>) {embedding_param}::{{pgvector_schema}}.vector
|
||||
ORDER BY uce.embedding <=> {embedding_param}::vector
|
||||
LIMIT 200
|
||||
)
|
||||
),
|
||||
@@ -307,7 +307,7 @@ async def unified_hybrid_search(
|
||||
uce.metadata,
|
||||
uce."updatedAt" as updated_at,
|
||||
-- Semantic score: cosine similarity (1 - distance)
|
||||
COALESCE(1 - (uce.embedding OPERATOR({{pgvector_schema}}.<=>) {embedding_param}::{{pgvector_schema}}.vector), 0) as semantic_score,
|
||||
COALESCE(1 - (uce.embedding <=> {embedding_param}::vector), 0) as semantic_score,
|
||||
-- Lexical score: ts_rank_cd
|
||||
COALESCE(ts_rank_cd(uce.search, plainto_tsquery('english', {query_param})), 0) as lexical_raw,
|
||||
-- Category match from metadata
|
||||
@@ -363,7 +363,9 @@ async def unified_hybrid_search(
|
||||
LIMIT {limit_param} OFFSET {offset_param}
|
||||
"""
|
||||
|
||||
results = await query_raw_with_schema(sql_query, *params)
|
||||
results = await query_raw_with_schema(
|
||||
sql_query, *params, set_public_search_path=True
|
||||
)
|
||||
|
||||
total = results[0]["total_count"] if results else 0
|
||||
# Apply BM25 reranking
|
||||
@@ -583,7 +585,7 @@ async def hybrid_search(
|
||||
WHERE uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
|
||||
AND uce."userId" IS NULL
|
||||
AND {where_clause}
|
||||
ORDER BY uce.embedding OPERATOR({{pgvector_schema}}.<=>) {embedding_param}::{{pgvector_schema}}.vector
|
||||
ORDER BY uce.embedding <=> {embedding_param}::vector
|
||||
LIMIT 200
|
||||
) uce
|
||||
),
|
||||
@@ -605,7 +607,7 @@ async def hybrid_search(
|
||||
-- Searchable text for BM25 reranking
|
||||
COALESCE(sa.agent_name, '') || ' ' || COALESCE(sa.sub_heading, '') || ' ' || COALESCE(sa.description, '') as searchable_text,
|
||||
-- Semantic score
|
||||
COALESCE(1 - (uce.embedding OPERATOR({{pgvector_schema}}.<=>) {embedding_param}::{{pgvector_schema}}.vector), 0) as semantic_score,
|
||||
COALESCE(1 - (uce.embedding <=> {embedding_param}::vector), 0) as semantic_score,
|
||||
-- Lexical score (raw, will normalize)
|
||||
COALESCE(ts_rank_cd(uce.search, plainto_tsquery('english', {query_param})), 0) as lexical_raw,
|
||||
-- Category match
|
||||
@@ -686,7 +688,9 @@ async def hybrid_search(
|
||||
LIMIT {limit_param} OFFSET {offset_param}
|
||||
"""
|
||||
|
||||
results = await query_raw_with_schema(sql_query, *params)
|
||||
results = await query_raw_with_schema(
|
||||
sql_query, *params, set_public_search_path=True
|
||||
)
|
||||
|
||||
total = results[0]["total_count"] if results else 0
|
||||
|
||||
|
||||
@@ -761,8 +761,10 @@ async def create_new_graph(
|
||||
graph.reassign_ids(user_id=user_id, reassign_graph_id=True)
|
||||
graph.validate_graph(for_run=False)
|
||||
|
||||
# The return value of the create graph & library function is intentionally not used here,
|
||||
# as the graph already valid and no sub-graphs are returned back.
|
||||
await graph_db.create_graph(graph, user_id=user_id)
|
||||
await library_db.create_library_agent(graph, user_id)
|
||||
await library_db.create_library_agent(graph, user_id=user_id)
|
||||
activated_graph = await on_graph_activate(graph, user_id=user_id)
|
||||
|
||||
if create_graph.source == "builder":
|
||||
@@ -886,19 +888,21 @@ async def set_graph_active_version(
|
||||
async def _update_library_agent_version_and_settings(
|
||||
user_id: str, agent_graph: graph_db.GraphModel
|
||||
) -> library_model.LibraryAgent:
|
||||
# Keep the library agent up to date with the new active version
|
||||
library = await library_db.update_agent_version_in_library(
|
||||
user_id, agent_graph.id, agent_graph.version
|
||||
)
|
||||
updated_settings = GraphSettings.from_graph(
|
||||
graph=agent_graph,
|
||||
hitl_safe_mode=library.settings.human_in_the_loop_safe_mode,
|
||||
sensitive_action_safe_mode=library.settings.sensitive_action_safe_mode,
|
||||
)
|
||||
if updated_settings != library.settings:
|
||||
library = await library_db.update_library_agent(
|
||||
library_agent_id=library.id,
|
||||
# If the graph has HITL node, initialize the setting if it's not already set.
|
||||
if (
|
||||
agent_graph.has_human_in_the_loop
|
||||
and library.settings.human_in_the_loop_safe_mode is None
|
||||
):
|
||||
await library_db.update_library_agent_settings(
|
||||
user_id=user_id,
|
||||
settings=updated_settings,
|
||||
agent_id=library.id,
|
||||
settings=library.settings.model_copy(
|
||||
update={"human_in_the_loop_safe_mode": True}
|
||||
),
|
||||
)
|
||||
return library
|
||||
|
||||
@@ -915,18 +919,21 @@ async def update_graph_settings(
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
) -> GraphSettings:
|
||||
"""Update graph settings for the user's library agent."""
|
||||
# Get the library agent for this graph
|
||||
library_agent = await library_db.get_library_agent_by_graph_id(
|
||||
graph_id=graph_id, user_id=user_id
|
||||
)
|
||||
if not library_agent:
|
||||
raise HTTPException(404, f"Graph #{graph_id} not found in user's library")
|
||||
|
||||
updated_agent = await library_db.update_library_agent(
|
||||
library_agent_id=library_agent.id,
|
||||
# Update the library agent settings
|
||||
updated_agent = await library_db.update_library_agent_settings(
|
||||
user_id=user_id,
|
||||
agent_id=library_agent.id,
|
||||
settings=settings,
|
||||
)
|
||||
|
||||
# Return the updated settings
|
||||
return GraphSettings.model_validate(updated_agent.settings)
|
||||
|
||||
|
||||
|
||||
@@ -680,58 +680,3 @@ class ListIsEmptyBlock(Block):
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
yield "is_empty", len(input_data.list) == 0
|
||||
|
||||
|
||||
class ConcatenateListsBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
lists: List[List[Any]] = SchemaField(
|
||||
description="A list of lists to concatenate together. All lists will be combined in order into a single list.",
|
||||
placeholder="e.g., [[1, 2], [3, 4], [5, 6]]",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
concatenated_list: List[Any] = SchemaField(
|
||||
description="The concatenated list containing all elements from all input lists in order."
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if concatenation failed due to invalid input types."
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="3cf9298b-5817-4141-9d80-7c2cc5199c8e",
|
||||
description="Concatenates multiple lists into a single list. All elements from all input lists are combined in order.",
|
||||
categories={BlockCategory.BASIC},
|
||||
input_schema=ConcatenateListsBlock.Input,
|
||||
output_schema=ConcatenateListsBlock.Output,
|
||||
test_input=[
|
||||
{"lists": [[1, 2, 3], [4, 5, 6]]},
|
||||
{"lists": [["a", "b"], ["c"], ["d", "e", "f"]]},
|
||||
{"lists": [[1, 2], []]},
|
||||
{"lists": []},
|
||||
],
|
||||
test_output=[
|
||||
("concatenated_list", [1, 2, 3, 4, 5, 6]),
|
||||
("concatenated_list", ["a", "b", "c", "d", "e", "f"]),
|
||||
("concatenated_list", [1, 2]),
|
||||
("concatenated_list", []),
|
||||
],
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
concatenated = []
|
||||
for idx, lst in enumerate(input_data.lists):
|
||||
if lst is None:
|
||||
# Skip None values to avoid errors
|
||||
continue
|
||||
if not isinstance(lst, list):
|
||||
# Type validation: each item must be a list
|
||||
# Strings are iterable and would cause extend() to iterate character-by-character
|
||||
# Non-iterable types would raise TypeError
|
||||
yield "error", (
|
||||
f"Invalid input at index {idx}: expected a list, got {type(lst).__name__}. "
|
||||
f"All items in 'lists' must be lists (e.g., [[1, 2], [3, 4]])."
|
||||
)
|
||||
return
|
||||
concatenated.extend(lst)
|
||||
yield "concatenated_list", concatenated
|
||||
|
||||
@@ -55,7 +55,6 @@ class HITLReviewHelper:
|
||||
async def _handle_review_request(
|
||||
input_data: Any,
|
||||
user_id: str,
|
||||
node_id: str,
|
||||
node_exec_id: str,
|
||||
graph_exec_id: str,
|
||||
graph_id: str,
|
||||
@@ -63,7 +62,6 @@ class HITLReviewHelper:
|
||||
execution_context: ExecutionContext,
|
||||
block_name: str = "Block",
|
||||
editable: bool = False,
|
||||
skip_safe_mode_check: bool = False,
|
||||
) -> Optional[ReviewResult]:
|
||||
"""
|
||||
Handle a review request for a block that requires human review.
|
||||
@@ -71,7 +69,6 @@ class HITLReviewHelper:
|
||||
Args:
|
||||
input_data: The input data to be reviewed
|
||||
user_id: ID of the user requesting the review
|
||||
node_id: ID of the node in the graph definition
|
||||
node_exec_id: ID of the node execution
|
||||
graph_exec_id: ID of the graph execution
|
||||
graph_id: ID of the graph
|
||||
@@ -79,8 +76,6 @@ class HITLReviewHelper:
|
||||
execution_context: Current execution context
|
||||
block_name: Name of the block requesting review
|
||||
editable: Whether the reviewer can edit the data
|
||||
skip_safe_mode_check: If True, skip the safe mode check (caller already
|
||||
verified). Used by sensitive action blocks that check their own flag.
|
||||
|
||||
Returns:
|
||||
ReviewResult if review is complete, None if waiting for human input
|
||||
@@ -89,11 +84,7 @@ class HITLReviewHelper:
|
||||
Exception: If review creation or status update fails
|
||||
"""
|
||||
# Skip review if safe mode is disabled - return auto-approved result
|
||||
# (unless caller already checked and wants to skip this check)
|
||||
if (
|
||||
not skip_safe_mode_check
|
||||
and not execution_context.human_in_the_loop_safe_mode
|
||||
):
|
||||
if not execution_context.safe_mode:
|
||||
logger.info(
|
||||
f"Block {block_name} skipping review for node {node_exec_id} - safe mode disabled"
|
||||
)
|
||||
@@ -105,27 +96,12 @@ class HITLReviewHelper:
|
||||
node_exec_id=node_exec_id,
|
||||
)
|
||||
|
||||
# Skip review if this specific node has been auto-approved by the user
|
||||
if node_id in execution_context.auto_approved_node_ids:
|
||||
logger.info(
|
||||
f"Block {block_name} skipping review for node {node_exec_id} - "
|
||||
f"node {node_id} is auto-approved"
|
||||
)
|
||||
return ReviewResult(
|
||||
data=input_data,
|
||||
status=ReviewStatus.APPROVED,
|
||||
message="Auto-approved (user approved all future actions for this block)",
|
||||
processed=True,
|
||||
node_exec_id=node_exec_id,
|
||||
)
|
||||
|
||||
result = await HITLReviewHelper.get_or_create_human_review(
|
||||
user_id=user_id,
|
||||
node_exec_id=node_exec_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph_version,
|
||||
node_id=node_id,
|
||||
input_data=input_data,
|
||||
message=f"Review required for {block_name} execution",
|
||||
editable=editable,
|
||||
@@ -153,7 +129,6 @@ class HITLReviewHelper:
|
||||
async def handle_review_decision(
|
||||
input_data: Any,
|
||||
user_id: str,
|
||||
node_id: str,
|
||||
node_exec_id: str,
|
||||
graph_exec_id: str,
|
||||
graph_id: str,
|
||||
@@ -161,7 +136,6 @@ class HITLReviewHelper:
|
||||
execution_context: ExecutionContext,
|
||||
block_name: str = "Block",
|
||||
editable: bool = False,
|
||||
skip_safe_mode_check: bool = False,
|
||||
) -> Optional[ReviewDecision]:
|
||||
"""
|
||||
Handle a review request and return the decision in a single call.
|
||||
@@ -169,7 +143,6 @@ class HITLReviewHelper:
|
||||
Args:
|
||||
input_data: The input data to be reviewed
|
||||
user_id: ID of the user requesting the review
|
||||
node_id: ID of the node in the graph definition
|
||||
node_exec_id: ID of the node execution
|
||||
graph_exec_id: ID of the graph execution
|
||||
graph_id: ID of the graph
|
||||
@@ -177,8 +150,6 @@ class HITLReviewHelper:
|
||||
execution_context: Current execution context
|
||||
block_name: Name of the block requesting review
|
||||
editable: Whether the reviewer can edit the data
|
||||
skip_safe_mode_check: If True, skip the safe mode check (caller already
|
||||
verified). Used by sensitive action blocks that check their own flag.
|
||||
|
||||
Returns:
|
||||
ReviewDecision if review is complete (approved/rejected),
|
||||
@@ -187,7 +158,6 @@ class HITLReviewHelper:
|
||||
review_result = await HITLReviewHelper._handle_review_request(
|
||||
input_data=input_data,
|
||||
user_id=user_id,
|
||||
node_id=node_id,
|
||||
node_exec_id=node_exec_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
graph_id=graph_id,
|
||||
@@ -195,7 +165,6 @@ class HITLReviewHelper:
|
||||
execution_context=execution_context,
|
||||
block_name=block_name,
|
||||
editable=editable,
|
||||
skip_safe_mode_check=skip_safe_mode_check,
|
||||
)
|
||||
|
||||
if review_result is None:
|
||||
|
||||
@@ -97,7 +97,6 @@ class HumanInTheLoopBlock(Block):
|
||||
input_data: Input,
|
||||
*,
|
||||
user_id: str,
|
||||
node_id: str,
|
||||
node_exec_id: str,
|
||||
graph_exec_id: str,
|
||||
graph_id: str,
|
||||
@@ -105,17 +104,7 @@ class HumanInTheLoopBlock(Block):
|
||||
execution_context: ExecutionContext,
|
||||
**_kwargs,
|
||||
) -> BlockOutput:
|
||||
# Check if this specific node has been auto-approved by the user
|
||||
if node_id in execution_context.auto_approved_node_ids:
|
||||
logger.info(
|
||||
f"HITL block skipping review for node {node_exec_id} - "
|
||||
f"node {node_id} is auto-approved"
|
||||
)
|
||||
yield "approved_data", input_data.data
|
||||
yield "review_message", "Auto-approved (user approved all future actions for this block)"
|
||||
return
|
||||
|
||||
if not execution_context.human_in_the_loop_safe_mode:
|
||||
if not execution_context.safe_mode:
|
||||
logger.info(
|
||||
f"HITL block skipping review for node {node_exec_id} - safe mode disabled"
|
||||
)
|
||||
@@ -126,7 +115,6 @@ class HumanInTheLoopBlock(Block):
|
||||
decision = await self.handle_review_decision(
|
||||
input_data=input_data.data,
|
||||
user_id=user_id,
|
||||
node_id=node_id,
|
||||
node_exec_id=node_exec_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
graph_id=graph_id,
|
||||
|
||||
@@ -242,7 +242,7 @@ async def test_smart_decision_maker_tracks_llm_stats():
|
||||
outputs = {}
|
||||
# Create execution context
|
||||
|
||||
mock_execution_context = ExecutionContext(human_in_the_loop_safe_mode=False)
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
|
||||
# Create a mock execution processor for tests
|
||||
|
||||
@@ -343,7 +343,7 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
|
||||
# Create execution context
|
||||
|
||||
mock_execution_context = ExecutionContext(human_in_the_loop_safe_mode=False)
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
|
||||
# Create a mock execution processor for tests
|
||||
|
||||
@@ -409,7 +409,7 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
|
||||
# Create execution context
|
||||
|
||||
mock_execution_context = ExecutionContext(human_in_the_loop_safe_mode=False)
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
|
||||
# Create a mock execution processor for tests
|
||||
|
||||
@@ -471,7 +471,7 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
outputs = {}
|
||||
# Create execution context
|
||||
|
||||
mock_execution_context = ExecutionContext(human_in_the_loop_safe_mode=False)
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
|
||||
# Create a mock execution processor for tests
|
||||
|
||||
@@ -535,7 +535,7 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
outputs = {}
|
||||
# Create execution context
|
||||
|
||||
mock_execution_context = ExecutionContext(human_in_the_loop_safe_mode=False)
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
|
||||
# Create a mock execution processor for tests
|
||||
|
||||
@@ -658,7 +658,7 @@ async def test_smart_decision_maker_raw_response_conversion():
|
||||
outputs = {}
|
||||
# Create execution context
|
||||
|
||||
mock_execution_context = ExecutionContext(human_in_the_loop_safe_mode=False)
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
|
||||
# Create a mock execution processor for tests
|
||||
|
||||
@@ -730,7 +730,7 @@ async def test_smart_decision_maker_raw_response_conversion():
|
||||
outputs = {}
|
||||
# Create execution context
|
||||
|
||||
mock_execution_context = ExecutionContext(human_in_the_loop_safe_mode=False)
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
|
||||
# Create a mock execution processor for tests
|
||||
|
||||
@@ -786,7 +786,7 @@ async def test_smart_decision_maker_raw_response_conversion():
|
||||
outputs = {}
|
||||
# Create execution context
|
||||
|
||||
mock_execution_context = ExecutionContext(human_in_the_loop_safe_mode=False)
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
|
||||
# Create a mock execution processor for tests
|
||||
|
||||
@@ -905,7 +905,7 @@ async def test_smart_decision_maker_agent_mode():
|
||||
# Create a mock execution context
|
||||
|
||||
mock_execution_context = ExecutionContext(
|
||||
human_in_the_loop_safe_mode=False,
|
||||
safe_mode=False,
|
||||
)
|
||||
|
||||
# Create a mock execution processor for agent mode tests
|
||||
@@ -1027,7 +1027,7 @@ async def test_smart_decision_maker_traditional_mode_default():
|
||||
|
||||
# Create execution context
|
||||
|
||||
mock_execution_context = ExecutionContext(human_in_the_loop_safe_mode=False)
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
|
||||
# Create a mock execution processor for tests
|
||||
|
||||
|
||||
@@ -386,7 +386,7 @@ async def test_output_yielding_with_dynamic_fields():
|
||||
outputs = {}
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
mock_execution_context = ExecutionContext(human_in_the_loop_safe_mode=False)
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
mock_execution_processor = MagicMock()
|
||||
|
||||
async for output_name, output_value in block.run(
|
||||
@@ -609,9 +609,7 @@ async def test_validation_errors_dont_pollute_conversation():
|
||||
outputs = {}
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
mock_execution_context = ExecutionContext(
|
||||
human_in_the_loop_safe_mode=False
|
||||
)
|
||||
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||
|
||||
# Create a proper mock execution processor for agent mode
|
||||
from collections import defaultdict
|
||||
|
||||
@@ -474,7 +474,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
self.block_type = block_type
|
||||
self.webhook_config = webhook_config
|
||||
self.execution_stats: NodeExecutionStats = NodeExecutionStats()
|
||||
self.is_sensitive_action: bool = False
|
||||
self.requires_human_review: bool = False
|
||||
|
||||
if self.webhook_config:
|
||||
if isinstance(self.webhook_config, BlockWebhookConfig):
|
||||
@@ -622,7 +622,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
input_data: BlockInput,
|
||||
*,
|
||||
user_id: str,
|
||||
node_id: str,
|
||||
node_exec_id: str,
|
||||
graph_exec_id: str,
|
||||
graph_id: str,
|
||||
@@ -638,9 +637,8 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
- should_pause: True if execution should be paused for review
|
||||
- input_data_to_use: The input data to use (may be modified by reviewer)
|
||||
"""
|
||||
if not (
|
||||
self.is_sensitive_action and execution_context.sensitive_action_safe_mode
|
||||
):
|
||||
# Skip review if not required or safe mode is disabled
|
||||
if not self.requires_human_review or not execution_context.safe_mode:
|
||||
return False, input_data
|
||||
|
||||
from backend.blocks.helpers.review import HITLReviewHelper
|
||||
@@ -649,7 +647,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
decision = await HITLReviewHelper.handle_review_decision(
|
||||
input_data=input_data,
|
||||
user_id=user_id,
|
||||
node_id=node_id,
|
||||
node_exec_id=node_exec_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
graph_id=graph_id,
|
||||
|
||||
@@ -38,6 +38,20 @@ POOL_TIMEOUT = os.getenv("DB_POOL_TIMEOUT")
|
||||
if POOL_TIMEOUT:
|
||||
DATABASE_URL = add_param(DATABASE_URL, "pool_timeout", POOL_TIMEOUT)
|
||||
|
||||
# Add public schema to search_path for pgvector type access
|
||||
# The vector extension is in public schema, but search_path is determined by schema parameter
|
||||
# Extract the schema from DATABASE_URL or default to 'public' (matching get_database_schema())
|
||||
parsed_url = urlparse(DATABASE_URL)
|
||||
url_params = dict(parse_qsl(parsed_url.query))
|
||||
db_schema = url_params.get("schema", "public")
|
||||
# Build search_path, avoiding duplicates if db_schema is already 'public'
|
||||
search_path_schemas = list(
|
||||
dict.fromkeys([db_schema, "public"])
|
||||
) # Preserves order, removes duplicates
|
||||
search_path = ",".join(search_path_schemas)
|
||||
# This allows using ::vector without schema qualification
|
||||
DATABASE_URL = add_param(DATABASE_URL, "options", f"-c search_path={search_path}")
|
||||
|
||||
HTTP_TIMEOUT = int(POOL_TIMEOUT) if POOL_TIMEOUT else None
|
||||
|
||||
prisma = Prisma(
|
||||
@@ -113,48 +127,38 @@ async def _raw_with_schema(
|
||||
*args,
|
||||
execute: bool = False,
|
||||
client: Prisma | None = None,
|
||||
set_public_search_path: bool = False,
|
||||
) -> list[dict] | int:
|
||||
"""Internal: Execute raw SQL with proper schema handling.
|
||||
|
||||
Use query_raw_with_schema() or execute_raw_with_schema() instead.
|
||||
|
||||
Supports placeholders:
|
||||
- {schema_prefix}: Table/type prefix (e.g., "platform".)
|
||||
- {schema}: Raw schema name for application tables (e.g., platform)
|
||||
- {pgvector_schema}: Schema where pgvector is installed (defaults to "public")
|
||||
|
||||
Args:
|
||||
query_template: SQL query with {schema_prefix}, {schema}, and/or {pgvector_schema} placeholders
|
||||
query_template: SQL query with {schema_prefix} placeholder
|
||||
*args: Query parameters
|
||||
execute: If False, executes SELECT query. If True, executes INSERT/UPDATE/DELETE.
|
||||
client: Optional Prisma client for transactions (only used when execute=True).
|
||||
set_public_search_path: If True, sets search_path to include public schema.
|
||||
Needed for pgvector types and other public schema objects.
|
||||
|
||||
Returns:
|
||||
- list[dict] if execute=False (query results)
|
||||
- int if execute=True (number of affected rows)
|
||||
|
||||
Example with vector type:
|
||||
await execute_raw_with_schema(
|
||||
'INSERT INTO {schema_prefix}"Embedding" (vec) VALUES ($1::{pgvector_schema}.vector)',
|
||||
embedding_data
|
||||
)
|
||||
"""
|
||||
schema = get_database_schema()
|
||||
schema_prefix = f'"{schema}".' if schema != "public" else ""
|
||||
# pgvector extension is typically installed in "public" schema
|
||||
# On Supabase it may be in "extensions" but "public" is the common default
|
||||
pgvector_schema = "public"
|
||||
|
||||
formatted_query = query_template.format(
|
||||
schema_prefix=schema_prefix,
|
||||
schema=schema,
|
||||
pgvector_schema=pgvector_schema,
|
||||
)
|
||||
formatted_query = query_template.format(schema_prefix=schema_prefix)
|
||||
|
||||
import prisma as prisma_module
|
||||
|
||||
db_client = client if client else prisma_module.get_client()
|
||||
|
||||
# Set search_path to include public schema if requested
|
||||
# Prisma doesn't support the 'options' connection parameter, so we set it per-session
|
||||
# This is idempotent and safe to call multiple times
|
||||
if set_public_search_path:
|
||||
await db_client.execute_raw(f"SET search_path = {schema}, public") # type: ignore
|
||||
|
||||
if execute:
|
||||
result = await db_client.execute_raw(formatted_query, *args) # type: ignore
|
||||
else:
|
||||
@@ -163,12 +167,16 @@ async def _raw_with_schema(
|
||||
return result
|
||||
|
||||
|
||||
async def query_raw_with_schema(query_template: str, *args) -> list[dict]:
|
||||
async def query_raw_with_schema(
|
||||
query_template: str, *args, set_public_search_path: bool = False
|
||||
) -> list[dict]:
|
||||
"""Execute raw SQL SELECT query with proper schema handling.
|
||||
|
||||
Args:
|
||||
query_template: SQL query with {schema_prefix} and/or {schema} placeholders
|
||||
query_template: SQL query with {schema_prefix} placeholder
|
||||
*args: Query parameters
|
||||
set_public_search_path: If True, sets search_path to include public schema.
|
||||
Needed for pgvector types and other public schema objects.
|
||||
|
||||
Returns:
|
||||
List of result rows as dictionaries
|
||||
@@ -179,20 +187,23 @@ async def query_raw_with_schema(query_template: str, *args) -> list[dict]:
|
||||
user_id
|
||||
)
|
||||
"""
|
||||
return await _raw_with_schema(query_template, *args, execute=False) # type: ignore
|
||||
return await _raw_with_schema(query_template, *args, execute=False, set_public_search_path=set_public_search_path) # type: ignore
|
||||
|
||||
|
||||
async def execute_raw_with_schema(
|
||||
query_template: str,
|
||||
*args,
|
||||
client: Prisma | None = None,
|
||||
set_public_search_path: bool = False,
|
||||
) -> int:
|
||||
"""Execute raw SQL command (INSERT/UPDATE/DELETE) with proper schema handling.
|
||||
|
||||
Args:
|
||||
query_template: SQL query with {schema_prefix} and/or {schema} placeholders
|
||||
query_template: SQL query with {schema_prefix} placeholder
|
||||
*args: Query parameters
|
||||
client: Optional Prisma client for transactions
|
||||
set_public_search_path: If True, sets search_path to include public schema.
|
||||
Needed for pgvector types and other public schema objects.
|
||||
|
||||
Returns:
|
||||
Number of affected rows
|
||||
@@ -204,7 +215,7 @@ async def execute_raw_with_schema(
|
||||
client=tx # Optional transaction client
|
||||
)
|
||||
"""
|
||||
return await _raw_with_schema(query_template, *args, execute=True, client=client) # type: ignore
|
||||
return await _raw_with_schema(query_template, *args, execute=True, client=client, set_public_search_path=set_public_search_path) # type: ignore
|
||||
|
||||
|
||||
class BaseDbModel(BaseModel):
|
||||
|
||||
@@ -81,12 +81,10 @@ class ExecutionContext(BaseModel):
|
||||
This includes information needed by blocks, sub-graphs, and execution management.
|
||||
"""
|
||||
|
||||
human_in_the_loop_safe_mode: bool = True
|
||||
sensitive_action_safe_mode: bool = False
|
||||
safe_mode: bool = True
|
||||
user_timezone: str = "UTC"
|
||||
root_execution_id: Optional[str] = None
|
||||
parent_execution_id: Optional[str] = None
|
||||
auto_approved_node_ids: set[str] = Field(default_factory=set)
|
||||
|
||||
|
||||
# -------------------------- Models -------------------------- #
|
||||
|
||||
@@ -62,23 +62,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GraphSettings(BaseModel):
|
||||
human_in_the_loop_safe_mode: bool = True
|
||||
sensitive_action_safe_mode: bool = False
|
||||
|
||||
@classmethod
|
||||
def from_graph(
|
||||
cls,
|
||||
graph: "GraphModel",
|
||||
hitl_safe_mode: bool | None = None,
|
||||
sensitive_action_safe_mode: bool = False,
|
||||
) -> "GraphSettings":
|
||||
# Default to True if not explicitly set
|
||||
if hitl_safe_mode is None:
|
||||
hitl_safe_mode = True
|
||||
return cls(
|
||||
human_in_the_loop_safe_mode=hitl_safe_mode,
|
||||
sensitive_action_safe_mode=sensitive_action_safe_mode,
|
||||
)
|
||||
human_in_the_loop_safe_mode: bool | None = None
|
||||
|
||||
|
||||
class Link(BaseDbModel):
|
||||
@@ -260,14 +244,10 @@ class BaseGraph(BaseDbModel):
|
||||
return any(
|
||||
node.block_id
|
||||
for node in self.nodes
|
||||
if node.block.block_type == BlockType.HUMAN_IN_THE_LOOP
|
||||
)
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def has_sensitive_action(self) -> bool:
|
||||
return any(
|
||||
node.block_id for node in self.nodes if node.block.is_sensitive_action
|
||||
if (
|
||||
node.block.block_type == BlockType.HUMAN_IN_THE_LOOP
|
||||
or node.block.requires_human_review
|
||||
)
|
||||
)
|
||||
|
||||
@property
|
||||
|
||||
@@ -38,7 +38,6 @@ async def get_or_create_human_review(
|
||||
graph_exec_id: str,
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
node_id: str,
|
||||
input_data: SafeJsonData,
|
||||
message: str,
|
||||
editable: bool,
|
||||
@@ -54,7 +53,6 @@ async def get_or_create_human_review(
|
||||
graph_exec_id: ID of the graph execution
|
||||
graph_id: ID of the graph template
|
||||
graph_version: Version of the graph template
|
||||
node_id: ID of the node in the graph definition
|
||||
input_data: The data to be reviewed
|
||||
message: Instructions for the reviewer
|
||||
editable: Whether the data can be edited
|
||||
@@ -75,7 +73,6 @@ async def get_or_create_human_review(
|
||||
"graphExecId": graph_exec_id,
|
||||
"graphId": graph_id,
|
||||
"graphVersion": graph_version,
|
||||
"nodeId": node_id,
|
||||
"payload": SafeJson(input_data),
|
||||
"instructions": message,
|
||||
"editable": editable,
|
||||
|
||||
@@ -23,7 +23,6 @@ def sample_db_review():
|
||||
mock_review.graphExecId = "test_graph_exec_456"
|
||||
mock_review.graphId = "test_graph_789"
|
||||
mock_review.graphVersion = 1
|
||||
mock_review.nodeId = "node_def_123"
|
||||
mock_review.payload = {"data": "test payload"}
|
||||
mock_review.instructions = "Please review"
|
||||
mock_review.editable = True
|
||||
@@ -56,7 +55,6 @@ async def test_get_or_create_human_review_new(
|
||||
graph_exec_id="test_graph_exec_456",
|
||||
graph_id="test_graph_789",
|
||||
graph_version=1,
|
||||
node_id="node_def_123",
|
||||
input_data={"data": "test payload"},
|
||||
message="Please review",
|
||||
editable=True,
|
||||
@@ -86,7 +84,6 @@ async def test_get_or_create_human_review_approved(
|
||||
graph_exec_id="test_graph_exec_456",
|
||||
graph_id="test_graph_789",
|
||||
graph_version=1,
|
||||
node_id="node_def_123",
|
||||
input_data={"data": "test payload"},
|
||||
message="Please review",
|
||||
editable=True,
|
||||
@@ -186,7 +183,6 @@ async def test_process_all_reviews_for_execution_success(
|
||||
updated_review.graphExecId = "test_graph_exec_456"
|
||||
updated_review.graphId = "test_graph_789"
|
||||
updated_review.graphVersion = 1
|
||||
updated_review.nodeId = "node_def_123"
|
||||
updated_review.payload = {"data": "modified"}
|
||||
updated_review.instructions = "Please review"
|
||||
updated_review.editable = True
|
||||
@@ -276,7 +272,6 @@ async def test_process_all_reviews_mixed_approval_rejection(
|
||||
second_review.graphExecId = "test_graph_exec_456"
|
||||
second_review.graphId = "test_graph_789"
|
||||
second_review.graphVersion = 1
|
||||
second_review.nodeId = "node_def_456"
|
||||
second_review.payload = {"data": "original"}
|
||||
second_review.instructions = "Second review"
|
||||
second_review.editable = True
|
||||
@@ -301,7 +296,6 @@ async def test_process_all_reviews_mixed_approval_rejection(
|
||||
approved_review.graphExecId = "test_graph_exec_456"
|
||||
approved_review.graphId = "test_graph_789"
|
||||
approved_review.graphVersion = 1
|
||||
approved_review.nodeId = "node_def_123"
|
||||
approved_review.payload = {"data": "modified"}
|
||||
approved_review.instructions = "Please review"
|
||||
approved_review.editable = True
|
||||
@@ -319,7 +313,6 @@ async def test_process_all_reviews_mixed_approval_rejection(
|
||||
rejected_review.graphExecId = "test_graph_exec_456"
|
||||
rejected_review.graphId = "test_graph_789"
|
||||
rejected_review.graphVersion = 1
|
||||
rejected_review.nodeId = "node_def_456"
|
||||
rejected_review.payload = {"data": "original"}
|
||||
rejected_review.instructions = "Please review"
|
||||
rejected_review.editable = True
|
||||
|
||||
@@ -309,7 +309,7 @@ def ensure_embeddings_coverage():
|
||||
|
||||
# Process in batches until no more missing embeddings
|
||||
while True:
|
||||
result = db_client.backfill_missing_embeddings(batch_size=100)
|
||||
result = db_client.backfill_missing_embeddings(batch_size=10)
|
||||
|
||||
total_processed += result["processed"]
|
||||
total_success += result["success"]
|
||||
|
||||
@@ -873,8 +873,11 @@ async def add_graph_execution(
|
||||
settings = await gdb.get_graph_settings(user_id=user_id, graph_id=graph_id)
|
||||
|
||||
execution_context = ExecutionContext(
|
||||
human_in_the_loop_safe_mode=settings.human_in_the_loop_safe_mode,
|
||||
sensitive_action_safe_mode=settings.sensitive_action_safe_mode,
|
||||
safe_mode=(
|
||||
settings.human_in_the_loop_safe_mode
|
||||
if settings.human_in_the_loop_safe_mode is not None
|
||||
else True
|
||||
),
|
||||
user_timezone=(
|
||||
user.timezone if user.timezone != USER_TIMEZONE_NOT_SET else "UTC"
|
||||
),
|
||||
|
||||
@@ -386,7 +386,6 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
|
||||
mock_user.timezone = "UTC"
|
||||
mock_settings = mocker.MagicMock()
|
||||
mock_settings.human_in_the_loop_safe_mode = True
|
||||
mock_settings.sensitive_action_safe_mode = False
|
||||
|
||||
mock_udb.get_user_by_id = mocker.AsyncMock(return_value=mock_user)
|
||||
mock_gdb.get_graph_settings = mocker.AsyncMock(return_value=mock_settings)
|
||||
@@ -652,7 +651,6 @@ async def test_add_graph_execution_with_nodes_to_skip(mocker: MockerFixture):
|
||||
mock_user.timezone = "UTC"
|
||||
mock_settings = mocker.MagicMock()
|
||||
mock_settings.human_in_the_loop_safe_mode = True
|
||||
mock_settings.sensitive_action_safe_mode = False
|
||||
|
||||
mock_udb.get_user_by_id = mocker.AsyncMock(return_value=mock_user)
|
||||
mock_gdb.get_graph_settings = mocker.AsyncMock(return_value=mock_settings)
|
||||
|
||||
@@ -1,2 +0,0 @@
|
||||
-- AlterTable
|
||||
ALTER TABLE "PendingHumanReview" ADD COLUMN "nodeId" TEXT NOT NULL DEFAULT '';
|
||||
@@ -573,7 +573,6 @@ model PendingHumanReview {
|
||||
graphExecId String
|
||||
graphId String
|
||||
graphVersion Int
|
||||
nodeId String // The node ID in the graph definition (for auto-approval tracking)
|
||||
payload Json // The actual payload data to be reviewed
|
||||
instructions String? // Instructions/message for the reviewer
|
||||
editable Boolean @default(true) // Whether the reviewer can edit the data
|
||||
|
||||
@@ -366,12 +366,12 @@ def generate_block_markdown(
|
||||
lines.append("")
|
||||
|
||||
# What it is (full description)
|
||||
lines.append("### What it is")
|
||||
lines.append(f"### What it is")
|
||||
lines.append(block.description or "No description available.")
|
||||
lines.append("")
|
||||
|
||||
# How it works (manual section)
|
||||
lines.append("### How it works")
|
||||
lines.append(f"### How it works")
|
||||
how_it_works = manual_content.get(
|
||||
"how_it_works", "_Add technical explanation here._"
|
||||
)
|
||||
@@ -383,7 +383,7 @@ def generate_block_markdown(
|
||||
# Inputs table (auto-generated)
|
||||
visible_inputs = [f for f in block.inputs if not f.hidden]
|
||||
if visible_inputs:
|
||||
lines.append("### Inputs")
|
||||
lines.append(f"### Inputs")
|
||||
lines.append("")
|
||||
lines.append("| Input | Description | Type | Required |")
|
||||
lines.append("|-------|-------------|------|----------|")
|
||||
@@ -400,7 +400,7 @@ def generate_block_markdown(
|
||||
# Outputs table (auto-generated)
|
||||
visible_outputs = [f for f in block.outputs if not f.hidden]
|
||||
if visible_outputs:
|
||||
lines.append("### Outputs")
|
||||
lines.append(f"### Outputs")
|
||||
lines.append("")
|
||||
lines.append("| Output | Description | Type |")
|
||||
lines.append("|--------|-------------|------|")
|
||||
@@ -414,7 +414,7 @@ def generate_block_markdown(
|
||||
lines.append("")
|
||||
|
||||
# Possible use case (manual section)
|
||||
lines.append("### Possible use case")
|
||||
lines.append(f"### Possible use case")
|
||||
use_case = manual_content.get("use_case", "_Add practical use case examples here._")
|
||||
lines.append("<!-- MANUAL: use_case -->")
|
||||
lines.append(use_case)
|
||||
|
||||
@@ -11,7 +11,6 @@
|
||||
"forked_from_version": null,
|
||||
"has_external_trigger": false,
|
||||
"has_human_in_the_loop": false,
|
||||
"has_sensitive_action": false,
|
||||
"id": "graph-123",
|
||||
"input_schema": {
|
||||
"properties": {},
|
||||
|
||||
@@ -11,7 +11,6 @@
|
||||
"forked_from_version": null,
|
||||
"has_external_trigger": false,
|
||||
"has_human_in_the_loop": false,
|
||||
"has_sensitive_action": false,
|
||||
"id": "graph-123",
|
||||
"input_schema": {
|
||||
"properties": {},
|
||||
|
||||
@@ -27,8 +27,6 @@
|
||||
"properties": {}
|
||||
},
|
||||
"has_external_trigger": false,
|
||||
"has_human_in_the_loop": false,
|
||||
"has_sensitive_action": false,
|
||||
"trigger_setup_info": null,
|
||||
"new_output": false,
|
||||
"can_access_graph": true,
|
||||
@@ -36,8 +34,7 @@
|
||||
"is_favorite": false,
|
||||
"recommended_schedule_cron": null,
|
||||
"settings": {
|
||||
"human_in_the_loop_safe_mode": true,
|
||||
"sensitive_action_safe_mode": false
|
||||
"human_in_the_loop_safe_mode": null
|
||||
},
|
||||
"marketplace_listing": null
|
||||
},
|
||||
@@ -68,8 +65,6 @@
|
||||
"properties": {}
|
||||
},
|
||||
"has_external_trigger": false,
|
||||
"has_human_in_the_loop": false,
|
||||
"has_sensitive_action": false,
|
||||
"trigger_setup_info": null,
|
||||
"new_output": false,
|
||||
"can_access_graph": false,
|
||||
@@ -77,8 +72,7 @@
|
||||
"is_favorite": false,
|
||||
"recommended_schedule_cron": null,
|
||||
"settings": {
|
||||
"human_in_the_loop_safe_mode": true,
|
||||
"sensitive_action_safe_mode": false
|
||||
"human_in_the_loop_safe_mode": null
|
||||
},
|
||||
"marketplace_listing": null
|
||||
}
|
||||
|
||||
@@ -5,11 +5,10 @@ import {
|
||||
TooltipContent,
|
||||
TooltipTrigger,
|
||||
} from "@/components/atoms/Tooltip/BaseTooltip";
|
||||
import { CircleNotchIcon, PlayIcon, StopIcon } from "@phosphor-icons/react";
|
||||
import { PlayIcon, StopIcon } from "@phosphor-icons/react";
|
||||
import { useShallow } from "zustand/react/shallow";
|
||||
import { RunInputDialog } from "../RunInputDialog/RunInputDialog";
|
||||
import { useRunGraph } from "./useRunGraph";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
export const RunGraph = ({ flowID }: { flowID: string | null }) => {
|
||||
const {
|
||||
@@ -25,31 +24,6 @@ export const RunGraph = ({ flowID }: { flowID: string | null }) => {
|
||||
useShallow((state) => state.isGraphRunning),
|
||||
);
|
||||
|
||||
const isLoading = isExecutingGraph || isTerminatingGraph || isSaving;
|
||||
|
||||
// Determine which icon to show with proper animation
|
||||
const renderIcon = () => {
|
||||
const iconClass = cn(
|
||||
"size-4 transition-transform duration-200 ease-out",
|
||||
!isLoading && "group-hover:scale-110",
|
||||
);
|
||||
|
||||
if (isLoading) {
|
||||
return (
|
||||
<CircleNotchIcon
|
||||
className={cn(iconClass, "animate-spin")}
|
||||
weight="bold"
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isGraphRunning) {
|
||||
return <StopIcon className={iconClass} weight="fill" />;
|
||||
}
|
||||
|
||||
return <PlayIcon className={iconClass} weight="fill" />;
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
<Tooltip>
|
||||
@@ -59,18 +33,18 @@ export const RunGraph = ({ flowID }: { flowID: string | null }) => {
|
||||
variant={isGraphRunning ? "destructive" : "primary"}
|
||||
data-id={isGraphRunning ? "stop-graph-button" : "run-graph-button"}
|
||||
onClick={isGraphRunning ? handleStopGraph : handleRunGraph}
|
||||
disabled={!flowID || isLoading}
|
||||
className="group"
|
||||
disabled={!flowID || isExecutingGraph || isTerminatingGraph}
|
||||
loading={isExecutingGraph || isTerminatingGraph || isSaving}
|
||||
>
|
||||
{renderIcon()}
|
||||
{!isGraphRunning ? (
|
||||
<PlayIcon className="size-4" />
|
||||
) : (
|
||||
<StopIcon className="size-4" />
|
||||
)}
|
||||
</Button>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>
|
||||
{isLoading
|
||||
? "Processing..."
|
||||
: isGraphRunning
|
||||
? "Stop agent"
|
||||
: "Run agent"}
|
||||
{isGraphRunning ? "Stop agent" : "Run agent"}
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
<RunInputDialog
|
||||
|
||||
@@ -10,7 +10,6 @@ import { useRunInputDialog } from "./useRunInputDialog";
|
||||
import { CronSchedulerDialog } from "../CronSchedulerDialog/CronSchedulerDialog";
|
||||
import { useTutorialStore } from "@/app/(platform)/build/stores/tutorialStore";
|
||||
import { useEffect } from "react";
|
||||
import { CredentialsGroupedView } from "@/components/contextual/CredentialsInput/components/CredentialsGroupedView/CredentialsGroupedView";
|
||||
|
||||
export const RunInputDialog = ({
|
||||
isOpen,
|
||||
@@ -24,17 +23,19 @@ export const RunInputDialog = ({
|
||||
const hasInputs = useGraphStore((state) => state.hasInputs);
|
||||
const hasCredentials = useGraphStore((state) => state.hasCredentials);
|
||||
const inputSchema = useGraphStore((state) => state.inputSchema);
|
||||
const credentialsSchema = useGraphStore(
|
||||
(state) => state.credentialsInputSchema,
|
||||
);
|
||||
|
||||
const {
|
||||
credentialFields,
|
||||
requiredCredentials,
|
||||
credentialsUiSchema,
|
||||
handleManualRun,
|
||||
handleInputChange,
|
||||
openCronSchedulerDialog,
|
||||
setOpenCronSchedulerDialog,
|
||||
inputValues,
|
||||
credentialValues,
|
||||
handleCredentialFieldChange,
|
||||
handleCredentialChange,
|
||||
isExecutingGraph,
|
||||
} = useRunInputDialog({ setIsOpen });
|
||||
|
||||
@@ -61,67 +62,67 @@ export const RunInputDialog = ({
|
||||
isOpen,
|
||||
set: setIsOpen,
|
||||
}}
|
||||
styling={{ maxWidth: "700px", minWidth: "700px" }}
|
||||
styling={{ maxWidth: "600px", minWidth: "600px" }}
|
||||
>
|
||||
<Dialog.Content>
|
||||
<div
|
||||
className="grid grid-cols-[1fr_auto] gap-10 p-1"
|
||||
data-id="run-input-dialog-content"
|
||||
>
|
||||
<div className="space-y-6">
|
||||
{/* Credentials Section */}
|
||||
{hasCredentials() && credentialFields.length > 0 && (
|
||||
<div data-id="run-input-credentials-section">
|
||||
<div className="mb-4">
|
||||
<Text variant="h4" className="text-gray-900">
|
||||
Credentials
|
||||
</Text>
|
||||
</div>
|
||||
<div className="px-2" data-id="run-input-credentials-form">
|
||||
<CredentialsGroupedView
|
||||
credentialFields={credentialFields}
|
||||
requiredCredentials={requiredCredentials}
|
||||
inputCredentials={credentialValues}
|
||||
inputValues={inputValues}
|
||||
onCredentialChange={handleCredentialFieldChange}
|
||||
/>
|
||||
</div>
|
||||
<div className="space-y-6 p-1" data-id="run-input-dialog-content">
|
||||
{/* Credentials Section */}
|
||||
{hasCredentials() && (
|
||||
<div data-id="run-input-credentials-section">
|
||||
<div className="mb-4">
|
||||
<Text variant="h4" className="text-gray-900">
|
||||
Credentials
|
||||
</Text>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Inputs Section */}
|
||||
{hasInputs() && (
|
||||
<div data-id="run-input-inputs-section">
|
||||
<div className="mb-4">
|
||||
<Text variant="h4" className="text-gray-900">
|
||||
Inputs
|
||||
</Text>
|
||||
</div>
|
||||
<div data-id="run-input-inputs-form">
|
||||
<FormRenderer
|
||||
jsonSchema={inputSchema as RJSFSchema}
|
||||
handleChange={(v) => handleInputChange(v.formData)}
|
||||
uiSchema={uiSchema}
|
||||
initialValues={{}}
|
||||
formContext={{
|
||||
showHandles: false,
|
||||
size: "large",
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
<div className="px-2" data-id="run-input-credentials-form">
|
||||
<FormRenderer
|
||||
jsonSchema={credentialsSchema as RJSFSchema}
|
||||
handleChange={(v) => handleCredentialChange(v.formData)}
|
||||
uiSchema={credentialsUiSchema}
|
||||
initialValues={{}}
|
||||
formContext={{
|
||||
showHandles: false,
|
||||
size: "large",
|
||||
showOptionalToggle: false,
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Inputs Section */}
|
||||
{hasInputs() && (
|
||||
<div data-id="run-input-inputs-section">
|
||||
<div className="mb-4">
|
||||
<Text variant="h4" className="text-gray-900">
|
||||
Inputs
|
||||
</Text>
|
||||
</div>
|
||||
<div data-id="run-input-inputs-form">
|
||||
<FormRenderer
|
||||
jsonSchema={inputSchema as RJSFSchema}
|
||||
handleChange={(v) => handleInputChange(v.formData)}
|
||||
uiSchema={uiSchema}
|
||||
initialValues={{}}
|
||||
formContext={{
|
||||
showHandles: false,
|
||||
size: "large",
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Action Button */}
|
||||
<div
|
||||
className="flex flex-col items-end justify-start"
|
||||
className="flex justify-end pt-2"
|
||||
data-id="run-input-actions-section"
|
||||
>
|
||||
{purpose === "run" && (
|
||||
<Button
|
||||
variant="primary"
|
||||
size="large"
|
||||
className="group h-fit min-w-0 gap-2 px-10"
|
||||
className="group h-fit min-w-0 gap-2"
|
||||
onClick={handleManualRun}
|
||||
loading={isExecutingGraph}
|
||||
data-id="run-input-manual-run-button"
|
||||
@@ -136,7 +137,7 @@ export const RunInputDialog = ({
|
||||
<Button
|
||||
variant="primary"
|
||||
size="large"
|
||||
className="group h-fit min-w-0 gap-2 px-10"
|
||||
className="group h-fit min-w-0 gap-2"
|
||||
onClick={() => setOpenCronSchedulerDialog(true)}
|
||||
data-id="run-input-schedule-button"
|
||||
>
|
||||
|
||||
@@ -7,11 +7,12 @@ import {
|
||||
GraphExecutionMeta,
|
||||
} from "@/lib/autogpt-server-api";
|
||||
import { parseAsInteger, parseAsString, useQueryStates } from "nuqs";
|
||||
import { useCallback, useMemo, useState } from "react";
|
||||
import { useMemo, useState } from "react";
|
||||
import { uiSchema } from "../../../FlowEditor/nodes/uiSchema";
|
||||
import { isCredentialFieldSchema } from "@/components/renderers/InputRenderer/custom/CredentialField/helpers";
|
||||
import { useNodeStore } from "@/app/(platform)/build/stores/nodeStore";
|
||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||
import { useReactFlow } from "@xyflow/react";
|
||||
import type { CredentialField } from "@/components/contextual/CredentialsInput/components/CredentialsGroupedView/helpers";
|
||||
|
||||
export const useRunInputDialog = ({
|
||||
setIsOpen,
|
||||
@@ -119,32 +120,27 @@ export const useRunInputDialog = ({
|
||||
},
|
||||
});
|
||||
|
||||
// Convert credentials schema to credential fields array for CredentialsGroupedView
|
||||
const credentialFields: CredentialField[] = useMemo(() => {
|
||||
if (!credentialsSchema?.properties) return [];
|
||||
return Object.entries(credentialsSchema.properties);
|
||||
}, [credentialsSchema]);
|
||||
// We are rendering the credentials field differently compared to other fields.
|
||||
// In the node, we have the field name as "credential" - so our library catches it and renders it differently.
|
||||
// But here we have a different name, something like `Firecrawl credentials`, so here we are telling the library that this field is a credential field type.
|
||||
|
||||
// Get required credentials as a Set
|
||||
const requiredCredentials = useMemo(() => {
|
||||
return new Set<string>(credentialsSchema?.required || []);
|
||||
}, [credentialsSchema]);
|
||||
const credentialsUiSchema = useMemo(() => {
|
||||
const dynamicUiSchema: any = { ...uiSchema };
|
||||
|
||||
// Handler for individual credential changes
|
||||
const handleCredentialFieldChange = useCallback(
|
||||
(key: string, value?: CredentialsMetaInput) => {
|
||||
setCredentialValues((prev) => {
|
||||
if (value) {
|
||||
return { ...prev, [key]: value };
|
||||
} else {
|
||||
const next = { ...prev };
|
||||
delete next[key];
|
||||
return next;
|
||||
if (credentialsSchema?.properties) {
|
||||
Object.keys(credentialsSchema.properties).forEach((fieldName) => {
|
||||
const fieldSchema = credentialsSchema.properties[fieldName];
|
||||
if (isCredentialFieldSchema(fieldSchema)) {
|
||||
dynamicUiSchema[fieldName] = {
|
||||
...dynamicUiSchema[fieldName],
|
||||
"ui:field": "custom/credential_field",
|
||||
};
|
||||
}
|
||||
});
|
||||
},
|
||||
[],
|
||||
);
|
||||
}
|
||||
|
||||
return dynamicUiSchema;
|
||||
}, [credentialsSchema]);
|
||||
|
||||
const handleManualRun = async () => {
|
||||
// Filter out incomplete credentials (those without a valid id)
|
||||
@@ -177,14 +173,12 @@ export const useRunInputDialog = ({
|
||||
};
|
||||
|
||||
return {
|
||||
credentialFields,
|
||||
requiredCredentials,
|
||||
credentialsUiSchema,
|
||||
inputValues,
|
||||
credentialValues,
|
||||
isExecutingGraph,
|
||||
handleInputChange,
|
||||
handleCredentialChange,
|
||||
handleCredentialFieldChange,
|
||||
handleManualRun,
|
||||
openCronSchedulerDialog,
|
||||
setOpenCronSchedulerDialog,
|
||||
|
||||
@@ -18,110 +18,69 @@ interface Props {
|
||||
fullWidth?: boolean;
|
||||
}
|
||||
|
||||
interface SafeModeButtonProps {
|
||||
isEnabled: boolean;
|
||||
label: string;
|
||||
tooltipEnabled: string;
|
||||
tooltipDisabled: string;
|
||||
onToggle: () => void;
|
||||
isPending: boolean;
|
||||
fullWidth?: boolean;
|
||||
}
|
||||
|
||||
function SafeModeButton({
|
||||
isEnabled,
|
||||
label,
|
||||
tooltipEnabled,
|
||||
tooltipDisabled,
|
||||
onToggle,
|
||||
isPending,
|
||||
fullWidth = false,
|
||||
}: SafeModeButtonProps) {
|
||||
return (
|
||||
<Tooltip delayDuration={100}>
|
||||
<TooltipTrigger asChild>
|
||||
<Button
|
||||
variant={isEnabled ? "primary" : "outline"}
|
||||
size="small"
|
||||
onClick={onToggle}
|
||||
disabled={isPending}
|
||||
className={cn("justify-start", fullWidth ? "w-full" : "")}
|
||||
>
|
||||
{isEnabled ? (
|
||||
<>
|
||||
<ShieldCheckIcon weight="bold" size={16} />
|
||||
<Text variant="body" className="text-zinc-200">
|
||||
{label}: ON
|
||||
</Text>
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<ShieldIcon weight="bold" size={16} />
|
||||
<Text variant="body" className="text-zinc-600">
|
||||
{label}: OFF
|
||||
</Text>
|
||||
</>
|
||||
)}
|
||||
</Button>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>
|
||||
<div className="text-center">
|
||||
<div className="font-medium">
|
||||
{label}: {isEnabled ? "ON" : "OFF"}
|
||||
</div>
|
||||
<div className="mt-1 text-xs text-muted-foreground">
|
||||
{isEnabled ? tooltipEnabled : tooltipDisabled}
|
||||
</div>
|
||||
</div>
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
);
|
||||
}
|
||||
|
||||
export function FloatingSafeModeToggle({
|
||||
graph,
|
||||
className,
|
||||
fullWidth = false,
|
||||
}: Props) {
|
||||
const {
|
||||
currentHITLSafeMode,
|
||||
showHITLToggle,
|
||||
handleHITLToggle,
|
||||
currentSensitiveActionSafeMode,
|
||||
showSensitiveActionToggle,
|
||||
handleSensitiveActionToggle,
|
||||
currentSafeMode,
|
||||
isPending,
|
||||
shouldShowToggle,
|
||||
isStateUndetermined,
|
||||
handleToggle,
|
||||
} = useAgentSafeMode(graph);
|
||||
|
||||
if (!shouldShowToggle || isPending) {
|
||||
if (!shouldShowToggle || isStateUndetermined || isPending) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<div className={cn("fixed z-50 flex flex-col gap-2", className)}>
|
||||
{showHITLToggle && (
|
||||
<SafeModeButton
|
||||
isEnabled={currentHITLSafeMode}
|
||||
label="Human in the loop block approval"
|
||||
tooltipEnabled="The agent will pause at human-in-the-loop blocks and wait for your approval"
|
||||
tooltipDisabled="Human in the loop blocks will proceed automatically"
|
||||
onToggle={handleHITLToggle}
|
||||
isPending={isPending}
|
||||
fullWidth={fullWidth}
|
||||
/>
|
||||
)}
|
||||
{showSensitiveActionToggle && (
|
||||
<SafeModeButton
|
||||
isEnabled={currentSensitiveActionSafeMode}
|
||||
label="Sensitive actions blocks approval"
|
||||
tooltipEnabled="The agent will pause at sensitive action blocks and wait for your approval"
|
||||
tooltipDisabled="Sensitive action blocks will proceed automatically"
|
||||
onToggle={handleSensitiveActionToggle}
|
||||
isPending={isPending}
|
||||
fullWidth={fullWidth}
|
||||
/>
|
||||
)}
|
||||
<div className={cn("fixed z-50", className)}>
|
||||
<Tooltip delayDuration={100}>
|
||||
<TooltipTrigger asChild>
|
||||
<Button
|
||||
variant={currentSafeMode! ? "primary" : "outline"}
|
||||
key={graph.id}
|
||||
size="small"
|
||||
title={
|
||||
currentSafeMode!
|
||||
? "Safe Mode: ON. Human in the loop blocks require manual review"
|
||||
: "Safe Mode: OFF. Human in the loop blocks proceed automatically"
|
||||
}
|
||||
onClick={handleToggle}
|
||||
className={cn(fullWidth ? "w-full" : "")}
|
||||
>
|
||||
{currentSafeMode! ? (
|
||||
<>
|
||||
<ShieldCheckIcon weight="bold" size={16} />
|
||||
<Text variant="body" className="text-zinc-200">
|
||||
Safe Mode: ON
|
||||
</Text>
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<ShieldIcon weight="bold" size={16} />
|
||||
<Text variant="body" className="text-zinc-600">
|
||||
Safe Mode: OFF
|
||||
</Text>
|
||||
</>
|
||||
)}
|
||||
</Button>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>
|
||||
<div className="text-center">
|
||||
<div className="font-medium">
|
||||
Safe Mode: {currentSafeMode! ? "ON" : "OFF"}
|
||||
</div>
|
||||
<div className="mt-1 text-xs text-muted-foreground">
|
||||
{currentSafeMode!
|
||||
? "Human in the loop blocks require manual review"
|
||||
: "Human in the loop blocks proceed automatically"}
|
||||
</div>
|
||||
</div>
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -53,14 +53,14 @@ export const CustomControls = memo(
|
||||
const controls = [
|
||||
{
|
||||
id: "zoom-in-button",
|
||||
icon: <PlusIcon className="size-3.5 text-zinc-600" />,
|
||||
icon: <PlusIcon className="size-4" />,
|
||||
label: "Zoom In",
|
||||
onClick: () => zoomIn(),
|
||||
className: "h-10 w-10 border-none",
|
||||
},
|
||||
{
|
||||
id: "zoom-out-button",
|
||||
icon: <MinusIcon className="size-3.5 text-zinc-600" />,
|
||||
icon: <MinusIcon className="size-4" />,
|
||||
label: "Zoom Out",
|
||||
onClick: () => zoomOut(),
|
||||
className: "h-10 w-10 border-none",
|
||||
@@ -68,9 +68,9 @@ export const CustomControls = memo(
|
||||
{
|
||||
id: "tutorial-button",
|
||||
icon: isTutorialLoading ? (
|
||||
<CircleNotchIcon className="size-3.5 animate-spin text-zinc-600" />
|
||||
<CircleNotchIcon className="size-4 animate-spin" />
|
||||
) : (
|
||||
<ChalkboardIcon className="size-3.5 text-zinc-600" />
|
||||
<ChalkboardIcon className="size-4" />
|
||||
),
|
||||
label: isTutorialLoading ? "Loading Tutorial..." : "Start Tutorial",
|
||||
onClick: handleTutorialClick,
|
||||
@@ -79,7 +79,7 @@ export const CustomControls = memo(
|
||||
},
|
||||
{
|
||||
id: "fit-view-button",
|
||||
icon: <FrameCornersIcon className="size-3.5 text-zinc-600" />,
|
||||
icon: <FrameCornersIcon className="size-4" />,
|
||||
label: "Fit View",
|
||||
onClick: () => fitView({ padding: 0.2, duration: 800, maxZoom: 1 }),
|
||||
className: "h-10 w-10 border-none",
|
||||
@@ -87,9 +87,9 @@ export const CustomControls = memo(
|
||||
{
|
||||
id: "lock-button",
|
||||
icon: !isLocked ? (
|
||||
<LockOpenIcon className="size-3.5 text-zinc-600" />
|
||||
<LockOpenIcon className="size-4" />
|
||||
) : (
|
||||
<LockIcon className="size-3.5 text-zinc-600" />
|
||||
<LockIcon className="size-4" />
|
||||
),
|
||||
label: "Toggle Lock",
|
||||
onClick: () => setIsLocked(!isLocked),
|
||||
|
||||
@@ -139,6 +139,14 @@ export const useFlow = () => {
|
||||
useNodeStore.getState().setNodes([]);
|
||||
useNodeStore.getState().clearResolutionState();
|
||||
addNodes(customNodes);
|
||||
|
||||
// Sync hardcoded values with handle IDs.
|
||||
// If a key–value field has a key without a value, the backend omits it from hardcoded values.
|
||||
// But if a handleId exists for that key, it causes inconsistency.
|
||||
// This ensures hardcoded values stay in sync with handle IDs.
|
||||
customNodes.forEach((node) => {
|
||||
useNodeStore.getState().syncHardcodedValuesWithHandleIds(node.id);
|
||||
});
|
||||
}
|
||||
}, [customNodes, addNodes]);
|
||||
|
||||
@@ -150,14 +158,6 @@ export const useFlow = () => {
|
||||
}
|
||||
}, [graph?.links, addLinks]);
|
||||
|
||||
useEffect(() => {
|
||||
if (customNodes.length > 0 && graph?.links) {
|
||||
customNodes.forEach((node) => {
|
||||
useNodeStore.getState().syncHardcodedValuesWithHandleIds(node.id);
|
||||
});
|
||||
}
|
||||
}, [customNodes, graph?.links]);
|
||||
|
||||
// update node execution status in nodes
|
||||
useEffect(() => {
|
||||
if (
|
||||
|
||||
@@ -19,8 +19,6 @@ export type CustomEdgeData = {
|
||||
beadUp?: number;
|
||||
beadDown?: number;
|
||||
beadData?: Map<string, NodeExecutionResult["status"]>;
|
||||
edgeColorClass?: string;
|
||||
edgeHexColor?: string;
|
||||
};
|
||||
|
||||
export type CustomEdge = XYEdge<CustomEdgeData, "custom">;
|
||||
@@ -38,6 +36,7 @@ const CustomEdge = ({
|
||||
selected,
|
||||
}: EdgeProps<CustomEdge>) => {
|
||||
const removeConnection = useEdgeStore((state) => state.removeEdge);
|
||||
// Subscribe to the brokenEdgeIDs map and check if this edge is broken across any node
|
||||
const isBroken = useNodeStore((state) => state.isEdgeBroken(id));
|
||||
const [isHovered, setIsHovered] = useState(false);
|
||||
|
||||
@@ -53,7 +52,6 @@ const CustomEdge = ({
|
||||
const isStatic = data?.isStatic ?? false;
|
||||
const beadUp = data?.beadUp ?? 0;
|
||||
const beadDown = data?.beadDown ?? 0;
|
||||
const edgeColorClass = data?.edgeColorClass;
|
||||
|
||||
const handleRemoveEdge = () => {
|
||||
removeConnection(id);
|
||||
@@ -72,9 +70,7 @@ const CustomEdge = ({
|
||||
? "!stroke-red-500 !stroke-[2px] [stroke-dasharray:4]"
|
||||
: selected
|
||||
? "stroke-zinc-800"
|
||||
: edgeColorClass
|
||||
? cn(edgeColorClass, "opacity-70 hover:opacity-100")
|
||||
: "stroke-zinc-500/50 hover:stroke-zinc-500",
|
||||
: "stroke-zinc-500/50 hover:stroke-zinc-500",
|
||||
)}
|
||||
/>
|
||||
<JSBeads
|
||||
|
||||
@@ -8,7 +8,6 @@ import { useCallback } from "react";
|
||||
import { useNodeStore } from "../../../stores/nodeStore";
|
||||
import { useHistoryStore } from "../../../stores/historyStore";
|
||||
import { CustomEdge } from "./CustomEdge";
|
||||
import { getEdgeColorFromOutputType } from "../nodes/helpers";
|
||||
|
||||
export const useCustomEdge = () => {
|
||||
const edges = useEdgeStore((s) => s.edges);
|
||||
@@ -35,13 +34,8 @@ export const useCustomEdge = () => {
|
||||
if (exists) return;
|
||||
|
||||
const nodes = useNodeStore.getState().nodes;
|
||||
const sourceNode = nodes.find((n) => n.id === conn.source);
|
||||
const isStatic = sourceNode?.data?.staticOutput;
|
||||
|
||||
const { colorClass, hexColor } = getEdgeColorFromOutputType(
|
||||
sourceNode?.data?.outputSchema,
|
||||
conn.sourceHandle,
|
||||
);
|
||||
const isStatic = nodes.find((n) => n.id === conn.source)?.data
|
||||
?.staticOutput;
|
||||
|
||||
addEdge({
|
||||
source: conn.source,
|
||||
@@ -50,8 +44,6 @@ export const useCustomEdge = () => {
|
||||
targetHandle: conn.targetHandle,
|
||||
data: {
|
||||
isStatic,
|
||||
edgeColorClass: colorClass,
|
||||
edgeHexColor: hexColor,
|
||||
},
|
||||
});
|
||||
},
|
||||
|
||||
@@ -1,21 +1,22 @@
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import {
|
||||
Accordion,
|
||||
AccordionContent,
|
||||
AccordionItem,
|
||||
AccordionTrigger,
|
||||
} from "@/components/molecules/Accordion/Accordion";
|
||||
import { beautifyString, cn } from "@/lib/utils";
|
||||
import { CopyIcon, CheckIcon } from "@phosphor-icons/react";
|
||||
import { CaretDownIcon, CopyIcon, CheckIcon } from "@phosphor-icons/react";
|
||||
import { NodeDataViewer } from "./components/NodeDataViewer/NodeDataViewer";
|
||||
import { ContentRenderer } from "./components/ContentRenderer";
|
||||
import { useNodeOutput } from "./useNodeOutput";
|
||||
import { ViewMoreData } from "./components/ViewMoreData";
|
||||
|
||||
export const NodeDataRenderer = ({ nodeId }: { nodeId: string }) => {
|
||||
const { outputData, copiedKey, handleCopy, executionResultId, inputData } =
|
||||
useNodeOutput(nodeId);
|
||||
const {
|
||||
outputData,
|
||||
isExpanded,
|
||||
setIsExpanded,
|
||||
copiedKey,
|
||||
handleCopy,
|
||||
executionResultId,
|
||||
inputData,
|
||||
} = useNodeOutput(nodeId);
|
||||
|
||||
if (Object.keys(outputData).length === 0) {
|
||||
return null;
|
||||
@@ -24,117 +25,122 @@ export const NodeDataRenderer = ({ nodeId }: { nodeId: string }) => {
|
||||
return (
|
||||
<div
|
||||
data-tutorial-id={`node-output`}
|
||||
className="rounded-b-xl border-t border-zinc-200 px-4 py-2"
|
||||
className="flex flex-col gap-3 rounded-b-xl border-t border-zinc-200 px-4 py-4"
|
||||
>
|
||||
<Accordion type="single" collapsible defaultValue="node-output">
|
||||
<AccordionItem value="node-output" className="border-none">
|
||||
<AccordionTrigger className="py-2 hover:no-underline">
|
||||
<Text
|
||||
variant="body-medium"
|
||||
className="!font-semibold text-slate-700"
|
||||
>
|
||||
Node Output
|
||||
</Text>
|
||||
</AccordionTrigger>
|
||||
<AccordionContent className="pt-2">
|
||||
<div className="flex max-w-[350px] flex-col gap-4">
|
||||
<div className="space-y-2">
|
||||
<Text variant="small-medium">Input</Text>
|
||||
<div className="flex items-center justify-between">
|
||||
<Text variant="body-medium" className="!font-semibold text-slate-700">
|
||||
Node Output
|
||||
</Text>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="small"
|
||||
onClick={() => setIsExpanded(!isExpanded)}
|
||||
className="h-fit min-w-0 p-1 text-slate-600 hover:text-slate-900"
|
||||
>
|
||||
<CaretDownIcon
|
||||
size={16}
|
||||
weight="bold"
|
||||
className={`transition-transform ${isExpanded ? "rotate-180" : ""}`}
|
||||
/>
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
<ContentRenderer value={inputData} shortContent={false} />
|
||||
{isExpanded && (
|
||||
<>
|
||||
<div className="flex max-w-[350px] flex-col gap-4">
|
||||
<div className="space-y-2">
|
||||
<Text variant="small-medium">Input</Text>
|
||||
|
||||
<div className="mt-1 flex justify-end gap-1">
|
||||
<NodeDataViewer
|
||||
data={inputData}
|
||||
pinName="Input"
|
||||
execId={executionResultId}
|
||||
/>
|
||||
<Button
|
||||
variant="secondary"
|
||||
size="small"
|
||||
onClick={() => handleCopy("input", inputData)}
|
||||
className={cn(
|
||||
"h-fit min-w-0 gap-1.5 border border-zinc-200 p-2 text-black hover:text-slate-900",
|
||||
copiedKey === "input" &&
|
||||
"border-green-400 bg-green-100 hover:border-green-400 hover:bg-green-200",
|
||||
)}
|
||||
>
|
||||
{copiedKey === "input" ? (
|
||||
<CheckIcon size={12} className="text-green-600" />
|
||||
) : (
|
||||
<CopyIcon size={12} />
|
||||
)}
|
||||
</Button>
|
||||
</div>
|
||||
<ContentRenderer value={inputData} shortContent={false} />
|
||||
|
||||
<div className="mt-1 flex justify-end gap-1">
|
||||
<NodeDataViewer
|
||||
data={inputData}
|
||||
pinName="Input"
|
||||
execId={executionResultId}
|
||||
/>
|
||||
<Button
|
||||
variant="secondary"
|
||||
size="small"
|
||||
onClick={() => handleCopy("input", inputData)}
|
||||
className={cn(
|
||||
"h-fit min-w-0 gap-1.5 border border-zinc-200 p-2 text-black hover:text-slate-900",
|
||||
copiedKey === "input" &&
|
||||
"border-green-400 bg-green-100 hover:border-green-400 hover:bg-green-200",
|
||||
)}
|
||||
>
|
||||
{copiedKey === "input" ? (
|
||||
<CheckIcon size={12} className="text-green-600" />
|
||||
) : (
|
||||
<CopyIcon size={12} />
|
||||
)}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{Object.entries(outputData)
|
||||
.slice(0, 2)
|
||||
.map(([key, value]) => (
|
||||
<div key={key} className="flex flex-col gap-2">
|
||||
<div className="flex items-center gap-2">
|
||||
<Text
|
||||
variant="small-medium"
|
||||
className="!font-semibold text-slate-600"
|
||||
>
|
||||
Pin:
|
||||
</Text>
|
||||
<Text variant="small" className="text-slate-700">
|
||||
{beautifyString(key)}
|
||||
</Text>
|
||||
</div>
|
||||
<div className="w-full space-y-2">
|
||||
<Text
|
||||
variant="small"
|
||||
className="!font-semibold text-slate-600"
|
||||
>
|
||||
Data:
|
||||
</Text>
|
||||
<div className="relative space-y-2">
|
||||
{value.map((item, index) => (
|
||||
<div key={index}>
|
||||
<ContentRenderer value={item} shortContent={true} />
|
||||
</div>
|
||||
))}
|
||||
|
||||
<div className="mt-1 flex justify-end gap-1">
|
||||
<NodeDataViewer
|
||||
data={value}
|
||||
pinName={key}
|
||||
execId={executionResultId}
|
||||
/>
|
||||
<Button
|
||||
variant="secondary"
|
||||
size="small"
|
||||
onClick={() => handleCopy(key, value)}
|
||||
className={cn(
|
||||
"h-fit min-w-0 gap-1.5 border border-zinc-200 p-2 text-black hover:text-slate-900",
|
||||
copiedKey === key &&
|
||||
"border-green-400 bg-green-100 hover:border-green-400 hover:bg-green-200",
|
||||
)}
|
||||
>
|
||||
{copiedKey === key ? (
|
||||
<CheckIcon size={12} className="text-green-600" />
|
||||
) : (
|
||||
<CopyIcon size={12} />
|
||||
)}
|
||||
</Button>
|
||||
{Object.entries(outputData)
|
||||
.slice(0, 2)
|
||||
.map(([key, value]) => (
|
||||
<div key={key} className="flex flex-col gap-2">
|
||||
<div className="flex items-center gap-2">
|
||||
<Text
|
||||
variant="small-medium"
|
||||
className="!font-semibold text-slate-600"
|
||||
>
|
||||
Pin:
|
||||
</Text>
|
||||
<Text variant="small" className="text-slate-700">
|
||||
{beautifyString(key)}
|
||||
</Text>
|
||||
</div>
|
||||
<div className="w-full space-y-2">
|
||||
<Text
|
||||
variant="small"
|
||||
className="!font-semibold text-slate-600"
|
||||
>
|
||||
Data:
|
||||
</Text>
|
||||
<div className="relative space-y-2">
|
||||
{value.map((item, index) => (
|
||||
<div key={index}>
|
||||
<ContentRenderer value={item} shortContent={true} />
|
||||
</div>
|
||||
))}
|
||||
|
||||
<div className="mt-1 flex justify-end gap-1">
|
||||
<NodeDataViewer
|
||||
data={value}
|
||||
pinName={key}
|
||||
execId={executionResultId}
|
||||
/>
|
||||
<Button
|
||||
variant="secondary"
|
||||
size="small"
|
||||
onClick={() => handleCopy(key, value)}
|
||||
className={cn(
|
||||
"h-fit min-w-0 gap-1.5 border border-zinc-200 p-2 text-black hover:text-slate-900",
|
||||
copiedKey === key &&
|
||||
"border-green-400 bg-green-100 hover:border-green-400 hover:bg-green-200",
|
||||
)}
|
||||
>
|
||||
{copiedKey === key ? (
|
||||
<CheckIcon size={12} className="text-green-600" />
|
||||
) : (
|
||||
<CopyIcon size={12} />
|
||||
)}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
|
||||
{Object.keys(outputData).length > 2 && (
|
||||
<ViewMoreData
|
||||
outputData={outputData}
|
||||
execId={executionResultId}
|
||||
/>
|
||||
)}
|
||||
</AccordionContent>
|
||||
</AccordionItem>
|
||||
</Accordion>
|
||||
{Object.keys(outputData).length > 2 && (
|
||||
<ViewMoreData outputData={outputData} execId={executionResultId} />
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -4,6 +4,7 @@ import { useShallow } from "zustand/react/shallow";
|
||||
import { useState } from "react";
|
||||
|
||||
export const useNodeOutput = (nodeId: string) => {
|
||||
const [isExpanded, setIsExpanded] = useState(true);
|
||||
const [copiedKey, setCopiedKey] = useState<string | null>(null);
|
||||
const { toast } = useToast();
|
||||
|
||||
@@ -36,10 +37,13 @@ export const useNodeOutput = (nodeId: string) => {
|
||||
}
|
||||
};
|
||||
return {
|
||||
outputData,
|
||||
inputData,
|
||||
copiedKey,
|
||||
handleCopy,
|
||||
outputData: outputData,
|
||||
inputData: inputData,
|
||||
isExpanded: isExpanded,
|
||||
setIsExpanded: setIsExpanded,
|
||||
copiedKey: copiedKey,
|
||||
setCopiedKey: setCopiedKey,
|
||||
handleCopy: handleCopy,
|
||||
executionResultId: nodeExecutionResult?.node_exec_id,
|
||||
};
|
||||
};
|
||||
|
||||
@@ -187,38 +187,3 @@ export const getTypeDisplayInfo = (schema: any) => {
|
||||
hexColor,
|
||||
};
|
||||
};
|
||||
|
||||
export function getEdgeColorFromOutputType(
|
||||
outputSchema: RJSFSchema | undefined,
|
||||
sourceHandle: string,
|
||||
): { colorClass: string; hexColor: string } {
|
||||
const defaultColor = {
|
||||
colorClass: "stroke-zinc-500/50",
|
||||
hexColor: "#6b7280",
|
||||
};
|
||||
|
||||
if (!outputSchema?.properties) return defaultColor;
|
||||
|
||||
const properties = outputSchema.properties as Record<string, unknown>;
|
||||
const handleParts = sourceHandle.split("_#_");
|
||||
let currentSchema: Record<string, unknown> = properties;
|
||||
|
||||
for (let i = 0; i < handleParts.length; i++) {
|
||||
const part = handleParts[i];
|
||||
const fieldSchema = currentSchema[part] as Record<string, unknown>;
|
||||
if (!fieldSchema) return defaultColor;
|
||||
|
||||
if (i === handleParts.length - 1) {
|
||||
const { hexColor, colorClass } = getTypeDisplayInfo(fieldSchema);
|
||||
return { colorClass: colorClass.replace("!text-", "stroke-"), hexColor };
|
||||
}
|
||||
|
||||
if (fieldSchema.properties) {
|
||||
currentSchema = fieldSchema.properties as Record<string, unknown>;
|
||||
} else {
|
||||
return defaultColor;
|
||||
}
|
||||
}
|
||||
|
||||
return defaultColor;
|
||||
}
|
||||
|
||||
@@ -1,32 +1,7 @@
|
||||
type IconOptions = {
|
||||
size?: number;
|
||||
color?: string;
|
||||
};
|
||||
|
||||
const DEFAULT_SIZE = 16;
|
||||
const DEFAULT_COLOR = "#52525b"; // zinc-600
|
||||
|
||||
const iconPaths = {
|
||||
ClickIcon: `M88,24V16a8,8,0,0,1,16,0v8a8,8,0,0,1-16,0ZM16,104h8a8,8,0,0,0,0-16H16a8,8,0,0,0,0,16ZM124.42,39.16a8,8,0,0,0,10.74-3.58l8-16a8,8,0,0,0-14.31-7.16l-8,16A8,8,0,0,0,124.42,39.16Zm-96,81.69-16,8a8,8,0,0,0,7.16,14.31l16-8a8,8,0,1,0-7.16-14.31ZM219.31,184a16,16,0,0,1,0,22.63l-12.68,12.68a16,16,0,0,1-22.63,0L132.7,168,115,214.09c0,.1-.08.21-.13.32a15.83,15.83,0,0,1-14.6,9.59l-.79,0a15.83,15.83,0,0,1-14.41-11L32.8,52.92A16,16,0,0,1,52.92,32.8L213,85.07a16,16,0,0,1,1.41,29.8l-.32.13L168,132.69ZM208,195.31,156.69,144h0a16,16,0,0,1,4.93-26l.32-.14,45.95-17.64L48,48l52.2,159.86,17.65-46c0-.11.08-.22.13-.33a16,16,0,0,1,11.69-9.34,16.72,16.72,0,0,1,3-.28,16,16,0,0,1,11.3,4.69L195.31,208Z`,
|
||||
Keyboard: `M224,48H32A16,16,0,0,0,16,64V192a16,16,0,0,0,16,16H224a16,16,0,0,0,16-16V64A16,16,0,0,0,224,48Zm0,144H32V64H224V192Zm-16-64a8,8,0,0,1-8,8H56a8,8,0,0,1,0-16H200A8,8,0,0,1,208,128Zm0-32a8,8,0,0,1-8,8H56a8,8,0,0,1,0-16H200A8,8,0,0,1,208,96ZM72,160a8,8,0,0,1-8,8H56a8,8,0,0,1,0-16h8A8,8,0,0,1,72,160Zm96,0a8,8,0,0,1-8,8H96a8,8,0,0,1,0-16h64A8,8,0,0,1,168,160Zm40,0a8,8,0,0,1-8,8h-8a8,8,0,0,1,0-16h8A8,8,0,0,1,208,160Z`,
|
||||
Drag: `M188,80a27.79,27.79,0,0,0-13.36,3.4,28,28,0,0,0-46.64-11A28,28,0,0,0,80,92v20H68a28,28,0,0,0-28,28v12a88,88,0,0,0,176,0V108A28,28,0,0,0,188,80Zm12,72a72,72,0,0,1-144,0V140a12,12,0,0,1,12-12H80v24a8,8,0,0,0,16,0V92a12,12,0,0,1,24,0v28a8,8,0,0,0,16,0V92a12,12,0,0,1,24,0v28a8,8,0,0,0,16,0V108a12,12,0,0,1,24,0Z`,
|
||||
};
|
||||
|
||||
function createIcon(path: string, options: IconOptions = {}): string {
|
||||
const size = options.size ?? DEFAULT_SIZE;
|
||||
const color = options.color ?? DEFAULT_COLOR;
|
||||
return `<svg xmlns="http://www.w3.org/2000/svg" width="${size}" height="${size}" fill="${color}" viewBox="0 0 256 256"><path d="${path}"></path></svg>`;
|
||||
}
|
||||
// These are SVG Phosphor icons
|
||||
|
||||
export const ICONS = {
|
||||
ClickIcon: createIcon(iconPaths.ClickIcon),
|
||||
Keyboard: createIcon(iconPaths.Keyboard),
|
||||
Drag: createIcon(iconPaths.Drag),
|
||||
ClickIcon: `<svg xmlns="http://www.w3.org/2000/svg" width="20" height="20" fill="#000000" viewBox="0 0 256 256"><path d="M88,24V16a8,8,0,0,1,16,0v8a8,8,0,0,1-16,0ZM16,104h8a8,8,0,0,0,0-16H16a8,8,0,0,0,0,16ZM124.42,39.16a8,8,0,0,0,10.74-3.58l8-16a8,8,0,0,0-14.31-7.16l-8,16A8,8,0,0,0,124.42,39.16Zm-96,81.69-16,8a8,8,0,0,0,7.16,14.31l16-8a8,8,0,1,0-7.16-14.31ZM219.31,184a16,16,0,0,1,0,22.63l-12.68,12.68a16,16,0,0,1-22.63,0L132.7,168,115,214.09c0,.1-.08.21-.13.32a15.83,15.83,0,0,1-14.6,9.59l-.79,0a15.83,15.83,0,0,1-14.41-11L32.8,52.92A16,16,0,0,1,52.92,32.8L213,85.07a16,16,0,0,1,1.41,29.8l-.32.13L168,132.69ZM208,195.31,156.69,144h0a16,16,0,0,1,4.93-26l.32-.14,45.95-17.64L48,48l52.2,159.86,17.65-46c0-.11.08-.22.13-.33a16,16,0,0,1,11.69-9.34,16.72,16.72,0,0,1,3-.28,16,16,0,0,1,11.3,4.69L195.31,208Z"></path></svg>`,
|
||||
Keyboard: `<svg xmlns="http://www.w3.org/2000/svg" width="20" height="20" fill="#000000" viewBox="0 0 256 256"><path d="M224,48H32A16,16,0,0,0,16,64V192a16,16,0,0,0,16,16H224a16,16,0,0,0,16-16V64A16,16,0,0,0,224,48Zm0,144H32V64H224V192Zm-16-64a8,8,0,0,1-8,8H56a8,8,0,0,1,0-16H200A8,8,0,0,1,208,128Zm0-32a8,8,0,0,1-8,8H56a8,8,0,0,1,0-16H200A8,8,0,0,1,208,96ZM72,160a8,8,0,0,1-8,8H56a8,8,0,0,1,0-16h8A8,8,0,0,1,72,160Zm96,0a8,8,0,0,1-8,8H96a8,8,0,0,1,0-16h64A8,8,0,0,1,168,160Zm40,0a8,8,0,0,1-8,8h-8a8,8,0,0,1,0-16h8A8,8,0,0,1,208,160Z"></path></svg>`,
|
||||
Drag: `<svg xmlns="http://www.w3.org/2000/svg" width="20" height="20" fill="#000000" viewBox="0 0 256 256"><path d="M188,80a27.79,27.79,0,0,0-13.36,3.4,28,28,0,0,0-46.64-11A28,28,0,0,0,80,92v20H68a28,28,0,0,0-28,28v12a88,88,0,0,0,176,0V108A28,28,0,0,0,188,80Zm12,72a72,72,0,0,1-144,0V140a12,12,0,0,1,12-12H80v24a8,8,0,0,0,16,0V92a12,12,0,0,1,24,0v28a8,8,0,0,0,16,0V92a12,12,0,0,1,24,0v28a8,8,0,0,0,16,0V108a12,12,0,0,1,24,0Z"></path></svg>`,
|
||||
};
|
||||
|
||||
export function getIcon(
|
||||
name: keyof typeof iconPaths,
|
||||
options?: IconOptions,
|
||||
): string {
|
||||
return createIcon(iconPaths[name], options);
|
||||
}
|
||||
|
||||
@@ -11,7 +11,6 @@ import {
|
||||
} from "./helpers";
|
||||
import { useNodeStore } from "../../../stores/nodeStore";
|
||||
import { useEdgeStore } from "../../../stores/edgeStore";
|
||||
import { useTutorialStore } from "../../../stores/tutorialStore";
|
||||
|
||||
let isTutorialLoading = false;
|
||||
let tutorialLoadingCallback: ((loading: boolean) => void) | null = null;
|
||||
@@ -61,14 +60,12 @@ export const startTutorial = async () => {
|
||||
handleTutorialComplete();
|
||||
removeTutorialStyles();
|
||||
clearPrefetchedBlocks();
|
||||
useTutorialStore.getState().setIsTutorialRunning(false);
|
||||
});
|
||||
|
||||
tour.on("cancel", () => {
|
||||
handleTutorialCancel(tour);
|
||||
removeTutorialStyles();
|
||||
clearPrefetchedBlocks();
|
||||
useTutorialStore.getState().setIsTutorialRunning(false);
|
||||
});
|
||||
|
||||
for (const step of tour.steps) {
|
||||
|
||||
@@ -61,18 +61,12 @@ export const convertNodesPlusBlockInfoIntoCustomNodes = (
|
||||
return customNode;
|
||||
};
|
||||
|
||||
const isToolSourceName = (sourceName: string): boolean =>
|
||||
sourceName.startsWith("tools_^_");
|
||||
|
||||
const cleanupSourceName = (sourceName: string): string =>
|
||||
isToolSourceName(sourceName) ? "tools" : sourceName;
|
||||
|
||||
export const linkToCustomEdge = (link: Link): CustomEdge => ({
|
||||
id: link.id ?? "",
|
||||
type: "custom" as const,
|
||||
source: link.source_id,
|
||||
target: link.sink_id,
|
||||
sourceHandle: cleanupSourceName(link.source_name),
|
||||
sourceHandle: link.source_name,
|
||||
targetHandle: link.sink_name,
|
||||
data: {
|
||||
isStatic: link.is_static,
|
||||
|
||||
@@ -267,34 +267,23 @@ export function extractCredentialsNeeded(
|
||||
| undefined;
|
||||
if (missingCreds && Object.keys(missingCreds).length > 0) {
|
||||
const agentName = (setupInfo?.agent_name as string) || "this block";
|
||||
const credentials = Object.values(missingCreds).map((credInfo) => {
|
||||
// Normalize to array at boundary - prefer 'types' array, fall back to single 'type'
|
||||
const typesArray = credInfo.types as
|
||||
| Array<"api_key" | "oauth2" | "user_password" | "host_scoped">
|
||||
| undefined;
|
||||
const singleType =
|
||||
const credentials = Object.values(missingCreds).map((credInfo) => ({
|
||||
provider: (credInfo.provider as string) || "unknown",
|
||||
providerName:
|
||||
(credInfo.provider_name as string) ||
|
||||
(credInfo.provider as string) ||
|
||||
"Unknown Provider",
|
||||
credentialType:
|
||||
(credInfo.type as
|
||||
| "api_key"
|
||||
| "oauth2"
|
||||
| "user_password"
|
||||
| "host_scoped"
|
||||
| undefined) || "api_key";
|
||||
const credentialTypes =
|
||||
typesArray && typesArray.length > 0 ? typesArray : [singleType];
|
||||
|
||||
return {
|
||||
provider: (credInfo.provider as string) || "unknown",
|
||||
providerName:
|
||||
(credInfo.provider_name as string) ||
|
||||
(credInfo.provider as string) ||
|
||||
"Unknown Provider",
|
||||
credentialTypes,
|
||||
title:
|
||||
(credInfo.title as string) ||
|
||||
`${(credInfo.provider_name as string) || (credInfo.provider as string)} credentials`,
|
||||
scopes: credInfo.scopes as string[] | undefined,
|
||||
};
|
||||
});
|
||||
| "host_scoped") || "api_key",
|
||||
title:
|
||||
(credInfo.title as string) ||
|
||||
`${(credInfo.provider_name as string) || (credInfo.provider as string)} credentials`,
|
||||
scopes: credInfo.scopes as string[] | undefined,
|
||||
}));
|
||||
return {
|
||||
type: "credentials_needed",
|
||||
toolName,
|
||||
@@ -369,14 +358,11 @@ export function extractInputsNeeded(
|
||||
credentials.forEach((cred) => {
|
||||
const id = cred.id as string;
|
||||
if (id) {
|
||||
const credentialTypes = Array.isArray(cred.types)
|
||||
? cred.types
|
||||
: [(cred.type as string) || "api_key"];
|
||||
credentialsSchema[id] = {
|
||||
type: "object",
|
||||
properties: {},
|
||||
credentials_provider: [cred.provider as string],
|
||||
credentials_types: credentialTypes,
|
||||
credentials_types: [(cred.type as string) || "api_key"],
|
||||
credentials_scopes: cred.scopes as string[] | undefined,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -9,9 +9,7 @@ import { useChatCredentialsSetup } from "./useChatCredentialsSetup";
|
||||
export interface CredentialInfo {
|
||||
provider: string;
|
||||
providerName: string;
|
||||
credentialTypes: Array<
|
||||
"api_key" | "oauth2" | "user_password" | "host_scoped"
|
||||
>;
|
||||
credentialType: "api_key" | "oauth2" | "user_password" | "host_scoped";
|
||||
title: string;
|
||||
scopes?: string[];
|
||||
}
|
||||
@@ -32,7 +30,7 @@ function createSchemaFromCredentialInfo(
|
||||
type: "object",
|
||||
properties: {},
|
||||
credentials_provider: [credential.provider],
|
||||
credentials_types: credential.credentialTypes,
|
||||
credentials_types: [credential.credentialType],
|
||||
credentials_scopes: credential.scopes,
|
||||
discriminator: undefined,
|
||||
discriminator_mapping: undefined,
|
||||
|
||||
@@ -41,9 +41,7 @@ export type ChatMessageData =
|
||||
credentials: Array<{
|
||||
provider: string;
|
||||
providerName: string;
|
||||
credentialTypes: Array<
|
||||
"api_key" | "oauth2" | "user_password" | "host_scoped"
|
||||
>;
|
||||
credentialType: "api_key" | "oauth2" | "user_password" | "host_scoped";
|
||||
title: string;
|
||||
scopes?: string[];
|
||||
}>;
|
||||
|
||||
@@ -31,18 +31,10 @@ export function AgentSettingsModal({
|
||||
}
|
||||
}
|
||||
|
||||
const {
|
||||
currentHITLSafeMode,
|
||||
showHITLToggle,
|
||||
handleHITLToggle,
|
||||
currentSensitiveActionSafeMode,
|
||||
showSensitiveActionToggle,
|
||||
handleSensitiveActionToggle,
|
||||
isPending,
|
||||
shouldShowToggle,
|
||||
} = useAgentSafeMode(agent);
|
||||
const { currentSafeMode, isPending, hasHITLBlocks, handleToggle } =
|
||||
useAgentSafeMode(agent);
|
||||
|
||||
if (!shouldShowToggle) return null;
|
||||
if (!hasHITLBlocks) return null;
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
@@ -65,48 +57,23 @@ export function AgentSettingsModal({
|
||||
)}
|
||||
<Dialog.Content>
|
||||
<div className="space-y-6">
|
||||
{showHITLToggle && (
|
||||
<div className="flex w-full flex-col items-start gap-4 rounded-xl border border-zinc-100 bg-white p-6">
|
||||
<div className="flex w-full items-start justify-between gap-4">
|
||||
<div className="flex-1">
|
||||
<Text variant="large-semibold">
|
||||
Human-in-the-loop approval
|
||||
</Text>
|
||||
<Text variant="large" className="mt-1 text-zinc-900">
|
||||
The agent will pause at human-in-the-loop blocks and wait
|
||||
for your review before continuing
|
||||
</Text>
|
||||
</div>
|
||||
<Switch
|
||||
checked={currentHITLSafeMode || false}
|
||||
onCheckedChange={handleHITLToggle}
|
||||
disabled={isPending}
|
||||
className="mt-1"
|
||||
/>
|
||||
<div className="flex w-full flex-col items-start gap-4 rounded-xl border border-zinc-100 bg-white p-6">
|
||||
<div className="flex w-full items-start justify-between gap-4">
|
||||
<div className="flex-1">
|
||||
<Text variant="large-semibold">Require human approval</Text>
|
||||
<Text variant="large" className="mt-1 text-zinc-900">
|
||||
The agent will pause and wait for your review before
|
||||
continuing
|
||||
</Text>
|
||||
</div>
|
||||
<Switch
|
||||
checked={currentSafeMode || false}
|
||||
onCheckedChange={handleToggle}
|
||||
disabled={isPending}
|
||||
className="mt-1"
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
{showSensitiveActionToggle && (
|
||||
<div className="flex w-full flex-col items-start gap-4 rounded-xl border border-zinc-100 bg-white p-6">
|
||||
<div className="flex w-full items-start justify-between gap-4">
|
||||
<div className="flex-1">
|
||||
<Text variant="large-semibold">
|
||||
Sensitive action approval
|
||||
</Text>
|
||||
<Text variant="large" className="mt-1 text-zinc-900">
|
||||
The agent will pause at sensitive action blocks and wait for
|
||||
your review before continuing
|
||||
</Text>
|
||||
</div>
|
||||
<Switch
|
||||
checked={currentSensitiveActionSafeMode}
|
||||
onCheckedChange={handleSensitiveActionToggle}
|
||||
disabled={isPending}
|
||||
className="mt-1"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
|
||||
@@ -14,10 +14,6 @@ import {
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import { ScheduleAgentModal } from "../ScheduleAgentModal/ScheduleAgentModal";
|
||||
import {
|
||||
AIAgentSafetyPopup,
|
||||
useAIAgentSafetyPopup,
|
||||
} from "./components/AIAgentSafetyPopup/AIAgentSafetyPopup";
|
||||
import { ModalHeader } from "./components/ModalHeader/ModalHeader";
|
||||
import { ModalRunSection } from "./components/ModalRunSection/ModalRunSection";
|
||||
import { RunActions } from "./components/RunActions/RunActions";
|
||||
@@ -87,17 +83,8 @@ export function RunAgentModal({
|
||||
|
||||
const [isScheduleModalOpen, setIsScheduleModalOpen] = useState(false);
|
||||
const [hasOverflow, setHasOverflow] = useState(false);
|
||||
const [isSafetyPopupOpen, setIsSafetyPopupOpen] = useState(false);
|
||||
const [pendingRunAction, setPendingRunAction] = useState<(() => void) | null>(
|
||||
null,
|
||||
);
|
||||
const contentRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
const { shouldShowPopup, dismissPopup } = useAIAgentSafetyPopup(
|
||||
agent.has_sensitive_action,
|
||||
agent.has_human_in_the_loop,
|
||||
);
|
||||
|
||||
const hasAnySetupFields =
|
||||
Object.keys(agentInputFields || {}).length > 0 ||
|
||||
Object.keys(agentCredentialsInputFields || {}).length > 0;
|
||||
@@ -178,24 +165,6 @@ export function RunAgentModal({
|
||||
onScheduleCreated?.(schedule);
|
||||
}
|
||||
|
||||
function handleRunWithSafetyCheck() {
|
||||
if (shouldShowPopup) {
|
||||
setPendingRunAction(() => handleRun);
|
||||
setIsSafetyPopupOpen(true);
|
||||
} else {
|
||||
handleRun();
|
||||
}
|
||||
}
|
||||
|
||||
function handleSafetyPopupAcknowledge() {
|
||||
setIsSafetyPopupOpen(false);
|
||||
dismissPopup();
|
||||
if (pendingRunAction) {
|
||||
pendingRunAction();
|
||||
setPendingRunAction(null);
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<Dialog
|
||||
@@ -279,7 +248,7 @@ export function RunAgentModal({
|
||||
)}
|
||||
<RunActions
|
||||
defaultRunType={defaultRunType}
|
||||
onRun={handleRunWithSafetyCheck}
|
||||
onRun={handleRun}
|
||||
isExecuting={isExecuting}
|
||||
isSettingUpTrigger={isSettingUpTrigger}
|
||||
isRunReady={allRequiredInputsAreSet}
|
||||
@@ -297,11 +266,6 @@ export function RunAgentModal({
|
||||
</div>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
|
||||
<AIAgentSafetyPopup
|
||||
isOpen={isSafetyPopupOpen}
|
||||
onAcknowledge={handleSafetyPopupAcknowledge}
|
||||
/>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,95 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { Key, storage } from "@/services/storage/local-storage";
|
||||
import { ShieldCheckIcon } from "@phosphor-icons/react";
|
||||
import { useCallback, useEffect, useState } from "react";
|
||||
|
||||
interface Props {
|
||||
onAcknowledge: () => void;
|
||||
isOpen: boolean;
|
||||
}
|
||||
|
||||
export function AIAgentSafetyPopup({ onAcknowledge, isOpen }: Props) {
|
||||
function handleAcknowledge() {
|
||||
// Mark popup as shown so it won't appear again
|
||||
storage.set(Key.AI_AGENT_SAFETY_POPUP_SHOWN, "true");
|
||||
onAcknowledge();
|
||||
}
|
||||
|
||||
if (!isOpen) return null;
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
controlled={{ isOpen, set: () => {} }}
|
||||
styling={{ maxWidth: "480px" }}
|
||||
>
|
||||
<Dialog.Content>
|
||||
<div className="flex flex-col items-center p-6 text-center">
|
||||
<div className="mb-6 flex h-16 w-16 items-center justify-center rounded-full bg-blue-50">
|
||||
<ShieldCheckIcon
|
||||
weight="fill"
|
||||
size={32}
|
||||
className="text-blue-600"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<Text variant="h3" className="mb-4">
|
||||
Safety Checks Enabled
|
||||
</Text>
|
||||
|
||||
<Text variant="body" className="mb-2 text-zinc-700">
|
||||
AI-generated agents may take actions that affect your data or
|
||||
external systems.
|
||||
</Text>
|
||||
|
||||
<Text variant="body" className="mb-8 text-zinc-700">
|
||||
AutoGPT includes safety checks so you'll always have the
|
||||
opportunity to review and approve sensitive actions before they
|
||||
happen.
|
||||
</Text>
|
||||
|
||||
<Button
|
||||
variant="primary"
|
||||
size="large"
|
||||
className="w-full"
|
||||
onClick={handleAcknowledge}
|
||||
>
|
||||
Got it
|
||||
</Button>
|
||||
</div>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
|
||||
export function useAIAgentSafetyPopup(
|
||||
hasSensitiveAction: boolean,
|
||||
hasHumanInTheLoop: boolean,
|
||||
) {
|
||||
const [shouldShowPopup, setShouldShowPopup] = useState(false);
|
||||
const [hasChecked, setHasChecked] = useState(false);
|
||||
|
||||
useEffect(() => {
|
||||
// Only check once after mount (to avoid SSR issues)
|
||||
if (hasChecked) return;
|
||||
|
||||
const hasSeenPopup =
|
||||
storage.get(Key.AI_AGENT_SAFETY_POPUP_SHOWN) === "true";
|
||||
const isRelevantAgent = hasSensitiveAction || hasHumanInTheLoop;
|
||||
|
||||
setShouldShowPopup(!hasSeenPopup && isRelevantAgent);
|
||||
setHasChecked(true);
|
||||
}, [hasSensitiveAction, hasHumanInTheLoop, hasChecked]);
|
||||
|
||||
const dismissPopup = useCallback(() => {
|
||||
setShouldShowPopup(false);
|
||||
}, []);
|
||||
|
||||
return {
|
||||
shouldShowPopup,
|
||||
dismissPopup,
|
||||
};
|
||||
}
|
||||
@@ -5,37 +5,30 @@ import {
|
||||
AccordionItem,
|
||||
AccordionTrigger,
|
||||
} from "@/components/molecules/Accordion/Accordion";
|
||||
import {
|
||||
CredentialsMetaInput,
|
||||
CredentialsType,
|
||||
} from "@/lib/autogpt-server-api/types";
|
||||
import { CredentialsProvidersContext } from "@/providers/agent-credentials/credentials-provider";
|
||||
import { SlidersHorizontalIcon } from "@phosphor-icons/react";
|
||||
import { SlidersHorizontal } from "@phosphor-icons/react";
|
||||
import { useContext, useEffect, useMemo, useRef } from "react";
|
||||
import { useRunAgentModalContext } from "../../context";
|
||||
import {
|
||||
areSystemCredentialProvidersLoading,
|
||||
CredentialField,
|
||||
findSavedCredentialByProviderAndType,
|
||||
hasMissingRequiredSystemCredentials,
|
||||
splitCredentialFieldsBySystem,
|
||||
} from "./helpers";
|
||||
} from "../helpers";
|
||||
|
||||
type Props = {
|
||||
credentialFields: CredentialField[];
|
||||
requiredCredentials: Set<string>;
|
||||
inputCredentials: Record<string, CredentialsMetaInput | undefined>;
|
||||
inputValues: Record<string, any>;
|
||||
onCredentialChange: (key: string, value?: CredentialsMetaInput) => void;
|
||||
};
|
||||
|
||||
export function CredentialsGroupedView({
|
||||
credentialFields,
|
||||
requiredCredentials,
|
||||
inputCredentials,
|
||||
inputValues,
|
||||
onCredentialChange,
|
||||
}: Props) {
|
||||
const allProviders = useContext(CredentialsProvidersContext);
|
||||
const { inputCredentials, setInputCredentialsValue, inputValues } =
|
||||
useRunAgentModalContext();
|
||||
|
||||
const { userCredentialFields, systemCredentialFields } = useMemo(
|
||||
() =>
|
||||
@@ -94,11 +87,11 @@ export function CredentialsGroupedView({
|
||||
);
|
||||
|
||||
if (savedCredential) {
|
||||
onCredentialChange(key, {
|
||||
setInputCredentialsValue(key, {
|
||||
id: savedCredential.id,
|
||||
provider: savedCredential.provider,
|
||||
type: savedCredential.type as CredentialsType,
|
||||
title: savedCredential.title,
|
||||
type: savedCredential.type,
|
||||
title: (savedCredential as { title?: string }).title,
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -110,7 +103,7 @@ export function CredentialsGroupedView({
|
||||
systemCredentialFields,
|
||||
requiredCredentials,
|
||||
inputCredentials,
|
||||
onCredentialChange,
|
||||
setInputCredentialsValue,
|
||||
isLoadingProviders,
|
||||
]);
|
||||
|
||||
@@ -130,7 +123,7 @@ export function CredentialsGroupedView({
|
||||
}
|
||||
selectedCredentials={selectedCred}
|
||||
onSelectCredentials={(value) => {
|
||||
onCredentialChange(key, value);
|
||||
setInputCredentialsValue(key, value);
|
||||
}}
|
||||
siblingInputs={inputValues}
|
||||
isOptional={!requiredCredentials.has(key)}
|
||||
@@ -150,8 +143,7 @@ export function CredentialsGroupedView({
|
||||
<AccordionItem value="system-credentials" className="border-none">
|
||||
<AccordionTrigger className="py-2 text-sm text-muted-foreground hover:no-underline">
|
||||
<div className="flex items-center gap-1">
|
||||
<SlidersHorizontalIcon size={16} weight="bold" /> System
|
||||
credentials
|
||||
<SlidersHorizontal size={16} weight="bold" /> System credentials
|
||||
{hasMissingSystemCredentials && (
|
||||
<span className="text-destructive">(missing)</span>
|
||||
)}
|
||||
@@ -171,7 +163,7 @@ export function CredentialsGroupedView({
|
||||
}
|
||||
selectedCredentials={selectedCred}
|
||||
onSelectCredentials={(value) => {
|
||||
onCredentialChange(key, value);
|
||||
setInputCredentialsValue(key, value);
|
||||
}}
|
||||
siblingInputs={inputValues}
|
||||
isOptional={!requiredCredentials.has(key)}
|
||||
@@ -1,9 +1,9 @@
|
||||
import { Input } from "@/components/atoms/Input/Input";
|
||||
import { CredentialsGroupedView } from "@/components/contextual/CredentialsInput/components/CredentialsGroupedView/CredentialsGroupedView";
|
||||
import { InformationTooltip } from "@/components/molecules/InformationTooltip/InformationTooltip";
|
||||
import { useMemo } from "react";
|
||||
import { RunAgentInputs } from "../../../RunAgentInputs/RunAgentInputs";
|
||||
import { useRunAgentModalContext } from "../../context";
|
||||
import { CredentialsGroupedView } from "../CredentialsGroupedView/CredentialsGroupedView";
|
||||
import { ModalSection } from "../ModalSection/ModalSection";
|
||||
import { WebhookTriggerBanner } from "../WebhookTriggerBanner/WebhookTriggerBanner";
|
||||
|
||||
@@ -19,8 +19,6 @@ export function ModalRunSection() {
|
||||
setInputValue,
|
||||
agentInputFields,
|
||||
agentCredentialsInputFields,
|
||||
inputCredentials,
|
||||
setInputCredentialsValue,
|
||||
} = useRunAgentModalContext();
|
||||
|
||||
const inputFields = Object.entries(agentInputFields || {});
|
||||
@@ -104,9 +102,6 @@ export function ModalRunSection() {
|
||||
<CredentialsGroupedView
|
||||
credentialFields={credentialFields}
|
||||
requiredCredentials={requiredCredentials}
|
||||
inputCredentials={inputCredentials}
|
||||
inputValues={inputValues}
|
||||
onCredentialChange={setInputCredentialsValue}
|
||||
/>
|
||||
</ModalSection>
|
||||
) : null}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { CredentialsProvidersContextType } from "@/providers/agent-credentials/credentials-provider";
|
||||
import { getSystemCredentials } from "../../helpers";
|
||||
import { getSystemCredentials } from "../../../../../../../../../../../components/contextual/CredentialsInput/helpers";
|
||||
|
||||
export type CredentialField = [string, any];
|
||||
|
||||
@@ -5,104 +5,48 @@ import { Graph } from "@/lib/autogpt-server-api/types";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { ShieldCheckIcon, ShieldIcon } from "@phosphor-icons/react";
|
||||
import { useAgentSafeMode } from "@/hooks/useAgentSafeMode";
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
TooltipTrigger,
|
||||
} from "@/components/atoms/Tooltip/BaseTooltip";
|
||||
|
||||
interface Props {
|
||||
graph: GraphModel | LibraryAgent | Graph;
|
||||
className?: string;
|
||||
fullWidth?: boolean;
|
||||
}
|
||||
|
||||
interface SafeModeIconButtonProps {
|
||||
isEnabled: boolean;
|
||||
label: string;
|
||||
tooltipEnabled: string;
|
||||
tooltipDisabled: string;
|
||||
onToggle: () => void;
|
||||
isPending: boolean;
|
||||
}
|
||||
|
||||
function SafeModeIconButton({
|
||||
isEnabled,
|
||||
label,
|
||||
tooltipEnabled,
|
||||
tooltipDisabled,
|
||||
onToggle,
|
||||
isPending,
|
||||
}: SafeModeIconButtonProps) {
|
||||
return (
|
||||
<Tooltip delayDuration={100}>
|
||||
<TooltipTrigger asChild>
|
||||
<Button
|
||||
variant="icon"
|
||||
size="icon"
|
||||
aria-label={`${label}: ${isEnabled ? "ON" : "OFF"}. ${isEnabled ? tooltipEnabled : tooltipDisabled}`}
|
||||
onClick={onToggle}
|
||||
disabled={isPending}
|
||||
className={cn(isPending ? "opacity-0" : "opacity-100")}
|
||||
>
|
||||
{isEnabled ? (
|
||||
<ShieldCheckIcon weight="bold" size={16} />
|
||||
) : (
|
||||
<ShieldIcon weight="bold" size={16} />
|
||||
)}
|
||||
</Button>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>
|
||||
<div className="text-center">
|
||||
<div className="font-medium">
|
||||
{label}: {isEnabled ? "ON" : "OFF"}
|
||||
</div>
|
||||
<div className="mt-1 text-xs text-muted-foreground">
|
||||
{isEnabled ? tooltipEnabled : tooltipDisabled}
|
||||
</div>
|
||||
</div>
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
);
|
||||
}
|
||||
|
||||
export function SafeModeToggle({ graph, className }: Props) {
|
||||
export function SafeModeToggle({ graph }: Props) {
|
||||
const {
|
||||
currentHITLSafeMode,
|
||||
showHITLToggle,
|
||||
handleHITLToggle,
|
||||
currentSensitiveActionSafeMode,
|
||||
showSensitiveActionToggle,
|
||||
handleSensitiveActionToggle,
|
||||
currentSafeMode,
|
||||
isPending,
|
||||
shouldShowToggle,
|
||||
isStateUndetermined,
|
||||
handleToggle,
|
||||
} = useAgentSafeMode(graph);
|
||||
|
||||
if (!shouldShowToggle) {
|
||||
if (!shouldShowToggle || isStateUndetermined) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<div className={cn("flex gap-1", className)}>
|
||||
{showHITLToggle && (
|
||||
<SafeModeIconButton
|
||||
isEnabled={currentHITLSafeMode}
|
||||
label="Human-in-the-loop"
|
||||
tooltipEnabled="The agent will pause at human-in-the-loop blocks and wait for your approval"
|
||||
tooltipDisabled="Human-in-the-loop blocks will proceed automatically"
|
||||
onToggle={handleHITLToggle}
|
||||
isPending={isPending}
|
||||
/>
|
||||
<Button
|
||||
variant="icon"
|
||||
key={graph.id}
|
||||
size="icon"
|
||||
aria-label={
|
||||
currentSafeMode!
|
||||
? "Safe Mode: ON. Human in the loop blocks require manual review"
|
||||
: "Safe Mode: OFF. Human in the loop blocks proceed automatically"
|
||||
}
|
||||
onClick={handleToggle}
|
||||
className={cn(isPending ? "opacity-0" : "opacity-100")}
|
||||
>
|
||||
{currentSafeMode! ? (
|
||||
<>
|
||||
<ShieldCheckIcon weight="bold" size={16} />
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<ShieldIcon weight="bold" size={16} />
|
||||
</>
|
||||
)}
|
||||
{showSensitiveActionToggle && (
|
||||
<SafeModeIconButton
|
||||
isEnabled={currentSensitiveActionSafeMode}
|
||||
label="Sensitive actions"
|
||||
tooltipEnabled="The agent will pause at sensitive action blocks and wait for your approval"
|
||||
tooltipDisabled="Sensitive action blocks will proceed automatically"
|
||||
onToggle={handleSensitiveActionToggle}
|
||||
isPending={isPending}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</Button>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -13,16 +13,8 @@ interface Props {
|
||||
}
|
||||
|
||||
export function SelectedSettingsView({ agent, onClearSelectedRun }: Props) {
|
||||
const {
|
||||
currentHITLSafeMode,
|
||||
showHITLToggle,
|
||||
handleHITLToggle,
|
||||
currentSensitiveActionSafeMode,
|
||||
showSensitiveActionToggle,
|
||||
handleSensitiveActionToggle,
|
||||
isPending,
|
||||
shouldShowToggle,
|
||||
} = useAgentSafeMode(agent);
|
||||
const { currentSafeMode, isPending, hasHITLBlocks, handleToggle } =
|
||||
useAgentSafeMode(agent);
|
||||
|
||||
return (
|
||||
<SelectedViewLayout agent={agent}>
|
||||
@@ -42,51 +34,24 @@ export function SelectedSettingsView({ agent, onClearSelectedRun }: Props) {
|
||||
</div>
|
||||
|
||||
<div className={`${AGENT_LIBRARY_SECTION_PADDING_X} space-y-6`}>
|
||||
{shouldShowToggle ? (
|
||||
<>
|
||||
{showHITLToggle && (
|
||||
<div className="flex w-full max-w-2xl flex-col items-start gap-4 rounded-xl border border-zinc-100 bg-white p-6">
|
||||
<div className="flex w-full items-start justify-between gap-4">
|
||||
<div className="flex-1">
|
||||
<Text variant="large-semibold">
|
||||
Human-in-the-loop approval
|
||||
</Text>
|
||||
<Text variant="large" className="mt-1 text-zinc-900">
|
||||
The agent will pause at human-in-the-loop blocks and
|
||||
wait for your review before continuing
|
||||
</Text>
|
||||
</div>
|
||||
<Switch
|
||||
checked={currentHITLSafeMode || false}
|
||||
onCheckedChange={handleHITLToggle}
|
||||
disabled={isPending}
|
||||
className="mt-1"
|
||||
/>
|
||||
</div>
|
||||
{hasHITLBlocks ? (
|
||||
<div className="flex w-full max-w-2xl flex-col items-start gap-4 rounded-xl border border-zinc-100 bg-white p-6">
|
||||
<div className="flex w-full items-start justify-between gap-4">
|
||||
<div className="flex-1">
|
||||
<Text variant="large-semibold">Require human approval</Text>
|
||||
<Text variant="large" className="mt-1 text-zinc-900">
|
||||
The agent will pause and wait for your review before
|
||||
continuing
|
||||
</Text>
|
||||
</div>
|
||||
)}
|
||||
{showSensitiveActionToggle && (
|
||||
<div className="flex w-full max-w-2xl flex-col items-start gap-4 rounded-xl border border-zinc-100 bg-white p-6">
|
||||
<div className="flex w-full items-start justify-between gap-4">
|
||||
<div className="flex-1">
|
||||
<Text variant="large-semibold">
|
||||
Sensitive action approval
|
||||
</Text>
|
||||
<Text variant="large" className="mt-1 text-zinc-900">
|
||||
The agent will pause at sensitive action blocks and wait
|
||||
for your review before continuing
|
||||
</Text>
|
||||
</div>
|
||||
<Switch
|
||||
checked={currentSensitiveActionSafeMode}
|
||||
onCheckedChange={handleSensitiveActionToggle}
|
||||
disabled={isPending}
|
||||
className="mt-1"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
<Switch
|
||||
checked={currentSafeMode || false}
|
||||
onCheckedChange={handleToggle}
|
||||
disabled={isPending}
|
||||
className="mt-1"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
<div className="rounded-xl border border-zinc-100 bg-white p-6">
|
||||
<Text variant="body" className="text-muted-foreground">
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { FileInput } from "@/components/atoms/FileInput/FileInput";
|
||||
import { Input } from "@/components/atoms/Input/Input";
|
||||
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import {
|
||||
Form,
|
||||
@@ -121,7 +120,7 @@ export default function LibraryUploadAgentDialog() {
|
||||
>
|
||||
{isUploading ? (
|
||||
<div className="flex items-center gap-2">
|
||||
<LoadingSpinner size="small" className="text-white" />
|
||||
<div className="h-4 w-4 animate-spin rounded-full border-b-2 border-t-2 border-white"></div>
|
||||
<span>Uploading...</span>
|
||||
</div>
|
||||
) : (
|
||||
|
||||
@@ -6383,11 +6383,6 @@
|
||||
"title": "Has Human In The Loop",
|
||||
"readOnly": true
|
||||
},
|
||||
"has_sensitive_action": {
|
||||
"type": "boolean",
|
||||
"title": "Has Sensitive Action",
|
||||
"readOnly": true
|
||||
},
|
||||
"trigger_setup_info": {
|
||||
"anyOf": [
|
||||
{ "$ref": "#/components/schemas/GraphTriggerInfo" },
|
||||
@@ -6404,7 +6399,6 @@
|
||||
"output_schema",
|
||||
"has_external_trigger",
|
||||
"has_human_in_the_loop",
|
||||
"has_sensitive_action",
|
||||
"trigger_setup_info"
|
||||
],
|
||||
"title": "BaseGraph"
|
||||
@@ -7635,11 +7629,6 @@
|
||||
"title": "Has Human In The Loop",
|
||||
"readOnly": true
|
||||
},
|
||||
"has_sensitive_action": {
|
||||
"type": "boolean",
|
||||
"title": "Has Sensitive Action",
|
||||
"readOnly": true
|
||||
},
|
||||
"trigger_setup_info": {
|
||||
"anyOf": [
|
||||
{ "$ref": "#/components/schemas/GraphTriggerInfo" },
|
||||
@@ -7663,7 +7652,6 @@
|
||||
"output_schema",
|
||||
"has_external_trigger",
|
||||
"has_human_in_the_loop",
|
||||
"has_sensitive_action",
|
||||
"trigger_setup_info",
|
||||
"credentials_input_schema"
|
||||
],
|
||||
@@ -7742,11 +7730,6 @@
|
||||
"title": "Has Human In The Loop",
|
||||
"readOnly": true
|
||||
},
|
||||
"has_sensitive_action": {
|
||||
"type": "boolean",
|
||||
"title": "Has Sensitive Action",
|
||||
"readOnly": true
|
||||
},
|
||||
"trigger_setup_info": {
|
||||
"anyOf": [
|
||||
{ "$ref": "#/components/schemas/GraphTriggerInfo" },
|
||||
@@ -7771,7 +7754,6 @@
|
||||
"output_schema",
|
||||
"has_external_trigger",
|
||||
"has_human_in_the_loop",
|
||||
"has_sensitive_action",
|
||||
"trigger_setup_info",
|
||||
"credentials_input_schema"
|
||||
],
|
||||
@@ -7780,14 +7762,8 @@
|
||||
"GraphSettings": {
|
||||
"properties": {
|
||||
"human_in_the_loop_safe_mode": {
|
||||
"type": "boolean",
|
||||
"title": "Human In The Loop Safe Mode",
|
||||
"default": true
|
||||
},
|
||||
"sensitive_action_safe_mode": {
|
||||
"type": "boolean",
|
||||
"title": "Sensitive Action Safe Mode",
|
||||
"default": false
|
||||
"anyOf": [{ "type": "boolean" }, { "type": "null" }],
|
||||
"title": "Human In The Loop Safe Mode"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
@@ -7945,16 +7921,6 @@
|
||||
"title": "Has External Trigger",
|
||||
"description": "Whether the agent has an external trigger (e.g. webhook) node"
|
||||
},
|
||||
"has_human_in_the_loop": {
|
||||
"type": "boolean",
|
||||
"title": "Has Human In The Loop",
|
||||
"description": "Whether the agent has human-in-the-loop blocks"
|
||||
},
|
||||
"has_sensitive_action": {
|
||||
"type": "boolean",
|
||||
"title": "Has Sensitive Action",
|
||||
"description": "Whether the agent has sensitive action blocks"
|
||||
},
|
||||
"trigger_setup_info": {
|
||||
"anyOf": [
|
||||
{ "$ref": "#/components/schemas/GraphTriggerInfo" },
|
||||
@@ -8001,8 +7967,6 @@
|
||||
"output_schema",
|
||||
"credentials_input_schema",
|
||||
"has_external_trigger",
|
||||
"has_human_in_the_loop",
|
||||
"has_sensitive_action",
|
||||
"new_output",
|
||||
"can_access_graph",
|
||||
"is_latest_version",
|
||||
@@ -8829,11 +8793,6 @@
|
||||
"title": "Graph Version",
|
||||
"description": "Graph version"
|
||||
},
|
||||
"node_id": {
|
||||
"type": "string",
|
||||
"title": "Node Id",
|
||||
"description": "Node ID in the graph definition"
|
||||
},
|
||||
"payload": {
|
||||
"anyOf": [
|
||||
{ "additionalProperties": true, "type": "object" },
|
||||
@@ -8907,7 +8866,6 @@
|
||||
"graph_exec_id",
|
||||
"graph_id",
|
||||
"graph_version",
|
||||
"node_id",
|
||||
"payload",
|
||||
"editable",
|
||||
"status",
|
||||
@@ -9431,12 +9389,6 @@
|
||||
"type": "array",
|
||||
"title": "Reviews",
|
||||
"description": "All reviews with their approval status, data, and messages"
|
||||
},
|
||||
"auto_approve_node_ids": {
|
||||
"items": { "type": "string" },
|
||||
"type": "array",
|
||||
"title": "Auto Approve Node Ids",
|
||||
"description": "List of node IDs (from the graph definition) to auto-approve for the remainder of this execution. Future reviews from these specific nodes will be automatically approved. This only affects the current execution run."
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
|
||||
@@ -37,7 +37,7 @@ export function PendingReviewsList({
|
||||
>({});
|
||||
|
||||
const [pendingAction, setPendingAction] = useState<
|
||||
"approve" | "approve-all" | "reject" | null
|
||||
"approve" | "reject" | null
|
||||
>(null);
|
||||
|
||||
const { toast } = useToast();
|
||||
@@ -92,10 +92,7 @@ export function PendingReviewsList({
|
||||
setReviewMessageMap((prev) => ({ ...prev, [nodeExecId]: message }));
|
||||
}
|
||||
|
||||
function processReviews(
|
||||
approved: boolean,
|
||||
autoApproveFutureActions: boolean = false,
|
||||
) {
|
||||
function processReviews(approved: boolean) {
|
||||
if (reviews.length === 0) {
|
||||
toast({
|
||||
title: "No reviews to process",
|
||||
@@ -105,13 +102,7 @@ export function PendingReviewsList({
|
||||
return;
|
||||
}
|
||||
|
||||
setPendingAction(
|
||||
autoApproveFutureActions
|
||||
? "approve-all"
|
||||
: approved
|
||||
? "approve"
|
||||
: "reject",
|
||||
);
|
||||
setPendingAction(approved ? "approve" : "reject");
|
||||
const reviewItems = [];
|
||||
|
||||
for (const review of reviews) {
|
||||
@@ -143,15 +134,9 @@ export function PendingReviewsList({
|
||||
});
|
||||
}
|
||||
|
||||
// Collect unique node_ids if auto-approving future actions
|
||||
const autoApproveNodeIds = autoApproveFutureActions
|
||||
? [...new Set(reviews.map((r) => r.node_id))]
|
||||
: [];
|
||||
|
||||
reviewActionMutation.mutate({
|
||||
data: {
|
||||
reviews: reviewItems,
|
||||
auto_approve_node_ids: autoApproveNodeIds,
|
||||
},
|
||||
});
|
||||
}
|
||||
@@ -206,8 +191,12 @@ export function PendingReviewsList({
|
||||
))}
|
||||
</div>
|
||||
|
||||
<div className="space-y-4">
|
||||
<div className="flex flex-wrap gap-2">
|
||||
<div className="space-y-7">
|
||||
<Text variant="body" className="text-textGrey">
|
||||
Note: Changes you make here apply only to this task
|
||||
</Text>
|
||||
|
||||
<div className="flex gap-2">
|
||||
<Button
|
||||
onClick={() => processReviews(true)}
|
||||
disabled={reviewActionMutation.isPending || reviews.length === 0}
|
||||
@@ -219,17 +208,6 @@ export function PendingReviewsList({
|
||||
>
|
||||
Approve
|
||||
</Button>
|
||||
<Button
|
||||
onClick={() => processReviews(true, true)}
|
||||
disabled={reviewActionMutation.isPending || reviews.length === 0}
|
||||
variant="secondary"
|
||||
className="flex items-center justify-center gap-2 rounded-full px-4 py-3"
|
||||
loading={
|
||||
pendingAction === "approve-all" && reviewActionMutation.isPending
|
||||
}
|
||||
>
|
||||
Approve all future actions
|
||||
</Button>
|
||||
<Button
|
||||
onClick={() => processReviews(false)}
|
||||
disabled={reviewActionMutation.isPending || reviews.length === 0}
|
||||
@@ -242,11 +220,6 @@ export function PendingReviewsList({
|
||||
Reject
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
<Text variant="small" className="text-textGrey">
|
||||
You can turn auto-approval on or off anytime in this agent's
|
||||
settings.
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -35,13 +35,12 @@ export const CredentialFieldTitle = (props: {
|
||||
uiOptions,
|
||||
);
|
||||
|
||||
const provider = getCredentialProviderFromSchema(
|
||||
useNodeStore.getState().getHardCodedValues(nodeId),
|
||||
schema as BlockIOCredentialsSubSchema,
|
||||
const credentialProvider = toDisplayName(
|
||||
getCredentialProviderFromSchema(
|
||||
useNodeStore.getState().getHardCodedValues(nodeId),
|
||||
schema as BlockIOCredentialsSubSchema,
|
||||
) ?? "",
|
||||
);
|
||||
const credentialProvider = provider
|
||||
? `${toDisplayName(provider)} credential`
|
||||
: "credential";
|
||||
|
||||
const updatedUiSchema = updateUiOption(uiSchema, {
|
||||
showHandles: false,
|
||||
|
||||
@@ -20,15 +20,11 @@ function hasHITLBlocks(graph: GraphModel | LibraryAgent | Graph): boolean {
|
||||
if ("has_human_in_the_loop" in graph) {
|
||||
return !!graph.has_human_in_the_loop;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
function hasSensitiveActionBlocks(
|
||||
graph: GraphModel | LibraryAgent | Graph,
|
||||
): boolean {
|
||||
if ("has_sensitive_action" in graph) {
|
||||
return !!graph.has_sensitive_action;
|
||||
if (isLibraryAgent(graph)) {
|
||||
return graph.settings?.human_in_the_loop_safe_mode !== null;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -44,9 +40,7 @@ export function useAgentSafeMode(graph: GraphModel | LibraryAgent | Graph) {
|
||||
|
||||
const graphId = getGraphId(graph);
|
||||
const isAgent = isLibraryAgent(graph);
|
||||
const showHITLToggle = hasHITLBlocks(graph);
|
||||
const showSensitiveActionToggle = hasSensitiveActionBlocks(graph);
|
||||
const shouldShowToggle = showHITLToggle || showSensitiveActionToggle;
|
||||
const shouldShowToggle = hasHITLBlocks(graph);
|
||||
|
||||
const { mutateAsync: updateGraphSettings, isPending } =
|
||||
usePatchV1UpdateGraphSettings();
|
||||
@@ -62,37 +56,27 @@ export function useAgentSafeMode(graph: GraphModel | LibraryAgent | Graph) {
|
||||
},
|
||||
);
|
||||
|
||||
const [localHITLSafeMode, setLocalHITLSafeMode] = useState<boolean>(true);
|
||||
const [localSensitiveActionSafeMode, setLocalSensitiveActionSafeMode] =
|
||||
useState<boolean>(false);
|
||||
const [isLocalStateLoaded, setIsLocalStateLoaded] = useState<boolean>(false);
|
||||
const [localSafeMode, setLocalSafeMode] = useState<boolean | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
if (!isAgent && libraryAgent) {
|
||||
setLocalHITLSafeMode(
|
||||
libraryAgent.settings?.human_in_the_loop_safe_mode ?? true,
|
||||
);
|
||||
setLocalSensitiveActionSafeMode(
|
||||
libraryAgent.settings?.sensitive_action_safe_mode ?? false,
|
||||
);
|
||||
setIsLocalStateLoaded(true);
|
||||
const backendValue = libraryAgent.settings?.human_in_the_loop_safe_mode;
|
||||
if (backendValue !== undefined) {
|
||||
setLocalSafeMode(backendValue);
|
||||
}
|
||||
}
|
||||
}, [isAgent, libraryAgent]);
|
||||
|
||||
const currentHITLSafeMode = isAgent
|
||||
? (graph.settings?.human_in_the_loop_safe_mode ?? true)
|
||||
: localHITLSafeMode;
|
||||
const currentSafeMode = isAgent
|
||||
? graph.settings?.human_in_the_loop_safe_mode
|
||||
: localSafeMode;
|
||||
|
||||
const currentSensitiveActionSafeMode = isAgent
|
||||
? (graph.settings?.sensitive_action_safe_mode ?? false)
|
||||
: localSensitiveActionSafeMode;
|
||||
const isStateUndetermined = isAgent
|
||||
? graph.settings?.human_in_the_loop_safe_mode == null
|
||||
: isLoading || localSafeMode === null;
|
||||
|
||||
const isHITLStateUndetermined = isAgent
|
||||
? false
|
||||
: isLoading || !isLocalStateLoaded;
|
||||
|
||||
const handleHITLToggle = useCallback(async () => {
|
||||
const newSafeMode = !currentHITLSafeMode;
|
||||
const handleToggle = useCallback(async () => {
|
||||
const newSafeMode = !currentSafeMode;
|
||||
|
||||
try {
|
||||
await updateGraphSettings({
|
||||
@@ -101,7 +85,7 @@ export function useAgentSafeMode(graph: GraphModel | LibraryAgent | Graph) {
|
||||
});
|
||||
|
||||
if (!isAgent) {
|
||||
setLocalHITLSafeMode(newSafeMode);
|
||||
setLocalSafeMode(newSafeMode);
|
||||
}
|
||||
|
||||
if (isAgent) {
|
||||
@@ -117,62 +101,37 @@ export function useAgentSafeMode(graph: GraphModel | LibraryAgent | Graph) {
|
||||
queryClient.invalidateQueries({ queryKey: ["v2", "executions"] });
|
||||
|
||||
toast({
|
||||
title: `HITL safe mode ${newSafeMode ? "enabled" : "disabled"}`,
|
||||
title: `Safe mode ${newSafeMode ? "enabled" : "disabled"}`,
|
||||
description: newSafeMode
|
||||
? "Human-in-the-loop blocks will require manual review"
|
||||
: "Human-in-the-loop blocks will proceed automatically",
|
||||
duration: 2000,
|
||||
});
|
||||
} catch (error) {
|
||||
handleToggleError(error, isAgent, toast);
|
||||
}
|
||||
}, [
|
||||
currentHITLSafeMode,
|
||||
graphId,
|
||||
isAgent,
|
||||
graph.id,
|
||||
updateGraphSettings,
|
||||
queryClient,
|
||||
toast,
|
||||
]);
|
||||
const isNotFoundError =
|
||||
error instanceof Error &&
|
||||
(error.message.includes("404") || error.message.includes("not found"));
|
||||
|
||||
const handleSensitiveActionToggle = useCallback(async () => {
|
||||
const newSafeMode = !currentSensitiveActionSafeMode;
|
||||
|
||||
try {
|
||||
await updateGraphSettings({
|
||||
graphId,
|
||||
data: { sensitive_action_safe_mode: newSafeMode },
|
||||
});
|
||||
|
||||
if (!isAgent) {
|
||||
setLocalSensitiveActionSafeMode(newSafeMode);
|
||||
}
|
||||
|
||||
if (isAgent) {
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV2GetLibraryAgentQueryOptions(graph.id.toString())
|
||||
.queryKey,
|
||||
if (!isAgent && isNotFoundError) {
|
||||
toast({
|
||||
title: "Safe mode not available",
|
||||
description:
|
||||
"To configure safe mode, please save this graph to your library first.",
|
||||
variant: "destructive",
|
||||
});
|
||||
} else {
|
||||
toast({
|
||||
title: "Failed to update safe mode",
|
||||
description:
|
||||
error instanceof Error
|
||||
? error.message
|
||||
: "An unexpected error occurred.",
|
||||
variant: "destructive",
|
||||
});
|
||||
}
|
||||
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: ["v1", "graphs", graphId, "executions"],
|
||||
});
|
||||
queryClient.invalidateQueries({ queryKey: ["v2", "executions"] });
|
||||
|
||||
toast({
|
||||
title: `Sensitive action safe mode ${newSafeMode ? "enabled" : "disabled"}`,
|
||||
description: newSafeMode
|
||||
? "Sensitive action blocks will require manual review"
|
||||
: "Sensitive action blocks will proceed automatically",
|
||||
duration: 2000,
|
||||
});
|
||||
} catch (error) {
|
||||
handleToggleError(error, isAgent, toast);
|
||||
}
|
||||
}, [
|
||||
currentSensitiveActionSafeMode,
|
||||
currentSafeMode,
|
||||
graphId,
|
||||
isAgent,
|
||||
graph.id,
|
||||
@@ -182,53 +141,11 @@ export function useAgentSafeMode(graph: GraphModel | LibraryAgent | Graph) {
|
||||
]);
|
||||
|
||||
return {
|
||||
// HITL safe mode
|
||||
currentHITLSafeMode,
|
||||
showHITLToggle,
|
||||
isHITLStateUndetermined,
|
||||
handleHITLToggle,
|
||||
|
||||
// Sensitive action safe mode
|
||||
currentSensitiveActionSafeMode,
|
||||
showSensitiveActionToggle,
|
||||
handleSensitiveActionToggle,
|
||||
|
||||
// General
|
||||
currentSafeMode,
|
||||
isPending,
|
||||
shouldShowToggle,
|
||||
|
||||
// Backwards compatibility
|
||||
currentSafeMode: currentHITLSafeMode,
|
||||
isStateUndetermined: isHITLStateUndetermined,
|
||||
handleToggle: handleHITLToggle,
|
||||
hasHITLBlocks: showHITLToggle,
|
||||
isStateUndetermined,
|
||||
handleToggle,
|
||||
hasHITLBlocks: shouldShowToggle,
|
||||
};
|
||||
}
|
||||
|
||||
function handleToggleError(
|
||||
error: unknown,
|
||||
isAgent: boolean,
|
||||
toast: ReturnType<typeof useToast>["toast"],
|
||||
) {
|
||||
const isNotFoundError =
|
||||
error instanceof Error &&
|
||||
(error.message.includes("404") || error.message.includes("not found"));
|
||||
|
||||
if (!isAgent && isNotFoundError) {
|
||||
toast({
|
||||
title: "Safe mode not available",
|
||||
description:
|
||||
"To configure safe mode, please save this graph to your library first.",
|
||||
variant: "destructive",
|
||||
});
|
||||
} else {
|
||||
toast({
|
||||
title: "Failed to update safe mode",
|
||||
description:
|
||||
error instanceof Error
|
||||
? error.message
|
||||
: "An unexpected error occurred.",
|
||||
variant: "destructive",
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ import isEqual from "lodash/isEqual";
|
||||
export function cleanNode(node: CustomNode) {
|
||||
return {
|
||||
id: node.id,
|
||||
// Note: position is intentionally excluded to prevent draft saves when dragging nodes
|
||||
position: node.position,
|
||||
data: {
|
||||
hardcodedValues: node.data.hardcodedValues,
|
||||
title: node.data.title,
|
||||
|
||||
@@ -10,7 +10,6 @@ export enum Key {
|
||||
LIBRARY_AGENTS_CACHE = "library-agents-cache",
|
||||
CHAT_SESSION_ID = "chat_session_id",
|
||||
COOKIE_CONSENT = "autogpt_cookie_consent",
|
||||
AI_AGENT_SAFETY_POPUP_SHOWN = "ai-agent-safety-popup-shown",
|
||||
}
|
||||
|
||||
function get(key: Key) {
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
[flake8]
|
||||
max-line-length = 88
|
||||
extend-ignore = E203
|
||||
exclude =
|
||||
.tox,
|
||||
__pycache__,
|
||||
*.pyc,
|
||||
.env
|
||||
venv*/*,
|
||||
.venv/*,
|
||||
reports/*,
|
||||
dist/*,
|
||||
data/*,
|
||||
.env,
|
||||
venv*,
|
||||
.venv,
|
||||
reports,
|
||||
dist,
|
||||
data,
|
||||
.benchmark_workspaces,
|
||||
.autogpt,
|
||||
|
||||
291
classic/CLAUDE.md
Normal file
291
classic/CLAUDE.md
Normal file
@@ -0,0 +1,291 @@
|
||||
# CLAUDE.md
|
||||
|
||||
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||
|
||||
## Project Overview
|
||||
|
||||
AutoGPT Classic is an experimental, **unsupported** project demonstrating autonomous GPT-4 operation. Dependencies will not be updated, and the codebase contains known vulnerabilities. This is preserved for educational/historical purposes.
|
||||
|
||||
## Repository Structure
|
||||
|
||||
```
|
||||
classic/
|
||||
├── pyproject.toml # Single consolidated Poetry project
|
||||
├── poetry.lock # Single lock file
|
||||
├── forge/
|
||||
│ └── forge/ # Core agent framework package
|
||||
├── original_autogpt/
|
||||
│ └── autogpt/ # AutoGPT agent package
|
||||
├── direct_benchmark/
|
||||
│ └── direct_benchmark/ # Benchmark harness package
|
||||
└── benchmark/ # Challenge definitions (data, not code)
|
||||
```
|
||||
|
||||
All packages are managed by a single `pyproject.toml` at the classic/ root.
|
||||
|
||||
## Common Commands
|
||||
|
||||
### Setup & Install
|
||||
```bash
|
||||
# Install everything from classic/ directory
|
||||
cd classic
|
||||
poetry install
|
||||
```
|
||||
|
||||
### Running Agents
|
||||
```bash
|
||||
# Run forge agent
|
||||
poetry run python -m forge
|
||||
|
||||
# Run original autogpt server
|
||||
poetry run serve --debug
|
||||
|
||||
# Run autogpt CLI
|
||||
poetry run autogpt
|
||||
```
|
||||
|
||||
Agents run on `http://localhost:8000` by default.
|
||||
|
||||
### Benchmarking
|
||||
```bash
|
||||
# Run benchmarks
|
||||
poetry run direct-benchmark run
|
||||
|
||||
# Run specific strategies and models
|
||||
poetry run direct-benchmark run \
|
||||
--strategies one_shot,rewoo \
|
||||
--models claude \
|
||||
--parallel 4
|
||||
|
||||
# Run a single test
|
||||
poetry run direct-benchmark run --tests ReadFile
|
||||
|
||||
# List available commands
|
||||
poetry run direct-benchmark --help
|
||||
```
|
||||
|
||||
### Testing
|
||||
```bash
|
||||
poetry run pytest # All tests
|
||||
poetry run pytest forge/tests/ # Forge tests only
|
||||
poetry run pytest original_autogpt/tests/ # AutoGPT tests only
|
||||
poetry run pytest -k test_name # Single test by name
|
||||
poetry run pytest path/to/test.py # Specific test file
|
||||
poetry run pytest --cov # With coverage
|
||||
```
|
||||
|
||||
### Linting & Formatting
|
||||
|
||||
Run from the classic/ directory:
|
||||
|
||||
```bash
|
||||
# Format everything (recommended to run together)
|
||||
poetry run black . && poetry run isort .
|
||||
|
||||
# Check formatting (CI-style, no changes)
|
||||
poetry run black --check . && poetry run isort --check-only .
|
||||
|
||||
# Lint
|
||||
poetry run flake8 # Style linting
|
||||
|
||||
# Type check
|
||||
poetry run pyright # Type checking (some errors are expected in infrastructure code)
|
||||
```
|
||||
|
||||
Note: Always run linters over the entire directory, not specific files, for best results.
|
||||
|
||||
## Architecture
|
||||
|
||||
### Forge (Core Framework)
|
||||
The `forge` package is the foundation that other components depend on:
|
||||
- `forge/agent/` - Agent implementation and protocols
|
||||
- `forge/llm/` - Multi-provider LLM integrations (OpenAI, Anthropic, Groq, LiteLLM)
|
||||
- `forge/components/` - Reusable agent components
|
||||
- `forge/file_storage/` - File system abstraction
|
||||
- `forge/config/` - Configuration management
|
||||
|
||||
### Original AutoGPT
|
||||
- `original_autogpt/autogpt/app/` - CLI application entry points
|
||||
- `original_autogpt/autogpt/agents/` - Agent implementations
|
||||
- `original_autogpt/autogpt/agent_factory/` - Agent creation logic
|
||||
|
||||
### Direct Benchmark
|
||||
Benchmark harness for testing agent performance:
|
||||
- `direct_benchmark/direct_benchmark/` - CLI and harness code
|
||||
- `benchmark/agbenchmark/challenges/` - Test cases organized by category (code, retrieval, data, etc.)
|
||||
- Reports generated in `direct_benchmark/reports/`
|
||||
|
||||
### Package Structure
|
||||
All three packages are included in a single Poetry project. Imports are fully qualified:
|
||||
- `from forge.agent.base import BaseAgent`
|
||||
- `from autogpt.agents.agent import Agent`
|
||||
- `from direct_benchmark.harness import BenchmarkHarness`
|
||||
|
||||
## Code Style
|
||||
|
||||
- Python 3.12 target
|
||||
- Line length: 88 characters (Black default)
|
||||
- Black for formatting, isort for imports (profile="black")
|
||||
- Type hints with Pyright checking
|
||||
|
||||
## Testing Patterns
|
||||
|
||||
- Async support via pytest-asyncio
|
||||
- Fixtures defined in `conftest.py` files provide: `tmp_project_root`, `storage`, `config`, `llm_provider`, `agent`
|
||||
- Tests requiring API keys (OPENAI_API_KEY, ANTHROPIC_API_KEY) will skip if not set
|
||||
|
||||
## Environment Setup
|
||||
|
||||
Copy `.env.example` to `.env` in the relevant directory and add your API keys:
|
||||
```bash
|
||||
cp .env.example .env
|
||||
# Edit .env with your OPENAI_API_KEY, etc.
|
||||
```
|
||||
|
||||
## Workspaces
|
||||
|
||||
Agents operate within a **workspace** - a directory containing all agent data and files. The workspace root defaults to the current working directory.
|
||||
|
||||
### Workspace Structure
|
||||
|
||||
```
|
||||
{workspace}/
|
||||
├── .autogpt/
|
||||
│ ├── autogpt.yaml # Workspace-level permissions
|
||||
│ ├── ap_server.db # Agent Protocol database (server mode)
|
||||
│ └── agents/
|
||||
│ └── AutoGPT-{agent_id}/
|
||||
│ ├── state.json # Agent profile, directives, action history
|
||||
│ ├── permissions.yaml # Agent-specific permission overrides
|
||||
│ └── workspace/ # Agent's sandboxed working directory
|
||||
```
|
||||
|
||||
### Key Concepts
|
||||
|
||||
- **Multiple agents** can coexist in the same workspace (each gets its own subdirectory)
|
||||
- **File access** is sandboxed to the agent's `workspace/` directory by default
|
||||
- **State persistence** - agent state saves to `state.json` and survives across sessions
|
||||
- **Storage backends** - supports local filesystem, S3, and GCS (via `FILE_STORAGE_BACKEND` env var)
|
||||
|
||||
### Specifying a Workspace
|
||||
|
||||
```bash
|
||||
# Default: uses current directory
|
||||
cd /path/to/my/project && poetry run autogpt
|
||||
|
||||
# Or specify explicitly via CLI (if supported)
|
||||
poetry run autogpt --workspace /path/to/workspace
|
||||
```
|
||||
|
||||
## Settings Location
|
||||
|
||||
Configuration uses a **layered system** with three levels (in order of precedence):
|
||||
|
||||
### 1. Environment Variables (Global)
|
||||
|
||||
Loaded from `.env` file in the working directory:
|
||||
|
||||
```bash
|
||||
# Required
|
||||
OPENAI_API_KEY=sk-...
|
||||
|
||||
# Optional LLM settings
|
||||
SMART_LLM=gpt-4o # Model for complex reasoning
|
||||
FAST_LLM=gpt-4o-mini # Model for simple tasks
|
||||
EMBEDDING_MODEL=text-embedding-3-small
|
||||
|
||||
# Optional search providers (for web search component)
|
||||
TAVILY_API_KEY=tvly-...
|
||||
SERPER_API_KEY=...
|
||||
GOOGLE_API_KEY=...
|
||||
GOOGLE_CUSTOM_SEARCH_ENGINE_ID=...
|
||||
|
||||
# Optional infrastructure
|
||||
LOG_LEVEL=DEBUG # DEBUG, INFO, WARNING, ERROR
|
||||
DATABASE_STRING=sqlite:///agent.db # Agent Protocol database
|
||||
PORT=8000 # Server port
|
||||
FILE_STORAGE_BACKEND=local # local, s3, or gcs
|
||||
```
|
||||
|
||||
### 2. Workspace Settings (`{workspace}/.autogpt/autogpt.yaml`)
|
||||
|
||||
Workspace-wide permissions that apply to **all agents** in this workspace:
|
||||
|
||||
```yaml
|
||||
allow:
|
||||
- read_file({workspace}/**)
|
||||
- write_to_file({workspace}/**)
|
||||
- list_folder({workspace}/**)
|
||||
- web_search(*)
|
||||
|
||||
deny:
|
||||
- read_file(**.env)
|
||||
- read_file(**.env.*)
|
||||
- read_file(**.key)
|
||||
- read_file(**.pem)
|
||||
- execute_shell(rm -rf:*)
|
||||
- execute_shell(sudo:*)
|
||||
```
|
||||
|
||||
Auto-generated with sensible defaults if missing.
|
||||
|
||||
### 3. Agent Settings (`{workspace}/.autogpt/agents/{id}/permissions.yaml`)
|
||||
|
||||
Agent-specific permission overrides:
|
||||
|
||||
```yaml
|
||||
allow:
|
||||
- execute_python(*)
|
||||
- web_search(*)
|
||||
|
||||
deny:
|
||||
- execute_shell(*)
|
||||
```
|
||||
|
||||
## Permissions
|
||||
|
||||
The permission system uses **pattern matching** with a **first-match-wins** evaluation order.
|
||||
|
||||
### Permission Check Order
|
||||
|
||||
1. Agent deny list → **Block**
|
||||
2. Workspace deny list → **Block**
|
||||
3. Agent allow list → **Allow**
|
||||
4. Workspace allow list → **Allow**
|
||||
5. Session denied list → **Block** (commands denied during this session)
|
||||
6. **Prompt user** → Interactive approval (if in interactive mode)
|
||||
|
||||
### Pattern Syntax
|
||||
|
||||
Format: `command_name(glob_pattern)`
|
||||
|
||||
| Pattern | Description |
|
||||
|---------|-------------|
|
||||
| `read_file({workspace}/**)` | Read any file in workspace (recursive) |
|
||||
| `write_to_file({workspace}/*.txt)` | Write only .txt files in workspace root |
|
||||
| `execute_shell(python:**)` | Execute Python commands only |
|
||||
| `execute_shell(git:*)` | Execute any git command |
|
||||
| `web_search(*)` | Allow all web searches |
|
||||
|
||||
Special tokens:
|
||||
- `{workspace}` - Replaced with actual workspace path
|
||||
- `**` - Matches any path including `/`
|
||||
- `*` - Matches any characters except `/`
|
||||
|
||||
### Interactive Approval Scopes
|
||||
|
||||
When prompted for permission, users can choose:
|
||||
|
||||
| Scope | Effect |
|
||||
|-------|--------|
|
||||
| **Once** | Allow this one time only (not saved) |
|
||||
| **Agent** | Always allow for this agent (saves to agent `permissions.yaml`) |
|
||||
| **Workspace** | Always allow for all agents (saves to `autogpt.yaml`) |
|
||||
| **Deny** | Deny this command (saves to appropriate deny list) |
|
||||
|
||||
### Default Security
|
||||
|
||||
Out of the box, the following are **denied by default**:
|
||||
- Reading sensitive files (`.env`, `.key`, `.pem`)
|
||||
- Destructive shell commands (`rm -rf`, `sudo`)
|
||||
- Operations outside the workspace directory
|
||||
@@ -1,182 +0,0 @@
|
||||
## CLI Documentation
|
||||
|
||||
This document describes how to interact with the project's CLI (Command Line Interface). It includes the types of outputs you can expect from each command. Note that the `agents stop` command will terminate any process running on port 8000.
|
||||
|
||||
### 1. Entry Point for the CLI
|
||||
|
||||
Running the `./run` command without any parameters will display the help message, which provides a list of available commands and options. Additionally, you can append `--help` to any command to view help information specific to that command.
|
||||
|
||||
```sh
|
||||
./run
|
||||
```
|
||||
|
||||
**Output**:
|
||||
|
||||
```
|
||||
Usage: cli.py [OPTIONS] COMMAND [ARGS]...
|
||||
|
||||
Options:
|
||||
--help Show this message and exit.
|
||||
|
||||
Commands:
|
||||
agent Commands to create, start and stop agents
|
||||
benchmark Commands to start the benchmark and list tests and categories
|
||||
setup Installs dependencies needed for your system.
|
||||
```
|
||||
|
||||
If you need assistance with any command, simply add the `--help` parameter to the end of your command, like so:
|
||||
|
||||
```sh
|
||||
./run COMMAND --help
|
||||
```
|
||||
|
||||
This will display a detailed help message regarding that specific command, including a list of any additional options and arguments it accepts.
|
||||
|
||||
### 2. Setup Command
|
||||
|
||||
```sh
|
||||
./run setup
|
||||
```
|
||||
|
||||
**Output**:
|
||||
|
||||
```
|
||||
Setup initiated
|
||||
Installation has been completed.
|
||||
```
|
||||
|
||||
This command initializes the setup of the project.
|
||||
|
||||
### 3. Agents Commands
|
||||
|
||||
**a. List All Agents**
|
||||
|
||||
```sh
|
||||
./run agent list
|
||||
```
|
||||
|
||||
**Output**:
|
||||
|
||||
```
|
||||
Available agents: 🤖
|
||||
🐙 forge
|
||||
🐙 autogpt
|
||||
```
|
||||
|
||||
Lists all the available agents.
|
||||
|
||||
**b. Create a New Agent**
|
||||
|
||||
```sh
|
||||
./run agent create my_agent
|
||||
```
|
||||
|
||||
**Output**:
|
||||
|
||||
```
|
||||
🎉 New agent 'my_agent' created and switched to the new directory in agents folder.
|
||||
```
|
||||
|
||||
Creates a new agent named 'my_agent'.
|
||||
|
||||
**c. Start an Agent**
|
||||
|
||||
```sh
|
||||
./run agent start my_agent
|
||||
```
|
||||
|
||||
**Output**:
|
||||
|
||||
```
|
||||
... (ASCII Art representing the agent startup)
|
||||
[Date and Time] [forge.sdk.db] [DEBUG] 🐛 Initializing AgentDB with database_string: sqlite:///agent.db
|
||||
[Date and Time] [forge.sdk.agent] [INFO] 📝 Agent server starting on http://0.0.0.0:8000
|
||||
```
|
||||
|
||||
Starts the 'my_agent' and displays startup ASCII art and logs.
|
||||
|
||||
**d. Stop an Agent**
|
||||
|
||||
```sh
|
||||
./run agent stop
|
||||
```
|
||||
|
||||
**Output**:
|
||||
|
||||
```
|
||||
Agent stopped
|
||||
```
|
||||
|
||||
Stops the running agent.
|
||||
|
||||
### 4. Benchmark Commands
|
||||
|
||||
**a. List Benchmark Categories**
|
||||
|
||||
```sh
|
||||
./run benchmark categories list
|
||||
```
|
||||
|
||||
**Output**:
|
||||
|
||||
```
|
||||
Available categories: 📚
|
||||
📖 code
|
||||
📖 safety
|
||||
📖 memory
|
||||
... (and so on)
|
||||
```
|
||||
|
||||
Lists all available benchmark categories.
|
||||
|
||||
**b. List Benchmark Tests**
|
||||
|
||||
```sh
|
||||
./run benchmark tests list
|
||||
```
|
||||
|
||||
**Output**:
|
||||
|
||||
```
|
||||
Available tests: 📚
|
||||
📖 interface
|
||||
🔬 Search - TestSearch
|
||||
🔬 Write File - TestWriteFile
|
||||
... (and so on)
|
||||
```
|
||||
|
||||
Lists all available benchmark tests.
|
||||
|
||||
**c. Show Details of a Benchmark Test**
|
||||
|
||||
```sh
|
||||
./run benchmark tests details TestWriteFile
|
||||
```
|
||||
|
||||
**Output**:
|
||||
|
||||
```
|
||||
TestWriteFile
|
||||
-------------
|
||||
|
||||
Category: interface
|
||||
Task: Write the word 'Washington' to a .txt file
|
||||
... (and other details)
|
||||
```
|
||||
|
||||
Displays the details of the 'TestWriteFile' benchmark test.
|
||||
|
||||
**d. Start Benchmark for the Agent**
|
||||
|
||||
```sh
|
||||
./run benchmark start my_agent
|
||||
```
|
||||
|
||||
**Output**:
|
||||
|
||||
```
|
||||
(more details about the testing process shown whilst the test are running)
|
||||
============= 13 failed, 1 passed in 0.97s ============...
|
||||
```
|
||||
|
||||
Displays the results of the benchmark tests on 'my_agent'.
|
||||
@@ -2,7 +2,7 @@
|
||||
ARG BUILD_TYPE=dev
|
||||
|
||||
# Use an official Python base image from the Docker Hub
|
||||
FROM python:3.10-slim AS autogpt-base
|
||||
FROM python:3.12-slim AS autogpt-base
|
||||
|
||||
# Install browsers
|
||||
RUN apt-get update && apt-get install -y \
|
||||
@@ -34,9 +34,6 @@ COPY original_autogpt/pyproject.toml original_autogpt/poetry.lock ./
|
||||
# Include forge so it can be used as a path dependency
|
||||
COPY forge/ ../forge
|
||||
|
||||
# Include frontend
|
||||
COPY frontend/ ../frontend
|
||||
|
||||
# Set the entrypoint
|
||||
ENTRYPOINT ["poetry", "run", "autogpt"]
|
||||
CMD []
|
||||
|
||||
@@ -1,173 +0,0 @@
|
||||
# Quickstart Guide
|
||||
|
||||
> For the complete getting started [tutorial series](https://aiedge.medium.com/autogpt-forge-e3de53cc58ec) <- click here
|
||||
|
||||
Welcome to the Quickstart Guide! This guide will walk you through setting up, building, and running your own AutoGPT agent. Whether you're a seasoned AI developer or just starting out, this guide will provide you with the steps to jumpstart your journey in AI development with AutoGPT.
|
||||
|
||||
## System Requirements
|
||||
|
||||
This project supports Linux (Debian-based), Mac, and Windows Subsystem for Linux (WSL). If you use a Windows system, you must install WSL. You can find the installation instructions for WSL [here](https://learn.microsoft.com/en-us/windows/wsl/).
|
||||
|
||||
|
||||
## Getting Setup
|
||||
1. **Fork the Repository**
|
||||
To fork the repository, follow these steps:
|
||||
- Navigate to the main page of the repository.
|
||||
|
||||

|
||||
- In the top-right corner of the page, click Fork.
|
||||
|
||||

|
||||
- On the next page, select your GitHub account to create the fork.
|
||||
- Wait for the forking process to complete. You now have a copy of the repository in your GitHub account.
|
||||
|
||||
2. **Clone the Repository**
|
||||
To clone the repository, you need to have Git installed on your system. If you don't have Git installed, download it from [here](https://git-scm.com/downloads). Once you have Git installed, follow these steps:
|
||||
- Open your terminal.
|
||||
- Navigate to the directory where you want to clone the repository.
|
||||
- Run the git clone command for the fork you just created
|
||||
|
||||

|
||||
|
||||
- Then open your project in your ide
|
||||
|
||||

|
||||
|
||||
4. **Setup the Project**
|
||||
Next, we need to set up the required dependencies. We have a tool to help you perform all the tasks on the repo.
|
||||
It can be accessed by running the `run` command by typing `./run` in the terminal.
|
||||
|
||||
The first command you need to use is `./run setup.` This will guide you through setting up your system.
|
||||
Initially, you will get instructions for installing Flutter and Chrome and setting up your GitHub access token like the following image:
|
||||
|
||||

|
||||
|
||||
### For Windows Users
|
||||
|
||||
If you're a Windows user and experience issues after installing WSL, follow the steps below to resolve them.
|
||||
|
||||
#### Update WSL
|
||||
Run the following command in Powershell or Command Prompt:
|
||||
1. Enable the optional WSL and Virtual Machine Platform components.
|
||||
2. Download and install the latest Linux kernel.
|
||||
3. Set WSL 2 as the default.
|
||||
4. Download and install the Ubuntu Linux distribution (a reboot may be required).
|
||||
|
||||
```shell
|
||||
wsl --install
|
||||
```
|
||||
|
||||
For more detailed information and additional steps, refer to [Microsoft's WSL Setup Environment Documentation](https://learn.microsoft.com/en-us/windows/wsl/setup/environment).
|
||||
|
||||
#### Resolve FileNotFoundError or "No such file or directory" Errors
|
||||
When you run `./run setup`, if you encounter errors like `No such file or directory` or `FileNotFoundError`, it might be because Windows-style line endings (CRLF - Carriage Return Line Feed) are not compatible with Unix/Linux style line endings (LF - Line Feed).
|
||||
|
||||
To resolve this, you can use the `dos2unix` utility to convert the line endings in your script from CRLF to LF. Here’s how to install and run `dos2unix` on the script:
|
||||
|
||||
```shell
|
||||
sudo apt update
|
||||
sudo apt install dos2unix
|
||||
dos2unix ./run
|
||||
```
|
||||
|
||||
After executing the above commands, running `./run setup` should work successfully.
|
||||
|
||||
#### Store Project Files within the WSL File System
|
||||
If you continue to experience issues, consider storing your project files within the WSL file system instead of the Windows file system. This method avoids path translations and permissions issues and provides a more consistent development environment.
|
||||
|
||||
You can keep running the command to get feedback on where you are up to with your setup.
|
||||
When setup has been completed, the command will return an output like this:
|
||||
|
||||

|
||||
|
||||
## Creating Your Agent
|
||||
|
||||
After completing the setup, the next step is to create your agent template.
|
||||
Execute the command `./run agent create YOUR_AGENT_NAME`, where `YOUR_AGENT_NAME` should be replaced with your chosen name.
|
||||
|
||||
Tips for naming your agent:
|
||||
* Give it its own unique name, or name it after yourself
|
||||
* Include an important aspect of your agent in the name, such as its purpose
|
||||
|
||||
Examples: `SwiftyosAssistant`, `PwutsPRAgent`, `MySuperAgent`
|
||||
|
||||

|
||||
|
||||
## Running your Agent
|
||||
|
||||
Your agent can be started using the command: `./run agent start YOUR_AGENT_NAME`
|
||||
|
||||
This starts the agent on the URL: `http://localhost:8000/`
|
||||
|
||||

|
||||
|
||||
The front end can be accessed from `http://localhost:8000/`; first, you must log in using either a Google account or your GitHub account.
|
||||
|
||||

|
||||
|
||||
Upon logging in, you will get a page that looks something like this: your task history down the left-hand side of the page, and the 'chat' window to send tasks to your agent.
|
||||
|
||||

|
||||
|
||||
When you have finished with your agent or just need to restart it, use Ctl-C to end the session. Then, you can re-run the start command.
|
||||
|
||||
If you are having issues and want to ensure the agent has been stopped, there is a `./run agent stop` command, which will kill the process using port 8000, which should be the agent.
|
||||
|
||||
## Benchmarking your Agent
|
||||
|
||||
The benchmarking system can also be accessed using the CLI too:
|
||||
|
||||
```bash
|
||||
agpt % ./run benchmark
|
||||
Usage: cli.py benchmark [OPTIONS] COMMAND [ARGS]...
|
||||
|
||||
Commands to start the benchmark and list tests and categories
|
||||
|
||||
Options:
|
||||
--help Show this message and exit.
|
||||
|
||||
Commands:
|
||||
categories Benchmark categories group command
|
||||
start Starts the benchmark command
|
||||
tests Benchmark tests group command
|
||||
agpt % ./run benchmark categories
|
||||
Usage: cli.py benchmark categories [OPTIONS] COMMAND [ARGS]...
|
||||
|
||||
Benchmark categories group command
|
||||
|
||||
Options:
|
||||
--help Show this message and exit.
|
||||
|
||||
Commands:
|
||||
list List benchmark categories command
|
||||
agpt % ./run benchmark tests
|
||||
Usage: cli.py benchmark tests [OPTIONS] COMMAND [ARGS]...
|
||||
|
||||
Benchmark tests group command
|
||||
|
||||
Options:
|
||||
--help Show this message and exit.
|
||||
|
||||
Commands:
|
||||
details Benchmark test details command
|
||||
list List benchmark tests command
|
||||
```
|
||||
|
||||
The benchmark has been split into different categories of skills you can test your agent on. You can see what categories are available with
|
||||
```bash
|
||||
./run benchmark categories list
|
||||
# And what tests are available with
|
||||
./run benchmark tests list
|
||||
```
|
||||
|
||||

|
||||
|
||||
|
||||
Finally, you can run the benchmark with
|
||||
|
||||
```bash
|
||||
./run benchmark start YOUR_AGENT_NAME
|
||||
|
||||
```
|
||||
|
||||
>
|
||||
@@ -4,7 +4,7 @@ AutoGPT Classic was an experimental project to demonstrate autonomous GPT-4 oper
|
||||
|
||||
## Project Status
|
||||
|
||||
⚠️ **This project is unsupported, and dependencies will not be updated. It was an experiment that has concluded its initial research phase. If you want to use AutoGPT, you should use the [AutoGPT Platform](/autogpt_platform)**
|
||||
**This project is unsupported, and dependencies will not be updated.** It was an experiment that has concluded its initial research phase. If you want to use AutoGPT, you should use the [AutoGPT Platform](/autogpt_platform).
|
||||
|
||||
For those interested in autonomous AI agents, we recommend exploring more actively maintained alternatives or referring to this codebase for educational purposes only.
|
||||
|
||||
@@ -16,37 +16,171 @@ AutoGPT Classic was one of the first implementations of autonomous AI agents - A
|
||||
- Learn from the results and adjust its approach
|
||||
- Chain multiple actions together to achieve an objective
|
||||
|
||||
## Key Features
|
||||
|
||||
- 🔄 Autonomous task chaining
|
||||
- 🛠 Tool and API integration capabilities
|
||||
- 💾 Memory management for context retention
|
||||
- 🔍 Web browsing and information gathering
|
||||
- 📝 File operations and content creation
|
||||
- 🔄 Self-prompting and task breakdown
|
||||
|
||||
## Structure
|
||||
|
||||
The project is organized into several key components:
|
||||
- `/benchmark` - Performance testing tools
|
||||
- `/forge` - Core autonomous agent framework
|
||||
- `/frontend` - User interface components
|
||||
- `/original_autogpt` - Original implementation
|
||||
```
|
||||
classic/
|
||||
├── pyproject.toml # Single consolidated Poetry project
|
||||
├── poetry.lock # Single lock file
|
||||
├── forge/ # Core autonomous agent framework
|
||||
├── original_autogpt/ # Original implementation
|
||||
├── direct_benchmark/ # Benchmark harness
|
||||
└── benchmark/ # Challenge definitions (data)
|
||||
```
|
||||
|
||||
## Getting Started
|
||||
|
||||
While this project is no longer actively maintained, you can still explore the codebase:
|
||||
### Prerequisites
|
||||
|
||||
- Python 3.12+
|
||||
- [Poetry](https://python-poetry.org/docs/#installation)
|
||||
|
||||
### Installation
|
||||
|
||||
1. Clone the repository:
|
||||
```bash
|
||||
# Clone the repository
|
||||
git clone https://github.com/Significant-Gravitas/AutoGPT.git
|
||||
cd classic
|
||||
|
||||
# Install everything
|
||||
poetry install
|
||||
```
|
||||
|
||||
2. Review the documentation:
|
||||
- For reference, see the [documentation](https://docs.agpt.co). You can browse at the same point in time as this commit so the docs don't change.
|
||||
- Check `CLI-USAGE.md` for command-line interface details
|
||||
- Refer to `TROUBLESHOOTING.md` for common issues
|
||||
### Configuration
|
||||
|
||||
Configuration uses a layered system:
|
||||
|
||||
1. **Environment variables** (`.env` file)
|
||||
2. **Workspace settings** (`.autogpt/autogpt.yaml`)
|
||||
3. **Agent settings** (`.autogpt/agents/{id}/permissions.yaml`)
|
||||
|
||||
Copy the example environment file and add your API keys:
|
||||
|
||||
```bash
|
||||
cp .env.example .env
|
||||
```
|
||||
|
||||
Key environment variables:
|
||||
```bash
|
||||
# Required
|
||||
OPENAI_API_KEY=sk-...
|
||||
|
||||
# Optional LLM settings
|
||||
SMART_LLM=gpt-4o # Model for complex reasoning
|
||||
FAST_LLM=gpt-4o-mini # Model for simple tasks
|
||||
|
||||
# Optional search providers
|
||||
TAVILY_API_KEY=tvly-...
|
||||
SERPER_API_KEY=...
|
||||
|
||||
# Optional infrastructure
|
||||
LOG_LEVEL=DEBUG
|
||||
PORT=8000
|
||||
FILE_STORAGE_BACKEND=local # local, s3, or gcs
|
||||
```
|
||||
|
||||
### Running
|
||||
|
||||
All commands run from the `classic/` directory:
|
||||
|
||||
```bash
|
||||
# Run forge agent
|
||||
poetry run python -m forge
|
||||
|
||||
# Run original autogpt server
|
||||
poetry run serve --debug
|
||||
|
||||
# Run autogpt CLI
|
||||
poetry run autogpt
|
||||
```
|
||||
|
||||
Agents run on `http://localhost:8000` by default.
|
||||
|
||||
### Benchmarking
|
||||
|
||||
```bash
|
||||
poetry run direct-benchmark run
|
||||
```
|
||||
|
||||
### Testing
|
||||
|
||||
```bash
|
||||
poetry run pytest # All tests
|
||||
poetry run pytest forge/tests/ # Forge tests only
|
||||
poetry run pytest original_autogpt/tests/ # AutoGPT tests only
|
||||
```
|
||||
|
||||
## Workspaces
|
||||
|
||||
Agents operate within a **workspace** directory that contains all agent data and files:
|
||||
|
||||
```
|
||||
{workspace}/
|
||||
├── .autogpt/
|
||||
│ ├── autogpt.yaml # Workspace-level permissions
|
||||
│ ├── ap_server.db # Agent Protocol database (server mode)
|
||||
│ └── agents/
|
||||
│ └── AutoGPT-{agent_id}/
|
||||
│ ├── state.json # Agent profile, directives, history
|
||||
│ ├── permissions.yaml # Agent-specific permissions
|
||||
│ └── workspace/ # Agent's sandboxed working directory
|
||||
```
|
||||
|
||||
- The workspace defaults to the current working directory
|
||||
- Multiple agents can coexist in the same workspace
|
||||
- Agent file access is sandboxed to their `workspace/` subdirectory
|
||||
- State persists across sessions via `state.json`
|
||||
|
||||
## Permissions
|
||||
|
||||
AutoGPT uses a **layered permission system** with pattern matching:
|
||||
|
||||
### Permission Files
|
||||
|
||||
| File | Scope | Location |
|
||||
|------|-------|----------|
|
||||
| `autogpt.yaml` | All agents in workspace | `.autogpt/autogpt.yaml` |
|
||||
| `permissions.yaml` | Single agent | `.autogpt/agents/{id}/permissions.yaml` |
|
||||
|
||||
### Permission Format
|
||||
|
||||
```yaml
|
||||
allow:
|
||||
- read_file({workspace}/**) # Read any file in workspace
|
||||
- write_to_file({workspace}/**) # Write any file in workspace
|
||||
- web_search(*) # All web searches
|
||||
|
||||
deny:
|
||||
- read_file(**.env) # Block .env files
|
||||
- execute_shell(sudo:*) # Block sudo commands
|
||||
```
|
||||
|
||||
### Check Order (First Match Wins)
|
||||
|
||||
1. Agent deny → Block
|
||||
2. Workspace deny → Block
|
||||
3. Agent allow → Allow
|
||||
4. Workspace allow → Allow
|
||||
5. Prompt user → Interactive approval
|
||||
|
||||
### Interactive Approval
|
||||
|
||||
When prompted, users can approve commands with different scopes:
|
||||
- **Once** - Allow this one time only
|
||||
- **Agent** - Always allow for this agent
|
||||
- **Workspace** - Always allow for all agents
|
||||
- **Deny** - Block this command
|
||||
|
||||
### Default Security
|
||||
|
||||
Denied by default:
|
||||
- Sensitive files (`.env`, `.key`, `.pem`)
|
||||
- Destructive commands (`rm -rf`, `sudo`)
|
||||
- Operations outside the workspace
|
||||
|
||||
## Security Notice
|
||||
|
||||
This codebase has **known vulnerabilities** and issues with its dependencies. It will not be updated to new dependencies. Use for educational purposes only.
|
||||
|
||||
## License
|
||||
|
||||
@@ -55,27 +189,3 @@ This project segment is licensed under the MIT License - see the [LICENSE](LICEN
|
||||
## Documentation
|
||||
|
||||
Please refer to the [documentation](https://docs.agpt.co) for more detailed information about the project's architecture and concepts.
|
||||
You can browse at the same point in time as this commit so the docs don't change.
|
||||
|
||||
## Historical Impact
|
||||
|
||||
AutoGPT Classic played a significant role in advancing the field of autonomous AI agents:
|
||||
- Demonstrated practical implementation of AI autonomy
|
||||
- Inspired numerous derivative projects and research
|
||||
- Contributed to the development of AI agent architectures
|
||||
- Helped identify key challenges in AI autonomy
|
||||
|
||||
## Security Notice
|
||||
|
||||
If you're studying this codebase, please understand this has KNOWN vulnerabilities and issues with its dependencies. It will not be updated to new dependencies.
|
||||
|
||||
## Community & Support
|
||||
|
||||
While active development has concluded:
|
||||
- The codebase remains available for study and reference
|
||||
- Historical discussions can be found in project issues
|
||||
- Related research and developments continue in the broader AI agent community
|
||||
|
||||
## Acknowledgments
|
||||
|
||||
Thanks to all contributors who participated in this experimental project and helped advance the field of autonomous AI agents.
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
AGENT_NAME=mini-agi
|
||||
REPORTS_FOLDER="reports/mini-agi"
|
||||
OPENAI_API_KEY="sk-" # for LLM eval
|
||||
BUILD_SKILL_TREE=false # set to true to build the skill tree.
|
||||
@@ -1,12 +0,0 @@
|
||||
[flake8]
|
||||
max-line-length = 88
|
||||
# Ignore rules that conflict with Black code style
|
||||
extend-ignore = E203, W503
|
||||
exclude =
|
||||
__pycache__/,
|
||||
*.pyc,
|
||||
.pytest_cache/,
|
||||
venv*/,
|
||||
.venv/,
|
||||
reports/,
|
||||
agbenchmark/reports/,
|
||||
174
classic/benchmark/.gitignore
vendored
174
classic/benchmark/.gitignore
vendored
@@ -1,174 +0,0 @@
|
||||
agbenchmark_config/workspace/
|
||||
backend/backend_stdout.txt
|
||||
reports/df*.pkl
|
||||
reports/raw*
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control.
|
||||
# https://pdm.fming.dev/#use-with-ide
|
||||
.pdm.toml
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
.idea/
|
||||
.DS_Store
|
||||
```
|
||||
secrets.json
|
||||
agbenchmark_config/challenges_already_beaten.json
|
||||
agbenchmark_config/challenges/pri_*
|
||||
agbenchmark_config/updates.json
|
||||
agbenchmark_config/reports/*
|
||||
agbenchmark_config/reports/success_rate.json
|
||||
agbenchmark_config/reports/regression_tests.json
|
||||
@@ -1,21 +0,0 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2024 AutoGPT
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
@@ -1,25 +0,0 @@
|
||||
# Auto-GPT Benchmarks
|
||||
|
||||
Built for the purpose of benchmarking the performance of agents regardless of how they work.
|
||||
|
||||
Objectively know how well your agent is performing in categories like code, retrieval, memory, and safety.
|
||||
|
||||
Save time and money while doing it through smart dependencies. The best part? It's all automated.
|
||||
|
||||
## Scores:
|
||||
|
||||
<img width="733" alt="Screenshot 2023-07-25 at 10 35 01 AM" src="https://github.com/Significant-Gravitas/Auto-GPT-Benchmarks/assets/9652976/98963e0b-18b9-4b17-9a6a-4d3e4418af70">
|
||||
|
||||
## Ranking overall:
|
||||
|
||||
- 1- [Beebot](https://github.com/AutoPackAI/beebot)
|
||||
- 2- [mini-agi](https://github.com/muellerberndt/mini-agi)
|
||||
- 3- [Auto-GPT](https://github.com/Significant-Gravitas/AutoGPT)
|
||||
|
||||
## Detailed results:
|
||||
|
||||
<img width="733" alt="Screenshot 2023-07-25 at 10 42 15 AM" src="https://github.com/Significant-Gravitas/Auto-GPT-Benchmarks/assets/9652976/39be464c-c842-4437-b28a-07d878542a83">
|
||||
|
||||
[Click here to see the results and the raw data!](https://docs.google.com/spreadsheets/d/1WXm16P2AHNbKpkOI0LYBpcsGG0O7D8HYTG5Uj0PaJjA/edit#gid=203558751)!
|
||||
|
||||
More agents coming soon !
|
||||
@@ -1,69 +0,0 @@
|
||||
## As a user
|
||||
|
||||
1. `pip install auto-gpt-benchmarks`
|
||||
2. Add boilerplate code to run and kill agent
|
||||
3. `agbenchmark`
|
||||
- `--category challenge_category` to run tests in a specific category
|
||||
- `--mock` to only run mock tests if they exists for each test
|
||||
- `--noreg` to skip any tests that have passed in the past. When you run without this flag and a previous challenge that passed fails, it will now not be regression tests
|
||||
4. We call boilerplate code for your agent
|
||||
5. Show pass rate of tests, logs, and any other metrics
|
||||
|
||||
## Contributing
|
||||
|
||||
##### Diagrams: https://whimsical.com/agbenchmark-5n4hXBq1ZGzBwRsK4TVY7x
|
||||
|
||||
### To run the existing mocks
|
||||
|
||||
1. clone the repo `auto-gpt-benchmarks`
|
||||
2. `pip install poetry`
|
||||
3. `poetry shell`
|
||||
4. `poetry install`
|
||||
5. `cp .env_example .env`
|
||||
6. `git submodule update --init --remote --recursive`
|
||||
7. `uvicorn server:app --reload`
|
||||
8. `agbenchmark --mock`
|
||||
Keep config the same and watch the logs :)
|
||||
|
||||
### To run with mini-agi
|
||||
|
||||
1. Navigate to `auto-gpt-benchmarks/agent/mini-agi`
|
||||
2. `pip install -r requirements.txt`
|
||||
3. `cp .env_example .env`, set `PROMPT_USER=false` and add your `OPENAI_API_KEY=`. Sset `MODEL="gpt-3.5-turbo"` if you don't have access to `gpt-4` yet. Also make sure you have Python 3.10^ installed
|
||||
4. set `AGENT_NAME=mini-agi` in `.env` file and where you want your `REPORTS_FOLDER` to be
|
||||
5. Make sure to follow the commands above, and remove mock flag `agbenchmark`
|
||||
|
||||
- To add requirements `poetry add requirement`.
|
||||
|
||||
Feel free to create prs to merge with `main` at will (but also feel free to ask for review) - if you can't send msg in R&D chat for access.
|
||||
|
||||
If you push at any point and break things - it'll happen to everyone - fix it asap. Step 1 is to revert `master` to last working commit
|
||||
|
||||
Let people know what beautiful code you write does, document everything well
|
||||
|
||||
Share your progress :)
|
||||
|
||||
#### Dataset
|
||||
|
||||
Manually created, existing challenges within Auto-Gpt, https://osu-nlp-group.github.io/Mind2Web/
|
||||
|
||||
## How do I add new agents to agbenchmark ?
|
||||
|
||||
Example with smol developer.
|
||||
|
||||
1- Create a github branch with your agent following the same pattern as this example:
|
||||
|
||||
https://github.com/smol-ai/developer/pull/114/files
|
||||
|
||||
2- Create the submodule and the github workflow by following the same pattern as this example:
|
||||
|
||||
https://github.com/Significant-Gravitas/Auto-GPT-Benchmarks/pull/48/files
|
||||
|
||||
## How do I run agent in different environments?
|
||||
|
||||
**To just use as the benchmark for your agent**. `pip install` the package and run `agbenchmark`
|
||||
|
||||
**For internal Auto-GPT ci runs**, specify the `AGENT_NAME` you want you use and set the `HOME_ENV`.
|
||||
Ex. `AGENT_NAME=mini-agi`
|
||||
|
||||
**To develop agent alongside benchmark**, you can specify the `AGENT_NAME` you want you use and add as a submodule to the repo
|
||||
@@ -1,352 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
import click
|
||||
from click_default_group import DefaultGroup
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from agbenchmark.config import AgentBenchmarkConfig
|
||||
from agbenchmark.utils.logging import configure_logging
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# try:
|
||||
# if os.getenv("HELICONE_API_KEY"):
|
||||
# import helicone # noqa
|
||||
|
||||
# helicone_enabled = True
|
||||
# else:
|
||||
# helicone_enabled = False
|
||||
# except ImportError:
|
||||
# helicone_enabled = False
|
||||
|
||||
|
||||
class InvalidInvocationError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
BENCHMARK_START_TIME_DT = datetime.now(timezone.utc)
|
||||
BENCHMARK_START_TIME = BENCHMARK_START_TIME_DT.strftime("%Y-%m-%dT%H:%M:%S+00:00")
|
||||
|
||||
|
||||
# if helicone_enabled:
|
||||
# from helicone.lock import HeliconeLockManager
|
||||
|
||||
# HeliconeLockManager.write_custom_property(
|
||||
# "benchmark_start_time", BENCHMARK_START_TIME
|
||||
# )
|
||||
|
||||
|
||||
@click.group(cls=DefaultGroup, default_if_no_args=True)
|
||||
@click.option("--debug", is_flag=True, help="Enable debug output")
|
||||
def cli(
|
||||
debug: bool,
|
||||
) -> Any:
|
||||
configure_logging(logging.DEBUG if debug else logging.INFO)
|
||||
|
||||
|
||||
@cli.command(hidden=True)
|
||||
def start():
|
||||
raise DeprecationWarning(
|
||||
"`agbenchmark start` is deprecated. Use `agbenchmark run` instead."
|
||||
)
|
||||
|
||||
|
||||
@cli.command(default=True)
|
||||
@click.option(
|
||||
"-N", "--attempts", default=1, help="Number of times to run each challenge."
|
||||
)
|
||||
@click.option(
|
||||
"-c",
|
||||
"--category",
|
||||
multiple=True,
|
||||
help="(+) Select a category to run.",
|
||||
)
|
||||
@click.option(
|
||||
"-s",
|
||||
"--skip-category",
|
||||
multiple=True,
|
||||
help="(+) Exclude a category from running.",
|
||||
)
|
||||
@click.option("--test", multiple=True, help="(+) Select a test to run.")
|
||||
@click.option("--maintain", is_flag=True, help="Run only regression tests.")
|
||||
@click.option("--improve", is_flag=True, help="Run only non-regression tests.")
|
||||
@click.option(
|
||||
"--explore",
|
||||
is_flag=True,
|
||||
help="Run only challenges that have never been beaten.",
|
||||
)
|
||||
@click.option(
|
||||
"--no-dep",
|
||||
is_flag=True,
|
||||
help="Run all (selected) challenges, regardless of dependency success/failure.",
|
||||
)
|
||||
@click.option("--cutoff", type=int, help="Override the challenge time limit (seconds).")
|
||||
@click.option("--nc", is_flag=True, help="Disable the challenge time limit.")
|
||||
@click.option("--mock", is_flag=True, help="Run with mock")
|
||||
@click.option("--keep-answers", is_flag=True, help="Keep answers")
|
||||
@click.option(
|
||||
"--backend",
|
||||
is_flag=True,
|
||||
help="Write log output to a file instead of the terminal.",
|
||||
)
|
||||
# @click.argument(
|
||||
# "agent_path",
|
||||
# type=click.Path(exists=True, file_okay=False, path_type=Path),
|
||||
# required=False,
|
||||
# )
|
||||
def run(
|
||||
maintain: bool,
|
||||
improve: bool,
|
||||
explore: bool,
|
||||
mock: bool,
|
||||
no_dep: bool,
|
||||
nc: bool,
|
||||
keep_answers: bool,
|
||||
test: tuple[str],
|
||||
category: tuple[str],
|
||||
skip_category: tuple[str],
|
||||
attempts: int,
|
||||
cutoff: Optional[int] = None,
|
||||
backend: Optional[bool] = False,
|
||||
# agent_path: Optional[Path] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Run the benchmark on the agent in the current directory.
|
||||
|
||||
Options marked with (+) can be specified multiple times, to select multiple items.
|
||||
"""
|
||||
from agbenchmark.main import run_benchmark, validate_args
|
||||
|
||||
agbenchmark_config = AgentBenchmarkConfig.load()
|
||||
logger.debug(f"agbenchmark_config: {agbenchmark_config.agbenchmark_config_dir}")
|
||||
try:
|
||||
validate_args(
|
||||
maintain=maintain,
|
||||
improve=improve,
|
||||
explore=explore,
|
||||
tests=test,
|
||||
categories=category,
|
||||
skip_categories=skip_category,
|
||||
no_cutoff=nc,
|
||||
cutoff=cutoff,
|
||||
)
|
||||
except InvalidInvocationError as e:
|
||||
logger.error("Error: " + "\n".join(e.args))
|
||||
sys.exit(1)
|
||||
|
||||
original_stdout = sys.stdout # Save the original standard output
|
||||
exit_code = None
|
||||
|
||||
if backend:
|
||||
with open("backend/backend_stdout.txt", "w") as f:
|
||||
sys.stdout = f
|
||||
exit_code = run_benchmark(
|
||||
config=agbenchmark_config,
|
||||
maintain=maintain,
|
||||
improve=improve,
|
||||
explore=explore,
|
||||
mock=mock,
|
||||
no_dep=no_dep,
|
||||
no_cutoff=nc,
|
||||
keep_answers=keep_answers,
|
||||
tests=test,
|
||||
categories=category,
|
||||
skip_categories=skip_category,
|
||||
attempts_per_challenge=attempts,
|
||||
cutoff=cutoff,
|
||||
)
|
||||
|
||||
sys.stdout = original_stdout
|
||||
|
||||
else:
|
||||
exit_code = run_benchmark(
|
||||
config=agbenchmark_config,
|
||||
maintain=maintain,
|
||||
improve=improve,
|
||||
explore=explore,
|
||||
mock=mock,
|
||||
no_dep=no_dep,
|
||||
no_cutoff=nc,
|
||||
keep_answers=keep_answers,
|
||||
tests=test,
|
||||
categories=category,
|
||||
skip_categories=skip_category,
|
||||
attempts_per_challenge=attempts,
|
||||
cutoff=cutoff,
|
||||
)
|
||||
|
||||
sys.exit(exit_code)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.option("--port", type=int, help="Port to run the API on.")
|
||||
def serve(port: Optional[int] = None):
|
||||
"""Serve the benchmark frontend and API on port 8080."""
|
||||
import uvicorn
|
||||
|
||||
from agbenchmark.app import setup_fastapi_app
|
||||
|
||||
config = AgentBenchmarkConfig.load()
|
||||
app = setup_fastapi_app(config)
|
||||
|
||||
# Run the FastAPI application using uvicorn
|
||||
port = port or int(os.getenv("PORT", 8080))
|
||||
uvicorn.run(app, host="0.0.0.0", port=port)
|
||||
|
||||
|
||||
@cli.command()
|
||||
def config():
|
||||
"""Displays info regarding the present AGBenchmark config."""
|
||||
from .utils.utils import pretty_print_model
|
||||
|
||||
try:
|
||||
config = AgentBenchmarkConfig.load()
|
||||
except FileNotFoundError as e:
|
||||
click.echo(e, err=True)
|
||||
return 1
|
||||
|
||||
pretty_print_model(config, include_header=False)
|
||||
|
||||
|
||||
@cli.group()
|
||||
def challenge():
|
||||
logging.getLogger().setLevel(logging.WARNING)
|
||||
|
||||
|
||||
@challenge.command("list")
|
||||
@click.option(
|
||||
"--all", "include_unavailable", is_flag=True, help="Include unavailable challenges."
|
||||
)
|
||||
@click.option(
|
||||
"--names", "only_names", is_flag=True, help="List only the challenge names."
|
||||
)
|
||||
@click.option("--json", "output_json", is_flag=True)
|
||||
def list_challenges(include_unavailable: bool, only_names: bool, output_json: bool):
|
||||
"""Lists [available|all] challenges."""
|
||||
import json
|
||||
|
||||
from tabulate import tabulate
|
||||
|
||||
from .challenges.builtin import load_builtin_challenges
|
||||
from .challenges.webarena import load_webarena_challenges
|
||||
from .utils.data_types import Category, DifficultyLevel
|
||||
from .utils.utils import sorted_by_enum_index
|
||||
|
||||
DIFFICULTY_COLORS = {
|
||||
difficulty: color
|
||||
for difficulty, color in zip(
|
||||
DifficultyLevel,
|
||||
["black", "blue", "cyan", "green", "yellow", "red", "magenta", "white"],
|
||||
)
|
||||
}
|
||||
CATEGORY_COLORS = {
|
||||
category: f"bright_{color}"
|
||||
for category, color in zip(
|
||||
Category,
|
||||
["blue", "cyan", "green", "yellow", "magenta", "red", "white", "black"],
|
||||
)
|
||||
}
|
||||
|
||||
# Load challenges
|
||||
challenges = filter(
|
||||
lambda c: c.info.available or include_unavailable,
|
||||
[
|
||||
*load_builtin_challenges(),
|
||||
*load_webarena_challenges(skip_unavailable=False),
|
||||
],
|
||||
)
|
||||
challenges = sorted_by_enum_index(
|
||||
challenges, DifficultyLevel, key=lambda c: c.info.difficulty
|
||||
)
|
||||
|
||||
if only_names:
|
||||
if output_json:
|
||||
click.echo(json.dumps([c.info.name for c in challenges]))
|
||||
return
|
||||
|
||||
for c in challenges:
|
||||
click.echo(
|
||||
click.style(c.info.name, fg=None if c.info.available else "black")
|
||||
)
|
||||
return
|
||||
|
||||
if output_json:
|
||||
click.echo(
|
||||
json.dumps([json.loads(c.info.model_dump_json()) for c in challenges])
|
||||
)
|
||||
return
|
||||
|
||||
headers = tuple(
|
||||
click.style(h, bold=True) for h in ("Name", "Difficulty", "Categories")
|
||||
)
|
||||
table = [
|
||||
tuple(
|
||||
v if challenge.info.available else click.style(v, fg="black")
|
||||
for v in (
|
||||
challenge.info.name,
|
||||
(
|
||||
click.style(
|
||||
challenge.info.difficulty.value,
|
||||
fg=DIFFICULTY_COLORS[challenge.info.difficulty],
|
||||
)
|
||||
if challenge.info.difficulty
|
||||
else click.style("-", fg="black")
|
||||
),
|
||||
" ".join(
|
||||
click.style(cat.value, fg=CATEGORY_COLORS[cat])
|
||||
for cat in sorted_by_enum_index(challenge.info.category, Category)
|
||||
),
|
||||
)
|
||||
)
|
||||
for challenge in challenges
|
||||
]
|
||||
click.echo(tabulate(table, headers=headers))
|
||||
|
||||
|
||||
@challenge.command()
|
||||
@click.option("--json", is_flag=True)
|
||||
@click.argument("name")
|
||||
def info(name: str, json: bool):
|
||||
from itertools import chain
|
||||
|
||||
from .challenges.builtin import load_builtin_challenges
|
||||
from .challenges.webarena import load_webarena_challenges
|
||||
from .utils.utils import pretty_print_model
|
||||
|
||||
for challenge in chain(
|
||||
load_builtin_challenges(),
|
||||
load_webarena_challenges(skip_unavailable=False),
|
||||
):
|
||||
if challenge.info.name != name:
|
||||
continue
|
||||
|
||||
if json:
|
||||
click.echo(challenge.info.model_dump_json())
|
||||
break
|
||||
|
||||
pretty_print_model(challenge.info)
|
||||
break
|
||||
else:
|
||||
click.echo(click.style(f"Unknown challenge '{name}'", fg="red"), err=True)
|
||||
|
||||
|
||||
@cli.command()
|
||||
def version():
|
||||
"""Print version info for the AGBenchmark application."""
|
||||
import toml
|
||||
|
||||
package_root = Path(__file__).resolve().parent.parent
|
||||
pyproject = toml.load(package_root / "pyproject.toml")
|
||||
version = pyproject["tool"]["poetry"]["version"]
|
||||
click.echo(f"AGBenchmark version {version}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
@@ -1,111 +0,0 @@
|
||||
import logging
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import AsyncIterator, Optional
|
||||
|
||||
from agent_protocol_client import (
|
||||
AgentApi,
|
||||
ApiClient,
|
||||
Configuration,
|
||||
Step,
|
||||
TaskRequestBody,
|
||||
)
|
||||
|
||||
from agbenchmark.agent_interface import get_list_of_file_paths
|
||||
from agbenchmark.config import AgentBenchmarkConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def run_api_agent(
|
||||
task: str,
|
||||
config: AgentBenchmarkConfig,
|
||||
timeout: int,
|
||||
artifacts_location: Optional[Path] = None,
|
||||
*,
|
||||
mock: bool = False,
|
||||
) -> AsyncIterator[Step]:
|
||||
configuration = Configuration(host=config.host)
|
||||
async with ApiClient(configuration) as api_client:
|
||||
api_instance = AgentApi(api_client)
|
||||
task_request_body = TaskRequestBody(input=task, additional_input=None)
|
||||
|
||||
start_time = time.time()
|
||||
response = await api_instance.create_agent_task(
|
||||
task_request_body=task_request_body
|
||||
)
|
||||
task_id = response.task_id
|
||||
|
||||
if artifacts_location:
|
||||
logger.debug("Uploading task input artifacts to agent...")
|
||||
await upload_artifacts(
|
||||
api_instance, artifacts_location, task_id, "artifacts_in"
|
||||
)
|
||||
|
||||
logger.debug("Running agent until finished or timeout...")
|
||||
while True:
|
||||
step = await api_instance.execute_agent_task_step(task_id=task_id)
|
||||
yield step
|
||||
|
||||
if time.time() - start_time > timeout:
|
||||
raise TimeoutError("Time limit exceeded")
|
||||
if step and mock:
|
||||
step.is_last = True
|
||||
if not step or step.is_last:
|
||||
break
|
||||
|
||||
if artifacts_location:
|
||||
# In "mock" mode, we cheat by giving the correct artifacts to pass the test
|
||||
if mock:
|
||||
logger.debug("Uploading mock artifacts to agent...")
|
||||
await upload_artifacts(
|
||||
api_instance, artifacts_location, task_id, "artifacts_out"
|
||||
)
|
||||
|
||||
logger.debug("Downloading agent artifacts...")
|
||||
await download_agent_artifacts_into_folder(
|
||||
api_instance, task_id, config.temp_folder
|
||||
)
|
||||
|
||||
|
||||
async def download_agent_artifacts_into_folder(
|
||||
api_instance: AgentApi, task_id: str, folder: Path
|
||||
):
|
||||
artifacts = await api_instance.list_agent_task_artifacts(task_id=task_id)
|
||||
|
||||
for artifact in artifacts.artifacts:
|
||||
# current absolute path of the directory of the file
|
||||
if artifact.relative_path:
|
||||
path: str = (
|
||||
artifact.relative_path
|
||||
if not artifact.relative_path.startswith("/")
|
||||
else artifact.relative_path[1:]
|
||||
)
|
||||
folder = (folder / path).parent
|
||||
|
||||
if not folder.exists():
|
||||
folder.mkdir(parents=True)
|
||||
|
||||
file_path = folder / artifact.file_name
|
||||
logger.debug(f"Downloading agent artifact {artifact.file_name} to {folder}")
|
||||
with open(file_path, "wb") as f:
|
||||
content = await api_instance.download_agent_task_artifact(
|
||||
task_id=task_id, artifact_id=artifact.artifact_id
|
||||
)
|
||||
|
||||
f.write(content)
|
||||
|
||||
|
||||
async def upload_artifacts(
|
||||
api_instance: AgentApi, artifacts_location: Path, task_id: str, type: str
|
||||
) -> None:
|
||||
for file_path in get_list_of_file_paths(artifacts_location, type):
|
||||
relative_path: Optional[str] = "/".join(
|
||||
str(file_path).split(f"{type}/", 1)[-1].split("/")[:-1]
|
||||
)
|
||||
if not relative_path:
|
||||
relative_path = None
|
||||
|
||||
await api_instance.upload_agent_task_artifacts(
|
||||
task_id=task_id, file=str(file_path), relative_path=relative_path
|
||||
)
|
||||
@@ -1,27 +0,0 @@
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
HELICONE_GRAPHQL_LOGS = os.getenv("HELICONE_GRAPHQL_LOGS", "").lower() == "true"
|
||||
|
||||
|
||||
def get_list_of_file_paths(
|
||||
challenge_dir_path: str | Path, artifact_folder_name: str
|
||||
) -> list[Path]:
|
||||
source_dir = Path(challenge_dir_path) / artifact_folder_name
|
||||
if not source_dir.exists():
|
||||
return []
|
||||
return list(source_dir.iterdir())
|
||||
|
||||
|
||||
def copy_challenge_artifacts_into_workspace(
|
||||
challenge_dir_path: str | Path, artifact_folder_name: str, workspace: str | Path
|
||||
) -> None:
|
||||
file_paths = get_list_of_file_paths(challenge_dir_path, artifact_folder_name)
|
||||
for file_path in file_paths:
|
||||
if file_path.is_file():
|
||||
shutil.copy(file_path, workspace)
|
||||
@@ -1,339 +0,0 @@
|
||||
import datetime
|
||||
import glob
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
from collections import deque
|
||||
from multiprocessing import Process
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
import psutil
|
||||
from agent_protocol_client import AgentApi, ApiClient, ApiException, Configuration
|
||||
from agent_protocol_client.models import Task, TaskRequestBody
|
||||
from fastapi import APIRouter, FastAPI, HTTPException, Request, Response
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel, ConfigDict, ValidationError
|
||||
|
||||
from agbenchmark.challenges import ChallengeInfo
|
||||
from agbenchmark.config import AgentBenchmarkConfig
|
||||
from agbenchmark.reports.processing.report_types_v2 import (
|
||||
BenchmarkRun,
|
||||
Metrics,
|
||||
RepositoryInfo,
|
||||
RunDetails,
|
||||
TaskInfo,
|
||||
)
|
||||
from agbenchmark.schema import TaskEvalRequestBody
|
||||
from agbenchmark.utils.utils import write_pretty_json
|
||||
|
||||
sys.path.append(str(Path(__file__).parent.parent))
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CHALLENGES: dict[str, ChallengeInfo] = {}
|
||||
challenges_path = Path(__file__).parent / "challenges"
|
||||
challenge_spec_files = deque(
|
||||
glob.glob(
|
||||
f"{challenges_path}/**/data.json",
|
||||
recursive=True,
|
||||
)
|
||||
)
|
||||
|
||||
logger.debug("Loading challenges...")
|
||||
while challenge_spec_files:
|
||||
challenge_spec_file = Path(challenge_spec_files.popleft())
|
||||
challenge_relpath = challenge_spec_file.relative_to(challenges_path.parent)
|
||||
if challenge_relpath.is_relative_to("challenges/deprecated"):
|
||||
continue
|
||||
|
||||
logger.debug(f"Loading {challenge_relpath}...")
|
||||
try:
|
||||
challenge_info = ChallengeInfo.model_validate_json(
|
||||
challenge_spec_file.read_text()
|
||||
)
|
||||
except ValidationError as e:
|
||||
if logging.getLogger().level == logging.DEBUG:
|
||||
logger.warning(f"Spec file {challenge_relpath} failed to load:\n{e}")
|
||||
logger.debug(f"Invalid challenge spec: {challenge_spec_file.read_text()}")
|
||||
continue
|
||||
|
||||
if not challenge_info.eval_id:
|
||||
challenge_info.eval_id = str(uuid.uuid4())
|
||||
# this will sort all the keys of the JSON systematically
|
||||
# so that the order is always the same
|
||||
write_pretty_json(challenge_info.model_dump(), challenge_spec_file)
|
||||
|
||||
CHALLENGES[challenge_info.eval_id] = challenge_info
|
||||
|
||||
|
||||
class BenchmarkTaskInfo(BaseModel):
|
||||
task_id: str
|
||||
start_time: datetime.datetime
|
||||
challenge_info: ChallengeInfo
|
||||
|
||||
|
||||
task_informations: dict[str, BenchmarkTaskInfo] = {}
|
||||
|
||||
|
||||
def find_agbenchmark_without_uvicorn():
|
||||
pids = []
|
||||
for process in psutil.process_iter(
|
||||
attrs=[
|
||||
"pid",
|
||||
"cmdline",
|
||||
"name",
|
||||
"username",
|
||||
"status",
|
||||
"cpu_percent",
|
||||
"memory_info",
|
||||
"create_time",
|
||||
"cwd",
|
||||
"connections",
|
||||
]
|
||||
):
|
||||
try:
|
||||
# Convert the process.info dictionary values to strings and concatenate them
|
||||
full_info = " ".join([str(v) for k, v in process.as_dict().items()])
|
||||
|
||||
if "agbenchmark" in full_info and "uvicorn" not in full_info:
|
||||
pids.append(process.pid)
|
||||
except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
|
||||
pass
|
||||
return pids
|
||||
|
||||
|
||||
class CreateReportRequest(BaseModel):
|
||||
test: str
|
||||
test_run_id: str
|
||||
# category: Optional[str] = []
|
||||
mock: Optional[bool] = False
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
updates_list = []
|
||||
|
||||
origins = [
|
||||
"http://localhost:8000",
|
||||
"http://localhost:8080",
|
||||
"http://127.0.0.1:5000",
|
||||
"http://localhost:5000",
|
||||
]
|
||||
|
||||
|
||||
def stream_output(pipe):
|
||||
for line in pipe:
|
||||
print(line, end="")
|
||||
|
||||
|
||||
def setup_fastapi_app(agbenchmark_config: AgentBenchmarkConfig) -> FastAPI:
|
||||
from agbenchmark.agent_api_interface import upload_artifacts
|
||||
from agbenchmark.challenges import get_challenge_from_source_uri
|
||||
from agbenchmark.main import run_benchmark
|
||||
|
||||
configuration = Configuration(
|
||||
host=agbenchmark_config.host or "http://localhost:8000"
|
||||
)
|
||||
app = FastAPI()
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
router = APIRouter()
|
||||
|
||||
@router.post("/reports")
|
||||
def run_single_test(body: CreateReportRequest) -> dict:
|
||||
pids = find_agbenchmark_without_uvicorn()
|
||||
logger.info(f"pids already running with agbenchmark: {pids}")
|
||||
|
||||
logger.debug(f"Request to /reports: {body.model_dump()}")
|
||||
|
||||
# Start the benchmark in a separate thread
|
||||
benchmark_process = Process(
|
||||
target=lambda: run_benchmark(
|
||||
config=agbenchmark_config,
|
||||
tests=(body.test,),
|
||||
mock=body.mock or False,
|
||||
)
|
||||
)
|
||||
benchmark_process.start()
|
||||
|
||||
# Wait for the benchmark to finish, with a timeout of 200 seconds
|
||||
timeout = 200
|
||||
start_time = time.time()
|
||||
while benchmark_process.is_alive():
|
||||
if time.time() - start_time > timeout:
|
||||
logger.warning(f"Benchmark run timed out after {timeout} seconds")
|
||||
benchmark_process.terminate()
|
||||
break
|
||||
time.sleep(1)
|
||||
else:
|
||||
logger.debug(f"Benchmark finished running in {time.time() - start_time} s")
|
||||
|
||||
# List all folders in the current working directory
|
||||
reports_folder = agbenchmark_config.reports_folder
|
||||
folders = [folder for folder in reports_folder.iterdir() if folder.is_dir()]
|
||||
|
||||
# Sort the folders based on their names
|
||||
sorted_folders = sorted(folders, key=lambda x: x.name)
|
||||
|
||||
# Get the last folder
|
||||
latest_folder = sorted_folders[-1] if sorted_folders else None
|
||||
|
||||
# Read report.json from this folder
|
||||
if latest_folder:
|
||||
report_path = latest_folder / "report.json"
|
||||
logger.debug(f"Getting latest report from {report_path}")
|
||||
if report_path.exists():
|
||||
with report_path.open() as file:
|
||||
data = json.load(file)
|
||||
logger.debug(f"Report data: {data}")
|
||||
else:
|
||||
raise HTTPException(
|
||||
502,
|
||||
"Could not get result after running benchmark: "
|
||||
f"'report.json' does not exist in '{latest_folder}'",
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
504, "Could not get result after running benchmark: no reports found"
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
@router.post("/agent/tasks", tags=["agent"])
|
||||
async def create_agent_task(task_eval_request: TaskEvalRequestBody) -> Task:
|
||||
"""
|
||||
Creates a new task using the provided TaskEvalRequestBody and returns a Task.
|
||||
|
||||
Args:
|
||||
task_eval_request: `TaskRequestBody` including an eval_id.
|
||||
|
||||
Returns:
|
||||
Task: A new task with task_id, input, additional_input,
|
||||
and empty lists for artifacts and steps.
|
||||
|
||||
Example:
|
||||
Request (TaskEvalRequestBody defined in schema.py):
|
||||
{
|
||||
...,
|
||||
"eval_id": "50da533e-3904-4401-8a07-c49adf88b5eb"
|
||||
}
|
||||
|
||||
Response (Task defined in `agent_protocol_client.models`):
|
||||
{
|
||||
"task_id": "50da533e-3904-4401-8a07-c49adf88b5eb",
|
||||
"input": "Write the word 'Washington' to a .txt file",
|
||||
"artifacts": []
|
||||
}
|
||||
"""
|
||||
try:
|
||||
challenge_info = CHALLENGES[task_eval_request.eval_id]
|
||||
async with ApiClient(configuration) as api_client:
|
||||
api_instance = AgentApi(api_client)
|
||||
task_input = challenge_info.task
|
||||
|
||||
task_request_body = TaskRequestBody(
|
||||
input=task_input, additional_input=None
|
||||
)
|
||||
task_response = await api_instance.create_agent_task(
|
||||
task_request_body=task_request_body
|
||||
)
|
||||
task_info = BenchmarkTaskInfo(
|
||||
task_id=task_response.task_id,
|
||||
start_time=datetime.datetime.now(datetime.timezone.utc),
|
||||
challenge_info=challenge_info,
|
||||
)
|
||||
task_informations[task_info.task_id] = task_info
|
||||
|
||||
if input_artifacts_dir := challenge_info.task_artifacts_dir:
|
||||
await upload_artifacts(
|
||||
api_instance,
|
||||
input_artifacts_dir,
|
||||
task_response.task_id,
|
||||
"artifacts_in",
|
||||
)
|
||||
return task_response
|
||||
except ApiException as e:
|
||||
logger.error(f"Error whilst trying to create a task:\n{e}")
|
||||
logger.error(
|
||||
"The above error was caused while processing request: "
|
||||
f"{task_eval_request}"
|
||||
)
|
||||
raise HTTPException(500)
|
||||
|
||||
@router.post("/agent/tasks/{task_id}/steps")
|
||||
async def proxy(request: Request, task_id: str):
|
||||
timeout = httpx.Timeout(300.0, read=300.0) # 5 minutes
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
# Construct the new URL
|
||||
new_url = f"{configuration.host}/ap/v1/agent/tasks/{task_id}/steps"
|
||||
|
||||
# Forward the request
|
||||
response = await client.post(
|
||||
new_url,
|
||||
content=await request.body(),
|
||||
headers=dict(request.headers),
|
||||
)
|
||||
|
||||
# Return the response from the forwarded request
|
||||
return Response(content=response.content, status_code=response.status_code)
|
||||
|
||||
@router.post("/agent/tasks/{task_id}/evaluations")
|
||||
async def create_evaluation(task_id: str) -> BenchmarkRun:
|
||||
task_info = task_informations[task_id]
|
||||
challenge = get_challenge_from_source_uri(task_info.challenge_info.source_uri)
|
||||
try:
|
||||
async with ApiClient(configuration) as api_client:
|
||||
api_instance = AgentApi(api_client)
|
||||
eval_results = await challenge.evaluate_task_state(
|
||||
api_instance, task_id
|
||||
)
|
||||
|
||||
eval_info = BenchmarkRun(
|
||||
repository_info=RepositoryInfo(),
|
||||
run_details=RunDetails(
|
||||
command=f"agbenchmark --test={challenge.info.name}",
|
||||
benchmark_start_time=(
|
||||
task_info.start_time.strftime("%Y-%m-%dT%H:%M:%S+00:00")
|
||||
),
|
||||
test_name=challenge.info.name,
|
||||
),
|
||||
task_info=TaskInfo(
|
||||
data_path=challenge.info.source_uri,
|
||||
is_regression=None,
|
||||
category=[c.value for c in challenge.info.category],
|
||||
task=challenge.info.task,
|
||||
answer=challenge.info.reference_answer or "",
|
||||
description=challenge.info.description or "",
|
||||
),
|
||||
metrics=Metrics(
|
||||
success=all(e.passed for e in eval_results),
|
||||
success_percentage=(
|
||||
100 * sum(e.score for e in eval_results) / len(eval_results)
|
||||
if eval_results # avoid division by 0
|
||||
else 0
|
||||
),
|
||||
attempted=True,
|
||||
),
|
||||
config={},
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Returning evaluation data:\n{eval_info.model_dump_json(indent=4)}"
|
||||
)
|
||||
return eval_info
|
||||
except ApiException as e:
|
||||
logger.error(f"Error {e} whilst trying to evaluate task: {task_id}")
|
||||
raise HTTPException(500)
|
||||
|
||||
app.include_router(router, prefix="/ap/v1")
|
||||
|
||||
return app
|
||||
@@ -1,128 +0,0 @@
|
||||
import json
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field, ValidationInfo, field_validator
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
def _calculate_info_test_path(base_path: Path, benchmark_start_time: datetime) -> Path:
|
||||
"""
|
||||
Calculates the path to the directory where the test report will be saved.
|
||||
"""
|
||||
# Ensure the reports path exists
|
||||
base_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Get current UTC date-time stamp
|
||||
date_stamp = benchmark_start_time.strftime("%Y%m%dT%H%M%S")
|
||||
|
||||
# Default run name
|
||||
run_name = "full_run"
|
||||
|
||||
# Map command-line arguments to their respective labels
|
||||
arg_labels = {
|
||||
"--test": None,
|
||||
"--category": None,
|
||||
"--maintain": "maintain",
|
||||
"--improve": "improve",
|
||||
"--explore": "explore",
|
||||
}
|
||||
|
||||
# Identify the relevant command-line argument
|
||||
for arg, label in arg_labels.items():
|
||||
if arg in sys.argv:
|
||||
test_arg = sys.argv[sys.argv.index(arg) + 1] if label is None else None
|
||||
run_name = arg.strip("--")
|
||||
if test_arg:
|
||||
run_name = f"{run_name}_{test_arg}"
|
||||
break
|
||||
|
||||
# Create the full new directory path with ISO standard UTC date-time stamp
|
||||
report_path = base_path / f"{date_stamp}_{run_name}"
|
||||
|
||||
# Ensure the new directory is created
|
||||
# FIXME: this is not a desirable side-effect of loading the config
|
||||
report_path.mkdir(exist_ok=True)
|
||||
|
||||
return report_path
|
||||
|
||||
|
||||
class AgentBenchmarkConfig(BaseSettings, extra="allow"):
|
||||
"""
|
||||
Configuration model and loader for the AGBenchmark.
|
||||
|
||||
Projects that want to use AGBenchmark should contain an agbenchmark_config folder
|
||||
with a config.json file that - at minimum - specifies the `host` at which the
|
||||
subject application exposes an Agent Protocol compliant API.
|
||||
"""
|
||||
|
||||
agbenchmark_config_dir: Path = Field(exclude=True)
|
||||
"""Path to the agbenchmark_config folder of the subject agent application."""
|
||||
|
||||
categories: list[str] | None = None
|
||||
"""Categories to benchmark the agent for. If omitted, all categories are assumed."""
|
||||
|
||||
host: str
|
||||
"""Host (scheme://address:port) of the subject agent application."""
|
||||
|
||||
reports_folder: Path = Field(None)
|
||||
"""
|
||||
Path to the folder where new reports should be stored.
|
||||
Defaults to {agbenchmark_config_dir}/reports.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def load(cls, config_dir: Optional[Path] = None) -> "AgentBenchmarkConfig":
|
||||
config_dir = config_dir or cls.find_config_folder()
|
||||
with (config_dir / "config.json").open("r") as f:
|
||||
return cls(
|
||||
agbenchmark_config_dir=config_dir,
|
||||
**json.load(f),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def find_config_folder(for_dir: Path = Path.cwd()) -> Path:
|
||||
"""
|
||||
Find the closest ancestor folder containing an agbenchmark_config folder,
|
||||
and returns the path of that agbenchmark_config folder.
|
||||
"""
|
||||
current_directory = for_dir
|
||||
while current_directory != Path("/"):
|
||||
if (path := current_directory / "agbenchmark_config").exists():
|
||||
if (path / "config.json").is_file():
|
||||
return path
|
||||
current_directory = current_directory.parent
|
||||
raise FileNotFoundError(
|
||||
"No 'agbenchmark_config' directory found in the path hierarchy."
|
||||
)
|
||||
|
||||
@property
|
||||
def config_file(self) -> Path:
|
||||
return self.agbenchmark_config_dir / "config.json"
|
||||
|
||||
@field_validator("reports_folder", mode="before")
|
||||
def set_reports_folder(cls, value: Path, info: ValidationInfo):
|
||||
if not value:
|
||||
return info.data["agbenchmark_config_dir"] / "reports"
|
||||
return value
|
||||
|
||||
def get_report_dir(self, benchmark_start_time: datetime) -> Path:
|
||||
return _calculate_info_test_path(self.reports_folder, benchmark_start_time)
|
||||
|
||||
@property
|
||||
def regression_tests_file(self) -> Path:
|
||||
return self.reports_folder / "regression_tests.json"
|
||||
|
||||
@property
|
||||
def success_rate_file(self) -> Path:
|
||||
return self.reports_folder / "success_rate.json"
|
||||
|
||||
@property
|
||||
def challenges_already_beaten_file(self) -> Path:
|
||||
return self.agbenchmark_config_dir / "challenges_already_beaten.json"
|
||||
|
||||
@property
|
||||
def temp_folder(self) -> Path:
|
||||
return self.agbenchmark_config_dir / "temp_folder"
|
||||
@@ -1,339 +0,0 @@
|
||||
import contextlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from agbenchmark.challenges import OPTIONAL_CATEGORIES, BaseChallenge
|
||||
from agbenchmark.config import AgentBenchmarkConfig
|
||||
from agbenchmark.reports.processing.report_types import Test
|
||||
from agbenchmark.reports.ReportManager import RegressionTestsTracker
|
||||
from agbenchmark.reports.reports import (
|
||||
add_test_result_to_report,
|
||||
make_empty_test_report,
|
||||
session_finish,
|
||||
)
|
||||
from agbenchmark.utils.data_types import Category
|
||||
|
||||
GLOBAL_TIMEOUT = (
|
||||
1500 # The tests will stop after 25 minutes so we can send the reports.
|
||||
)
|
||||
|
||||
agbenchmark_config = AgentBenchmarkConfig.load()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
pytest_plugins = ["agbenchmark.utils.dependencies"]
|
||||
collect_ignore = ["challenges"]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def config() -> AgentBenchmarkConfig:
|
||||
return agbenchmark_config
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def temp_folder() -> Generator[Path, None, None]:
|
||||
"""
|
||||
Pytest fixture that sets up and tears down the temporary folder for each test.
|
||||
It is automatically used in every test due to the 'autouse=True' parameter.
|
||||
"""
|
||||
|
||||
# create output directory if it doesn't exist
|
||||
if not os.path.exists(agbenchmark_config.temp_folder):
|
||||
os.makedirs(agbenchmark_config.temp_folder, exist_ok=True)
|
||||
|
||||
yield agbenchmark_config.temp_folder
|
||||
# teardown after test function completes
|
||||
if not os.getenv("KEEP_TEMP_FOLDER_FILES"):
|
||||
for filename in os.listdir(agbenchmark_config.temp_folder):
|
||||
file_path = os.path.join(agbenchmark_config.temp_folder, filename)
|
||||
try:
|
||||
if os.path.isfile(file_path) or os.path.islink(file_path):
|
||||
os.unlink(file_path)
|
||||
elif os.path.isdir(file_path):
|
||||
shutil.rmtree(file_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete {file_path}. Reason: {e}")
|
||||
|
||||
|
||||
def pytest_addoption(parser: pytest.Parser) -> None:
|
||||
"""
|
||||
Pytest hook that adds command-line options to the `pytest` command.
|
||||
The added options are specific to agbenchmark and control its behavior:
|
||||
* `--mock` is used to run the tests in mock mode.
|
||||
* `--host` is used to specify the host for the tests.
|
||||
* `--category` is used to run only tests of a specific category.
|
||||
* `--nc` is used to run the tests without caching.
|
||||
* `--cutoff` is used to specify a cutoff time for the tests.
|
||||
* `--improve` is used to run only the tests that are marked for improvement.
|
||||
* `--maintain` is used to run only the tests that are marked for maintenance.
|
||||
* `--explore` is used to run the tests in exploration mode.
|
||||
* `--test` is used to run a specific test.
|
||||
* `--no-dep` is used to run the tests without dependencies.
|
||||
* `--keep-answers` is used to keep the answers of the tests.
|
||||
|
||||
Args:
|
||||
parser: The Pytest CLI parser to which the command-line options are added.
|
||||
"""
|
||||
parser.addoption("-N", "--attempts", action="store")
|
||||
parser.addoption("--no-dep", action="store_true")
|
||||
parser.addoption("--mock", action="store_true")
|
||||
parser.addoption("--host", default=None)
|
||||
parser.addoption("--nc", action="store_true")
|
||||
parser.addoption("--cutoff", action="store")
|
||||
parser.addoption("--category", action="append")
|
||||
parser.addoption("--test", action="append")
|
||||
parser.addoption("--improve", action="store_true")
|
||||
parser.addoption("--maintain", action="store_true")
|
||||
parser.addoption("--explore", action="store_true")
|
||||
parser.addoption("--keep-answers", action="store_true")
|
||||
|
||||
|
||||
def pytest_configure(config: pytest.Config) -> None:
|
||||
# Register category markers to prevent "unknown marker" warnings
|
||||
for category in Category:
|
||||
config.addinivalue_line("markers", f"{category.value}: {category}")
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def check_regression(request: pytest.FixtureRequest) -> None:
|
||||
"""
|
||||
Fixture that checks for every test if it should be treated as a regression test,
|
||||
and whether to skip it based on that.
|
||||
|
||||
The test name is retrieved from the `request` object. Regression reports are loaded
|
||||
from the path specified in the benchmark configuration.
|
||||
|
||||
Effect:
|
||||
* If the `--improve` option is used and the current test is considered a regression
|
||||
test, it is skipped.
|
||||
* If the `--maintain` option is used and the current test is not considered a
|
||||
regression test, it is also skipped.
|
||||
|
||||
Args:
|
||||
request: The request object from which the test name and the benchmark
|
||||
configuration are retrieved.
|
||||
"""
|
||||
with contextlib.suppress(FileNotFoundError):
|
||||
rt_tracker = RegressionTestsTracker(agbenchmark_config.regression_tests_file)
|
||||
|
||||
assert isinstance(request.node, pytest.Function)
|
||||
assert isinstance(request.node.parent, pytest.Class)
|
||||
test_name = request.node.parent.name
|
||||
challenge_location = getattr(request.node.cls, "CHALLENGE_LOCATION", "")
|
||||
skip_string = f"Skipping {test_name} at {challenge_location}"
|
||||
|
||||
# Check if the test name exists in the regression tests
|
||||
is_regression_test = rt_tracker.has_regression_test(test_name)
|
||||
if request.config.getoption("--improve") and is_regression_test:
|
||||
pytest.skip(f"{skip_string} because it's a regression test")
|
||||
elif request.config.getoption("--maintain") and not is_regression_test:
|
||||
pytest.skip(f"{skip_string} because it's not a regression test")
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope="session")
|
||||
def mock(request: pytest.FixtureRequest) -> bool:
|
||||
"""
|
||||
Pytest fixture that retrieves the value of the `--mock` command-line option.
|
||||
The `--mock` option is used to run the tests in mock mode.
|
||||
|
||||
Args:
|
||||
request: The `pytest.FixtureRequest` from which the `--mock` option value
|
||||
is retrieved.
|
||||
|
||||
Returns:
|
||||
bool: Whether `--mock` is set for this session.
|
||||
"""
|
||||
mock = request.config.getoption("--mock")
|
||||
assert isinstance(mock, bool)
|
||||
return mock
|
||||
|
||||
|
||||
test_reports: dict[str, Test] = {}
|
||||
|
||||
|
||||
def pytest_runtest_makereport(item: pytest.Item, call: pytest.CallInfo) -> None:
|
||||
"""
|
||||
Pytest hook that is called when a test report is being generated.
|
||||
It is used to generate and finalize reports for each test.
|
||||
|
||||
Args:
|
||||
item: The test item for which the report is being generated.
|
||||
call: The call object from which the test result is retrieved.
|
||||
"""
|
||||
challenge: type[BaseChallenge] = item.cls # type: ignore
|
||||
challenge_id = challenge.info.eval_id
|
||||
|
||||
if challenge_id not in test_reports:
|
||||
test_reports[challenge_id] = make_empty_test_report(challenge.info)
|
||||
|
||||
if call.when == "setup":
|
||||
test_name = item.nodeid.split("::")[1]
|
||||
item.user_properties.append(("test_name", test_name))
|
||||
|
||||
if call.when == "call":
|
||||
add_test_result_to_report(
|
||||
test_reports[challenge_id], item, call, agbenchmark_config
|
||||
)
|
||||
|
||||
|
||||
def timeout_monitor(start_time: int) -> None:
|
||||
"""
|
||||
Function that limits the total execution time of the test suite.
|
||||
This function is supposed to be run in a separate thread and calls `pytest.exit`
|
||||
if the total execution time has exceeded the global timeout.
|
||||
|
||||
Args:
|
||||
start_time (int): The start time of the test suite.
|
||||
"""
|
||||
while time.time() - start_time < GLOBAL_TIMEOUT:
|
||||
time.sleep(1) # check every second
|
||||
|
||||
pytest.exit("Test suite exceeded the global timeout", returncode=1)
|
||||
|
||||
|
||||
def pytest_sessionstart(session: pytest.Session) -> None:
|
||||
"""
|
||||
Pytest hook that is called at the start of a test session.
|
||||
|
||||
Sets up and runs a `timeout_monitor` in a separate thread.
|
||||
"""
|
||||
start_time = time.time()
|
||||
t = threading.Thread(target=timeout_monitor, args=(start_time,))
|
||||
t.daemon = True # Daemon threads are abruptly stopped at shutdown
|
||||
t.start()
|
||||
|
||||
|
||||
def pytest_sessionfinish(session: pytest.Session) -> None:
|
||||
"""
|
||||
Pytest hook that is called at the end of a test session.
|
||||
|
||||
Finalizes and saves the test reports.
|
||||
"""
|
||||
session_finish(agbenchmark_config)
|
||||
|
||||
|
||||
def pytest_generate_tests(metafunc: pytest.Metafunc):
|
||||
n = metafunc.config.getoption("-N")
|
||||
metafunc.parametrize("i_attempt", range(int(n)) if type(n) is str else [0])
|
||||
|
||||
|
||||
def pytest_collection_modifyitems(
|
||||
items: list[pytest.Function], config: pytest.Config
|
||||
) -> None:
|
||||
"""
|
||||
Pytest hook that is called after initial test collection has been performed.
|
||||
Modifies the collected test items based on the agent benchmark configuration,
|
||||
adding the dependency marker and category markers.
|
||||
|
||||
Args:
|
||||
items: The collected test items to be modified.
|
||||
config: The active pytest configuration.
|
||||
"""
|
||||
rt_tracker = RegressionTestsTracker(agbenchmark_config.regression_tests_file)
|
||||
|
||||
try:
|
||||
challenges_beaten_in_the_past = json.loads(
|
||||
agbenchmark_config.challenges_already_beaten_file.read_bytes()
|
||||
)
|
||||
except FileNotFoundError:
|
||||
challenges_beaten_in_the_past = {}
|
||||
|
||||
selected_tests: tuple[str] = config.getoption("--test") # type: ignore
|
||||
selected_categories: tuple[str] = config.getoption("--category") # type: ignore
|
||||
|
||||
# Can't use a for-loop to remove items in-place
|
||||
i = 0
|
||||
while i < len(items):
|
||||
item = items[i]
|
||||
assert item.cls and issubclass(item.cls, BaseChallenge)
|
||||
challenge = item.cls
|
||||
challenge_name = challenge.info.name
|
||||
|
||||
if not issubclass(challenge, BaseChallenge):
|
||||
item.warn(
|
||||
pytest.PytestCollectionWarning(
|
||||
f"Non-challenge item collected: {challenge}"
|
||||
)
|
||||
)
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# --test: remove the test from the set if it's not specifically selected
|
||||
if selected_tests and challenge.info.name not in selected_tests:
|
||||
items.remove(item)
|
||||
continue
|
||||
|
||||
# Filter challenges for --maintain, --improve, and --explore:
|
||||
# --maintain -> only challenges expected to be passed (= regression tests)
|
||||
# --improve -> only challenges that so far are not passed (reliably)
|
||||
# --explore -> only challenges that have never been passed
|
||||
is_regression_test = rt_tracker.has_regression_test(challenge.info.name)
|
||||
has_been_passed = challenges_beaten_in_the_past.get(challenge.info.name, False)
|
||||
if (
|
||||
(config.getoption("--maintain") and not is_regression_test)
|
||||
or (config.getoption("--improve") and is_regression_test)
|
||||
or (config.getoption("--explore") and has_been_passed)
|
||||
):
|
||||
items.remove(item)
|
||||
continue
|
||||
|
||||
dependencies = challenge.info.dependencies
|
||||
if (
|
||||
config.getoption("--test")
|
||||
or config.getoption("--no-dep")
|
||||
or config.getoption("--maintain")
|
||||
):
|
||||
# Ignore dependencies:
|
||||
# --test -> user selected specific tests to run, don't care about deps
|
||||
# --no-dep -> ignore dependency relations regardless of test selection
|
||||
# --maintain -> all "regression" tests must pass, so run all of them
|
||||
dependencies = []
|
||||
elif config.getoption("--improve"):
|
||||
# Filter dependencies, keep only deps that are not "regression" tests
|
||||
dependencies = [
|
||||
d for d in dependencies if not rt_tracker.has_regression_test(d)
|
||||
]
|
||||
|
||||
# Set category markers
|
||||
challenge_categories = set(c.value for c in challenge.info.category)
|
||||
for category in challenge_categories:
|
||||
item.add_marker(category)
|
||||
|
||||
# Enforce category selection
|
||||
if selected_categories:
|
||||
if not challenge_categories.intersection(set(selected_categories)):
|
||||
items.remove(item)
|
||||
continue
|
||||
# # Filter dependencies, keep only deps from selected categories
|
||||
# dependencies = [
|
||||
# d for d in dependencies
|
||||
# if not set(d.categories).intersection(set(selected_categories))
|
||||
# ]
|
||||
|
||||
# Skip items in optional categories that are not selected for the subject agent
|
||||
challenge_optional_categories = challenge_categories & set(OPTIONAL_CATEGORIES)
|
||||
if challenge_optional_categories and not (
|
||||
agbenchmark_config.categories
|
||||
and challenge_optional_categories.issubset(
|
||||
set(agbenchmark_config.categories)
|
||||
)
|
||||
):
|
||||
logger.debug(
|
||||
f"Skipping {challenge_name}: "
|
||||
f"category {' and '.join(challenge_optional_categories)} is optional, "
|
||||
"and not explicitly selected in the benchmark config."
|
||||
)
|
||||
items.remove(item)
|
||||
continue
|
||||
|
||||
# Add marker for the DependencyManager
|
||||
item.add_marker(pytest.mark.depends(on=dependencies, name=challenge_name))
|
||||
|
||||
i += 1
|
||||
@@ -1,26 +0,0 @@
|
||||
"""
|
||||
AGBenchmark's test discovery endpoint for Pytest.
|
||||
|
||||
This module is picked up by Pytest's *_test.py file matching pattern, and all challenge
|
||||
classes in the module that conform to the `Test*` pattern are collected.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
from itertools import chain
|
||||
|
||||
from agbenchmark.challenges.builtin import load_builtin_challenges
|
||||
from agbenchmark.challenges.webarena import load_webarena_challenges
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DATA_CATEGORY = {}
|
||||
|
||||
# Load challenges and attach them to this module
|
||||
for challenge in chain(load_builtin_challenges(), load_webarena_challenges()):
|
||||
# Attach the Challenge class to this module so it can be discovered by pytest
|
||||
module = importlib.import_module(__name__)
|
||||
setattr(module, challenge.__name__, challenge)
|
||||
|
||||
# Build a map of challenge names and their primary category
|
||||
DATA_CATEGORY[challenge.info.name] = challenge.info.category[0].value
|
||||
@@ -1,158 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from agbenchmark.challenges import get_unique_categories
|
||||
from agbenchmark.config import AgentBenchmarkConfig
|
||||
|
||||
load_dotenv()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def run_benchmark(
|
||||
config: AgentBenchmarkConfig,
|
||||
maintain: bool = False,
|
||||
improve: bool = False,
|
||||
explore: bool = False,
|
||||
tests: tuple[str, ...] = tuple(),
|
||||
categories: tuple[str, ...] = tuple(),
|
||||
skip_categories: tuple[str, ...] = tuple(),
|
||||
attempts_per_challenge: int = 1,
|
||||
mock: bool = False,
|
||||
no_dep: bool = False,
|
||||
no_cutoff: bool = False,
|
||||
cutoff: Optional[int] = None,
|
||||
keep_answers: bool = False,
|
||||
server: bool = False,
|
||||
) -> int:
|
||||
"""
|
||||
Starts the benchmark. If a category flag is provided, only challenges with the
|
||||
corresponding mark will be run.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from agbenchmark.reports.ReportManager import SingletonReportManager
|
||||
|
||||
validate_args(
|
||||
maintain=maintain,
|
||||
improve=improve,
|
||||
explore=explore,
|
||||
tests=tests,
|
||||
categories=categories,
|
||||
skip_categories=skip_categories,
|
||||
no_cutoff=no_cutoff,
|
||||
cutoff=cutoff,
|
||||
)
|
||||
|
||||
SingletonReportManager()
|
||||
|
||||
for key, value in vars(config).items():
|
||||
logger.debug(f"config.{key} = {repr(value)}")
|
||||
|
||||
pytest_args = ["-vs"]
|
||||
|
||||
if tests:
|
||||
logger.info(f"Running specific test(s): {' '.join(tests)}")
|
||||
pytest_args += [f"--test={t}" for t in tests]
|
||||
else:
|
||||
all_categories = get_unique_categories()
|
||||
|
||||
if categories or skip_categories:
|
||||
categories_to_run = set(categories) or all_categories
|
||||
if skip_categories:
|
||||
categories_to_run = categories_to_run.difference(set(skip_categories))
|
||||
assert categories_to_run, "Error: You can't skip all categories"
|
||||
pytest_args += [f"--category={c}" for c in categories_to_run]
|
||||
logger.info(f"Running tests of category: {categories_to_run}")
|
||||
else:
|
||||
logger.info("Running all categories")
|
||||
|
||||
if maintain:
|
||||
logger.info("Running only regression tests")
|
||||
elif improve:
|
||||
logger.info("Running only non-regression tests")
|
||||
elif explore:
|
||||
logger.info("Only attempt challenges that have never been beaten")
|
||||
|
||||
if mock:
|
||||
# TODO: unhack
|
||||
os.environ[
|
||||
"IS_MOCK"
|
||||
] = "True" # ugly hack to make the mock work when calling from API
|
||||
|
||||
# Pass through flags
|
||||
for flag, active in {
|
||||
"--maintain": maintain,
|
||||
"--improve": improve,
|
||||
"--explore": explore,
|
||||
"--no-dep": no_dep,
|
||||
"--mock": mock,
|
||||
"--nc": no_cutoff,
|
||||
"--keep-answers": keep_answers,
|
||||
}.items():
|
||||
if active:
|
||||
pytest_args.append(flag)
|
||||
|
||||
if attempts_per_challenge > 1:
|
||||
pytest_args.append(f"--attempts={attempts_per_challenge}")
|
||||
|
||||
if cutoff:
|
||||
pytest_args.append(f"--cutoff={cutoff}")
|
||||
logger.debug(f"Setting cuttoff override to {cutoff} seconds.")
|
||||
|
||||
current_dir = Path(__file__).resolve().parent
|
||||
pytest_args.append(str(current_dir / "generate_test.py"))
|
||||
|
||||
pytest_args.append("--cache-clear")
|
||||
logger.debug(f"Running Pytest with args: {pytest_args}")
|
||||
exit_code = pytest.main(pytest_args)
|
||||
|
||||
SingletonReportManager.clear_instance()
|
||||
return exit_code
|
||||
|
||||
|
||||
class InvalidInvocationError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
def validate_args(
|
||||
maintain: bool,
|
||||
improve: bool,
|
||||
explore: bool,
|
||||
tests: Sequence[str],
|
||||
categories: Sequence[str],
|
||||
skip_categories: Sequence[str],
|
||||
no_cutoff: bool,
|
||||
cutoff: Optional[int],
|
||||
) -> None:
|
||||
if categories:
|
||||
all_categories = get_unique_categories()
|
||||
invalid_categories = set(categories) - all_categories
|
||||
if invalid_categories:
|
||||
raise InvalidInvocationError(
|
||||
"One or more invalid categories were specified: "
|
||||
f"{', '.join(invalid_categories)}.\n"
|
||||
f"Valid categories are: {', '.join(all_categories)}."
|
||||
)
|
||||
|
||||
if (maintain + improve + explore) > 1:
|
||||
raise InvalidInvocationError(
|
||||
"You can't use --maintain, --improve or --explore at the same time. "
|
||||
"Please choose one."
|
||||
)
|
||||
|
||||
if tests and (categories or skip_categories or maintain or improve or explore):
|
||||
raise InvalidInvocationError(
|
||||
"If you're running a specific test make sure no other options are "
|
||||
"selected. Please just pass the --test."
|
||||
)
|
||||
|
||||
if no_cutoff and cutoff:
|
||||
raise InvalidInvocationError(
|
||||
"You can't use both --nc and --cutoff at the same time. "
|
||||
"Please choose one."
|
||||
)
|
||||
@@ -1,217 +0,0 @@
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from agbenchmark.config import AgentBenchmarkConfig
|
||||
from agbenchmark.reports.processing.graphs import save_single_radar_chart
|
||||
from agbenchmark.reports.processing.process_report import (
|
||||
get_highest_achieved_difficulty_per_category,
|
||||
)
|
||||
from agbenchmark.reports.processing.report_types import MetricsOverall, Report, Test
|
||||
from agbenchmark.utils.utils import get_highest_success_difficulty
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SingletonReportManager:
|
||||
instance = None
|
||||
|
||||
INFO_MANAGER: "SessionReportManager"
|
||||
REGRESSION_MANAGER: "RegressionTestsTracker"
|
||||
SUCCESS_RATE_TRACKER: "SuccessRatesTracker"
|
||||
|
||||
def __new__(cls):
|
||||
if not cls.instance:
|
||||
cls.instance = super(SingletonReportManager, cls).__new__(cls)
|
||||
|
||||
agent_benchmark_config = AgentBenchmarkConfig.load()
|
||||
benchmark_start_time_dt = datetime.now(
|
||||
timezone.utc
|
||||
) # or any logic to fetch the datetime
|
||||
|
||||
# Make the Managers class attributes
|
||||
cls.INFO_MANAGER = SessionReportManager(
|
||||
agent_benchmark_config.get_report_dir(benchmark_start_time_dt)
|
||||
/ "report.json",
|
||||
benchmark_start_time_dt,
|
||||
)
|
||||
cls.REGRESSION_MANAGER = RegressionTestsTracker(
|
||||
agent_benchmark_config.regression_tests_file
|
||||
)
|
||||
cls.SUCCESS_RATE_TRACKER = SuccessRatesTracker(
|
||||
agent_benchmark_config.success_rate_file
|
||||
)
|
||||
|
||||
return cls.instance
|
||||
|
||||
@classmethod
|
||||
def clear_instance(cls):
|
||||
cls.instance = None
|
||||
del cls.INFO_MANAGER
|
||||
del cls.REGRESSION_MANAGER
|
||||
del cls.SUCCESS_RATE_TRACKER
|
||||
|
||||
|
||||
class BaseReportManager:
|
||||
"""Abstracts interaction with the regression tests file"""
|
||||
|
||||
tests: dict[str, Any]
|
||||
|
||||
def __init__(self, report_file: Path):
|
||||
self.report_file = report_file
|
||||
|
||||
self.load()
|
||||
|
||||
def load(self) -> None:
|
||||
if not self.report_file.exists():
|
||||
self.report_file.parent.mkdir(exist_ok=True)
|
||||
|
||||
try:
|
||||
with self.report_file.open("r") as f:
|
||||
data = json.load(f)
|
||||
self.tests = {k: data[k] for k in sorted(data)}
|
||||
except FileNotFoundError:
|
||||
self.tests = {}
|
||||
except json.decoder.JSONDecodeError as e:
|
||||
logger.warning(f"Could not parse {self.report_file}: {e}")
|
||||
self.tests = {}
|
||||
|
||||
def save(self) -> None:
|
||||
with self.report_file.open("w") as f:
|
||||
json.dump(self.tests, f, indent=4)
|
||||
|
||||
def remove_test(self, test_name: str) -> None:
|
||||
if test_name in self.tests:
|
||||
del self.tests[test_name]
|
||||
self.save()
|
||||
|
||||
def reset(self) -> None:
|
||||
self.tests = {}
|
||||
self.save()
|
||||
|
||||
|
||||
class SessionReportManager(BaseReportManager):
|
||||
"""Abstracts interaction with the regression tests file"""
|
||||
|
||||
tests: dict[str, Test]
|
||||
report: Report | None = None
|
||||
|
||||
def __init__(self, report_file: Path, benchmark_start_time: datetime):
|
||||
super().__init__(report_file)
|
||||
|
||||
self.start_time = time.time()
|
||||
self.benchmark_start_time = benchmark_start_time
|
||||
|
||||
def save(self) -> None:
|
||||
with self.report_file.open("w") as f:
|
||||
if self.report:
|
||||
f.write(self.report.model_dump_json(indent=4))
|
||||
else:
|
||||
json.dump(
|
||||
{k: v.model_dump() for k, v in self.tests.items()}, f, indent=4
|
||||
)
|
||||
|
||||
def load(self) -> None:
|
||||
super().load()
|
||||
|
||||
if "tests" in self.tests:
|
||||
self.report = Report.model_validate(self.tests)
|
||||
else:
|
||||
self.tests = {n: Test.model_validate(d) for n, d in self.tests.items()}
|
||||
|
||||
def add_test_report(self, test_name: str, test_report: Test) -> None:
|
||||
if self.report:
|
||||
raise RuntimeError("Session report already finalized")
|
||||
|
||||
if test_name.startswith("Test"):
|
||||
test_name = test_name[4:]
|
||||
self.tests[test_name] = test_report
|
||||
|
||||
self.save()
|
||||
|
||||
def finalize_session_report(self, config: AgentBenchmarkConfig) -> None:
|
||||
command = " ".join(sys.argv)
|
||||
|
||||
if self.report:
|
||||
raise RuntimeError("Session report already finalized")
|
||||
|
||||
self.report = Report(
|
||||
command=command.split(os.sep)[-1],
|
||||
benchmark_git_commit_sha="---",
|
||||
agent_git_commit_sha="---",
|
||||
completion_time=datetime.now(timezone.utc).strftime(
|
||||
"%Y-%m-%dT%H:%M:%S+00:00"
|
||||
),
|
||||
benchmark_start_time=self.benchmark_start_time.strftime(
|
||||
"%Y-%m-%dT%H:%M:%S+00:00"
|
||||
),
|
||||
metrics=MetricsOverall(
|
||||
run_time=str(round(time.time() - self.start_time, 2)) + " seconds",
|
||||
highest_difficulty=get_highest_success_difficulty(self.tests),
|
||||
total_cost=self.get_total_costs(),
|
||||
),
|
||||
tests=copy.copy(self.tests),
|
||||
config=config.model_dump(exclude={"reports_folder"}, exclude_none=True),
|
||||
)
|
||||
|
||||
agent_categories = get_highest_achieved_difficulty_per_category(self.report)
|
||||
if len(agent_categories) > 1:
|
||||
save_single_radar_chart(
|
||||
agent_categories,
|
||||
config.get_report_dir(self.benchmark_start_time) / "radar_chart.png",
|
||||
)
|
||||
|
||||
self.save()
|
||||
|
||||
def get_total_costs(self):
|
||||
if self.report:
|
||||
tests = self.report.tests
|
||||
else:
|
||||
tests = self.tests
|
||||
|
||||
total_cost = 0
|
||||
all_costs_none = True
|
||||
for test_data in tests.values():
|
||||
cost = sum(r.cost or 0 for r in test_data.results)
|
||||
|
||||
if cost is not None: # check if cost is not None
|
||||
all_costs_none = False
|
||||
total_cost += cost # add cost to total
|
||||
if all_costs_none:
|
||||
total_cost = None
|
||||
return total_cost
|
||||
|
||||
|
||||
class RegressionTestsTracker(BaseReportManager):
|
||||
"""Abstracts interaction with the regression tests file"""
|
||||
|
||||
tests: dict[str, dict]
|
||||
|
||||
def add_test(self, test_name: str, test_details: dict) -> None:
|
||||
if test_name.startswith("Test"):
|
||||
test_name = test_name[4:]
|
||||
|
||||
self.tests[test_name] = test_details
|
||||
self.save()
|
||||
|
||||
def has_regression_test(self, test_name: str) -> bool:
|
||||
return self.tests.get(test_name) is not None
|
||||
|
||||
|
||||
class SuccessRatesTracker(BaseReportManager):
|
||||
"""Abstracts interaction with the regression tests file"""
|
||||
|
||||
tests: dict[str, list[bool | None]]
|
||||
|
||||
def update(self, test_name: str, success_history: list[bool | None]) -> None:
|
||||
if test_name.startswith("Test"):
|
||||
test_name = test_name[4:]
|
||||
|
||||
self.tests[test_name] = success_history
|
||||
self.save()
|
||||
@@ -1,45 +0,0 @@
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from agbenchmark.reports.processing.graphs import (
|
||||
save_combined_bar_chart,
|
||||
save_combined_radar_chart,
|
||||
)
|
||||
from agbenchmark.reports.processing.process_report import (
|
||||
all_agent_categories,
|
||||
get_reports_data,
|
||||
)
|
||||
|
||||
|
||||
def generate_combined_chart() -> None:
|
||||
all_agents_path = Path(__file__).parent.parent.parent.parent / "reports"
|
||||
|
||||
combined_charts_folder = all_agents_path / "combined_charts"
|
||||
|
||||
reports_data = get_reports_data(str(all_agents_path))
|
||||
|
||||
categories = all_agent_categories(reports_data)
|
||||
|
||||
# Count the number of directories in this directory
|
||||
num_dirs = len([f for f in combined_charts_folder.iterdir() if f.is_dir()])
|
||||
|
||||
run_charts_folder = combined_charts_folder / f"run{num_dirs + 1}"
|
||||
|
||||
if not os.path.exists(run_charts_folder):
|
||||
os.makedirs(run_charts_folder)
|
||||
|
||||
info_data = {
|
||||
report_name: data.benchmark_start_time
|
||||
for report_name, data in reports_data.items()
|
||||
if report_name in categories
|
||||
}
|
||||
with open(Path(run_charts_folder) / "run_info.json", "w") as f:
|
||||
json.dump(info_data, f)
|
||||
|
||||
save_combined_radar_chart(categories, Path(run_charts_folder) / "radar_chart.png")
|
||||
save_combined_bar_chart(categories, Path(run_charts_folder) / "bar_chart.png")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
generate_combined_chart()
|
||||
@@ -1,34 +0,0 @@
|
||||
import os
|
||||
|
||||
|
||||
def get_last_subdirectory(directory_path: str) -> str | None:
|
||||
# Get all subdirectories in the directory
|
||||
subdirs = [
|
||||
os.path.join(directory_path, name)
|
||||
for name in os.listdir(directory_path)
|
||||
if os.path.isdir(os.path.join(directory_path, name))
|
||||
]
|
||||
|
||||
# Sort the subdirectories by creation time
|
||||
subdirs.sort(key=os.path.getctime)
|
||||
|
||||
# Return the last subdirectory in the list
|
||||
return subdirs[-1] if subdirs else None
|
||||
|
||||
|
||||
def get_latest_report_from_agent_directories(
|
||||
directory_path: str,
|
||||
) -> list[tuple[os.DirEntry[str], str]]:
|
||||
latest_reports = []
|
||||
|
||||
for subdir in os.scandir(directory_path):
|
||||
if subdir.is_dir():
|
||||
# Get the most recently created subdirectory within this agent's directory
|
||||
latest_subdir = get_last_subdirectory(subdir.path)
|
||||
if latest_subdir is not None:
|
||||
# Look for 'report.json' in the subdirectory
|
||||
report_file = os.path.join(latest_subdir, "report.json")
|
||||
if os.path.isfile(report_file):
|
||||
latest_reports.append((subdir, report_file))
|
||||
|
||||
return latest_reports
|
||||
@@ -1,199 +0,0 @@
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import matplotlib.patches as mpatches
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def save_combined_radar_chart(
|
||||
categories: dict[str, Any], save_path: str | Path
|
||||
) -> None:
|
||||
categories = {k: v for k, v in categories.items() if v}
|
||||
if not all(categories.values()):
|
||||
raise Exception("No data to plot")
|
||||
labels = np.array(
|
||||
list(next(iter(categories.values())).keys())
|
||||
) # We use the first category to get the keys
|
||||
num_vars = len(labels)
|
||||
angles = np.linspace(0, 2 * np.pi, num_vars, endpoint=False).tolist()
|
||||
angles += angles[
|
||||
:1
|
||||
] # Add the first angle to the end of the list to ensure the polygon is closed
|
||||
|
||||
# Create radar chart
|
||||
fig, ax = plt.subplots(figsize=(6, 6), subplot_kw=dict(polar=True))
|
||||
ax.set_theta_offset(np.pi / 2) # type: ignore
|
||||
ax.set_theta_direction(-1) # type: ignore
|
||||
ax.spines["polar"].set_visible(False) # Remove border
|
||||
|
||||
cmap = plt.cm.get_cmap("nipy_spectral", len(categories)) # type: ignore
|
||||
|
||||
colors = [cmap(i) for i in range(len(categories))]
|
||||
|
||||
for i, (cat_name, cat_values) in enumerate(
|
||||
categories.items()
|
||||
): # Iterating through each category (series)
|
||||
values = np.array(list(cat_values.values()))
|
||||
values = np.concatenate((values, values[:1])) # Ensure the polygon is closed
|
||||
|
||||
ax.fill(angles, values, color=colors[i], alpha=0.25) # Draw the filled polygon
|
||||
ax.plot(angles, values, color=colors[i], linewidth=2) # Draw polygon
|
||||
ax.plot(
|
||||
angles,
|
||||
values,
|
||||
"o",
|
||||
color="white",
|
||||
markersize=7,
|
||||
markeredgecolor=colors[i],
|
||||
markeredgewidth=2,
|
||||
) # Draw points
|
||||
|
||||
# Draw legend
|
||||
ax.legend(
|
||||
handles=[
|
||||
mpatches.Patch(color=color, label=cat_name, alpha=0.25)
|
||||
for cat_name, color in zip(categories.keys(), colors)
|
||||
],
|
||||
loc="upper left",
|
||||
bbox_to_anchor=(0.7, 1.3),
|
||||
)
|
||||
|
||||
# Adjust layout to make room for the legend
|
||||
plt.tight_layout()
|
||||
|
||||
lines, labels = plt.thetagrids(
|
||||
np.degrees(angles[:-1]), (list(next(iter(categories.values())).keys()))
|
||||
) # We use the first category to get the keys
|
||||
|
||||
highest_score = 7
|
||||
|
||||
# Set y-axis limit to 7
|
||||
ax.set_ylim(top=highest_score)
|
||||
|
||||
# Move labels away from the plot
|
||||
for label in labels:
|
||||
label.set_position(
|
||||
(label.get_position()[0], label.get_position()[1] + -0.05)
|
||||
) # adjust 0.1 as needed
|
||||
|
||||
# Move radial labels away from the plot
|
||||
ax.set_rlabel_position(180) # type: ignore
|
||||
|
||||
ax.set_yticks([]) # Remove default yticks
|
||||
|
||||
# Manually create gridlines
|
||||
for y in np.arange(0, highest_score + 1, 1):
|
||||
if y != highest_score:
|
||||
ax.plot(
|
||||
angles, [y] * len(angles), color="gray", linewidth=0.5, linestyle=":"
|
||||
)
|
||||
# Add labels for manually created gridlines
|
||||
ax.text(
|
||||
angles[0],
|
||||
y + 0.2,
|
||||
str(int(y)),
|
||||
color="black",
|
||||
size=9,
|
||||
horizontalalignment="center",
|
||||
verticalalignment="center",
|
||||
)
|
||||
|
||||
plt.savefig(save_path, dpi=300) # Save the figure as a PNG file
|
||||
plt.close() # Close the figure to free up memory
|
||||
|
||||
|
||||
def save_single_radar_chart(
|
||||
category_dict: dict[str, int], save_path: str | Path
|
||||
) -> None:
|
||||
labels = np.array(list(category_dict.keys()))
|
||||
values = np.array(list(category_dict.values()))
|
||||
|
||||
num_vars = len(labels)
|
||||
|
||||
angles = np.linspace(0, 2 * np.pi, num_vars, endpoint=False).tolist()
|
||||
|
||||
angles += angles[:1]
|
||||
values = np.concatenate((values, values[:1]))
|
||||
|
||||
colors = ["#1f77b4"]
|
||||
|
||||
fig, ax = plt.subplots(figsize=(6, 6), subplot_kw=dict(polar=True))
|
||||
ax.set_theta_offset(np.pi / 2) # type: ignore
|
||||
ax.set_theta_direction(-1) # type: ignore
|
||||
|
||||
ax.spines["polar"].set_visible(False)
|
||||
|
||||
lines, labels = plt.thetagrids(
|
||||
np.degrees(angles[:-1]), (list(category_dict.keys()))
|
||||
)
|
||||
|
||||
highest_score = 7
|
||||
|
||||
# Set y-axis limit to 7
|
||||
ax.set_ylim(top=highest_score)
|
||||
|
||||
for label in labels:
|
||||
label.set_position((label.get_position()[0], label.get_position()[1] + -0.05))
|
||||
|
||||
ax.fill(angles, values, color=colors[0], alpha=0.25)
|
||||
ax.plot(angles, values, color=colors[0], linewidth=2)
|
||||
|
||||
for i, (angle, value) in enumerate(zip(angles, values)):
|
||||
ha = "left"
|
||||
if angle in {0, np.pi}:
|
||||
ha = "center"
|
||||
elif np.pi < angle < 2 * np.pi:
|
||||
ha = "right"
|
||||
ax.text(
|
||||
angle,
|
||||
value - 0.5,
|
||||
f"{value}",
|
||||
size=10,
|
||||
horizontalalignment=ha,
|
||||
verticalalignment="center",
|
||||
color="black",
|
||||
)
|
||||
|
||||
ax.set_yticklabels([])
|
||||
|
||||
ax.set_yticks([])
|
||||
|
||||
if values.size == 0:
|
||||
return
|
||||
|
||||
for y in np.arange(0, highest_score, 1):
|
||||
ax.plot(angles, [y] * len(angles), color="gray", linewidth=0.5, linestyle=":")
|
||||
|
||||
for angle, value in zip(angles, values):
|
||||
ax.plot(
|
||||
angle,
|
||||
value,
|
||||
"o",
|
||||
color="white",
|
||||
markersize=7,
|
||||
markeredgecolor=colors[0],
|
||||
markeredgewidth=2,
|
||||
)
|
||||
|
||||
plt.savefig(save_path, dpi=300) # Save the figure as a PNG file
|
||||
plt.close() # Close the figure to free up memory
|
||||
|
||||
|
||||
def save_combined_bar_chart(categories: dict[str, Any], save_path: str | Path) -> None:
|
||||
if not all(categories.values()):
|
||||
raise Exception("No data to plot")
|
||||
|
||||
# Convert dictionary to DataFrame
|
||||
df = pd.DataFrame(categories)
|
||||
|
||||
# Create a grouped bar chart
|
||||
df.plot(kind="bar", figsize=(10, 7))
|
||||
|
||||
plt.title("Performance by Category for Each Agent")
|
||||
plt.xlabel("Category")
|
||||
plt.ylabel("Performance")
|
||||
|
||||
plt.savefig(save_path, dpi=300) # Save the figure as a PNG file
|
||||
plt.close() # Close the figure to free up memory
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user