mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-30 09:28:19 -05:00
Compare commits
76 Commits
abhi/check
...
make-old-w
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d8d87f2853 | ||
|
|
3b822cdaf7 | ||
|
|
791e1d8982 | ||
|
|
b2eb4831bd | ||
|
|
4cd5da678d | ||
|
|
b94c83aacc | ||
|
|
7668c17d9c | ||
|
|
0040636948 | ||
|
|
c671af851f | ||
|
|
7dd181f4b0 | ||
|
|
114856cef1 | ||
|
|
68b9bd0c51 | ||
|
|
ff076b1f15 | ||
|
|
57fbab500b | ||
|
|
6faabef24d | ||
|
|
a67d475a69 | ||
|
|
326554d89a | ||
|
|
5e22a1888a | ||
|
|
a4d7b0142f | ||
|
|
7d6375f59c | ||
|
|
aeec0ce509 | ||
|
|
b32bfcaac5 | ||
|
|
5373a6eb6e | ||
|
|
98cde46ccb | ||
|
|
bd10da10d9 | ||
|
|
60fdee1345 | ||
|
|
6f2783468c | ||
|
|
c1031b286d | ||
|
|
b849eafb7f | ||
|
|
572c3f5e0d | ||
|
|
89003a585d | ||
|
|
0e65785228 | ||
|
|
f07dff1cdd | ||
|
|
00e02a4696 | ||
|
|
634bff8277 | ||
|
|
d591f36c7b | ||
|
|
a347bed0b1 | ||
|
|
4eeb6ee2b0 | ||
|
|
7db962b9f9 | ||
|
|
9108b21541 | ||
|
|
ffe9325296 | ||
|
|
0a616d9267 | ||
|
|
ab95077e5b | ||
|
|
e477150979 | ||
|
|
804430e243 | ||
|
|
acb320d32d | ||
|
|
32f68d5999 | ||
|
|
49f56b4e8d | ||
|
|
bead811e73 | ||
|
|
013f728ebf | ||
|
|
cda9572acd | ||
|
|
e0784f8f6b | ||
|
|
3040f39136 | ||
|
|
515504c604 | ||
|
|
18edeaeaf4 | ||
|
|
44182aff9c | ||
|
|
864c5a7846 | ||
|
|
699fffb1a8 | ||
|
|
f0641c2d26 | ||
|
|
94b6f74c95 | ||
|
|
46aabab3ea | ||
|
|
0a65df5102 | ||
|
|
6fbd208fe3 | ||
|
|
8fc174ca87 | ||
|
|
cacc89790f | ||
|
|
b9113bee02 | ||
|
|
3f65da03e7 | ||
|
|
9e96d11b2d | ||
|
|
4c264b7ae9 | ||
|
|
0adbc0bd05 | ||
|
|
8f3291bc92 | ||
|
|
7a20de880d | ||
|
|
ef8a6d2528 | ||
|
|
fd66be2aaa | ||
|
|
ae2cc97dc4 | ||
|
|
ea521eed26 |
@@ -29,8 +29,7 @@
|
||||
"postCreateCmd": [
|
||||
"cd autogpt_platform/autogpt_libs && poetry install",
|
||||
"cd autogpt_platform/backend && poetry install && poetry run prisma generate",
|
||||
"cd autogpt_platform/frontend && pnpm install",
|
||||
"cd docs && pip install -r requirements.txt"
|
||||
"cd autogpt_platform/frontend && pnpm install"
|
||||
],
|
||||
"terminalCommand": "code .",
|
||||
"deleteBranchWithWorktree": false
|
||||
|
||||
6
.github/copilot-instructions.md
vendored
6
.github/copilot-instructions.md
vendored
@@ -160,7 +160,7 @@ pnpm storybook # Start component development server
|
||||
|
||||
**Backend Entry Points:**
|
||||
|
||||
- `backend/backend/server/server.py` - FastAPI application setup
|
||||
- `backend/backend/api/rest_api.py` - FastAPI application setup
|
||||
- `backend/backend/data/` - Database models and user management
|
||||
- `backend/blocks/` - Agent execution blocks and logic
|
||||
|
||||
@@ -219,7 +219,7 @@ Agents are built using a visual block-based system where each block performs a s
|
||||
|
||||
### API Development
|
||||
|
||||
1. Update routes in `/backend/backend/server/routers/`
|
||||
1. Update routes in `/backend/backend/api/features/`
|
||||
2. Add/update Pydantic models in same directory
|
||||
3. Write tests alongside route files
|
||||
4. For `data/*.py` changes, validate user ID checks
|
||||
@@ -285,7 +285,7 @@ Agents are built using a visual block-based system where each block performs a s
|
||||
|
||||
### Security Guidelines
|
||||
|
||||
**Cache Protection Middleware** (`/backend/backend/server/middleware/security.py`):
|
||||
**Cache Protection Middleware** (`/backend/backend/api/middleware/security.py`):
|
||||
|
||||
- Default: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
|
||||
- Uses allow list approach for cacheable paths (static assets, health checks, public pages)
|
||||
|
||||
73
.github/workflows/classic-autogpt-ci.yml
vendored
73
.github/workflows/classic-autogpt-ci.yml
vendored
@@ -6,11 +6,15 @@ on:
|
||||
paths:
|
||||
- '.github/workflows/classic-autogpt-ci.yml'
|
||||
- 'classic/original_autogpt/**'
|
||||
- 'classic/direct_benchmark/**'
|
||||
- 'classic/forge/**'
|
||||
pull_request:
|
||||
branches: [ master, dev, release-* ]
|
||||
paths:
|
||||
- '.github/workflows/classic-autogpt-ci.yml'
|
||||
- 'classic/original_autogpt/**'
|
||||
- 'classic/direct_benchmark/**'
|
||||
- 'classic/forge/**'
|
||||
|
||||
concurrency:
|
||||
group: ${{ format('classic-autogpt-ci-{0}', github.head_ref && format('{0}-{1}', github.event_name, github.event.pull_request.number) || github.sha) }}
|
||||
@@ -19,47 +23,22 @@ concurrency:
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: classic/original_autogpt
|
||||
working-directory: classic
|
||||
|
||||
jobs:
|
||||
test:
|
||||
permissions:
|
||||
contents: read
|
||||
timeout-minutes: 30
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.10"]
|
||||
platform-os: [ubuntu, macos, macos-arm64, windows]
|
||||
runs-on: ${{ matrix.platform-os != 'macos-arm64' && format('{0}-latest', matrix.platform-os) || 'macos-14' }}
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
# Quite slow on macOS (2~4 minutes to set up Docker)
|
||||
# - name: Set up Docker (macOS)
|
||||
# if: runner.os == 'macOS'
|
||||
# uses: crazy-max/ghaction-setup-docker@v3
|
||||
|
||||
- name: Start MinIO service (Linux)
|
||||
if: runner.os == 'Linux'
|
||||
- name: Start MinIO service
|
||||
working-directory: '.'
|
||||
run: |
|
||||
docker pull minio/minio:edge-cicd
|
||||
docker run -d -p 9000:9000 minio/minio:edge-cicd
|
||||
|
||||
- name: Start MinIO service (macOS)
|
||||
if: runner.os == 'macOS'
|
||||
working-directory: ${{ runner.temp }}
|
||||
run: |
|
||||
brew install minio/stable/minio
|
||||
mkdir data
|
||||
minio server ./data &
|
||||
|
||||
# No MinIO on Windows:
|
||||
# - Windows doesn't support running Linux Docker containers
|
||||
# - It doesn't seem possible to start background processes on Windows. They are
|
||||
# killed after the step returns.
|
||||
# See: https://github.com/actions/runner/issues/598#issuecomment-2011890429
|
||||
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
@@ -71,41 +50,23 @@ jobs:
|
||||
git config --global user.name "Auto-GPT-Bot"
|
||||
git config --global user.email "github-bot@agpt.co"
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
- name: Set up Python 3.12
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
python-version: "3.12"
|
||||
|
||||
- id: get_date
|
||||
name: Get date
|
||||
run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
# On Windows, unpacking cached dependencies takes longer than just installing them
|
||||
if: runner.os != 'Windows'
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ${{ runner.os == 'macOS' && '~/Library/Caches/pypoetry' || '~/.cache/pypoetry' }}
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('classic/original_autogpt/poetry.lock') }}
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('classic/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry (Unix)
|
||||
if: runner.os != 'Windows'
|
||||
run: |
|
||||
curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
if [ "${{ runner.os }}" = "macOS" ]; then
|
||||
PATH="$HOME/.local/bin:$PATH"
|
||||
echo "$HOME/.local/bin" >> $GITHUB_PATH
|
||||
fi
|
||||
|
||||
- name: Install Poetry (Windows)
|
||||
if: runner.os == 'Windows'
|
||||
shell: pwsh
|
||||
run: |
|
||||
(Invoke-WebRequest -Uri https://install.python-poetry.org -UseBasicParsing).Content | python -
|
||||
|
||||
$env:PATH += ";$env:APPDATA\Python\Scripts"
|
||||
echo "$env:APPDATA\Python\Scripts" >> $env:GITHUB_PATH
|
||||
- name: Install Poetry
|
||||
run: curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: poetry install
|
||||
@@ -116,12 +77,12 @@ jobs:
|
||||
--cov=autogpt --cov-branch --cov-report term-missing --cov-report xml \
|
||||
--numprocesses=logical --durations=10 \
|
||||
--junitxml=junit.xml -o junit_family=legacy \
|
||||
tests/unit tests/integration
|
||||
original_autogpt/tests/unit original_autogpt/tests/integration
|
||||
env:
|
||||
CI: true
|
||||
PLAIN_OUTPUT: True
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
S3_ENDPOINT_URL: ${{ runner.os != 'Windows' && 'http://127.0.0.1:9000' || '' }}
|
||||
S3_ENDPOINT_URL: http://127.0.0.1:9000
|
||||
AWS_ACCESS_KEY_ID: minioadmin
|
||||
AWS_SECRET_ACCESS_KEY: minioadmin
|
||||
|
||||
@@ -135,11 +96,11 @@ jobs:
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
flags: autogpt-agent,${{ runner.os }}
|
||||
flags: autogpt-agent
|
||||
|
||||
- name: Upload logs to artifact
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: test-logs
|
||||
path: classic/original_autogpt/logs/
|
||||
path: classic/logs/
|
||||
|
||||
36
.github/workflows/classic-autogpts-ci.yml
vendored
36
.github/workflows/classic-autogpts-ci.yml
vendored
@@ -11,9 +11,6 @@ on:
|
||||
- 'classic/original_autogpt/**'
|
||||
- 'classic/forge/**'
|
||||
- 'classic/benchmark/**'
|
||||
- 'classic/run'
|
||||
- 'classic/cli.py'
|
||||
- 'classic/setup.py'
|
||||
- '!**/*.md'
|
||||
pull_request:
|
||||
branches: [ master, dev, release-* ]
|
||||
@@ -22,9 +19,6 @@ on:
|
||||
- 'classic/original_autogpt/**'
|
||||
- 'classic/forge/**'
|
||||
- 'classic/benchmark/**'
|
||||
- 'classic/run'
|
||||
- 'classic/cli.py'
|
||||
- 'classic/setup.py'
|
||||
- '!**/*.md'
|
||||
|
||||
defaults:
|
||||
@@ -35,13 +29,9 @@ defaults:
|
||||
jobs:
|
||||
serve-agent-protocol:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
agent-name: [ original_autogpt ]
|
||||
fail-fast: false
|
||||
timeout-minutes: 20
|
||||
env:
|
||||
min-python-version: '3.10'
|
||||
min-python-version: '3.12'
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
@@ -55,22 +45,22 @@ jobs:
|
||||
python-version: ${{ env.min-python-version }}
|
||||
|
||||
- name: Install Poetry
|
||||
working-directory: ./classic/${{ matrix.agent-name }}/
|
||||
run: |
|
||||
curl -sSL https://install.python-poetry.org | python -
|
||||
|
||||
- name: Run regression tests
|
||||
- name: Install dependencies
|
||||
run: poetry install
|
||||
|
||||
- name: Run smoke tests with direct-benchmark
|
||||
run: |
|
||||
./run agent start ${{ matrix.agent-name }}
|
||||
cd ${{ matrix.agent-name }}
|
||||
poetry run agbenchmark --mock --test=BasicRetrieval --test=Battleship --test=WebArenaTask_0
|
||||
poetry run agbenchmark --test=WriteFile
|
||||
poetry run direct-benchmark run \
|
||||
--strategies one_shot \
|
||||
--models claude \
|
||||
--tests ReadFile,WriteFile \
|
||||
--json
|
||||
env:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
AGENT_NAME: ${{ matrix.agent-name }}
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
REQUESTS_CA_BUNDLE: /etc/ssl/certs/ca-certificates.crt
|
||||
HELICONE_CACHE_ENABLED: false
|
||||
HELICONE_PROPERTY_AGENT: ${{ matrix.agent-name }}
|
||||
REPORTS_FOLDER: ${{ format('../../reports/{0}', matrix.agent-name) }}
|
||||
TELEMETRY_ENVIRONMENT: autogpt-ci
|
||||
TELEMETRY_OPT_IN: ${{ github.ref_name == 'master' }}
|
||||
NONINTERACTIVE_MODE: "true"
|
||||
CI: true
|
||||
|
||||
194
.github/workflows/classic-benchmark-ci.yml
vendored
194
.github/workflows/classic-benchmark-ci.yml
vendored
@@ -1,17 +1,21 @@
|
||||
name: Classic - AGBenchmark CI
|
||||
name: Classic - Direct Benchmark CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ master, dev, ci-test* ]
|
||||
paths:
|
||||
- 'classic/benchmark/**'
|
||||
- '!classic/benchmark/reports/**'
|
||||
- 'classic/direct_benchmark/**'
|
||||
- 'classic/benchmark/agbenchmark/challenges/**'
|
||||
- 'classic/original_autogpt/**'
|
||||
- 'classic/forge/**'
|
||||
- .github/workflows/classic-benchmark-ci.yml
|
||||
pull_request:
|
||||
branches: [ master, dev, release-* ]
|
||||
paths:
|
||||
- 'classic/benchmark/**'
|
||||
- '!classic/benchmark/reports/**'
|
||||
- 'classic/direct_benchmark/**'
|
||||
- 'classic/benchmark/agbenchmark/challenges/**'
|
||||
- 'classic/original_autogpt/**'
|
||||
- 'classic/forge/**'
|
||||
- .github/workflows/classic-benchmark-ci.yml
|
||||
|
||||
concurrency:
|
||||
@@ -23,23 +27,16 @@ defaults:
|
||||
shell: bash
|
||||
|
||||
env:
|
||||
min-python-version: '3.10'
|
||||
min-python-version: '3.12'
|
||||
|
||||
jobs:
|
||||
test:
|
||||
permissions:
|
||||
contents: read
|
||||
benchmark-tests:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 30
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.10"]
|
||||
platform-os: [ubuntu, macos, macos-arm64, windows]
|
||||
runs-on: ${{ matrix.platform-os != 'macos-arm64' && format('{0}-latest', matrix.platform-os) || 'macos-14' }}
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: classic/benchmark
|
||||
working-directory: classic
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
@@ -47,71 +44,88 @@ jobs:
|
||||
fetch-depth: 0
|
||||
submodules: true
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
- name: Set up Python ${{ env.min-python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
python-version: ${{ env.min-python-version }}
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
# On Windows, unpacking cached dependencies takes longer than just installing them
|
||||
if: runner.os != 'Windows'
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ${{ runner.os == 'macOS' && '~/Library/Caches/pypoetry' || '~/.cache/pypoetry' }}
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('classic/benchmark/poetry.lock') }}
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('classic/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry (Unix)
|
||||
if: runner.os != 'Windows'
|
||||
- name: Install Poetry
|
||||
run: |
|
||||
curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
if [ "${{ runner.os }}" = "macOS" ]; then
|
||||
PATH="$HOME/.local/bin:$PATH"
|
||||
echo "$HOME/.local/bin" >> $GITHUB_PATH
|
||||
fi
|
||||
|
||||
- name: Install Poetry (Windows)
|
||||
if: runner.os == 'Windows'
|
||||
shell: pwsh
|
||||
run: |
|
||||
(Invoke-WebRequest -Uri https://install.python-poetry.org -UseBasicParsing).Content | python -
|
||||
|
||||
$env:PATH += ";$env:APPDATA\Python\Scripts"
|
||||
echo "$env:APPDATA\Python\Scripts" >> $env:GITHUB_PATH
|
||||
|
||||
- name: Install Python dependencies
|
||||
- name: Install dependencies
|
||||
run: poetry install
|
||||
|
||||
- name: Run pytest with coverage
|
||||
- name: Run basic benchmark tests
|
||||
run: |
|
||||
poetry run pytest -vv \
|
||||
--cov=agbenchmark --cov-branch --cov-report term-missing --cov-report xml \
|
||||
--durations=10 \
|
||||
--junitxml=junit.xml -o junit_family=legacy \
|
||||
tests
|
||||
echo "Testing ReadFile challenge with one_shot strategy..."
|
||||
poetry run direct-benchmark run \
|
||||
--fresh \
|
||||
--strategies one_shot \
|
||||
--models claude \
|
||||
--tests ReadFile \
|
||||
--json
|
||||
|
||||
echo "Testing WriteFile challenge..."
|
||||
poetry run direct-benchmark run \
|
||||
--fresh \
|
||||
--strategies one_shot \
|
||||
--models claude \
|
||||
--tests WriteFile \
|
||||
--json
|
||||
env:
|
||||
CI: true
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
NONINTERACTIVE_MODE: "true"
|
||||
|
||||
- name: Upload test results to Codecov
|
||||
if: ${{ !cancelled() }} # Run even if tests fail
|
||||
uses: codecov/test-results-action@v1
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
- name: Test category filtering
|
||||
run: |
|
||||
echo "Testing coding category..."
|
||||
poetry run direct-benchmark run \
|
||||
--fresh \
|
||||
--strategies one_shot \
|
||||
--models claude \
|
||||
--categories coding \
|
||||
--tests ReadFile,WriteFile \
|
||||
--json
|
||||
env:
|
||||
CI: true
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
NONINTERACTIVE_MODE: "true"
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
flags: agbenchmark,${{ runner.os }}
|
||||
- name: Test multiple strategies
|
||||
run: |
|
||||
echo "Testing multiple strategies..."
|
||||
poetry run direct-benchmark run \
|
||||
--fresh \
|
||||
--strategies one_shot,plan_execute \
|
||||
--models claude \
|
||||
--tests ReadFile \
|
||||
--parallel 2 \
|
||||
--json
|
||||
env:
|
||||
CI: true
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
NONINTERACTIVE_MODE: "true"
|
||||
|
||||
self-test-with-agent:
|
||||
# Run regression tests on maintain challenges
|
||||
regression-tests:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
agent-name: [forge]
|
||||
fail-fast: false
|
||||
timeout-minutes: 20
|
||||
timeout-minutes: 45
|
||||
if: github.ref == 'refs/heads/master' || github.ref == 'refs/heads/dev'
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: classic
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
@@ -126,51 +140,23 @@ jobs:
|
||||
|
||||
- name: Install Poetry
|
||||
run: |
|
||||
curl -sSL https://install.python-poetry.org | python -
|
||||
curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
- name: Install dependencies
|
||||
run: poetry install
|
||||
|
||||
- name: Run regression tests
|
||||
working-directory: classic
|
||||
run: |
|
||||
./run agent start ${{ matrix.agent-name }}
|
||||
cd ${{ matrix.agent-name }}
|
||||
|
||||
set +e # Ignore non-zero exit codes and continue execution
|
||||
echo "Running the following command: poetry run agbenchmark --maintain --mock"
|
||||
poetry run agbenchmark --maintain --mock
|
||||
EXIT_CODE=$?
|
||||
set -e # Stop ignoring non-zero exit codes
|
||||
# Check if the exit code was 5, and if so, exit with 0 instead
|
||||
if [ $EXIT_CODE -eq 5 ]; then
|
||||
echo "regression_tests.json is empty."
|
||||
fi
|
||||
|
||||
echo "Running the following command: poetry run agbenchmark --mock"
|
||||
poetry run agbenchmark --mock
|
||||
|
||||
echo "Running the following command: poetry run agbenchmark --mock --category=data"
|
||||
poetry run agbenchmark --mock --category=data
|
||||
|
||||
echo "Running the following command: poetry run agbenchmark --mock --category=coding"
|
||||
poetry run agbenchmark --mock --category=coding
|
||||
|
||||
# echo "Running the following command: poetry run agbenchmark --test=WriteFile"
|
||||
# poetry run agbenchmark --test=WriteFile
|
||||
cd ../benchmark
|
||||
poetry install
|
||||
echo "Adding the BUILD_SKILL_TREE environment variable. This will attempt to add new elements in the skill tree. If new elements are added, the CI fails because they should have been pushed"
|
||||
export BUILD_SKILL_TREE=true
|
||||
|
||||
# poetry run agbenchmark --mock
|
||||
|
||||
# CHANGED=$(git diff --name-only | grep -E '(agbenchmark/challenges)|(../classic/frontend/assets)') || echo "No diffs"
|
||||
# if [ ! -z "$CHANGED" ]; then
|
||||
# echo "There are unstaged changes please run agbenchmark and commit those changes since they are needed."
|
||||
# echo "$CHANGED"
|
||||
# exit 1
|
||||
# else
|
||||
# echo "No unstaged changes."
|
||||
# fi
|
||||
echo "Running regression tests (previously beaten challenges)..."
|
||||
poetry run direct-benchmark run \
|
||||
--fresh \
|
||||
--strategies one_shot \
|
||||
--models claude \
|
||||
--maintain \
|
||||
--parallel 4 \
|
||||
--json
|
||||
env:
|
||||
CI: true
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
TELEMETRY_ENVIRONMENT: autogpt-benchmark-ci
|
||||
TELEMETRY_OPT_IN: ${{ github.ref_name == 'master' }}
|
||||
NONINTERACTIVE_MODE: "true"
|
||||
|
||||
182
.github/workflows/classic-forge-ci.yml
vendored
182
.github/workflows/classic-forge-ci.yml
vendored
@@ -6,13 +6,11 @@ on:
|
||||
paths:
|
||||
- '.github/workflows/classic-forge-ci.yml'
|
||||
- 'classic/forge/**'
|
||||
- '!classic/forge/tests/vcr_cassettes'
|
||||
pull_request:
|
||||
branches: [ master, dev, release-* ]
|
||||
paths:
|
||||
- '.github/workflows/classic-forge-ci.yml'
|
||||
- 'classic/forge/**'
|
||||
- '!classic/forge/tests/vcr_cassettes'
|
||||
|
||||
concurrency:
|
||||
group: ${{ format('forge-ci-{0}', github.head_ref && format('{0}-{1}', github.event_name, github.event.pull_request.number) || github.sha) }}
|
||||
@@ -21,115 +19,38 @@ concurrency:
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: classic/forge
|
||||
working-directory: classic
|
||||
|
||||
jobs:
|
||||
test:
|
||||
permissions:
|
||||
contents: read
|
||||
timeout-minutes: 30
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.10"]
|
||||
platform-os: [ubuntu, macos, macos-arm64, windows]
|
||||
runs-on: ${{ matrix.platform-os != 'macos-arm64' && format('{0}-latest', matrix.platform-os) || 'macos-14' }}
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
# Quite slow on macOS (2~4 minutes to set up Docker)
|
||||
# - name: Set up Docker (macOS)
|
||||
# if: runner.os == 'macOS'
|
||||
# uses: crazy-max/ghaction-setup-docker@v3
|
||||
|
||||
- name: Start MinIO service (Linux)
|
||||
if: runner.os == 'Linux'
|
||||
- name: Start MinIO service
|
||||
working-directory: '.'
|
||||
run: |
|
||||
docker pull minio/minio:edge-cicd
|
||||
docker run -d -p 9000:9000 minio/minio:edge-cicd
|
||||
|
||||
- name: Start MinIO service (macOS)
|
||||
if: runner.os == 'macOS'
|
||||
working-directory: ${{ runner.temp }}
|
||||
run: |
|
||||
brew install minio/stable/minio
|
||||
mkdir data
|
||||
minio server ./data &
|
||||
|
||||
# No MinIO on Windows:
|
||||
# - Windows doesn't support running Linux Docker containers
|
||||
# - It doesn't seem possible to start background processes on Windows. They are
|
||||
# killed after the step returns.
|
||||
# See: https://github.com/actions/runner/issues/598#issuecomment-2011890429
|
||||
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
submodules: true
|
||||
|
||||
- name: Checkout cassettes
|
||||
if: ${{ startsWith(github.event_name, 'pull_request') }}
|
||||
env:
|
||||
PR_BASE: ${{ github.event.pull_request.base.ref }}
|
||||
PR_BRANCH: ${{ github.event.pull_request.head.ref }}
|
||||
PR_AUTHOR: ${{ github.event.pull_request.user.login }}
|
||||
run: |
|
||||
cassette_branch="${PR_AUTHOR}-${PR_BRANCH}"
|
||||
cassette_base_branch="${PR_BASE}"
|
||||
cd tests/vcr_cassettes
|
||||
|
||||
if ! git ls-remote --exit-code --heads origin $cassette_base_branch ; then
|
||||
cassette_base_branch="master"
|
||||
fi
|
||||
|
||||
if git ls-remote --exit-code --heads origin $cassette_branch ; then
|
||||
git fetch origin $cassette_branch
|
||||
git fetch origin $cassette_base_branch
|
||||
|
||||
git checkout $cassette_branch
|
||||
|
||||
# Pick non-conflicting cassette updates from the base branch
|
||||
git merge --no-commit --strategy-option=ours origin/$cassette_base_branch
|
||||
echo "Using cassettes from mirror branch '$cassette_branch'," \
|
||||
"synced to upstream branch '$cassette_base_branch'."
|
||||
else
|
||||
git checkout -b $cassette_branch
|
||||
echo "Branch '$cassette_branch' does not exist in cassette submodule." \
|
||||
"Using cassettes from '$cassette_base_branch'."
|
||||
fi
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
- name: Set up Python 3.12
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
python-version: "3.12"
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
# On Windows, unpacking cached dependencies takes longer than just installing them
|
||||
if: runner.os != 'Windows'
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ${{ runner.os == 'macOS' && '~/Library/Caches/pypoetry' || '~/.cache/pypoetry' }}
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('classic/forge/poetry.lock') }}
|
||||
path: ~/.cache/pypoetry
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('classic/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry (Unix)
|
||||
if: runner.os != 'Windows'
|
||||
run: |
|
||||
curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
if [ "${{ runner.os }}" = "macOS" ]; then
|
||||
PATH="$HOME/.local/bin:$PATH"
|
||||
echo "$HOME/.local/bin" >> $GITHUB_PATH
|
||||
fi
|
||||
|
||||
- name: Install Poetry (Windows)
|
||||
if: runner.os == 'Windows'
|
||||
shell: pwsh
|
||||
run: |
|
||||
(Invoke-WebRequest -Uri https://install.python-poetry.org -UseBasicParsing).Content | python -
|
||||
|
||||
$env:PATH += ";$env:APPDATA\Python\Scripts"
|
||||
echo "$env:APPDATA\Python\Scripts" >> $env:GITHUB_PATH
|
||||
- name: Install Poetry
|
||||
run: curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: poetry install
|
||||
@@ -140,12 +61,15 @@ jobs:
|
||||
--cov=forge --cov-branch --cov-report term-missing --cov-report xml \
|
||||
--durations=10 \
|
||||
--junitxml=junit.xml -o junit_family=legacy \
|
||||
forge
|
||||
forge/forge forge/tests
|
||||
env:
|
||||
CI: true
|
||||
PLAIN_OUTPUT: True
|
||||
# API keys - tests that need these will skip if not available
|
||||
# Secrets are not available to fork PRs (GitHub security feature)
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
S3_ENDPOINT_URL: ${{ runner.os != 'Windows' && 'http://127.0.0.1:9000' || '' }}
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
S3_ENDPOINT_URL: http://127.0.0.1:9000
|
||||
AWS_ACCESS_KEY_ID: minioadmin
|
||||
AWS_SECRET_ACCESS_KEY: minioadmin
|
||||
|
||||
@@ -159,85 +83,11 @@ jobs:
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
flags: forge,${{ runner.os }}
|
||||
|
||||
- id: setup_git_auth
|
||||
name: Set up git token authentication
|
||||
# Cassettes may be pushed even when tests fail
|
||||
if: success() || failure()
|
||||
run: |
|
||||
config_key="http.${{ github.server_url }}/.extraheader"
|
||||
if [ "${{ runner.os }}" = 'macOS' ]; then
|
||||
base64_pat=$(echo -n "pat:${{ secrets.PAT_REVIEW }}" | base64)
|
||||
else
|
||||
base64_pat=$(echo -n "pat:${{ secrets.PAT_REVIEW }}" | base64 -w0)
|
||||
fi
|
||||
|
||||
git config "$config_key" \
|
||||
"Authorization: Basic $base64_pat"
|
||||
|
||||
cd tests/vcr_cassettes
|
||||
git config "$config_key" \
|
||||
"Authorization: Basic $base64_pat"
|
||||
|
||||
echo "config_key=$config_key" >> $GITHUB_OUTPUT
|
||||
|
||||
- id: push_cassettes
|
||||
name: Push updated cassettes
|
||||
# For pull requests, push updated cassettes even when tests fail
|
||||
if: github.event_name == 'push' || (! github.event.pull_request.head.repo.fork && (success() || failure()))
|
||||
env:
|
||||
PR_BRANCH: ${{ github.event.pull_request.head.ref }}
|
||||
PR_AUTHOR: ${{ github.event.pull_request.user.login }}
|
||||
run: |
|
||||
if [ "${{ startsWith(github.event_name, 'pull_request') }}" = "true" ]; then
|
||||
is_pull_request=true
|
||||
cassette_branch="${PR_AUTHOR}-${PR_BRANCH}"
|
||||
else
|
||||
cassette_branch="${{ github.ref_name }}"
|
||||
fi
|
||||
|
||||
cd tests/vcr_cassettes
|
||||
# Commit & push changes to cassettes if any
|
||||
if ! git diff --quiet; then
|
||||
git add .
|
||||
git commit -m "Auto-update cassettes"
|
||||
git push origin HEAD:$cassette_branch
|
||||
if [ ! $is_pull_request ]; then
|
||||
cd ../..
|
||||
git add tests/vcr_cassettes
|
||||
git commit -m "Update cassette submodule"
|
||||
git push origin HEAD:$cassette_branch
|
||||
fi
|
||||
echo "updated=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "updated=false" >> $GITHUB_OUTPUT
|
||||
echo "No cassette changes to commit"
|
||||
fi
|
||||
|
||||
- name: Post Set up git token auth
|
||||
if: steps.setup_git_auth.outcome == 'success'
|
||||
run: |
|
||||
git config --unset-all '${{ steps.setup_git_auth.outputs.config_key }}'
|
||||
git submodule foreach git config --unset-all '${{ steps.setup_git_auth.outputs.config_key }}'
|
||||
|
||||
- name: Apply "behaviour change" label and comment on PR
|
||||
if: ${{ startsWith(github.event_name, 'pull_request') }}
|
||||
run: |
|
||||
PR_NUMBER="${{ github.event.pull_request.number }}"
|
||||
TOKEN="${{ secrets.PAT_REVIEW }}"
|
||||
REPO="${{ github.repository }}"
|
||||
|
||||
if [[ "${{ steps.push_cassettes.outputs.updated }}" == "true" ]]; then
|
||||
echo "Adding label and comment..."
|
||||
echo $TOKEN | gh auth login --with-token
|
||||
gh issue edit $PR_NUMBER --add-label "behaviour change"
|
||||
gh issue comment $PR_NUMBER --body "You changed AutoGPT's behaviour on ${{ runner.os }}. The cassettes have been updated and will be merged to the submodule when this Pull Request gets merged."
|
||||
fi
|
||||
flags: forge
|
||||
|
||||
- name: Upload logs to artifact
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: test-logs
|
||||
path: classic/forge/logs/
|
||||
path: classic/logs/
|
||||
|
||||
60
.github/workflows/classic-frontend-ci.yml
vendored
60
.github/workflows/classic-frontend-ci.yml
vendored
@@ -1,60 +0,0 @@
|
||||
name: Classic - Frontend CI/CD
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
- dev
|
||||
- 'ci-test*' # This will match any branch that starts with "ci-test"
|
||||
paths:
|
||||
- 'classic/frontend/**'
|
||||
- '.github/workflows/classic-frontend-ci.yml'
|
||||
pull_request:
|
||||
paths:
|
||||
- 'classic/frontend/**'
|
||||
- '.github/workflows/classic-frontend-ci.yml'
|
||||
|
||||
jobs:
|
||||
build:
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
BUILD_BRANCH: ${{ format('classic-frontend-build/{0}', github.ref_name) }}
|
||||
|
||||
steps:
|
||||
- name: Checkout Repo
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Flutter
|
||||
uses: subosito/flutter-action@v2
|
||||
with:
|
||||
flutter-version: '3.13.2'
|
||||
|
||||
- name: Build Flutter to Web
|
||||
run: |
|
||||
cd classic/frontend
|
||||
flutter build web --base-href /app/
|
||||
|
||||
# - name: Commit and Push to ${{ env.BUILD_BRANCH }}
|
||||
# if: github.event_name == 'push'
|
||||
# run: |
|
||||
# git config --local user.email "action@github.com"
|
||||
# git config --local user.name "GitHub Action"
|
||||
# git add classic/frontend/build/web
|
||||
# git checkout -B ${{ env.BUILD_BRANCH }}
|
||||
# git commit -m "Update frontend build to ${GITHUB_SHA:0:7}" -a
|
||||
# git push -f origin ${{ env.BUILD_BRANCH }}
|
||||
|
||||
- name: Create PR ${{ env.BUILD_BRANCH }} -> ${{ github.ref_name }}
|
||||
if: github.event_name == 'push'
|
||||
uses: peter-evans/create-pull-request@v7
|
||||
with:
|
||||
add-paths: classic/frontend/build/web
|
||||
base: ${{ github.ref_name }}
|
||||
branch: ${{ env.BUILD_BRANCH }}
|
||||
delete-branch: true
|
||||
title: "Update frontend build in `${{ github.ref_name }}`"
|
||||
body: "This PR updates the frontend build based on commit ${{ github.sha }}."
|
||||
commit-message: "Update frontend build based on commit ${{ github.sha }}"
|
||||
67
.github/workflows/classic-python-checks.yml
vendored
67
.github/workflows/classic-python-checks.yml
vendored
@@ -7,7 +7,9 @@ on:
|
||||
- '.github/workflows/classic-python-checks-ci.yml'
|
||||
- 'classic/original_autogpt/**'
|
||||
- 'classic/forge/**'
|
||||
- 'classic/benchmark/**'
|
||||
- 'classic/direct_benchmark/**'
|
||||
- 'classic/pyproject.toml'
|
||||
- 'classic/poetry.lock'
|
||||
- '**.py'
|
||||
- '!classic/forge/tests/vcr_cassettes'
|
||||
pull_request:
|
||||
@@ -16,7 +18,9 @@ on:
|
||||
- '.github/workflows/classic-python-checks-ci.yml'
|
||||
- 'classic/original_autogpt/**'
|
||||
- 'classic/forge/**'
|
||||
- 'classic/benchmark/**'
|
||||
- 'classic/direct_benchmark/**'
|
||||
- 'classic/pyproject.toml'
|
||||
- 'classic/poetry.lock'
|
||||
- '**.py'
|
||||
- '!classic/forge/tests/vcr_cassettes'
|
||||
|
||||
@@ -27,44 +31,13 @@ concurrency:
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: classic
|
||||
|
||||
jobs:
|
||||
get-changed-parts:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- id: changes-in
|
||||
name: Determine affected subprojects
|
||||
uses: dorny/paths-filter@v3
|
||||
with:
|
||||
filters: |
|
||||
original_autogpt:
|
||||
- classic/original_autogpt/autogpt/**
|
||||
- classic/original_autogpt/tests/**
|
||||
- classic/original_autogpt/poetry.lock
|
||||
forge:
|
||||
- classic/forge/forge/**
|
||||
- classic/forge/tests/**
|
||||
- classic/forge/poetry.lock
|
||||
benchmark:
|
||||
- classic/benchmark/agbenchmark/**
|
||||
- classic/benchmark/tests/**
|
||||
- classic/benchmark/poetry.lock
|
||||
outputs:
|
||||
changed-parts: ${{ steps.changes-in.outputs.changes }}
|
||||
|
||||
lint:
|
||||
needs: get-changed-parts
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
min-python-version: "3.10"
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
sub-package: ${{ fromJson(needs.get-changed-parts.outputs.changed-parts) }}
|
||||
fail-fast: false
|
||||
min-python-version: "3.12"
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
@@ -81,42 +54,31 @@ jobs:
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: ${{ runner.os }}-poetry-${{ hashFiles(format('{0}/poetry.lock', matrix.sub-package)) }}
|
||||
key: ${{ runner.os }}-poetry-${{ hashFiles('classic/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry
|
||||
run: curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
# Install dependencies
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: poetry -C classic/${{ matrix.sub-package }} install
|
||||
run: poetry install
|
||||
|
||||
# Lint
|
||||
|
||||
- name: Lint (isort)
|
||||
run: poetry run isort --check .
|
||||
working-directory: classic/${{ matrix.sub-package }}
|
||||
|
||||
- name: Lint (Black)
|
||||
if: success() || failure()
|
||||
run: poetry run black --check .
|
||||
working-directory: classic/${{ matrix.sub-package }}
|
||||
|
||||
- name: Lint (Flake8)
|
||||
if: success() || failure()
|
||||
run: poetry run flake8 .
|
||||
working-directory: classic/${{ matrix.sub-package }}
|
||||
|
||||
types:
|
||||
needs: get-changed-parts
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
min-python-version: "3.10"
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
sub-package: ${{ fromJson(needs.get-changed-parts.outputs.changed-parts) }}
|
||||
fail-fast: false
|
||||
min-python-version: "3.12"
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
@@ -133,19 +95,16 @@ jobs:
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: ${{ runner.os }}-poetry-${{ hashFiles(format('{0}/poetry.lock', matrix.sub-package)) }}
|
||||
key: ${{ runner.os }}-poetry-${{ hashFiles('classic/poetry.lock') }}
|
||||
|
||||
- name: Install Poetry
|
||||
run: curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
# Install dependencies
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: poetry -C classic/${{ matrix.sub-package }} install
|
||||
run: poetry install
|
||||
|
||||
# Typecheck
|
||||
|
||||
- name: Typecheck
|
||||
if: success() || failure()
|
||||
run: poetry run pyright
|
||||
working-directory: classic/${{ matrix.sub-package }}
|
||||
|
||||
10
.gitignore
vendored
10
.gitignore
vendored
@@ -3,6 +3,7 @@
|
||||
classic/original_autogpt/keys.py
|
||||
classic/original_autogpt/*.json
|
||||
auto_gpt_workspace/*
|
||||
.autogpt/
|
||||
*.mpeg
|
||||
.env
|
||||
# Root .env files
|
||||
@@ -159,6 +160,10 @@ CURRENT_BULLETIN.md
|
||||
|
||||
# AgBenchmark
|
||||
classic/benchmark/agbenchmark/reports/
|
||||
classic/reports/
|
||||
classic/direct_benchmark/reports/
|
||||
classic/.benchmark_workspaces/
|
||||
classic/direct_benchmark/.benchmark_workspaces/
|
||||
|
||||
# Nodejs
|
||||
package-lock.json
|
||||
@@ -177,5 +182,10 @@ autogpt_platform/backend/settings.py
|
||||
|
||||
*.ign.*
|
||||
.test-contents
|
||||
**/.claude/settings.local.json
|
||||
.claude/settings.local.json
|
||||
CLAUDE.local.md
|
||||
/autogpt_platform/backend/logs
|
||||
|
||||
# Test database
|
||||
test.db
|
||||
|
||||
3
.gitmodules
vendored
3
.gitmodules
vendored
@@ -1,3 +0,0 @@
|
||||
[submodule "classic/forge/tests/vcr_cassettes"]
|
||||
path = classic/forge/tests/vcr_cassettes
|
||||
url = https://github.com/Significant-Gravitas/Auto-GPT-test-cassettes
|
||||
@@ -43,29 +43,10 @@ repos:
|
||||
pass_filenames: false
|
||||
|
||||
- id: poetry-install
|
||||
name: Check & Install dependencies - Classic - AutoGPT
|
||||
alias: poetry-install-classic-autogpt
|
||||
entry: poetry -C classic/original_autogpt install
|
||||
# include forge source (since it's a path dependency)
|
||||
files: ^classic/(original_autogpt|forge)/poetry\.lock$
|
||||
types: [file]
|
||||
language: system
|
||||
pass_filenames: false
|
||||
|
||||
- id: poetry-install
|
||||
name: Check & Install dependencies - Classic - Forge
|
||||
alias: poetry-install-classic-forge
|
||||
entry: poetry -C classic/forge install
|
||||
files: ^classic/forge/poetry\.lock$
|
||||
types: [file]
|
||||
language: system
|
||||
pass_filenames: false
|
||||
|
||||
- id: poetry-install
|
||||
name: Check & Install dependencies - Classic - Benchmark
|
||||
alias: poetry-install-classic-benchmark
|
||||
entry: poetry -C classic/benchmark install
|
||||
files: ^classic/benchmark/poetry\.lock$
|
||||
name: Check & Install dependencies - Classic
|
||||
alias: poetry-install-classic
|
||||
entry: poetry -C classic install
|
||||
files: ^classic/poetry\.lock$
|
||||
types: [file]
|
||||
language: system
|
||||
pass_filenames: false
|
||||
@@ -116,26 +97,10 @@ repos:
|
||||
language: system
|
||||
|
||||
- id: isort
|
||||
name: Lint (isort) - Classic - AutoGPT
|
||||
alias: isort-classic-autogpt
|
||||
entry: poetry -P classic/original_autogpt run isort -p autogpt
|
||||
files: ^classic/original_autogpt/
|
||||
types: [file, python]
|
||||
language: system
|
||||
|
||||
- id: isort
|
||||
name: Lint (isort) - Classic - Forge
|
||||
alias: isort-classic-forge
|
||||
entry: poetry -P classic/forge run isort -p forge
|
||||
files: ^classic/forge/
|
||||
types: [file, python]
|
||||
language: system
|
||||
|
||||
- id: isort
|
||||
name: Lint (isort) - Classic - Benchmark
|
||||
alias: isort-classic-benchmark
|
||||
entry: poetry -P classic/benchmark run isort -p agbenchmark
|
||||
files: ^classic/benchmark/
|
||||
name: Lint (isort) - Classic
|
||||
alias: isort-classic
|
||||
entry: bash -c 'cd classic && poetry run isort $(echo "$@" | sed "s|classic/||g")' --
|
||||
files: ^classic/(original_autogpt|forge|direct_benchmark)/
|
||||
types: [file, python]
|
||||
language: system
|
||||
|
||||
@@ -149,26 +114,13 @@ repos:
|
||||
|
||||
- repo: https://github.com/PyCQA/flake8
|
||||
rev: 7.0.0
|
||||
# To have flake8 load the config of the individual subprojects, we have to call
|
||||
# them separately.
|
||||
# Use consolidated flake8 config at classic/.flake8
|
||||
hooks:
|
||||
- id: flake8
|
||||
name: Lint (Flake8) - Classic - AutoGPT
|
||||
alias: flake8-classic-autogpt
|
||||
files: ^classic/original_autogpt/(autogpt|scripts|tests)/
|
||||
args: [--config=classic/original_autogpt/.flake8]
|
||||
|
||||
- id: flake8
|
||||
name: Lint (Flake8) - Classic - Forge
|
||||
alias: flake8-classic-forge
|
||||
files: ^classic/forge/(forge|tests)/
|
||||
args: [--config=classic/forge/.flake8]
|
||||
|
||||
- id: flake8
|
||||
name: Lint (Flake8) - Classic - Benchmark
|
||||
alias: flake8-classic-benchmark
|
||||
files: ^classic/benchmark/(agbenchmark|tests)/((?!reports).)*[/.]
|
||||
args: [--config=classic/benchmark/.flake8]
|
||||
name: Lint (Flake8) - Classic
|
||||
alias: flake8-classic
|
||||
files: ^classic/(original_autogpt|forge|direct_benchmark)/
|
||||
args: [--config=classic/.flake8]
|
||||
|
||||
- repo: local
|
||||
hooks:
|
||||
@@ -204,29 +156,10 @@ repos:
|
||||
pass_filenames: false
|
||||
|
||||
- id: pyright
|
||||
name: Typecheck - Classic - AutoGPT
|
||||
alias: pyright-classic-autogpt
|
||||
entry: poetry -C classic/original_autogpt run pyright
|
||||
# include forge source (since it's a path dependency) but exclude *_test.py files:
|
||||
files: ^(classic/original_autogpt/((autogpt|scripts|tests)/|poetry\.lock$)|classic/forge/(forge/.*(?<!_test)\.py|poetry\.lock)$)
|
||||
types: [file]
|
||||
language: system
|
||||
pass_filenames: false
|
||||
|
||||
- id: pyright
|
||||
name: Typecheck - Classic - Forge
|
||||
alias: pyright-classic-forge
|
||||
entry: poetry -C classic/forge run pyright
|
||||
files: ^classic/forge/(forge/|poetry\.lock$)
|
||||
types: [file]
|
||||
language: system
|
||||
pass_filenames: false
|
||||
|
||||
- id: pyright
|
||||
name: Typecheck - Classic - Benchmark
|
||||
alias: pyright-classic-benchmark
|
||||
entry: poetry -C classic/benchmark run pyright
|
||||
files: ^classic/benchmark/(agbenchmark/|tests/|poetry\.lock$)
|
||||
name: Typecheck - Classic
|
||||
alias: pyright-classic
|
||||
entry: poetry -C classic run pyright
|
||||
files: ^classic/(original_autogpt|forge|direct_benchmark)/.*\.py$|^classic/poetry\.lock$
|
||||
types: [file]
|
||||
language: system
|
||||
pass_filenames: false
|
||||
|
||||
24
AGENTS.md
24
AGENTS.md
@@ -16,7 +16,6 @@ See `docs/content/platform/getting-started.md` for setup instructions.
|
||||
- Format Python code with `poetry run format`.
|
||||
- Format frontend code using `pnpm format`.
|
||||
|
||||
|
||||
## Frontend guidelines:
|
||||
|
||||
See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
|
||||
@@ -33,14 +32,17 @@ See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
|
||||
4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
|
||||
5. **Testing**: Add Storybook stories for new components, Playwright for E2E
|
||||
6. **Code conventions**: Function declarations (not arrow functions) for components/handlers
|
||||
|
||||
- Component props should be `interface Props { ... }` (not exported) unless the interface needs to be used outside the component
|
||||
- Separate render logic from business logic (component.tsx + useComponent.ts + helpers.ts)
|
||||
- Colocate state when possible and avoid creating large components, use sub-components ( local `/components` folder next to the parent component ) when sensible
|
||||
- Avoid large hooks, abstract logic into `helpers.ts` files when sensible
|
||||
- Use function declarations for components, arrow functions only for callbacks
|
||||
- No barrel files or `index.ts` re-exports
|
||||
- Do not use `useCallback` or `useMemo` unless strictly needed
|
||||
- Avoid comments at all times unless the code is very complex
|
||||
- Do not use `useCallback` or `useMemo` unless asked to optimise a given function
|
||||
- Do not type hook returns, let Typescript infer as much as possible
|
||||
- Never type with `any`, if not types available use `unknown`
|
||||
|
||||
## Testing
|
||||
|
||||
@@ -49,22 +51,8 @@ See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
|
||||
|
||||
Always run the relevant linters and tests before committing.
|
||||
Use conventional commit messages for all commits (e.g. `feat(backend): add API`).
|
||||
Types:
|
||||
- feat
|
||||
- fix
|
||||
- refactor
|
||||
- ci
|
||||
- dx (developer experience)
|
||||
Scopes:
|
||||
- platform
|
||||
- platform/library
|
||||
- platform/marketplace
|
||||
- backend
|
||||
- backend/executor
|
||||
- frontend
|
||||
- frontend/library
|
||||
- frontend/marketplace
|
||||
- blocks
|
||||
Types: - feat - fix - refactor - ci - dx (developer experience)
|
||||
Scopes: - platform - platform/library - platform/marketplace - backend - backend/executor - frontend - frontend/library - frontend/marketplace - blocks
|
||||
|
||||
## Pull requests
|
||||
|
||||
|
||||
@@ -6,152 +6,30 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co
|
||||
|
||||
AutoGPT Platform is a monorepo containing:
|
||||
|
||||
- **Backend** (`/backend`): Python FastAPI server with async support
|
||||
- **Frontend** (`/frontend`): Next.js React application
|
||||
- **Shared Libraries** (`/autogpt_libs`): Common Python utilities
|
||||
- **Backend** (`backend`): Python FastAPI server with async support
|
||||
- **Frontend** (`frontend`): Next.js React application
|
||||
- **Shared Libraries** (`autogpt_libs`): Common Python utilities
|
||||
|
||||
## Essential Commands
|
||||
## Component Documentation
|
||||
|
||||
### Backend Development
|
||||
- **Backend**: See @backend/CLAUDE.md for backend-specific commands, architecture, and development tasks
|
||||
- **Frontend**: See @frontend/CLAUDE.md for frontend-specific commands, architecture, and development patterns
|
||||
|
||||
```bash
|
||||
# Install dependencies
|
||||
cd backend && poetry install
|
||||
|
||||
# Run database migrations
|
||||
poetry run prisma migrate dev
|
||||
|
||||
# Start all services (database, redis, rabbitmq, clamav)
|
||||
docker compose up -d
|
||||
|
||||
# Run the backend server
|
||||
poetry run serve
|
||||
|
||||
# Run tests
|
||||
poetry run test
|
||||
|
||||
# Run specific test
|
||||
poetry run pytest path/to/test_file.py::test_function_name
|
||||
|
||||
# Run block tests (tests that validate all blocks work correctly)
|
||||
poetry run pytest backend/blocks/test/test_block.py -xvs
|
||||
|
||||
# Run tests for a specific block (e.g., GetCurrentTimeBlock)
|
||||
poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[GetCurrentTimeBlock]' -xvs
|
||||
|
||||
# Lint and format
|
||||
# prefer format if you want to just "fix" it and only get the errors that can't be autofixed
|
||||
poetry run format # Black + isort
|
||||
poetry run lint # ruff
|
||||
```
|
||||
|
||||
More details can be found in TESTING.md
|
||||
|
||||
#### Creating/Updating Snapshots
|
||||
|
||||
When you first write a test or when the expected output changes:
|
||||
|
||||
```bash
|
||||
poetry run pytest path/to/test.py --snapshot-update
|
||||
```
|
||||
|
||||
⚠️ **Important**: Always review snapshot changes before committing! Use `git diff` to verify the changes are expected.
|
||||
|
||||
### Frontend Development
|
||||
|
||||
```bash
|
||||
# Install dependencies
|
||||
cd frontend && pnpm i
|
||||
|
||||
# Generate API client from OpenAPI spec
|
||||
pnpm generate:api
|
||||
|
||||
# Start development server
|
||||
pnpm dev
|
||||
|
||||
# Run E2E tests
|
||||
pnpm test
|
||||
|
||||
# Run Storybook for component development
|
||||
pnpm storybook
|
||||
|
||||
# Build production
|
||||
pnpm build
|
||||
|
||||
# Format and lint
|
||||
pnpm format
|
||||
|
||||
# Type checking
|
||||
pnpm types
|
||||
```
|
||||
|
||||
**📖 Complete Guide**: See `/frontend/CONTRIBUTING.md` and `/frontend/.cursorrules` for comprehensive frontend patterns.
|
||||
|
||||
**Key Frontend Conventions:**
|
||||
|
||||
- Separate render logic from data/behavior in components
|
||||
- Use generated API hooks from `@/app/api/__generated__/endpoints/`
|
||||
- Use function declarations (not arrow functions) for components/handlers
|
||||
- Use design system components from `src/components/` (atoms, molecules, organisms)
|
||||
- Only use Phosphor Icons
|
||||
- Never use `src/components/__legacy__/*` or deprecated `BackendAPI`
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
### Backend Architecture
|
||||
|
||||
- **API Layer**: FastAPI with REST and WebSocket endpoints
|
||||
- **Database**: PostgreSQL with Prisma ORM, includes pgvector for embeddings
|
||||
- **Queue System**: RabbitMQ for async task processing
|
||||
- **Execution Engine**: Separate executor service processes agent workflows
|
||||
- **Authentication**: JWT-based with Supabase integration
|
||||
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
|
||||
|
||||
### Frontend Architecture
|
||||
|
||||
- **Framework**: Next.js 15 App Router (client-first approach)
|
||||
- **Data Fetching**: Type-safe generated API hooks via Orval + React Query
|
||||
- **State Management**: React Query for server state, co-located UI state in components/hooks
|
||||
- **Component Structure**: Separate render logic (`.tsx`) from business logic (`use*.ts` hooks)
|
||||
- **Workflow Builder**: Visual graph editor using @xyflow/react
|
||||
- **UI Components**: shadcn/ui (Radix UI primitives) with Tailwind CSS styling
|
||||
- **Icons**: Phosphor Icons only
|
||||
- **Feature Flags**: LaunchDarkly integration
|
||||
- **Error Handling**: ErrorCard for render errors, toast for mutations, Sentry for exceptions
|
||||
- **Testing**: Playwright for E2E, Storybook for component development
|
||||
|
||||
### Key Concepts
|
||||
## Key Concepts
|
||||
|
||||
1. **Agent Graphs**: Workflow definitions stored as JSON, executed by the backend
|
||||
2. **Blocks**: Reusable components in `/backend/blocks/` that perform specific tasks
|
||||
2. **Blocks**: Reusable components in `backend/backend/blocks/` that perform specific tasks
|
||||
3. **Integrations**: OAuth and API connections stored per user
|
||||
4. **Store**: Marketplace for sharing agent templates
|
||||
5. **Virus Scanning**: ClamAV integration for file upload security
|
||||
|
||||
### Testing Approach
|
||||
|
||||
- Backend uses pytest with snapshot testing for API responses
|
||||
- Test files are colocated with source files (`*_test.py`)
|
||||
- Frontend uses Playwright for E2E tests
|
||||
- Component testing via Storybook
|
||||
|
||||
### Database Schema
|
||||
|
||||
Key models (defined in `/backend/schema.prisma`):
|
||||
|
||||
- `User`: Authentication and profile data
|
||||
- `AgentGraph`: Workflow definitions with version control
|
||||
- `AgentGraphExecution`: Execution history and results
|
||||
- `AgentNode`: Individual nodes in a workflow
|
||||
- `StoreListing`: Marketplace listings for sharing agents
|
||||
|
||||
### Environment Configuration
|
||||
|
||||
#### Configuration Files
|
||||
|
||||
- **Backend**: `/backend/.env.default` (defaults) → `/backend/.env` (user overrides)
|
||||
- **Frontend**: `/frontend/.env.default` (defaults) → `/frontend/.env` (user overrides)
|
||||
- **Platform**: `/.env.default` (Supabase/shared defaults) → `/.env` (user overrides)
|
||||
- **Backend**: `backend/.env.default` (defaults) → `backend/.env` (user overrides)
|
||||
- **Frontend**: `frontend/.env.default` (defaults) → `frontend/.env` (user overrides)
|
||||
- **Platform**: `.env.default` (Supabase/shared defaults) → `.env` (user overrides)
|
||||
|
||||
#### Docker Environment Loading Order
|
||||
|
||||
@@ -167,83 +45,12 @@ Key models (defined in `/backend/schema.prisma`):
|
||||
- Backend/Frontend services use YAML anchors for consistent configuration
|
||||
- Supabase services (`db/docker/docker-compose.yml`) follow the same pattern
|
||||
|
||||
### Common Development Tasks
|
||||
|
||||
**Adding a new block:**
|
||||
|
||||
Follow the comprehensive [Block SDK Guide](../../../docs/content/platform/block-sdk-guide.md) which covers:
|
||||
|
||||
- Provider configuration with `ProviderBuilder`
|
||||
- Block schema definition
|
||||
- Authentication (API keys, OAuth, webhooks)
|
||||
- Testing and validation
|
||||
- File organization
|
||||
|
||||
Quick steps:
|
||||
|
||||
1. Create new file in `/backend/backend/blocks/`
|
||||
2. Configure provider using `ProviderBuilder` in `_config.py`
|
||||
3. Inherit from `Block` base class
|
||||
4. Define input/output schemas using `BlockSchema`
|
||||
5. Implement async `run` method
|
||||
6. Generate unique block ID using `uuid.uuid4()`
|
||||
7. Test with `poetry run pytest backend/blocks/test/test_block.py`
|
||||
|
||||
Note: when making many new blocks analyze the interfaces for each of these blocks and picture if they would go well together in a graph based editor or would they struggle to connect productively?
|
||||
ex: do the inputs and outputs tie well together?
|
||||
|
||||
If you get any pushback or hit complex block conditions check the new_blocks guide in the docs.
|
||||
|
||||
**Modifying the API:**
|
||||
|
||||
1. Update route in `/backend/backend/server/routers/`
|
||||
2. Add/update Pydantic models in same directory
|
||||
3. Write tests alongside the route file
|
||||
4. Run `poetry run test` to verify
|
||||
|
||||
### Frontend guidelines:
|
||||
|
||||
See `/frontend/CONTRIBUTING.md` for complete patterns. Quick reference:
|
||||
|
||||
1. **Pages**: Create in `src/app/(platform)/feature-name/page.tsx`
|
||||
- Add `usePageName.ts` hook for logic
|
||||
- Put sub-components in local `components/` folder
|
||||
2. **Components**: Structure as `ComponentName/ComponentName.tsx` + `useComponentName.ts` + `helpers.ts`
|
||||
- Use design system components from `src/components/` (atoms, molecules, organisms)
|
||||
- Never use `src/components/__legacy__/*`
|
||||
3. **Data fetching**: Use generated API hooks from `@/app/api/__generated__/endpoints/`
|
||||
- Regenerate with `pnpm generate:api`
|
||||
- Pattern: `use{Method}{Version}{OperationName}`
|
||||
4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
|
||||
5. **Testing**: Add Storybook stories for new components, Playwright for E2E
|
||||
6. **Code conventions**: Function declarations (not arrow functions) for components/handlers
|
||||
- Component props should be `interface Props { ... }` (not exported) unless the interface needs to be used outside the component
|
||||
- Separate render logic from business logic (component.tsx + useComponent.ts + helpers.ts)
|
||||
- Colocate state when possible and avoid creating large components, use sub-components ( local `/components` folder next to the parent component ) when sensible
|
||||
- Avoid large hooks, abstract logic into `helpers.ts` files when sensible
|
||||
- Use function declarations for components, arrow functions only for callbacks
|
||||
- No barrel files or `index.ts` re-exports
|
||||
- Do not use `useCallback` or `useMemo` unless strictly needed
|
||||
- Avoid comments at all times unless the code is very complex
|
||||
|
||||
### Security Implementation
|
||||
|
||||
**Cache Protection Middleware:**
|
||||
|
||||
- Located in `/backend/backend/server/middleware/security.py`
|
||||
- Default behavior: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
|
||||
- Uses an allow list approach - only explicitly permitted paths can be cached
|
||||
- Cacheable paths include: static assets (`/static/*`, `/_next/static/*`), health checks, public store pages, documentation
|
||||
- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies
|
||||
- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware
|
||||
- Applied to both main API server and external API applications
|
||||
|
||||
### Creating Pull Requests
|
||||
|
||||
- Create the PR aginst the `dev` branch of the repository.
|
||||
- Ensure the branch name is descriptive (e.g., `feature/add-new-block`)/
|
||||
- Use conventional commit messages (see below)/
|
||||
- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description/
|
||||
- Create the PR against the `dev` branch of the repository.
|
||||
- Ensure the branch name is descriptive (e.g., `feature/add-new-block`)
|
||||
- Use conventional commit messages (see below)
|
||||
- Fill out the .github/PULL_REQUEST_TEMPLATE.md template as the PR description
|
||||
- Run the github pre-commit hooks to ensure code quality.
|
||||
|
||||
### Reviewing/Revising Pull Requests
|
||||
|
||||
170
autogpt_platform/backend/CLAUDE.md
Normal file
170
autogpt_platform/backend/CLAUDE.md
Normal file
@@ -0,0 +1,170 @@
|
||||
# CLAUDE.md - Backend
|
||||
|
||||
This file provides guidance to Claude Code when working with the backend.
|
||||
|
||||
## Essential Commands
|
||||
|
||||
To run something with Python package dependencies you MUST use `poetry run ...`.
|
||||
|
||||
```bash
|
||||
# Install dependencies
|
||||
poetry install
|
||||
|
||||
# Run database migrations
|
||||
poetry run prisma migrate dev
|
||||
|
||||
# Start all services (database, redis, rabbitmq, clamav)
|
||||
docker compose up -d
|
||||
|
||||
# Run the backend as a whole
|
||||
poetry run app
|
||||
|
||||
# Run tests
|
||||
poetry run test
|
||||
|
||||
# Run specific test
|
||||
poetry run pytest path/to/test_file.py::test_function_name
|
||||
|
||||
# Run block tests (tests that validate all blocks work correctly)
|
||||
poetry run pytest backend/blocks/test/test_block.py -xvs
|
||||
|
||||
# Run tests for a specific block (e.g., GetCurrentTimeBlock)
|
||||
poetry run pytest 'backend/blocks/test/test_block.py::test_available_blocks[GetCurrentTimeBlock]' -xvs
|
||||
|
||||
# Lint and format
|
||||
# prefer format if you want to just "fix" it and only get the errors that can't be autofixed
|
||||
poetry run format # Black + isort
|
||||
poetry run lint # ruff
|
||||
```
|
||||
|
||||
More details can be found in @TESTING.md
|
||||
|
||||
### Creating/Updating Snapshots
|
||||
|
||||
When you first write a test or when the expected output changes:
|
||||
|
||||
```bash
|
||||
poetry run pytest path/to/test.py --snapshot-update
|
||||
```
|
||||
|
||||
⚠️ **Important**: Always review snapshot changes before committing! Use `git diff` to verify the changes are expected.
|
||||
|
||||
## Architecture
|
||||
|
||||
- **API Layer**: FastAPI with REST and WebSocket endpoints
|
||||
- **Database**: PostgreSQL with Prisma ORM, includes pgvector for embeddings
|
||||
- **Queue System**: RabbitMQ for async task processing
|
||||
- **Execution Engine**: Separate executor service processes agent workflows
|
||||
- **Authentication**: JWT-based with Supabase integration
|
||||
- **Security**: Cache protection middleware prevents sensitive data caching in browsers/proxies
|
||||
|
||||
## Testing Approach
|
||||
|
||||
- Uses pytest with snapshot testing for API responses
|
||||
- Test files are colocated with source files (`*_test.py`)
|
||||
|
||||
## Database Schema
|
||||
|
||||
Key models (defined in `schema.prisma`):
|
||||
|
||||
- `User`: Authentication and profile data
|
||||
- `AgentGraph`: Workflow definitions with version control
|
||||
- `AgentGraphExecution`: Execution history and results
|
||||
- `AgentNode`: Individual nodes in a workflow
|
||||
- `StoreListing`: Marketplace listings for sharing agents
|
||||
|
||||
## Environment Configuration
|
||||
|
||||
- **Backend**: `.env.default` (defaults) → `.env` (user overrides)
|
||||
|
||||
## Common Development Tasks
|
||||
|
||||
### Adding a new block
|
||||
|
||||
Follow the comprehensive [Block SDK Guide](@../../docs/content/platform/block-sdk-guide.md) which covers:
|
||||
|
||||
- Provider configuration with `ProviderBuilder`
|
||||
- Block schema definition
|
||||
- Authentication (API keys, OAuth, webhooks)
|
||||
- Testing and validation
|
||||
- File organization
|
||||
|
||||
Quick steps:
|
||||
|
||||
1. Create new file in `backend/blocks/`
|
||||
2. Configure provider using `ProviderBuilder` in `_config.py`
|
||||
3. Inherit from `Block` base class
|
||||
4. Define input/output schemas using `BlockSchema`
|
||||
5. Implement async `run` method
|
||||
6. Generate unique block ID using `uuid.uuid4()`
|
||||
7. Test with `poetry run pytest backend/blocks/test/test_block.py`
|
||||
|
||||
Note: when making many new blocks analyze the interfaces for each of these blocks and picture if they would go well together in a graph-based editor or would they struggle to connect productively?
|
||||
ex: do the inputs and outputs tie well together?
|
||||
|
||||
If you get any pushback or hit complex block conditions check the new_blocks guide in the docs.
|
||||
|
||||
#### Handling files in blocks with `store_media_file()`
|
||||
|
||||
When blocks need to work with files (images, videos, documents), use `store_media_file()` from `backend.util.file`. The `return_format` parameter determines what you get back:
|
||||
|
||||
| Format | Use When | Returns |
|
||||
|--------|----------|---------|
|
||||
| `"for_local_processing"` | Processing with local tools (ffmpeg, MoviePy, PIL) | Local file path (e.g., `"image.png"`) |
|
||||
| `"for_external_api"` | Sending content to external APIs (Replicate, OpenAI) | Data URI (e.g., `"data:image/png;base64,..."`) |
|
||||
| `"for_block_output"` | Returning output from your block | Smart: `workspace://` in CoPilot, data URI in graphs |
|
||||
|
||||
**Examples:**
|
||||
|
||||
```python
|
||||
# INPUT: Need to process file locally with ffmpeg
|
||||
local_path = await store_media_file(
|
||||
file=input_data.video,
|
||||
execution_context=execution_context,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
# local_path = "video.mp4" - use with Path/ffmpeg/etc
|
||||
|
||||
# INPUT: Need to send to external API like Replicate
|
||||
image_b64 = await store_media_file(
|
||||
file=input_data.image,
|
||||
execution_context=execution_context,
|
||||
return_format="for_external_api",
|
||||
)
|
||||
# image_b64 = "..." - send to API
|
||||
|
||||
# OUTPUT: Returning result from block
|
||||
result_url = await store_media_file(
|
||||
file=generated_image_url,
|
||||
execution_context=execution_context,
|
||||
return_format="for_block_output",
|
||||
)
|
||||
yield "image_url", result_url
|
||||
# In CoPilot: result_url = "workspace://abc123"
|
||||
# In graphs: result_url = "data:image/png;base64,..."
|
||||
```
|
||||
|
||||
**Key points:**
|
||||
|
||||
- `for_block_output` is the ONLY format that auto-adapts to execution context
|
||||
- Always use `for_block_output` for block outputs unless you have a specific reason not to
|
||||
- Never hardcode workspace checks - let `for_block_output` handle it
|
||||
|
||||
### Modifying the API
|
||||
|
||||
1. Update route in `backend/api/features/`
|
||||
2. Add/update Pydantic models in same directory
|
||||
3. Write tests alongside the route file
|
||||
4. Run `poetry run test` to verify
|
||||
|
||||
## Security Implementation
|
||||
|
||||
### Cache Protection Middleware
|
||||
|
||||
- Located in `backend/api/middleware/security.py`
|
||||
- Default behavior: Disables caching for ALL endpoints with `Cache-Control: no-store, no-cache, must-revalidate, private`
|
||||
- Uses an allow list approach - only explicitly permitted paths can be cached
|
||||
- Cacheable paths include: static assets (`static/*`, `_next/static/*`), health checks, public store pages, documentation
|
||||
- Prevents sensitive data (auth tokens, API keys, user data) from being cached by browsers/proxies
|
||||
- To allow caching for a new endpoint, add it to `CACHEABLE_PATHS` in the middleware
|
||||
- Applied to both main API server and external API applications
|
||||
@@ -138,7 +138,7 @@ If the test doesn't need the `user_id` specifically, mocking is not necessary as
|
||||
|
||||
#### Using Global Auth Fixtures
|
||||
|
||||
Two global auth fixtures are provided by `backend/server/conftest.py`:
|
||||
Two global auth fixtures are provided by `backend/api/conftest.py`:
|
||||
|
||||
- `mock_jwt_user` - Regular user with `test_user_id` ("test-user-id")
|
||||
- `mock_jwt_admin` - Admin user with `admin_user_id` ("admin-user-id")
|
||||
|
||||
@@ -17,7 +17,7 @@ router = fastapi.APIRouter(
|
||||
)
|
||||
|
||||
|
||||
# Taken from backend/server/v2/store/db.py
|
||||
# Taken from backend/api/features/store/db.py
|
||||
def sanitize_query(query: str | None) -> str | None:
|
||||
if query is None:
|
||||
return query
|
||||
|
||||
@@ -0,0 +1,79 @@
|
||||
# CoPilot Tools - Future Ideas
|
||||
|
||||
## Multimodal Image Support for CoPilot
|
||||
|
||||
**Problem:** CoPilot uses a vision-capable model but can't "see" workspace images. When a block generates an image and returns `workspace://abc123`, CoPilot can't evaluate it (e.g., checking blog thumbnail quality).
|
||||
|
||||
**Backend Solution:**
|
||||
When preparing messages for the LLM, detect `workspace://` image references and convert them to proper image content blocks:
|
||||
|
||||
```python
|
||||
# Before sending to LLM, scan for workspace image references
|
||||
# and inject them as image content parts
|
||||
|
||||
# Example message transformation:
|
||||
# FROM: {"role": "assistant", "content": "Generated image: workspace://abc123"}
|
||||
# TO: {"role": "assistant", "content": [
|
||||
# {"type": "text", "text": "Generated image: workspace://abc123"},
|
||||
# {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}
|
||||
# ]}
|
||||
```
|
||||
|
||||
**Where to implement:**
|
||||
- In the chat stream handler before calling the LLM
|
||||
- Or in a message preprocessing step
|
||||
- Need to fetch image from workspace, convert to base64, add as image content
|
||||
|
||||
**Considerations:**
|
||||
- Only do this for image MIME types (image/png, image/jpeg, etc.)
|
||||
- May want a size limit (don't pass 10MB images)
|
||||
- Track which images were "shown" to the AI for frontend indicator
|
||||
- Cost implications - vision API calls are more expensive
|
||||
|
||||
**Frontend Solution:**
|
||||
Show visual indicator on workspace files in chat:
|
||||
- If AI saw the image: normal display
|
||||
- If AI didn't see it: overlay icon saying "AI can't see this image"
|
||||
|
||||
Requires response metadata indicating which `workspace://` refs were passed to the model.
|
||||
|
||||
---
|
||||
|
||||
## Output Post-Processing Layer for run_block
|
||||
|
||||
**Problem:** Many blocks produce large outputs that:
|
||||
- Consume massive context (100KB base64 image = ~133KB tokens)
|
||||
- Can't fit in conversation
|
||||
- Break things and cause high LLM costs
|
||||
|
||||
**Proposed Solution:** Instead of modifying individual blocks or `store_media_file()`, implement a centralized output processor in `run_block.py` that handles outputs before they're returned to CoPilot.
|
||||
|
||||
**Benefits:**
|
||||
1. **Centralized** - one place to handle all output processing
|
||||
2. **Future-proof** - new blocks automatically get output processing
|
||||
3. **Keeps blocks pure** - they don't need to know about context constraints
|
||||
4. **Handles all large outputs** - not just images
|
||||
|
||||
**Processing Rules:**
|
||||
- Detect base64 data URIs → save to workspace, return `workspace://` reference
|
||||
- Truncate very long strings (>N chars) with truncation note
|
||||
- Summarize large arrays/lists (e.g., "Array with 1000 items, first 5: [...]")
|
||||
- Handle nested large outputs in dicts recursively
|
||||
- Cap total output size
|
||||
|
||||
**Implementation Location:** `run_block.py` after block execution, before returning `BlockOutputResponse`
|
||||
|
||||
**Example:**
|
||||
```python
|
||||
def _process_outputs_for_context(
|
||||
outputs: dict[str, list[Any]],
|
||||
workspace_manager: WorkspaceManager,
|
||||
max_string_length: int = 10000,
|
||||
max_array_preview: int = 5,
|
||||
) -> dict[str, list[Any]]:
|
||||
"""Process block outputs to prevent context bloat."""
|
||||
processed = {}
|
||||
for name, values in outputs.items():
|
||||
processed[name] = [_process_value(v, workspace_manager) for v in values]
|
||||
return processed
|
||||
```
|
||||
@@ -18,6 +18,12 @@ from .get_doc_page import GetDocPageTool
|
||||
from .run_agent import RunAgentTool
|
||||
from .run_block import RunBlockTool
|
||||
from .search_docs import SearchDocsTool
|
||||
from .workspace_files import (
|
||||
DeleteWorkspaceFileTool,
|
||||
ListWorkspaceFilesTool,
|
||||
ReadWorkspaceFileTool,
|
||||
WriteWorkspaceFileTool,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.api.features.chat.response_model import StreamToolOutputAvailable
|
||||
@@ -37,6 +43,11 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
|
||||
"view_agent_output": AgentOutputTool(),
|
||||
"search_docs": SearchDocsTool(),
|
||||
"get_doc_page": GetDocPageTool(),
|
||||
# Workspace tools for CoPilot file operations
|
||||
"list_workspace_files": ListWorkspaceFilesTool(),
|
||||
"read_workspace_file": ReadWorkspaceFileTool(),
|
||||
"write_workspace_file": WriteWorkspaceFileTool(),
|
||||
"delete_workspace_file": DeleteWorkspaceFileTool(),
|
||||
}
|
||||
|
||||
# Export individual tool instances for backwards compatibility
|
||||
|
||||
@@ -9,6 +9,7 @@ from .core import (
|
||||
json_to_graph,
|
||||
save_agent_to_library,
|
||||
)
|
||||
from .errors import get_user_message_for_error
|
||||
from .service import health_check as check_external_service_health
|
||||
from .service import is_external_service_configured
|
||||
|
||||
@@ -25,4 +26,6 @@ __all__ = [
|
||||
# Service
|
||||
"is_external_service_configured",
|
||||
"check_external_service_health",
|
||||
# Error handling
|
||||
"get_user_message_for_error",
|
||||
]
|
||||
|
||||
@@ -64,7 +64,7 @@ async def generate_agent(instructions: dict[str, Any]) -> dict[str, Any] | None:
|
||||
instructions: Structured instructions from decompose_goal
|
||||
|
||||
Returns:
|
||||
Agent JSON dict or None on error
|
||||
Agent JSON dict, error dict {"type": "error", ...}, or None on error
|
||||
|
||||
Raises:
|
||||
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
||||
@@ -73,7 +73,10 @@ async def generate_agent(instructions: dict[str, Any]) -> dict[str, Any] | None:
|
||||
logger.info("Calling external Agent Generator service for generate_agent")
|
||||
result = await generate_agent_external(instructions)
|
||||
if result:
|
||||
# Ensure required fields
|
||||
# Check if it's an error response - pass through as-is
|
||||
if isinstance(result, dict) and result.get("type") == "error":
|
||||
return result
|
||||
# Ensure required fields for successful agent generation
|
||||
if "id" not in result:
|
||||
result["id"] = str(uuid.uuid4())
|
||||
if "version" not in result:
|
||||
@@ -267,7 +270,8 @@ async def generate_agent_patch(
|
||||
current_agent: Current agent JSON
|
||||
|
||||
Returns:
|
||||
Updated agent JSON, clarifying questions dict, or None on error
|
||||
Updated agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
|
||||
error dict {"type": "error", ...}, or None on unexpected error
|
||||
|
||||
Raises:
|
||||
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
"""Error handling utilities for agent generator."""
|
||||
|
||||
|
||||
def get_user_message_for_error(
|
||||
error_type: str,
|
||||
operation: str = "process the request",
|
||||
llm_parse_message: str | None = None,
|
||||
validation_message: str | None = None,
|
||||
) -> str:
|
||||
"""Get a user-friendly error message based on error type.
|
||||
|
||||
This function maps internal error types to user-friendly messages,
|
||||
providing a consistent experience across different agent operations.
|
||||
|
||||
Args:
|
||||
error_type: The error type from the external service
|
||||
(e.g., "llm_parse_error", "timeout", "rate_limit")
|
||||
operation: Description of what operation failed, used in the default
|
||||
message (e.g., "analyze the goal", "generate the agent")
|
||||
llm_parse_message: Custom message for llm_parse_error type
|
||||
validation_message: Custom message for validation_error type
|
||||
|
||||
Returns:
|
||||
User-friendly error message suitable for display to the user
|
||||
"""
|
||||
if error_type == "llm_parse_error":
|
||||
return (
|
||||
llm_parse_message
|
||||
or "The AI had trouble processing this request. Please try again."
|
||||
)
|
||||
elif error_type == "validation_error":
|
||||
return (
|
||||
validation_message
|
||||
or "The request failed validation. Please try rephrasing."
|
||||
)
|
||||
elif error_type == "patch_error":
|
||||
return "Failed to apply the changes. Please try a different approach."
|
||||
elif error_type in ("timeout", "llm_timeout"):
|
||||
return "The request took too long. Please try again."
|
||||
elif error_type in ("rate_limit", "llm_rate_limit"):
|
||||
return "The service is currently busy. Please try again in a moment."
|
||||
else:
|
||||
return f"Failed to {operation}. Please try again."
|
||||
@@ -14,6 +14,70 @@ from backend.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _create_error_response(
|
||||
error_message: str,
|
||||
error_type: str = "unknown",
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Create a standardized error response dict.
|
||||
|
||||
Args:
|
||||
error_message: Human-readable error message
|
||||
error_type: Machine-readable error type
|
||||
details: Optional additional error details
|
||||
|
||||
Returns:
|
||||
Error dict with type="error" and error details
|
||||
"""
|
||||
response: dict[str, Any] = {
|
||||
"type": "error",
|
||||
"error": error_message,
|
||||
"error_type": error_type,
|
||||
}
|
||||
if details:
|
||||
response["details"] = details
|
||||
return response
|
||||
|
||||
|
||||
def _classify_http_error(e: httpx.HTTPStatusError) -> tuple[str, str]:
|
||||
"""Classify an HTTP error into error_type and message.
|
||||
|
||||
Args:
|
||||
e: The HTTP status error
|
||||
|
||||
Returns:
|
||||
Tuple of (error_type, error_message)
|
||||
"""
|
||||
status = e.response.status_code
|
||||
if status == 429:
|
||||
return "rate_limit", f"Agent Generator rate limited: {e}"
|
||||
elif status == 503:
|
||||
return "service_unavailable", f"Agent Generator unavailable: {e}"
|
||||
elif status == 504 or status == 408:
|
||||
return "timeout", f"Agent Generator timed out: {e}"
|
||||
else:
|
||||
return "http_error", f"HTTP error calling Agent Generator: {e}"
|
||||
|
||||
|
||||
def _classify_request_error(e: httpx.RequestError) -> tuple[str, str]:
|
||||
"""Classify a request error into error_type and message.
|
||||
|
||||
Args:
|
||||
e: The request error
|
||||
|
||||
Returns:
|
||||
Tuple of (error_type, error_message)
|
||||
"""
|
||||
error_str = str(e).lower()
|
||||
if "timeout" in error_str or "timed out" in error_str:
|
||||
return "timeout", f"Agent Generator request timed out: {e}"
|
||||
elif "connect" in error_str:
|
||||
return "connection_error", f"Could not connect to Agent Generator: {e}"
|
||||
else:
|
||||
return "request_error", f"Request error calling Agent Generator: {e}"
|
||||
|
||||
|
||||
_client: httpx.AsyncClient | None = None
|
||||
_settings: Settings | None = None
|
||||
|
||||
@@ -67,7 +131,8 @@ async def decompose_goal_external(
|
||||
- {"type": "instructions", "steps": [...]}
|
||||
- {"type": "unachievable_goal", ...}
|
||||
- {"type": "vague_goal", ...}
|
||||
Or None on error
|
||||
- {"type": "error", "error": "...", "error_type": "..."} on error
|
||||
Or None on unexpected error
|
||||
"""
|
||||
client = _get_client()
|
||||
|
||||
@@ -83,8 +148,13 @@ async def decompose_goal_external(
|
||||
data = response.json()
|
||||
|
||||
if not data.get("success"):
|
||||
logger.error(f"External service returned error: {data.get('error')}")
|
||||
return None
|
||||
error_msg = data.get("error", "Unknown error from Agent Generator")
|
||||
error_type = data.get("error_type", "unknown")
|
||||
logger.error(
|
||||
f"Agent Generator decomposition failed: {error_msg} "
|
||||
f"(type: {error_type})"
|
||||
)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
|
||||
# Map the response to the expected format
|
||||
response_type = data.get("type")
|
||||
@@ -106,25 +176,37 @@ async def decompose_goal_external(
|
||||
"type": "vague_goal",
|
||||
"suggested_goal": data.get("suggested_goal"),
|
||||
}
|
||||
elif response_type == "error":
|
||||
# Pass through error from the service
|
||||
return _create_error_response(
|
||||
data.get("error", "Unknown error"),
|
||||
data.get("error_type", "unknown"),
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"Unknown response type from external service: {response_type}"
|
||||
)
|
||||
return None
|
||||
return _create_error_response(
|
||||
f"Unknown response type from Agent Generator: {response_type}",
|
||||
"invalid_response",
|
||||
)
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"HTTP error calling external agent generator: {e}")
|
||||
return None
|
||||
error_type, error_msg = _classify_http_error(e)
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Request error calling external agent generator: {e}")
|
||||
return None
|
||||
error_type, error_msg = _classify_request_error(e)
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error calling external agent generator: {e}")
|
||||
return None
|
||||
error_msg = f"Unexpected error calling Agent Generator: {e}"
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, "unexpected_error")
|
||||
|
||||
|
||||
async def generate_agent_external(
|
||||
instructions: dict[str, Any]
|
||||
instructions: dict[str, Any],
|
||||
) -> dict[str, Any] | None:
|
||||
"""Call the external service to generate an agent from instructions.
|
||||
|
||||
@@ -132,7 +214,7 @@ async def generate_agent_external(
|
||||
instructions: Structured instructions from decompose_goal
|
||||
|
||||
Returns:
|
||||
Agent JSON dict or None on error
|
||||
Agent JSON dict on success, or error dict {"type": "error", ...} on error
|
||||
"""
|
||||
client = _get_client()
|
||||
|
||||
@@ -144,20 +226,28 @@ async def generate_agent_external(
|
||||
data = response.json()
|
||||
|
||||
if not data.get("success"):
|
||||
logger.error(f"External service returned error: {data.get('error')}")
|
||||
return None
|
||||
error_msg = data.get("error", "Unknown error from Agent Generator")
|
||||
error_type = data.get("error_type", "unknown")
|
||||
logger.error(
|
||||
f"Agent Generator generation failed: {error_msg} "
|
||||
f"(type: {error_type})"
|
||||
)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
|
||||
return data.get("agent_json")
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"HTTP error calling external agent generator: {e}")
|
||||
return None
|
||||
error_type, error_msg = _classify_http_error(e)
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Request error calling external agent generator: {e}")
|
||||
return None
|
||||
error_type, error_msg = _classify_request_error(e)
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error calling external agent generator: {e}")
|
||||
return None
|
||||
error_msg = f"Unexpected error calling Agent Generator: {e}"
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, "unexpected_error")
|
||||
|
||||
|
||||
async def generate_agent_patch_external(
|
||||
@@ -170,7 +260,7 @@ async def generate_agent_patch_external(
|
||||
current_agent: Current agent JSON
|
||||
|
||||
Returns:
|
||||
Updated agent JSON, clarifying questions dict, or None on error
|
||||
Updated agent JSON, clarifying questions dict, or error dict on error
|
||||
"""
|
||||
client = _get_client()
|
||||
|
||||
@@ -186,8 +276,13 @@ async def generate_agent_patch_external(
|
||||
data = response.json()
|
||||
|
||||
if not data.get("success"):
|
||||
logger.error(f"External service returned error: {data.get('error')}")
|
||||
return None
|
||||
error_msg = data.get("error", "Unknown error from Agent Generator")
|
||||
error_type = data.get("error_type", "unknown")
|
||||
logger.error(
|
||||
f"Agent Generator patch generation failed: {error_msg} "
|
||||
f"(type: {error_type})"
|
||||
)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
|
||||
# Check if it's clarifying questions
|
||||
if data.get("type") == "clarifying_questions":
|
||||
@@ -196,18 +291,28 @@ async def generate_agent_patch_external(
|
||||
"questions": data.get("questions", []),
|
||||
}
|
||||
|
||||
# Check if it's an error passed through
|
||||
if data.get("type") == "error":
|
||||
return _create_error_response(
|
||||
data.get("error", "Unknown error"),
|
||||
data.get("error_type", "unknown"),
|
||||
)
|
||||
|
||||
# Otherwise return the updated agent JSON
|
||||
return data.get("agent_json")
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"HTTP error calling external agent generator: {e}")
|
||||
return None
|
||||
error_type, error_msg = _classify_http_error(e)
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Request error calling external agent generator: {e}")
|
||||
return None
|
||||
error_type, error_msg = _classify_request_error(e)
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error calling external agent generator: {e}")
|
||||
return None
|
||||
error_msg = f"Unexpected error calling Agent Generator: {e}"
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, "unexpected_error")
|
||||
|
||||
|
||||
async def get_blocks_external() -> list[dict[str, Any]] | None:
|
||||
|
||||
@@ -9,6 +9,7 @@ from .agent_generator import (
|
||||
AgentGeneratorNotConfiguredError,
|
||||
decompose_goal,
|
||||
generate_agent,
|
||||
get_user_message_for_error,
|
||||
save_agent_to_library,
|
||||
)
|
||||
from .base import BaseTool
|
||||
@@ -117,11 +118,29 @@ class CreateAgentTool(BaseTool):
|
||||
|
||||
if decomposition_result is None:
|
||||
return ErrorResponse(
|
||||
message="Failed to analyze the goal. The agent generation service may be unavailable or timed out. Please try again.",
|
||||
message="Failed to analyze the goal. The agent generation service may be unavailable. Please try again.",
|
||||
error="decomposition_failed",
|
||||
details={"description": description[:100]},
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check if the result is an error from the external service
|
||||
if decomposition_result.get("type") == "error":
|
||||
error_msg = decomposition_result.get("error", "Unknown error")
|
||||
error_type = decomposition_result.get("error_type", "unknown")
|
||||
user_message = get_user_message_for_error(
|
||||
error_type,
|
||||
operation="analyze the goal",
|
||||
llm_parse_message="The AI had trouble understanding this request. Please try rephrasing your goal.",
|
||||
)
|
||||
return ErrorResponse(
|
||||
message=user_message,
|
||||
error=f"decomposition_failed:{error_type}",
|
||||
details={
|
||||
"description": description[:100]
|
||||
}, # Include context for debugging
|
||||
"description": description[:100],
|
||||
"service_error": error_msg,
|
||||
"error_type": error_type,
|
||||
},
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
@@ -186,11 +205,30 @@ class CreateAgentTool(BaseTool):
|
||||
|
||||
if agent_json is None:
|
||||
return ErrorResponse(
|
||||
message="Failed to generate the agent. The agent generation service may be unavailable or timed out. Please try again.",
|
||||
message="Failed to generate the agent. The agent generation service may be unavailable. Please try again.",
|
||||
error="generation_failed",
|
||||
details={"description": description[:100]},
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check if the result is an error from the external service
|
||||
if isinstance(agent_json, dict) and agent_json.get("type") == "error":
|
||||
error_msg = agent_json.get("error", "Unknown error")
|
||||
error_type = agent_json.get("error_type", "unknown")
|
||||
user_message = get_user_message_for_error(
|
||||
error_type,
|
||||
operation="generate the agent",
|
||||
llm_parse_message="The AI had trouble generating the agent. Please try again or simplify your goal.",
|
||||
validation_message="The generated agent failed validation. Please try rephrasing your goal.",
|
||||
)
|
||||
return ErrorResponse(
|
||||
message=user_message,
|
||||
error=f"generation_failed:{error_type}",
|
||||
details={
|
||||
"description": description[:100]
|
||||
}, # Include context for debugging
|
||||
"description": description[:100],
|
||||
"service_error": error_msg,
|
||||
"error_type": error_type,
|
||||
},
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ from .agent_generator import (
|
||||
AgentGeneratorNotConfiguredError,
|
||||
generate_agent_patch,
|
||||
get_agent_as_json,
|
||||
get_user_message_for_error,
|
||||
save_agent_to_library,
|
||||
)
|
||||
from .base import BaseTool
|
||||
@@ -152,6 +153,28 @@ class EditAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check if the result is an error from the external service
|
||||
if isinstance(result, dict) and result.get("type") == "error":
|
||||
error_msg = result.get("error", "Unknown error")
|
||||
error_type = result.get("error_type", "unknown")
|
||||
user_message = get_user_message_for_error(
|
||||
error_type,
|
||||
operation="generate the changes",
|
||||
llm_parse_message="The AI had trouble generating the changes. Please try again or simplify your request.",
|
||||
validation_message="The generated changes failed validation. Please try rephrasing your request.",
|
||||
)
|
||||
return ErrorResponse(
|
||||
message=user_message,
|
||||
error=f"update_generation_failed:{error_type}",
|
||||
details={
|
||||
"agent_id": agent_id,
|
||||
"changes": changes[:100],
|
||||
"service_error": error_msg,
|
||||
"error_type": error_type,
|
||||
},
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check if LLM returned clarifying questions
|
||||
if result.get("type") == "clarifying_questions":
|
||||
questions = result.get("questions", [])
|
||||
|
||||
@@ -28,6 +28,12 @@ class ResponseType(str, Enum):
|
||||
BLOCK_OUTPUT = "block_output"
|
||||
DOC_SEARCH_RESULTS = "doc_search_results"
|
||||
DOC_PAGE = "doc_page"
|
||||
# Workspace response types
|
||||
WORKSPACE_FILE_LIST = "workspace_file_list"
|
||||
WORKSPACE_FILE_CONTENT = "workspace_file_content"
|
||||
WORKSPACE_FILE_METADATA = "workspace_file_metadata"
|
||||
WORKSPACE_FILE_WRITTEN = "workspace_file_written"
|
||||
WORKSPACE_FILE_DELETED = "workspace_file_deleted"
|
||||
# Long-running operation types
|
||||
OPERATION_STARTED = "operation_started"
|
||||
OPERATION_PENDING = "operation_pending"
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Tool for executing blocks directly."""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
|
||||
@@ -8,6 +9,7 @@ from backend.api.features.chat.model import ChatSession
|
||||
from backend.data.block import get_block
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.data.workspace import get_or_create_workspace
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.util.exceptions import BlockError
|
||||
|
||||
@@ -223,11 +225,48 @@ class RunBlockTool(BaseTool):
|
||||
)
|
||||
|
||||
try:
|
||||
# Fetch actual credentials and prepare kwargs for block execution
|
||||
# Create execution context with defaults (blocks may require it)
|
||||
# Get or create user's workspace for CoPilot file operations
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
|
||||
# Generate synthetic IDs for CoPilot context
|
||||
# Each chat session is treated as its own agent with one continuous run
|
||||
# This means:
|
||||
# - graph_id (agent) = session (memories scoped to session when limit_to_agent=True)
|
||||
# - graph_exec_id (run) = session (memories scoped to session when limit_to_run=True)
|
||||
# - node_exec_id = unique per block execution
|
||||
synthetic_graph_id = f"copilot-session-{session.session_id}"
|
||||
synthetic_graph_exec_id = f"copilot-session-{session.session_id}"
|
||||
synthetic_node_id = f"copilot-node-{block_id}"
|
||||
synthetic_node_exec_id = (
|
||||
f"copilot-{session.session_id}-{uuid.uuid4().hex[:8]}"
|
||||
)
|
||||
|
||||
# Create unified execution context with all required fields
|
||||
execution_context = ExecutionContext(
|
||||
# Execution identity
|
||||
user_id=user_id,
|
||||
graph_id=synthetic_graph_id,
|
||||
graph_exec_id=synthetic_graph_exec_id,
|
||||
graph_version=1, # Versions are 1-indexed
|
||||
node_id=synthetic_node_id,
|
||||
node_exec_id=synthetic_node_exec_id,
|
||||
# Workspace with session scoping
|
||||
workspace_id=workspace.id,
|
||||
session_id=session.session_id,
|
||||
)
|
||||
|
||||
# Prepare kwargs for block execution
|
||||
# Keep individual kwargs for backwards compatibility with existing blocks
|
||||
exec_kwargs: dict[str, Any] = {
|
||||
"user_id": user_id,
|
||||
"execution_context": ExecutionContext(),
|
||||
"execution_context": execution_context,
|
||||
# Legacy: individual kwargs for blocks not yet using execution_context
|
||||
"workspace_id": workspace.id,
|
||||
"graph_exec_id": synthetic_graph_exec_id,
|
||||
"node_exec_id": synthetic_node_exec_id,
|
||||
"node_id": synthetic_node_id,
|
||||
"graph_version": 1, # Versions are 1-indexed
|
||||
"graph_id": synthetic_graph_id,
|
||||
}
|
||||
|
||||
for field_name, cred_meta in matched_credentials.items():
|
||||
|
||||
@@ -0,0 +1,620 @@
|
||||
"""CoPilot tools for workspace file operations."""
|
||||
|
||||
import base64
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.data.workspace import get_or_create_workspace
|
||||
from backend.util.settings import Config
|
||||
from backend.util.virus_scanner import scan_content_safe
|
||||
from backend.util.workspace import WorkspaceManager
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import ErrorResponse, ResponseType, ToolResponseBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkspaceFileInfoData(BaseModel):
|
||||
"""Data model for workspace file information (not a response itself)."""
|
||||
|
||||
file_id: str
|
||||
name: str
|
||||
path: str
|
||||
mime_type: str
|
||||
size_bytes: int
|
||||
|
||||
|
||||
class WorkspaceFileListResponse(ToolResponseBase):
|
||||
"""Response containing list of workspace files."""
|
||||
|
||||
type: ResponseType = ResponseType.WORKSPACE_FILE_LIST
|
||||
files: list[WorkspaceFileInfoData]
|
||||
total_count: int
|
||||
|
||||
|
||||
class WorkspaceFileContentResponse(ToolResponseBase):
|
||||
"""Response containing workspace file content (legacy, for small text files)."""
|
||||
|
||||
type: ResponseType = ResponseType.WORKSPACE_FILE_CONTENT
|
||||
file_id: str
|
||||
name: str
|
||||
path: str
|
||||
mime_type: str
|
||||
content_base64: str
|
||||
|
||||
|
||||
class WorkspaceFileMetadataResponse(ToolResponseBase):
|
||||
"""Response containing workspace file metadata and download URL (prevents context bloat)."""
|
||||
|
||||
type: ResponseType = ResponseType.WORKSPACE_FILE_METADATA
|
||||
file_id: str
|
||||
name: str
|
||||
path: str
|
||||
mime_type: str
|
||||
size_bytes: int
|
||||
download_url: str
|
||||
preview: str | None = None # First 500 chars for text files
|
||||
|
||||
|
||||
class WorkspaceWriteResponse(ToolResponseBase):
|
||||
"""Response after writing a file to workspace."""
|
||||
|
||||
type: ResponseType = ResponseType.WORKSPACE_FILE_WRITTEN
|
||||
file_id: str
|
||||
name: str
|
||||
path: str
|
||||
size_bytes: int
|
||||
|
||||
|
||||
class WorkspaceDeleteResponse(ToolResponseBase):
|
||||
"""Response after deleting a file from workspace."""
|
||||
|
||||
type: ResponseType = ResponseType.WORKSPACE_FILE_DELETED
|
||||
file_id: str
|
||||
success: bool
|
||||
|
||||
|
||||
class ListWorkspaceFilesTool(BaseTool):
|
||||
"""Tool for listing files in user's workspace."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "list_workspace_files"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"List files in the user's workspace. "
|
||||
"Returns file names, paths, sizes, and metadata. "
|
||||
"Optionally filter by path prefix."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path_prefix": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Optional path prefix to filter files "
|
||||
"(e.g., '/documents/' to list only files in documents folder). "
|
||||
"By default, only files from the current session are listed."
|
||||
),
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Maximum number of files to return (default 50, max 100)",
|
||||
"minimum": 1,
|
||||
"maximum": 100,
|
||||
},
|
||||
"include_all_sessions": {
|
||||
"type": "boolean",
|
||||
"description": (
|
||||
"If true, list files from all sessions. "
|
||||
"Default is false (only current session's files)."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": [],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
session_id = session.session_id
|
||||
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="Authentication required",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
path_prefix: Optional[str] = kwargs.get("path_prefix")
|
||||
limit = min(kwargs.get("limit", 50), 100)
|
||||
include_all_sessions: bool = kwargs.get("include_all_sessions", False)
|
||||
|
||||
try:
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
# Pass session_id for session-scoped file access
|
||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||
|
||||
files = await manager.list_files(
|
||||
path=path_prefix,
|
||||
limit=limit,
|
||||
include_all_sessions=include_all_sessions,
|
||||
)
|
||||
total = await manager.get_file_count(
|
||||
path=path_prefix,
|
||||
include_all_sessions=include_all_sessions,
|
||||
)
|
||||
|
||||
file_infos = [
|
||||
WorkspaceFileInfoData(
|
||||
file_id=f.id,
|
||||
name=f.name,
|
||||
path=f.path,
|
||||
mime_type=f.mimeType,
|
||||
size_bytes=f.sizeBytes,
|
||||
)
|
||||
for f in files
|
||||
]
|
||||
|
||||
scope_msg = "all sessions" if include_all_sessions else "current session"
|
||||
return WorkspaceFileListResponse(
|
||||
files=file_infos,
|
||||
total_count=total,
|
||||
message=f"Found {len(files)} files in workspace ({scope_msg})",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing workspace files: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"Failed to list workspace files: {str(e)}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
class ReadWorkspaceFileTool(BaseTool):
|
||||
"""Tool for reading file content from workspace."""
|
||||
|
||||
# Size threshold for returning full content vs metadata+URL
|
||||
# Files larger than this return metadata with download URL to prevent context bloat
|
||||
MAX_INLINE_SIZE_BYTES = 32 * 1024 # 32KB
|
||||
# Preview size for text files
|
||||
PREVIEW_SIZE = 500
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "read_workspace_file"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Read a file from the user's workspace. "
|
||||
"Specify either file_id or path to identify the file. "
|
||||
"For small text files, returns content directly. "
|
||||
"For large or binary files, returns metadata and a download URL. "
|
||||
"Paths are scoped to the current session by default. "
|
||||
"Use /sessions/<session_id>/... for cross-session access."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_id": {
|
||||
"type": "string",
|
||||
"description": "The file's unique ID (from list_workspace_files)",
|
||||
},
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The virtual file path (e.g., '/documents/report.pdf'). "
|
||||
"Scoped to current session by default."
|
||||
),
|
||||
},
|
||||
"force_download_url": {
|
||||
"type": "boolean",
|
||||
"description": (
|
||||
"If true, always return metadata+URL instead of inline content. "
|
||||
"Default is false (auto-selects based on file size/type)."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": [], # At least one must be provided
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
def _is_text_mime_type(self, mime_type: str) -> bool:
|
||||
"""Check if the MIME type is a text-based type."""
|
||||
text_types = [
|
||||
"text/",
|
||||
"application/json",
|
||||
"application/xml",
|
||||
"application/javascript",
|
||||
"application/x-python",
|
||||
"application/x-sh",
|
||||
]
|
||||
return any(mime_type.startswith(t) for t in text_types)
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
session_id = session.session_id
|
||||
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="Authentication required",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
file_id: Optional[str] = kwargs.get("file_id")
|
||||
path: Optional[str] = kwargs.get("path")
|
||||
force_download_url: bool = kwargs.get("force_download_url", False)
|
||||
|
||||
if not file_id and not path:
|
||||
return ErrorResponse(
|
||||
message="Please provide either file_id or path",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
# Pass session_id for session-scoped file access
|
||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||
|
||||
# Get file info
|
||||
if file_id:
|
||||
file_info = await manager.get_file_info(file_id)
|
||||
if file_info is None:
|
||||
return ErrorResponse(
|
||||
message=f"File not found: {file_id}",
|
||||
session_id=session_id,
|
||||
)
|
||||
target_file_id = file_id
|
||||
else:
|
||||
# path is guaranteed to be non-None here due to the check above
|
||||
assert path is not None
|
||||
file_info = await manager.get_file_info_by_path(path)
|
||||
if file_info is None:
|
||||
return ErrorResponse(
|
||||
message=f"File not found at path: {path}",
|
||||
session_id=session_id,
|
||||
)
|
||||
target_file_id = file_info.id
|
||||
|
||||
# Decide whether to return inline content or metadata+URL
|
||||
is_small_file = file_info.sizeBytes <= self.MAX_INLINE_SIZE_BYTES
|
||||
is_text_file = self._is_text_mime_type(file_info.mimeType)
|
||||
|
||||
# Return inline content for small text files (unless force_download_url)
|
||||
if is_small_file and is_text_file and not force_download_url:
|
||||
content = await manager.read_file_by_id(target_file_id)
|
||||
content_b64 = base64.b64encode(content).decode("utf-8")
|
||||
|
||||
return WorkspaceFileContentResponse(
|
||||
file_id=file_info.id,
|
||||
name=file_info.name,
|
||||
path=file_info.path,
|
||||
mime_type=file_info.mimeType,
|
||||
content_base64=content_b64,
|
||||
message=f"Successfully read file: {file_info.name}",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Return metadata + workspace:// reference for large or binary files
|
||||
# This prevents context bloat (100KB file = ~133KB as base64)
|
||||
# Use workspace:// format so frontend urlTransform can add proxy prefix
|
||||
download_url = f"workspace://{target_file_id}"
|
||||
|
||||
# Generate preview for text files
|
||||
preview: str | None = None
|
||||
if is_text_file:
|
||||
try:
|
||||
content = await manager.read_file_by_id(target_file_id)
|
||||
preview_text = content[: self.PREVIEW_SIZE].decode(
|
||||
"utf-8", errors="replace"
|
||||
)
|
||||
if len(content) > self.PREVIEW_SIZE:
|
||||
preview_text += "..."
|
||||
preview = preview_text
|
||||
except Exception:
|
||||
pass # Preview is optional
|
||||
|
||||
return WorkspaceFileMetadataResponse(
|
||||
file_id=file_info.id,
|
||||
name=file_info.name,
|
||||
path=file_info.path,
|
||||
mime_type=file_info.mimeType,
|
||||
size_bytes=file_info.sizeBytes,
|
||||
download_url=download_url,
|
||||
preview=preview,
|
||||
message=f"File: {file_info.name} ({file_info.sizeBytes} bytes). Use download_url to retrieve content.",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
except FileNotFoundError as e:
|
||||
return ErrorResponse(
|
||||
message=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading workspace file: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"Failed to read workspace file: {str(e)}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
class WriteWorkspaceFileTool(BaseTool):
|
||||
"""Tool for writing files to workspace."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "write_workspace_file"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Write or create a file in the user's workspace. "
|
||||
"Provide the content as a base64-encoded string. "
|
||||
f"Maximum file size is {Config().max_file_size_mb}MB. "
|
||||
"Files are saved to the current session's folder by default. "
|
||||
"Use /sessions/<session_id>/... for cross-session access."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"filename": {
|
||||
"type": "string",
|
||||
"description": "Name for the file (e.g., 'report.pdf')",
|
||||
},
|
||||
"content_base64": {
|
||||
"type": "string",
|
||||
"description": "Base64-encoded file content",
|
||||
},
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Optional virtual path where to save the file "
|
||||
"(e.g., '/documents/report.pdf'). "
|
||||
"Defaults to '/{filename}'. Scoped to current session."
|
||||
),
|
||||
},
|
||||
"mime_type": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Optional MIME type of the file. "
|
||||
"Auto-detected from filename if not provided."
|
||||
),
|
||||
},
|
||||
"overwrite": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to overwrite if file exists at path (default: false)",
|
||||
},
|
||||
},
|
||||
"required": ["filename", "content_base64"],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
session_id = session.session_id
|
||||
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="Authentication required",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
filename: str = kwargs.get("filename", "")
|
||||
content_b64: str = kwargs.get("content_base64", "")
|
||||
path: Optional[str] = kwargs.get("path")
|
||||
mime_type: Optional[str] = kwargs.get("mime_type")
|
||||
overwrite: bool = kwargs.get("overwrite", False)
|
||||
|
||||
if not filename:
|
||||
return ErrorResponse(
|
||||
message="Please provide a filename",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not content_b64:
|
||||
return ErrorResponse(
|
||||
message="Please provide content_base64",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Decode content
|
||||
try:
|
||||
content = base64.b64decode(content_b64)
|
||||
except Exception:
|
||||
return ErrorResponse(
|
||||
message="Invalid base64-encoded content",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check size
|
||||
max_file_size = Config().max_file_size_mb * 1024 * 1024
|
||||
if len(content) > max_file_size:
|
||||
return ErrorResponse(
|
||||
message=f"File too large. Maximum size is {Config().max_file_size_mb}MB",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
# Virus scan
|
||||
await scan_content_safe(content, filename=filename)
|
||||
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
# Pass session_id for session-scoped file access
|
||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||
|
||||
file_record = await manager.write_file(
|
||||
content=content,
|
||||
filename=filename,
|
||||
path=path,
|
||||
mime_type=mime_type,
|
||||
overwrite=overwrite,
|
||||
)
|
||||
|
||||
return WorkspaceWriteResponse(
|
||||
file_id=file_record.id,
|
||||
name=file_record.name,
|
||||
path=file_record.path,
|
||||
size_bytes=file_record.sizeBytes,
|
||||
message=f"Successfully wrote file: {file_record.name}",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
return ErrorResponse(
|
||||
message=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error writing workspace file: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"Failed to write workspace file: {str(e)}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
class DeleteWorkspaceFileTool(BaseTool):
|
||||
"""Tool for deleting files from workspace."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "delete_workspace_file"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Delete a file from the user's workspace. "
|
||||
"Specify either file_id or path to identify the file. "
|
||||
"Paths are scoped to the current session by default. "
|
||||
"Use /sessions/<session_id>/... for cross-session access."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_id": {
|
||||
"type": "string",
|
||||
"description": "The file's unique ID (from list_workspace_files)",
|
||||
},
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The virtual file path (e.g., '/documents/report.pdf'). "
|
||||
"Scoped to current session by default."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": [], # At least one must be provided
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
session_id = session.session_id
|
||||
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="Authentication required",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
file_id: Optional[str] = kwargs.get("file_id")
|
||||
path: Optional[str] = kwargs.get("path")
|
||||
|
||||
if not file_id and not path:
|
||||
return ErrorResponse(
|
||||
message="Please provide either file_id or path",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
workspace = await get_or_create_workspace(user_id)
|
||||
# Pass session_id for session-scoped file access
|
||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||
|
||||
# Determine the file_id to delete
|
||||
target_file_id: str
|
||||
if file_id:
|
||||
target_file_id = file_id
|
||||
else:
|
||||
# path is guaranteed to be non-None here due to the check above
|
||||
assert path is not None
|
||||
file_info = await manager.get_file_info_by_path(path)
|
||||
if file_info is None:
|
||||
return ErrorResponse(
|
||||
message=f"File not found at path: {path}",
|
||||
session_id=session_id,
|
||||
)
|
||||
target_file_id = file_info.id
|
||||
|
||||
success = await manager.delete_file(target_file_id)
|
||||
|
||||
if not success:
|
||||
return ErrorResponse(
|
||||
message=f"File not found: {target_file_id}",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
return WorkspaceDeleteResponse(
|
||||
file_id=target_file_id,
|
||||
success=True,
|
||||
message="File deleted successfully",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting workspace file: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"Failed to delete workspace file: {str(e)}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -0,0 +1 @@
|
||||
# Workspace API feature module
|
||||
@@ -0,0 +1,122 @@
|
||||
"""
|
||||
Workspace API routes for managing user file storage.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Annotated
|
||||
from urllib.parse import quote
|
||||
|
||||
import fastapi
|
||||
from autogpt_libs.auth.dependencies import get_user_id, requires_user
|
||||
from fastapi.responses import Response
|
||||
|
||||
from backend.data.workspace import get_workspace, get_workspace_file
|
||||
from backend.util.workspace_storage import get_workspace_storage
|
||||
|
||||
|
||||
def _sanitize_filename_for_header(filename: str) -> str:
|
||||
"""
|
||||
Sanitize filename for Content-Disposition header to prevent header injection.
|
||||
|
||||
Removes/replaces characters that could break the header or inject new headers.
|
||||
Uses RFC5987 encoding for non-ASCII characters.
|
||||
"""
|
||||
# Remove CR, LF, and null bytes (header injection prevention)
|
||||
sanitized = re.sub(r"[\r\n\x00]", "", filename)
|
||||
# Escape quotes
|
||||
sanitized = sanitized.replace('"', '\\"')
|
||||
# For non-ASCII, use RFC5987 filename* parameter
|
||||
# Check if filename has non-ASCII characters
|
||||
try:
|
||||
sanitized.encode("ascii")
|
||||
return f'attachment; filename="{sanitized}"'
|
||||
except UnicodeEncodeError:
|
||||
# Use RFC5987 encoding for UTF-8 filenames
|
||||
encoded = quote(sanitized, safe="")
|
||||
return f"attachment; filename*=UTF-8''{encoded}"
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = fastapi.APIRouter(
|
||||
dependencies=[fastapi.Security(requires_user)],
|
||||
)
|
||||
|
||||
|
||||
def _create_streaming_response(content: bytes, file) -> Response:
|
||||
"""Create a streaming response for file content."""
|
||||
return Response(
|
||||
content=content,
|
||||
media_type=file.mimeType,
|
||||
headers={
|
||||
"Content-Disposition": _sanitize_filename_for_header(file.name),
|
||||
"Content-Length": str(len(content)),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def _create_file_download_response(file) -> Response:
|
||||
"""
|
||||
Create a download response for a workspace file.
|
||||
|
||||
Handles both local storage (direct streaming) and GCS (signed URL redirect
|
||||
with fallback to streaming).
|
||||
"""
|
||||
storage = await get_workspace_storage()
|
||||
|
||||
# For local storage, stream the file directly
|
||||
if file.storagePath.startswith("local://"):
|
||||
content = await storage.retrieve(file.storagePath)
|
||||
return _create_streaming_response(content, file)
|
||||
|
||||
# For GCS, try to redirect to signed URL, fall back to streaming
|
||||
try:
|
||||
url = await storage.get_download_url(file.storagePath, expires_in=300)
|
||||
# If we got back an API path (fallback), stream directly instead
|
||||
if url.startswith("/api/"):
|
||||
content = await storage.retrieve(file.storagePath)
|
||||
return _create_streaming_response(content, file)
|
||||
return fastapi.responses.RedirectResponse(url=url, status_code=302)
|
||||
except Exception as e:
|
||||
# Log the signed URL failure with context
|
||||
logger.error(
|
||||
f"Failed to get signed URL for file {file.id} "
|
||||
f"(storagePath={file.storagePath}): {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
# Fall back to streaming directly from GCS
|
||||
try:
|
||||
content = await storage.retrieve(file.storagePath)
|
||||
return _create_streaming_response(content, file)
|
||||
except Exception as fallback_error:
|
||||
logger.error(
|
||||
f"Fallback streaming also failed for file {file.id} "
|
||||
f"(storagePath={file.storagePath}): {fallback_error}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
@router.get(
|
||||
"/files/{file_id}/download",
|
||||
summary="Download file by ID",
|
||||
)
|
||||
async def download_file(
|
||||
user_id: Annotated[str, fastapi.Security(get_user_id)],
|
||||
file_id: str,
|
||||
) -> Response:
|
||||
"""
|
||||
Download a file by its ID.
|
||||
|
||||
Returns the file content directly or redirects to a signed URL for GCS.
|
||||
"""
|
||||
workspace = await get_workspace(user_id)
|
||||
if workspace is None:
|
||||
raise fastapi.HTTPException(status_code=404, detail="Workspace not found")
|
||||
|
||||
file = await get_workspace_file(file_id, workspace.id)
|
||||
if file is None:
|
||||
raise fastapi.HTTPException(status_code=404, detail="File not found")
|
||||
|
||||
return await _create_file_download_response(file)
|
||||
@@ -32,6 +32,7 @@ import backend.api.features.postmark.postmark
|
||||
import backend.api.features.store.model
|
||||
import backend.api.features.store.routes
|
||||
import backend.api.features.v1
|
||||
import backend.api.features.workspace.routes as workspace_routes
|
||||
import backend.data.block
|
||||
import backend.data.db
|
||||
import backend.data.graph
|
||||
@@ -52,6 +53,7 @@ from backend.util.exceptions import (
|
||||
)
|
||||
from backend.util.feature_flag import initialize_launchdarkly, shutdown_launchdarkly
|
||||
from backend.util.service import UnhealthyServiceError
|
||||
from backend.util.workspace_storage import shutdown_workspace_storage
|
||||
|
||||
from .external.fastapi_app import external_api
|
||||
from .features.analytics import router as analytics_router
|
||||
@@ -124,6 +126,11 @@ async def lifespan_context(app: fastapi.FastAPI):
|
||||
except Exception as e:
|
||||
logger.warning(f"Error shutting down cloud storage handler: {e}")
|
||||
|
||||
try:
|
||||
await shutdown_workspace_storage()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error shutting down workspace storage: {e}")
|
||||
|
||||
await backend.data.db.disconnect()
|
||||
|
||||
|
||||
@@ -315,6 +322,11 @@ app.include_router(
|
||||
tags=["v2", "chat"],
|
||||
prefix="/api/chat",
|
||||
)
|
||||
app.include_router(
|
||||
workspace_routes.router,
|
||||
tags=["workspace"],
|
||||
prefix="/api/workspace",
|
||||
)
|
||||
app.include_router(
|
||||
backend.api.features.oauth.router,
|
||||
tags=["oauth"],
|
||||
|
||||
@@ -13,6 +13,7 @@ from backend.data.block import (
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsField,
|
||||
@@ -117,11 +118,13 @@ class AIImageCustomizerBlock(Block):
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_output=[
|
||||
("image_url", "https://replicate.delivery/generated-image.jpg"),
|
||||
# Output will be a workspace ref or data URI depending on context
|
||||
("image_url", lambda x: x.startswith(("workspace://", "data:"))),
|
||||
],
|
||||
test_mock={
|
||||
# Use data URI to avoid HTTP requests during tests
|
||||
"run_model": lambda *args, **kwargs: MediaFileType(
|
||||
"https://replicate.delivery/generated-image.jpg"
|
||||
""
|
||||
),
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
@@ -132,8 +135,7 @@ class AIImageCustomizerBlock(Block):
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: APIKeyCredentials,
|
||||
graph_exec_id: str,
|
||||
user_id: str,
|
||||
execution_context: ExecutionContext,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
@@ -141,10 +143,9 @@ class AIImageCustomizerBlock(Block):
|
||||
processed_images = await asyncio.gather(
|
||||
*(
|
||||
store_media_file(
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=img,
|
||||
user_id=user_id,
|
||||
return_content=True,
|
||||
execution_context=execution_context,
|
||||
return_format="for_external_api", # Get content for Replicate API
|
||||
)
|
||||
for img in input_data.images
|
||||
)
|
||||
@@ -158,7 +159,14 @@ class AIImageCustomizerBlock(Block):
|
||||
aspect_ratio=input_data.aspect_ratio.value,
|
||||
output_format=input_data.output_format.value,
|
||||
)
|
||||
yield "image_url", result
|
||||
|
||||
# Store the generated image to the user's workspace for persistence
|
||||
stored_url = await store_media_file(
|
||||
file=result,
|
||||
execution_context=execution_context,
|
||||
return_format="for_block_output",
|
||||
)
|
||||
yield "image_url", stored_url
|
||||
except Exception as e:
|
||||
yield "error", str(e)
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ from replicate.client import Client as ReplicateClient
|
||||
from replicate.helpers import FileOutput
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockSchemaInput, BlockSchemaOutput
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsField,
|
||||
@@ -13,6 +14,8 @@ from backend.data.model import (
|
||||
SchemaField,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.file import store_media_file
|
||||
from backend.util.type import MediaFileType
|
||||
|
||||
|
||||
class ImageSize(str, Enum):
|
||||
@@ -165,11 +168,13 @@ class AIImageGeneratorBlock(Block):
|
||||
test_output=[
|
||||
(
|
||||
"image_url",
|
||||
"https://replicate.delivery/generated-image.webp",
|
||||
# Test output is a data URI since we now store images
|
||||
lambda x: x.startswith(""
|
||||
},
|
||||
)
|
||||
|
||||
@@ -318,11 +323,24 @@ class AIImageGeneratorBlock(Block):
|
||||
style_text = style_map.get(style, "")
|
||||
return f"{style_text} of" if style_text else ""
|
||||
|
||||
async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: APIKeyCredentials,
|
||||
execution_context: ExecutionContext,
|
||||
**kwargs,
|
||||
):
|
||||
try:
|
||||
url = await self.generate_image(input_data, credentials)
|
||||
if url:
|
||||
yield "image_url", url
|
||||
# Store the generated image to the user's workspace/execution folder
|
||||
stored_url = await store_media_file(
|
||||
file=MediaFileType(url),
|
||||
execution_context=execution_context,
|
||||
return_format="for_block_output",
|
||||
)
|
||||
yield "image_url", stored_url
|
||||
else:
|
||||
yield "error", "Image generation returned an empty result."
|
||||
except Exception as e:
|
||||
|
||||
@@ -13,6 +13,7 @@ from backend.data.block import (
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsField,
|
||||
@@ -21,7 +22,9 @@ from backend.data.model import (
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.exceptions import BlockExecutionError
|
||||
from backend.util.file import store_media_file
|
||||
from backend.util.request import Requests
|
||||
from backend.util.type import MediaFileType
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
@@ -271,7 +274,10 @@ class AIShortformVideoCreatorBlock(Block):
|
||||
"voice": Voice.LILY,
|
||||
"video_style": VisualMediaType.STOCK_VIDEOS,
|
||||
},
|
||||
test_output=("video_url", "https://example.com/video.mp4"),
|
||||
test_output=(
|
||||
"video_url",
|
||||
lambda x: x.startswith(("workspace://", "data:")),
|
||||
),
|
||||
test_mock={
|
||||
"create_webhook": lambda *args, **kwargs: (
|
||||
"test_uuid",
|
||||
@@ -280,15 +286,21 @@ class AIShortformVideoCreatorBlock(Block):
|
||||
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
|
||||
"check_video_status": lambda *args, **kwargs: {
|
||||
"status": "ready",
|
||||
"videoUrl": "https://example.com/video.mp4",
|
||||
"videoUrl": "data:video/mp4;base64,AAAA",
|
||||
},
|
||||
"wait_for_video": lambda *args, **kwargs: "https://example.com/video.mp4",
|
||||
# Use data URI to avoid HTTP requests during tests
|
||||
"wait_for_video": lambda *args, **kwargs: "data:video/mp4;base64,AAAA",
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: APIKeyCredentials,
|
||||
execution_context: ExecutionContext,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
# Create a new Webhook.site URL
|
||||
webhook_token, webhook_url = await self.create_webhook()
|
||||
@@ -340,7 +352,13 @@ class AIShortformVideoCreatorBlock(Block):
|
||||
)
|
||||
video_url = await self.wait_for_video(credentials.api_key, pid)
|
||||
logger.debug(f"Video ready: {video_url}")
|
||||
yield "video_url", video_url
|
||||
# Store the generated video to the user's workspace for persistence
|
||||
stored_url = await store_media_file(
|
||||
file=MediaFileType(video_url),
|
||||
execution_context=execution_context,
|
||||
return_format="for_block_output",
|
||||
)
|
||||
yield "video_url", stored_url
|
||||
|
||||
|
||||
class AIAdMakerVideoCreatorBlock(Block):
|
||||
@@ -447,7 +465,10 @@ class AIAdMakerVideoCreatorBlock(Block):
|
||||
"https://cdn.revid.ai/uploads/1747076315114-image.png",
|
||||
],
|
||||
},
|
||||
test_output=("video_url", "https://example.com/ad.mp4"),
|
||||
test_output=(
|
||||
"video_url",
|
||||
lambda x: x.startswith(("workspace://", "data:")),
|
||||
),
|
||||
test_mock={
|
||||
"create_webhook": lambda *args, **kwargs: (
|
||||
"test_uuid",
|
||||
@@ -456,14 +477,21 @@ class AIAdMakerVideoCreatorBlock(Block):
|
||||
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
|
||||
"check_video_status": lambda *args, **kwargs: {
|
||||
"status": "ready",
|
||||
"videoUrl": "https://example.com/ad.mp4",
|
||||
"videoUrl": "data:video/mp4;base64,AAAA",
|
||||
},
|
||||
"wait_for_video": lambda *args, **kwargs: "https://example.com/ad.mp4",
|
||||
"wait_for_video": lambda *args, **kwargs: "data:video/mp4;base64,AAAA",
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: APIKeyCredentials,
|
||||
execution_context: ExecutionContext,
|
||||
**kwargs,
|
||||
):
|
||||
webhook_token, webhook_url = await self.create_webhook()
|
||||
|
||||
payload = {
|
||||
@@ -531,7 +559,13 @@ class AIAdMakerVideoCreatorBlock(Block):
|
||||
raise RuntimeError("Failed to create video: No project ID returned")
|
||||
|
||||
video_url = await self.wait_for_video(credentials.api_key, pid)
|
||||
yield "video_url", video_url
|
||||
# Store the generated video to the user's workspace for persistence
|
||||
stored_url = await store_media_file(
|
||||
file=MediaFileType(video_url),
|
||||
execution_context=execution_context,
|
||||
return_format="for_block_output",
|
||||
)
|
||||
yield "video_url", stored_url
|
||||
|
||||
|
||||
class AIScreenshotToVideoAdBlock(Block):
|
||||
@@ -626,7 +660,10 @@ class AIScreenshotToVideoAdBlock(Block):
|
||||
"script": "Amazing numbers!",
|
||||
"screenshot_url": "https://cdn.revid.ai/uploads/1747080376028-image.png",
|
||||
},
|
||||
test_output=("video_url", "https://example.com/screenshot.mp4"),
|
||||
test_output=(
|
||||
"video_url",
|
||||
lambda x: x.startswith(("workspace://", "data:")),
|
||||
),
|
||||
test_mock={
|
||||
"create_webhook": lambda *args, **kwargs: (
|
||||
"test_uuid",
|
||||
@@ -635,14 +672,21 @@ class AIScreenshotToVideoAdBlock(Block):
|
||||
"create_video": lambda *args, **kwargs: {"pid": "test_pid"},
|
||||
"check_video_status": lambda *args, **kwargs: {
|
||||
"status": "ready",
|
||||
"videoUrl": "https://example.com/screenshot.mp4",
|
||||
"videoUrl": "data:video/mp4;base64,AAAA",
|
||||
},
|
||||
"wait_for_video": lambda *args, **kwargs: "https://example.com/screenshot.mp4",
|
||||
"wait_for_video": lambda *args, **kwargs: "data:video/mp4;base64,AAAA",
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs):
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: APIKeyCredentials,
|
||||
execution_context: ExecutionContext,
|
||||
**kwargs,
|
||||
):
|
||||
webhook_token, webhook_url = await self.create_webhook()
|
||||
|
||||
payload = {
|
||||
@@ -710,4 +754,10 @@ class AIScreenshotToVideoAdBlock(Block):
|
||||
raise RuntimeError("Failed to create video: No project ID returned")
|
||||
|
||||
video_url = await self.wait_for_video(credentials.api_key, pid)
|
||||
yield "video_url", video_url
|
||||
# Store the generated video to the user's workspace for persistence
|
||||
stored_url = await store_media_file(
|
||||
file=MediaFileType(video_url),
|
||||
execution_context=execution_context,
|
||||
return_format="for_block_output",
|
||||
)
|
||||
yield "video_url", stored_url
|
||||
|
||||
@@ -6,6 +6,7 @@ if TYPE_CHECKING:
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.sdk import (
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
@@ -17,6 +18,8 @@ from backend.sdk import (
|
||||
Requests,
|
||||
SchemaField,
|
||||
)
|
||||
from backend.util.file import store_media_file
|
||||
from backend.util.type import MediaFileType
|
||||
|
||||
from ._config import bannerbear
|
||||
|
||||
@@ -135,15 +138,17 @@ class BannerbearTextOverlayBlock(Block):
|
||||
},
|
||||
test_output=[
|
||||
("success", True),
|
||||
("image_url", "https://cdn.bannerbear.com/test-image.jpg"),
|
||||
# Output will be a workspace ref or data URI depending on context
|
||||
("image_url", lambda x: x.startswith(("workspace://", "data:"))),
|
||||
("uid", "test-uid-123"),
|
||||
("status", "completed"),
|
||||
],
|
||||
test_mock={
|
||||
# Use data URI to avoid HTTP requests during tests
|
||||
"_make_api_request": lambda *args, **kwargs: {
|
||||
"uid": "test-uid-123",
|
||||
"status": "completed",
|
||||
"image_url": "https://cdn.bannerbear.com/test-image.jpg",
|
||||
"image_url": "",
|
||||
}
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
@@ -177,7 +182,12 @@ class BannerbearTextOverlayBlock(Block):
|
||||
raise Exception(error_msg)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: APIKeyCredentials,
|
||||
execution_context: ExecutionContext,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
# Build the modifications array
|
||||
modifications = []
|
||||
@@ -234,6 +244,18 @@ class BannerbearTextOverlayBlock(Block):
|
||||
|
||||
# Synchronous request - image should be ready
|
||||
yield "success", True
|
||||
yield "image_url", data.get("image_url", "")
|
||||
|
||||
# Store the generated image to workspace for persistence
|
||||
image_url = data.get("image_url", "")
|
||||
if image_url:
|
||||
stored_url = await store_media_file(
|
||||
file=MediaFileType(image_url),
|
||||
execution_context=execution_context,
|
||||
return_format="for_block_output",
|
||||
)
|
||||
yield "image_url", stored_url
|
||||
else:
|
||||
yield "image_url", ""
|
||||
|
||||
yield "uid", data.get("uid", "")
|
||||
yield "status", data.get("status", "completed")
|
||||
|
||||
@@ -9,6 +9,7 @@ from backend.data.block import (
|
||||
BlockSchemaOutput,
|
||||
BlockType,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.file import store_media_file
|
||||
from backend.util.type import MediaFileType, convert
|
||||
@@ -17,10 +18,10 @@ from backend.util.type import MediaFileType, convert
|
||||
class FileStoreBlock(Block):
|
||||
class Input(BlockSchemaInput):
|
||||
file_in: MediaFileType = SchemaField(
|
||||
description="The file to store in the temporary directory, it can be a URL, data URI, or local path."
|
||||
description="The file to download and store. Can be a URL (https://...), data URI, or local path."
|
||||
)
|
||||
base_64: bool = SchemaField(
|
||||
description="Whether produce an output in base64 format (not recommended, you can pass the string path just fine accross blocks).",
|
||||
description="Whether to produce output in base64 format (not recommended, you can pass the file reference across blocks).",
|
||||
default=False,
|
||||
advanced=True,
|
||||
title="Produce Base64 Output",
|
||||
@@ -28,13 +29,18 @@ class FileStoreBlock(Block):
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
file_out: MediaFileType = SchemaField(
|
||||
description="The relative path to the stored file in the temporary directory."
|
||||
description="Reference to the stored file. In CoPilot: workspace:// URI (visible in list_workspace_files). In graphs: data URI for passing to other blocks."
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="cbb50872-625b-42f0-8203-a2ae78242d8a",
|
||||
description="Stores the input file in the temporary directory.",
|
||||
description=(
|
||||
"Downloads and stores a file from a URL, data URI, or local path. "
|
||||
"Use this to fetch images, documents, or other files for processing. "
|
||||
"In CoPilot: saves to workspace (use list_workspace_files to see it). "
|
||||
"In graphs: outputs a data URI to pass to other blocks."
|
||||
),
|
||||
categories={BlockCategory.BASIC, BlockCategory.MULTIMEDIA},
|
||||
input_schema=FileStoreBlock.Input,
|
||||
output_schema=FileStoreBlock.Output,
|
||||
@@ -45,15 +51,18 @@ class FileStoreBlock(Block):
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
graph_exec_id: str,
|
||||
user_id: str,
|
||||
execution_context: ExecutionContext,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
# Determine return format based on user preference
|
||||
# for_external_api: always returns data URI (base64) - honors "Produce Base64 Output"
|
||||
# for_block_output: smart format - workspace:// in CoPilot, data URI in graphs
|
||||
return_format = "for_external_api" if input_data.base_64 else "for_block_output"
|
||||
|
||||
yield "file_out", await store_media_file(
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=input_data.file_in,
|
||||
user_id=user_id,
|
||||
return_content=input_data.base_64,
|
||||
execution_context=execution_context,
|
||||
return_format=return_format,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ from backend.data.block import (
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import APIKeyCredentials, SchemaField
|
||||
from backend.util.file import store_media_file
|
||||
from backend.util.request import Requests
|
||||
@@ -666,8 +667,7 @@ class SendDiscordFileBlock(Block):
|
||||
file: MediaFileType,
|
||||
filename: str,
|
||||
message_content: str,
|
||||
graph_exec_id: str,
|
||||
user_id: str,
|
||||
execution_context: ExecutionContext,
|
||||
) -> dict:
|
||||
intents = discord.Intents.default()
|
||||
intents.guilds = True
|
||||
@@ -731,10 +731,9 @@ class SendDiscordFileBlock(Block):
|
||||
# Local file path - read from stored media file
|
||||
# This would be a path from a previous block's output
|
||||
stored_file = await store_media_file(
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=file,
|
||||
user_id=user_id,
|
||||
return_content=True, # Get as data URI
|
||||
execution_context=execution_context,
|
||||
return_format="for_external_api", # Get content to send to Discord
|
||||
)
|
||||
# Now process as data URI
|
||||
header, encoded = stored_file.split(",", 1)
|
||||
@@ -781,8 +780,7 @@ class SendDiscordFileBlock(Block):
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: APIKeyCredentials,
|
||||
graph_exec_id: str,
|
||||
user_id: str,
|
||||
execution_context: ExecutionContext,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
@@ -793,8 +791,7 @@ class SendDiscordFileBlock(Block):
|
||||
file=input_data.file,
|
||||
filename=input_data.filename,
|
||||
message_content=input_data.message_content,
|
||||
graph_exec_id=graph_exec_id,
|
||||
user_id=user_id,
|
||||
execution_context=execution_context,
|
||||
)
|
||||
|
||||
yield "status", result.get("status", "Unknown error")
|
||||
|
||||
@@ -17,8 +17,11 @@ from backend.data.block import (
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.file import store_media_file
|
||||
from backend.util.request import ClientResponseError, Requests
|
||||
from backend.util.type import MediaFileType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -64,9 +67,13 @@ class AIVideoGeneratorBlock(Block):
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("video_url", "https://fal.media/files/example/video.mp4")],
|
||||
test_output=[
|
||||
# Output will be a workspace ref or data URI depending on context
|
||||
("video_url", lambda x: x.startswith(("workspace://", "data:"))),
|
||||
],
|
||||
test_mock={
|
||||
"generate_video": lambda *args, **kwargs: "https://fal.media/files/example/video.mp4"
|
||||
# Use data URI to avoid HTTP requests during tests
|
||||
"generate_video": lambda *args, **kwargs: "data:video/mp4;base64,AAAA"
|
||||
},
|
||||
)
|
||||
|
||||
@@ -208,11 +215,22 @@ class AIVideoGeneratorBlock(Block):
|
||||
raise RuntimeError(f"API request failed: {str(e)}")
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: FalCredentials, **kwargs
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: FalCredentials,
|
||||
execution_context: ExecutionContext,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
video_url = await self.generate_video(input_data, credentials)
|
||||
yield "video_url", video_url
|
||||
# Store the generated video to the user's workspace for persistence
|
||||
stored_url = await store_media_file(
|
||||
file=MediaFileType(video_url),
|
||||
execution_context=execution_context,
|
||||
return_format="for_block_output",
|
||||
)
|
||||
yield "video_url", stored_url
|
||||
except Exception as e:
|
||||
error_message = str(e)
|
||||
yield "error", error_message
|
||||
|
||||
@@ -12,6 +12,7 @@ from backend.data.block import (
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsField,
|
||||
@@ -121,10 +122,12 @@ class AIImageEditorBlock(Block):
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_output=[
|
||||
("output_image", "https://replicate.com/output/edited-image.png"),
|
||||
# Output will be a workspace ref or data URI depending on context
|
||||
("output_image", lambda x: x.startswith(("workspace://", "data:"))),
|
||||
],
|
||||
test_mock={
|
||||
"run_model": lambda *args, **kwargs: "https://replicate.com/output/edited-image.png",
|
||||
# Use data URI to avoid HTTP requests during tests
|
||||
"run_model": lambda *args, **kwargs: "",
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
)
|
||||
@@ -134,8 +137,7 @@ class AIImageEditorBlock(Block):
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: APIKeyCredentials,
|
||||
graph_exec_id: str,
|
||||
user_id: str,
|
||||
execution_context: ExecutionContext,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
result = await self.run_model(
|
||||
@@ -144,20 +146,25 @@ class AIImageEditorBlock(Block):
|
||||
prompt=input_data.prompt,
|
||||
input_image_b64=(
|
||||
await store_media_file(
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=input_data.input_image,
|
||||
user_id=user_id,
|
||||
return_content=True,
|
||||
execution_context=execution_context,
|
||||
return_format="for_external_api", # Get content for Replicate API
|
||||
)
|
||||
if input_data.input_image
|
||||
else None
|
||||
),
|
||||
aspect_ratio=input_data.aspect_ratio.value,
|
||||
seed=input_data.seed,
|
||||
user_id=user_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
user_id=execution_context.user_id or "",
|
||||
graph_exec_id=execution_context.graph_exec_id or "",
|
||||
)
|
||||
yield "output_image", result
|
||||
# Store the generated image to the user's workspace for persistence
|
||||
stored_url = await store_media_file(
|
||||
file=result,
|
||||
execution_context=execution_context,
|
||||
return_format="for_block_output",
|
||||
)
|
||||
yield "output_image", stored_url
|
||||
|
||||
async def run_model(
|
||||
self,
|
||||
|
||||
@@ -21,6 +21,7 @@ from backend.data.block import (
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
||||
from backend.util.settings import Settings
|
||||
@@ -95,8 +96,7 @@ def _make_mime_text(
|
||||
|
||||
async def create_mime_message(
|
||||
input_data,
|
||||
graph_exec_id: str,
|
||||
user_id: str,
|
||||
execution_context: ExecutionContext,
|
||||
) -> str:
|
||||
"""Create a MIME message with attachments and return base64-encoded raw message."""
|
||||
|
||||
@@ -117,12 +117,12 @@ async def create_mime_message(
|
||||
if input_data.attachments:
|
||||
for attach in input_data.attachments:
|
||||
local_path = await store_media_file(
|
||||
user_id=user_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=attach,
|
||||
return_content=False,
|
||||
execution_context=execution_context,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
abs_path = get_exec_file_path(graph_exec_id, local_path)
|
||||
assert execution_context.graph_exec_id # Validated by store_media_file
|
||||
abs_path = get_exec_file_path(execution_context.graph_exec_id, local_path)
|
||||
part = MIMEBase("application", "octet-stream")
|
||||
with open(abs_path, "rb") as f:
|
||||
part.set_payload(f.read())
|
||||
@@ -582,27 +582,25 @@ class GmailSendBlock(GmailBase):
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GoogleCredentials,
|
||||
graph_exec_id: str,
|
||||
user_id: str,
|
||||
execution_context: ExecutionContext,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
service = self._build_service(credentials, **kwargs)
|
||||
result = await self._send_email(
|
||||
service,
|
||||
input_data,
|
||||
graph_exec_id,
|
||||
user_id,
|
||||
execution_context,
|
||||
)
|
||||
yield "result", result
|
||||
|
||||
async def _send_email(
|
||||
self, service, input_data: Input, graph_exec_id: str, user_id: str
|
||||
self, service, input_data: Input, execution_context: ExecutionContext
|
||||
) -> dict:
|
||||
if not input_data.to or not input_data.subject or not input_data.body:
|
||||
raise ValueError(
|
||||
"At least one recipient, subject, and body are required for sending an email"
|
||||
)
|
||||
raw_message = await create_mime_message(input_data, graph_exec_id, user_id)
|
||||
raw_message = await create_mime_message(input_data, execution_context)
|
||||
sent_message = await asyncio.to_thread(
|
||||
lambda: service.users()
|
||||
.messages()
|
||||
@@ -692,30 +690,28 @@ class GmailCreateDraftBlock(GmailBase):
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GoogleCredentials,
|
||||
graph_exec_id: str,
|
||||
user_id: str,
|
||||
execution_context: ExecutionContext,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
service = self._build_service(credentials, **kwargs)
|
||||
result = await self._create_draft(
|
||||
service,
|
||||
input_data,
|
||||
graph_exec_id,
|
||||
user_id,
|
||||
execution_context,
|
||||
)
|
||||
yield "result", GmailDraftResult(
|
||||
id=result["id"], message_id=result["message"]["id"], status="draft_created"
|
||||
)
|
||||
|
||||
async def _create_draft(
|
||||
self, service, input_data: Input, graph_exec_id: str, user_id: str
|
||||
self, service, input_data: Input, execution_context: ExecutionContext
|
||||
) -> dict:
|
||||
if not input_data.to or not input_data.subject:
|
||||
raise ValueError(
|
||||
"At least one recipient and subject are required for creating a draft"
|
||||
)
|
||||
|
||||
raw_message = await create_mime_message(input_data, graph_exec_id, user_id)
|
||||
raw_message = await create_mime_message(input_data, execution_context)
|
||||
draft = await asyncio.to_thread(
|
||||
lambda: service.users()
|
||||
.drafts()
|
||||
@@ -1100,7 +1096,7 @@ class GmailGetThreadBlock(GmailBase):
|
||||
|
||||
|
||||
async def _build_reply_message(
|
||||
service, input_data, graph_exec_id: str, user_id: str
|
||||
service, input_data, execution_context: ExecutionContext
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Builds a reply MIME message for Gmail threads.
|
||||
@@ -1190,12 +1186,12 @@ async def _build_reply_message(
|
||||
# Handle attachments
|
||||
for attach in input_data.attachments:
|
||||
local_path = await store_media_file(
|
||||
user_id=user_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=attach,
|
||||
return_content=False,
|
||||
execution_context=execution_context,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
abs_path = get_exec_file_path(graph_exec_id, local_path)
|
||||
assert execution_context.graph_exec_id # Validated by store_media_file
|
||||
abs_path = get_exec_file_path(execution_context.graph_exec_id, local_path)
|
||||
part = MIMEBase("application", "octet-stream")
|
||||
with open(abs_path, "rb") as f:
|
||||
part.set_payload(f.read())
|
||||
@@ -1311,16 +1307,14 @@ class GmailReplyBlock(GmailBase):
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GoogleCredentials,
|
||||
graph_exec_id: str,
|
||||
user_id: str,
|
||||
execution_context: ExecutionContext,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
service = self._build_service(credentials, **kwargs)
|
||||
message = await self._reply(
|
||||
service,
|
||||
input_data,
|
||||
graph_exec_id,
|
||||
user_id,
|
||||
execution_context,
|
||||
)
|
||||
yield "messageId", message["id"]
|
||||
yield "threadId", message.get("threadId", input_data.threadId)
|
||||
@@ -1343,11 +1337,11 @@ class GmailReplyBlock(GmailBase):
|
||||
yield "email", email
|
||||
|
||||
async def _reply(
|
||||
self, service, input_data: Input, graph_exec_id: str, user_id: str
|
||||
self, service, input_data: Input, execution_context: ExecutionContext
|
||||
) -> dict:
|
||||
# Build the reply message using the shared helper
|
||||
raw, thread_id = await _build_reply_message(
|
||||
service, input_data, graph_exec_id, user_id
|
||||
service, input_data, execution_context
|
||||
)
|
||||
|
||||
# Send the message
|
||||
@@ -1441,16 +1435,14 @@ class GmailDraftReplyBlock(GmailBase):
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GoogleCredentials,
|
||||
graph_exec_id: str,
|
||||
user_id: str,
|
||||
execution_context: ExecutionContext,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
service = self._build_service(credentials, **kwargs)
|
||||
draft = await self._create_draft_reply(
|
||||
service,
|
||||
input_data,
|
||||
graph_exec_id,
|
||||
user_id,
|
||||
execution_context,
|
||||
)
|
||||
yield "draftId", draft["id"]
|
||||
yield "messageId", draft["message"]["id"]
|
||||
@@ -1458,11 +1450,11 @@ class GmailDraftReplyBlock(GmailBase):
|
||||
yield "status", "draft_created"
|
||||
|
||||
async def _create_draft_reply(
|
||||
self, service, input_data: Input, graph_exec_id: str, user_id: str
|
||||
self, service, input_data: Input, execution_context: ExecutionContext
|
||||
) -> dict:
|
||||
# Build the reply message using the shared helper
|
||||
raw, thread_id = await _build_reply_message(
|
||||
service, input_data, graph_exec_id, user_id
|
||||
service, input_data, execution_context
|
||||
)
|
||||
|
||||
# Create draft with proper thread association
|
||||
@@ -1629,23 +1621,21 @@ class GmailForwardBlock(GmailBase):
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GoogleCredentials,
|
||||
graph_exec_id: str,
|
||||
user_id: str,
|
||||
execution_context: ExecutionContext,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
service = self._build_service(credentials, **kwargs)
|
||||
result = await self._forward_message(
|
||||
service,
|
||||
input_data,
|
||||
graph_exec_id,
|
||||
user_id,
|
||||
execution_context,
|
||||
)
|
||||
yield "messageId", result["id"]
|
||||
yield "threadId", result.get("threadId", "")
|
||||
yield "status", "forwarded"
|
||||
|
||||
async def _forward_message(
|
||||
self, service, input_data: Input, graph_exec_id: str, user_id: str
|
||||
self, service, input_data: Input, execution_context: ExecutionContext
|
||||
) -> dict:
|
||||
if not input_data.to:
|
||||
raise ValueError("At least one recipient is required for forwarding")
|
||||
@@ -1727,12 +1717,12 @@ To: {original_to}
|
||||
# Add any additional attachments
|
||||
for attach in input_data.additionalAttachments:
|
||||
local_path = await store_media_file(
|
||||
user_id=user_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=attach,
|
||||
return_content=False,
|
||||
execution_context=execution_context,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
abs_path = get_exec_file_path(graph_exec_id, local_path)
|
||||
assert execution_context.graph_exec_id # Validated by store_media_file
|
||||
abs_path = get_exec_file_path(execution_context.graph_exec_id, local_path)
|
||||
part = MIMEBase("application", "octet-stream")
|
||||
with open(abs_path, "rb") as f:
|
||||
part.set_payload(f.read())
|
||||
|
||||
@@ -15,6 +15,7 @@ from backend.data.block import (
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import (
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
@@ -116,10 +117,9 @@ class SendWebRequestBlock(Block):
|
||||
|
||||
@staticmethod
|
||||
async def _prepare_files(
|
||||
graph_exec_id: str,
|
||||
execution_context: ExecutionContext,
|
||||
files_name: str,
|
||||
files: list[MediaFileType],
|
||||
user_id: str,
|
||||
) -> list[tuple[str, tuple[str, BytesIO, str]]]:
|
||||
"""
|
||||
Prepare files for the request by storing them and reading their content.
|
||||
@@ -127,11 +127,16 @@ class SendWebRequestBlock(Block):
|
||||
(files_name, (filename, BytesIO, mime_type))
|
||||
"""
|
||||
files_payload: list[tuple[str, tuple[str, BytesIO, str]]] = []
|
||||
graph_exec_id = execution_context.graph_exec_id
|
||||
if graph_exec_id is None:
|
||||
raise ValueError("graph_exec_id is required for file operations")
|
||||
|
||||
for media in files:
|
||||
# Normalise to a list so we can repeat the same key
|
||||
rel_path = await store_media_file(
|
||||
graph_exec_id, media, user_id, return_content=False
|
||||
file=media,
|
||||
execution_context=execution_context,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
abs_path = get_exec_file_path(graph_exec_id, rel_path)
|
||||
async with aiofiles.open(abs_path, "rb") as f:
|
||||
@@ -143,7 +148,7 @@ class SendWebRequestBlock(Block):
|
||||
return files_payload
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, graph_exec_id: str, user_id: str, **kwargs
|
||||
self, input_data: Input, *, execution_context: ExecutionContext, **kwargs
|
||||
) -> BlockOutput:
|
||||
# ─── Parse/normalise body ────────────────────────────────────
|
||||
body = input_data.body
|
||||
@@ -174,7 +179,7 @@ class SendWebRequestBlock(Block):
|
||||
files_payload: list[tuple[str, tuple[str, BytesIO, str]]] = []
|
||||
if use_files:
|
||||
files_payload = await self._prepare_files(
|
||||
graph_exec_id, input_data.files_name, input_data.files, user_id
|
||||
execution_context, input_data.files_name, input_data.files
|
||||
)
|
||||
|
||||
# Enforce body format rules
|
||||
@@ -238,9 +243,8 @@ class SendAuthenticatedWebRequestBlock(SendWebRequestBlock):
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
graph_exec_id: str,
|
||||
execution_context: ExecutionContext,
|
||||
credentials: HostScopedCredentials,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
# Create SendWebRequestBlock.Input from our input (removing credentials field)
|
||||
@@ -271,6 +275,6 @@ class SendAuthenticatedWebRequestBlock(SendWebRequestBlock):
|
||||
|
||||
# Use parent class run method
|
||||
async for output_name, output_data in super().run(
|
||||
base_input, graph_exec_id=graph_exec_id, user_id=user_id, **kwargs
|
||||
base_input, execution_context=execution_context, **kwargs
|
||||
):
|
||||
yield output_name, output_data
|
||||
|
||||
@@ -12,6 +12,7 @@ from backend.data.block import (
|
||||
BlockSchemaInput,
|
||||
BlockType,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.file import store_media_file
|
||||
from backend.util.mock import MockObject
|
||||
@@ -462,18 +463,21 @@ class AgentFileInputBlock(AgentInputBlock):
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
graph_exec_id: str,
|
||||
user_id: str,
|
||||
execution_context: ExecutionContext,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
if not input_data.value:
|
||||
return
|
||||
|
||||
# Determine return format based on user preference
|
||||
# for_external_api: always returns data URI (base64) - honors "Produce Base64 Output"
|
||||
# for_block_output: smart format - workspace:// in CoPilot, data URI in graphs
|
||||
return_format = "for_external_api" if input_data.base_64 else "for_block_output"
|
||||
|
||||
yield "result", await store_media_file(
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=input_data.value,
|
||||
user_id=user_id,
|
||||
return_content=input_data.base_64,
|
||||
execution_context=execution_context,
|
||||
return_format=return_format,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Literal, Optional
|
||||
from typing import Optional
|
||||
|
||||
from moviepy.audio.io.AudioFileClip import AudioFileClip
|
||||
from moviepy.video.fx.Loop import Loop
|
||||
@@ -13,6 +13,7 @@ from backend.data.block import (
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
||||
|
||||
@@ -46,18 +47,19 @@ class MediaDurationBlock(Block):
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
graph_exec_id: str,
|
||||
user_id: str,
|
||||
execution_context: ExecutionContext,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
# 1) Store the input media locally
|
||||
local_media_path = await store_media_file(
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=input_data.media_in,
|
||||
user_id=user_id,
|
||||
return_content=False,
|
||||
execution_context=execution_context,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
assert execution_context.graph_exec_id is not None
|
||||
media_abspath = get_exec_file_path(
|
||||
execution_context.graph_exec_id, local_media_path
|
||||
)
|
||||
media_abspath = get_exec_file_path(graph_exec_id, local_media_path)
|
||||
|
||||
# 2) Load the clip
|
||||
if input_data.is_video:
|
||||
@@ -88,10 +90,6 @@ class LoopVideoBlock(Block):
|
||||
default=None,
|
||||
ge=1,
|
||||
)
|
||||
output_return_type: Literal["file_path", "data_uri"] = SchemaField(
|
||||
description="How to return the output video. Either a relative path or base64 data URI.",
|
||||
default="file_path",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
video_out: str = SchemaField(
|
||||
@@ -111,17 +109,19 @@ class LoopVideoBlock(Block):
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
node_exec_id: str,
|
||||
graph_exec_id: str,
|
||||
user_id: str,
|
||||
execution_context: ExecutionContext,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
assert execution_context.graph_exec_id is not None
|
||||
assert execution_context.node_exec_id is not None
|
||||
graph_exec_id = execution_context.graph_exec_id
|
||||
node_exec_id = execution_context.node_exec_id
|
||||
|
||||
# 1) Store the input video locally
|
||||
local_video_path = await store_media_file(
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=input_data.video_in,
|
||||
user_id=user_id,
|
||||
return_content=False,
|
||||
execution_context=execution_context,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
input_abspath = get_exec_file_path(graph_exec_id, local_video_path)
|
||||
|
||||
@@ -149,12 +149,11 @@ class LoopVideoBlock(Block):
|
||||
looped_clip = looped_clip.with_audio(clip.audio)
|
||||
looped_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac")
|
||||
|
||||
# Return as data URI
|
||||
# Return output - for_block_output returns workspace:// if available, else data URI
|
||||
video_out = await store_media_file(
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=output_filename,
|
||||
user_id=user_id,
|
||||
return_content=input_data.output_return_type == "data_uri",
|
||||
execution_context=execution_context,
|
||||
return_format="for_block_output",
|
||||
)
|
||||
|
||||
yield "video_out", video_out
|
||||
@@ -177,10 +176,6 @@ class AddAudioToVideoBlock(Block):
|
||||
description="Volume scale for the newly attached audio track (1.0 = original).",
|
||||
default=1.0,
|
||||
)
|
||||
output_return_type: Literal["file_path", "data_uri"] = SchemaField(
|
||||
description="Return the final output as a relative path or base64 data URI.",
|
||||
default="file_path",
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
video_out: MediaFileType = SchemaField(
|
||||
@@ -200,23 +195,24 @@ class AddAudioToVideoBlock(Block):
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
node_exec_id: str,
|
||||
graph_exec_id: str,
|
||||
user_id: str,
|
||||
execution_context: ExecutionContext,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
assert execution_context.graph_exec_id is not None
|
||||
assert execution_context.node_exec_id is not None
|
||||
graph_exec_id = execution_context.graph_exec_id
|
||||
node_exec_id = execution_context.node_exec_id
|
||||
|
||||
# 1) Store the inputs locally
|
||||
local_video_path = await store_media_file(
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=input_data.video_in,
|
||||
user_id=user_id,
|
||||
return_content=False,
|
||||
execution_context=execution_context,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
local_audio_path = await store_media_file(
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=input_data.audio_in,
|
||||
user_id=user_id,
|
||||
return_content=False,
|
||||
execution_context=execution_context,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
|
||||
abs_temp_dir = os.path.join(tempfile.gettempdir(), "exec_file", graph_exec_id)
|
||||
@@ -240,12 +236,11 @@ class AddAudioToVideoBlock(Block):
|
||||
output_abspath = os.path.join(abs_temp_dir, output_filename)
|
||||
final_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac")
|
||||
|
||||
# 5) Return either path or data URI
|
||||
# 5) Return output - for_block_output returns workspace:// if available, else data URI
|
||||
video_out = await store_media_file(
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=output_filename,
|
||||
user_id=user_id,
|
||||
return_content=input_data.output_return_type == "data_uri",
|
||||
execution_context=execution_context,
|
||||
return_format="for_block_output",
|
||||
)
|
||||
|
||||
yield "video_out", video_out
|
||||
|
||||
@@ -11,6 +11,7 @@ from backend.data.block import (
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsField,
|
||||
@@ -112,8 +113,7 @@ class ScreenshotWebPageBlock(Block):
|
||||
@staticmethod
|
||||
async def take_screenshot(
|
||||
credentials: APIKeyCredentials,
|
||||
graph_exec_id: str,
|
||||
user_id: str,
|
||||
execution_context: ExecutionContext,
|
||||
url: str,
|
||||
viewport_width: int,
|
||||
viewport_height: int,
|
||||
@@ -155,12 +155,11 @@ class ScreenshotWebPageBlock(Block):
|
||||
|
||||
return {
|
||||
"image": await store_media_file(
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=MediaFileType(
|
||||
f"data:image/{format.value};base64,{b64encode(content).decode('utf-8')}"
|
||||
),
|
||||
user_id=user_id,
|
||||
return_content=True,
|
||||
execution_context=execution_context,
|
||||
return_format="for_block_output",
|
||||
)
|
||||
}
|
||||
|
||||
@@ -169,15 +168,13 @@ class ScreenshotWebPageBlock(Block):
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: APIKeyCredentials,
|
||||
graph_exec_id: str,
|
||||
user_id: str,
|
||||
execution_context: ExecutionContext,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
screenshot_data = await self.take_screenshot(
|
||||
credentials=credentials,
|
||||
graph_exec_id=graph_exec_id,
|
||||
user_id=user_id,
|
||||
execution_context=execution_context,
|
||||
url=input_data.url,
|
||||
viewport_width=input_data.viewport_width,
|
||||
viewport_height=input_data.viewport_height,
|
||||
|
||||
@@ -7,6 +7,7 @@ from backend.data.block import (
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import ContributorDetails, SchemaField
|
||||
from backend.util.file import get_exec_file_path, store_media_file
|
||||
from backend.util.type import MediaFileType
|
||||
@@ -98,7 +99,7 @@ class ReadSpreadsheetBlock(Block):
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, graph_exec_id: str, user_id: str, **_kwargs
|
||||
self, input_data: Input, *, execution_context: ExecutionContext, **_kwargs
|
||||
) -> BlockOutput:
|
||||
import csv
|
||||
from io import StringIO
|
||||
@@ -106,14 +107,16 @@ class ReadSpreadsheetBlock(Block):
|
||||
# Determine data source - prefer file_input if provided, otherwise use contents
|
||||
if input_data.file_input:
|
||||
stored_file_path = await store_media_file(
|
||||
user_id=user_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=input_data.file_input,
|
||||
return_content=False,
|
||||
execution_context=execution_context,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
|
||||
# Get full file path
|
||||
file_path = get_exec_file_path(graph_exec_id, stored_file_path)
|
||||
assert execution_context.graph_exec_id # Validated by store_media_file
|
||||
file_path = get_exec_file_path(
|
||||
execution_context.graph_exec_id, stored_file_path
|
||||
)
|
||||
if not Path(file_path).exists():
|
||||
raise ValueError(f"File does not exist: {file_path}")
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ from backend.data.block import (
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsField,
|
||||
@@ -17,7 +18,9 @@ from backend.data.model import (
|
||||
SchemaField,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.file import store_media_file
|
||||
from backend.util.request import Requests
|
||||
from backend.util.type import MediaFileType
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
@@ -102,7 +105,7 @@ class CreateTalkingAvatarVideoBlock(Block):
|
||||
test_output=[
|
||||
(
|
||||
"video_url",
|
||||
"https://d-id.com/api/clips/abcd1234-5678-efgh-ijkl-mnopqrstuvwx/video",
|
||||
lambda x: x.startswith(("workspace://", "data:")),
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
@@ -110,9 +113,10 @@ class CreateTalkingAvatarVideoBlock(Block):
|
||||
"id": "abcd1234-5678-efgh-ijkl-mnopqrstuvwx",
|
||||
"status": "created",
|
||||
},
|
||||
# Use data URI to avoid HTTP requests during tests
|
||||
"get_clip_status": lambda *args, **kwargs: {
|
||||
"status": "done",
|
||||
"result_url": "https://d-id.com/api/clips/abcd1234-5678-efgh-ijkl-mnopqrstuvwx/video",
|
||||
"result_url": "data:video/mp4;base64,AAAA",
|
||||
},
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
@@ -138,7 +142,12 @@ class CreateTalkingAvatarVideoBlock(Block):
|
||||
return response.json()
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: APIKeyCredentials,
|
||||
execution_context: ExecutionContext,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
# Create the clip
|
||||
payload = {
|
||||
@@ -165,7 +174,14 @@ class CreateTalkingAvatarVideoBlock(Block):
|
||||
for _ in range(input_data.max_polling_attempts):
|
||||
status_response = await self.get_clip_status(credentials.api_key, clip_id)
|
||||
if status_response["status"] == "done":
|
||||
yield "video_url", status_response["result_url"]
|
||||
# Store the generated video to the user's workspace for persistence
|
||||
video_url = status_response["result_url"]
|
||||
stored_url = await store_media_file(
|
||||
file=MediaFileType(video_url),
|
||||
execution_context=execution_context,
|
||||
return_format="for_block_output",
|
||||
)
|
||||
yield "video_url", stored_url
|
||||
return
|
||||
elif status_response["status"] == "error":
|
||||
raise RuntimeError(
|
||||
|
||||
@@ -12,6 +12,7 @@ from backend.blocks.iteration import StepThroughItemsBlock
|
||||
from backend.blocks.llm import AITextSummarizerBlock
|
||||
from backend.blocks.text import ExtractTextInformationBlock
|
||||
from backend.blocks.xml_parser import XMLParserBlock
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.util.file import store_media_file
|
||||
from backend.util.type import MediaFileType
|
||||
|
||||
@@ -233,9 +234,12 @@ class TestStoreMediaFileSecurity:
|
||||
|
||||
with pytest.raises(ValueError, match="File too large"):
|
||||
await store_media_file(
|
||||
graph_exec_id="test",
|
||||
file=MediaFileType(large_data_uri),
|
||||
user_id="test_user",
|
||||
execution_context=ExecutionContext(
|
||||
user_id="test_user",
|
||||
graph_exec_id="test",
|
||||
),
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
|
||||
@patch("backend.util.file.Path")
|
||||
@@ -270,9 +274,12 @@ class TestStoreMediaFileSecurity:
|
||||
# Should raise an error when directory size exceeds limit
|
||||
with pytest.raises(ValueError, match="Disk usage limit exceeded"):
|
||||
await store_media_file(
|
||||
graph_exec_id="test",
|
||||
file=MediaFileType(
|
||||
"data:text/plain;base64,dGVzdA=="
|
||||
), # Small test file
|
||||
user_id="test_user",
|
||||
execution_context=ExecutionContext(
|
||||
user_id="test_user",
|
||||
graph_exec_id="test",
|
||||
),
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
|
||||
@@ -11,10 +11,22 @@ from backend.blocks.http import (
|
||||
HttpMethod,
|
||||
SendAuthenticatedWebRequestBlock,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import HostScopedCredentials
|
||||
from backend.util.request import Response
|
||||
|
||||
|
||||
def make_test_context(
|
||||
graph_exec_id: str = "test-exec-id",
|
||||
user_id: str = "test-user-id",
|
||||
) -> ExecutionContext:
|
||||
"""Helper to create test ExecutionContext."""
|
||||
return ExecutionContext(
|
||||
user_id=user_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
)
|
||||
|
||||
|
||||
class TestHttpBlockWithHostScopedCredentials:
|
||||
"""Test suite for HTTP block integration with HostScopedCredentials."""
|
||||
|
||||
@@ -105,8 +117,7 @@ class TestHttpBlockWithHostScopedCredentials:
|
||||
async for output_name, output_data in http_block.run(
|
||||
input_data,
|
||||
credentials=exact_match_credentials,
|
||||
graph_exec_id="test-exec-id",
|
||||
user_id="test-user-id",
|
||||
execution_context=make_test_context(),
|
||||
):
|
||||
result.append((output_name, output_data))
|
||||
|
||||
@@ -161,8 +172,7 @@ class TestHttpBlockWithHostScopedCredentials:
|
||||
async for output_name, output_data in http_block.run(
|
||||
input_data,
|
||||
credentials=wildcard_credentials,
|
||||
graph_exec_id="test-exec-id",
|
||||
user_id="test-user-id",
|
||||
execution_context=make_test_context(),
|
||||
):
|
||||
result.append((output_name, output_data))
|
||||
|
||||
@@ -208,8 +218,7 @@ class TestHttpBlockWithHostScopedCredentials:
|
||||
async for output_name, output_data in http_block.run(
|
||||
input_data,
|
||||
credentials=non_matching_credentials,
|
||||
graph_exec_id="test-exec-id",
|
||||
user_id="test-user-id",
|
||||
execution_context=make_test_context(),
|
||||
):
|
||||
result.append((output_name, output_data))
|
||||
|
||||
@@ -258,8 +267,7 @@ class TestHttpBlockWithHostScopedCredentials:
|
||||
async for output_name, output_data in http_block.run(
|
||||
input_data,
|
||||
credentials=exact_match_credentials,
|
||||
graph_exec_id="test-exec-id",
|
||||
user_id="test-user-id",
|
||||
execution_context=make_test_context(),
|
||||
):
|
||||
result.append((output_name, output_data))
|
||||
|
||||
@@ -318,8 +326,7 @@ class TestHttpBlockWithHostScopedCredentials:
|
||||
async for output_name, output_data in http_block.run(
|
||||
input_data,
|
||||
credentials=auto_discovered_creds, # Execution manager found these
|
||||
graph_exec_id="test-exec-id",
|
||||
user_id="test-user-id",
|
||||
execution_context=make_test_context(),
|
||||
):
|
||||
result.append((output_name, output_data))
|
||||
|
||||
@@ -382,8 +389,7 @@ class TestHttpBlockWithHostScopedCredentials:
|
||||
async for output_name, output_data in http_block.run(
|
||||
input_data,
|
||||
credentials=multi_header_creds,
|
||||
graph_exec_id="test-exec-id",
|
||||
user_id="test-user-id",
|
||||
execution_context=make_test_context(),
|
||||
):
|
||||
result.append((output_name, output_data))
|
||||
|
||||
@@ -471,8 +477,7 @@ class TestHttpBlockWithHostScopedCredentials:
|
||||
async for output_name, output_data in http_block.run(
|
||||
input_data,
|
||||
credentials=test_creds,
|
||||
graph_exec_id="test-exec-id",
|
||||
user_id="test-user-id",
|
||||
execution_context=make_test_context(),
|
||||
):
|
||||
result.append((output_name, output_data))
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ from backend.data.block import (
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util import json, text
|
||||
from backend.util.file import get_exec_file_path, store_media_file
|
||||
@@ -444,18 +445,21 @@ class FileReadBlock(Block):
|
||||
)
|
||||
|
||||
async def run(
|
||||
self, input_data: Input, *, graph_exec_id: str, user_id: str, **_kwargs
|
||||
self, input_data: Input, *, execution_context: ExecutionContext, **_kwargs
|
||||
) -> BlockOutput:
|
||||
# Store the media file properly (handles URLs, data URIs, etc.)
|
||||
stored_file_path = await store_media_file(
|
||||
user_id=user_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
file=input_data.file_input,
|
||||
return_content=False,
|
||||
execution_context=execution_context,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
|
||||
# Get full file path
|
||||
file_path = get_exec_file_path(graph_exec_id, stored_file_path)
|
||||
# Get full file path (graph_exec_id validated by store_media_file above)
|
||||
if not execution_context.graph_exec_id:
|
||||
raise ValueError("execution_context.graph_exec_id is required")
|
||||
file_path = get_exec_file_path(
|
||||
execution_context.graph_exec_id, stored_file_path
|
||||
)
|
||||
|
||||
if not Path(file_path).exists():
|
||||
raise ValueError(f"File does not exist: {file_path}")
|
||||
|
||||
@@ -83,12 +83,29 @@ class ExecutionContext(BaseModel):
|
||||
|
||||
model_config = {"extra": "ignore"}
|
||||
|
||||
# Execution identity
|
||||
user_id: Optional[str] = None
|
||||
graph_id: Optional[str] = None
|
||||
graph_exec_id: Optional[str] = None
|
||||
graph_version: Optional[int] = None
|
||||
node_id: Optional[str] = None
|
||||
node_exec_id: Optional[str] = None
|
||||
|
||||
# Safety settings
|
||||
human_in_the_loop_safe_mode: bool = True
|
||||
sensitive_action_safe_mode: bool = False
|
||||
|
||||
# User settings
|
||||
user_timezone: str = "UTC"
|
||||
|
||||
# Execution hierarchy
|
||||
root_execution_id: Optional[str] = None
|
||||
parent_execution_id: Optional[str] = None
|
||||
|
||||
# Workspace
|
||||
workspace_id: Optional[str] = None
|
||||
session_id: Optional[str] = None
|
||||
|
||||
|
||||
# -------------------------- Models -------------------------- #
|
||||
|
||||
|
||||
276
autogpt_platform/backend/backend/data/workspace.py
Normal file
276
autogpt_platform/backend/backend/data/workspace.py
Normal file
@@ -0,0 +1,276 @@
|
||||
"""
|
||||
Database CRUD operations for User Workspace.
|
||||
|
||||
This module provides functions for managing user workspaces and workspace files.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from prisma.models import UserWorkspace, UserWorkspaceFile
|
||||
from prisma.types import UserWorkspaceFileWhereInput
|
||||
|
||||
from backend.util.json import SafeJson
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def get_or_create_workspace(user_id: str) -> UserWorkspace:
|
||||
"""
|
||||
Get user's workspace, creating one if it doesn't exist.
|
||||
|
||||
Uses upsert to handle race conditions when multiple concurrent requests
|
||||
attempt to create a workspace for the same user.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID
|
||||
|
||||
Returns:
|
||||
UserWorkspace instance
|
||||
"""
|
||||
workspace = await UserWorkspace.prisma().upsert(
|
||||
where={"userId": user_id},
|
||||
data={
|
||||
"create": {"userId": user_id},
|
||||
"update": {}, # No updates needed if exists
|
||||
},
|
||||
)
|
||||
|
||||
return workspace
|
||||
|
||||
|
||||
async def get_workspace(user_id: str) -> Optional[UserWorkspace]:
|
||||
"""
|
||||
Get user's workspace if it exists.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID
|
||||
|
||||
Returns:
|
||||
UserWorkspace instance or None
|
||||
"""
|
||||
return await UserWorkspace.prisma().find_unique(where={"userId": user_id})
|
||||
|
||||
|
||||
async def create_workspace_file(
|
||||
workspace_id: str,
|
||||
file_id: str,
|
||||
name: str,
|
||||
path: str,
|
||||
storage_path: str,
|
||||
mime_type: str,
|
||||
size_bytes: int,
|
||||
checksum: Optional[str] = None,
|
||||
metadata: Optional[dict] = None,
|
||||
) -> UserWorkspaceFile:
|
||||
"""
|
||||
Create a new workspace file record.
|
||||
|
||||
Args:
|
||||
workspace_id: The workspace ID
|
||||
file_id: The file ID (same as used in storage path for consistency)
|
||||
name: User-visible filename
|
||||
path: Virtual path (e.g., "/documents/report.pdf")
|
||||
storage_path: Actual storage path (GCS or local)
|
||||
mime_type: MIME type of the file
|
||||
size_bytes: File size in bytes
|
||||
checksum: Optional SHA256 checksum
|
||||
metadata: Optional additional metadata
|
||||
|
||||
Returns:
|
||||
Created UserWorkspaceFile instance
|
||||
"""
|
||||
# Normalize path to start with /
|
||||
if not path.startswith("/"):
|
||||
path = f"/{path}"
|
||||
|
||||
file = await UserWorkspaceFile.prisma().create(
|
||||
data={
|
||||
"id": file_id,
|
||||
"workspaceId": workspace_id,
|
||||
"name": name,
|
||||
"path": path,
|
||||
"storagePath": storage_path,
|
||||
"mimeType": mime_type,
|
||||
"sizeBytes": size_bytes,
|
||||
"checksum": checksum,
|
||||
"metadata": SafeJson(metadata or {}),
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Created workspace file {file.id} at path {path} "
|
||||
f"in workspace {workspace_id}"
|
||||
)
|
||||
return file
|
||||
|
||||
|
||||
async def get_workspace_file(
|
||||
file_id: str,
|
||||
workspace_id: Optional[str] = None,
|
||||
) -> Optional[UserWorkspaceFile]:
|
||||
"""
|
||||
Get a workspace file by ID.
|
||||
|
||||
Args:
|
||||
file_id: The file ID
|
||||
workspace_id: Optional workspace ID for validation
|
||||
|
||||
Returns:
|
||||
UserWorkspaceFile instance or None
|
||||
"""
|
||||
where_clause: dict = {"id": file_id, "isDeleted": False}
|
||||
if workspace_id:
|
||||
where_clause["workspaceId"] = workspace_id
|
||||
|
||||
return await UserWorkspaceFile.prisma().find_first(where=where_clause)
|
||||
|
||||
|
||||
async def get_workspace_file_by_path(
|
||||
workspace_id: str,
|
||||
path: str,
|
||||
) -> Optional[UserWorkspaceFile]:
|
||||
"""
|
||||
Get a workspace file by its virtual path.
|
||||
|
||||
Args:
|
||||
workspace_id: The workspace ID
|
||||
path: Virtual path
|
||||
|
||||
Returns:
|
||||
UserWorkspaceFile instance or None
|
||||
"""
|
||||
# Normalize path
|
||||
if not path.startswith("/"):
|
||||
path = f"/{path}"
|
||||
|
||||
return await UserWorkspaceFile.prisma().find_first(
|
||||
where={
|
||||
"workspaceId": workspace_id,
|
||||
"path": path,
|
||||
"isDeleted": False,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def list_workspace_files(
|
||||
workspace_id: str,
|
||||
path_prefix: Optional[str] = None,
|
||||
include_deleted: bool = False,
|
||||
limit: Optional[int] = None,
|
||||
offset: int = 0,
|
||||
) -> list[UserWorkspaceFile]:
|
||||
"""
|
||||
List files in a workspace.
|
||||
|
||||
Args:
|
||||
workspace_id: The workspace ID
|
||||
path_prefix: Optional path prefix to filter (e.g., "/documents/")
|
||||
include_deleted: Whether to include soft-deleted files
|
||||
limit: Maximum number of files to return
|
||||
offset: Number of files to skip
|
||||
|
||||
Returns:
|
||||
List of UserWorkspaceFile instances
|
||||
"""
|
||||
where_clause: UserWorkspaceFileWhereInput = {"workspaceId": workspace_id}
|
||||
|
||||
if not include_deleted:
|
||||
where_clause["isDeleted"] = False
|
||||
|
||||
if path_prefix:
|
||||
# Normalize prefix
|
||||
if not path_prefix.startswith("/"):
|
||||
path_prefix = f"/{path_prefix}"
|
||||
where_clause["path"] = {"startswith": path_prefix}
|
||||
|
||||
return await UserWorkspaceFile.prisma().find_many(
|
||||
where=where_clause,
|
||||
order={"createdAt": "desc"},
|
||||
take=limit,
|
||||
skip=offset,
|
||||
)
|
||||
|
||||
|
||||
async def count_workspace_files(
|
||||
workspace_id: str,
|
||||
path_prefix: Optional[str] = None,
|
||||
include_deleted: bool = False,
|
||||
) -> int:
|
||||
"""
|
||||
Count files in a workspace.
|
||||
|
||||
Args:
|
||||
workspace_id: The workspace ID
|
||||
path_prefix: Optional path prefix to filter (e.g., "/sessions/abc123/")
|
||||
include_deleted: Whether to include soft-deleted files
|
||||
|
||||
Returns:
|
||||
Number of files
|
||||
"""
|
||||
where_clause: dict = {"workspaceId": workspace_id}
|
||||
if not include_deleted:
|
||||
where_clause["isDeleted"] = False
|
||||
|
||||
if path_prefix:
|
||||
# Normalize prefix
|
||||
if not path_prefix.startswith("/"):
|
||||
path_prefix = f"/{path_prefix}"
|
||||
where_clause["path"] = {"startswith": path_prefix}
|
||||
|
||||
return await UserWorkspaceFile.prisma().count(where=where_clause)
|
||||
|
||||
|
||||
async def soft_delete_workspace_file(
|
||||
file_id: str,
|
||||
workspace_id: Optional[str] = None,
|
||||
) -> Optional[UserWorkspaceFile]:
|
||||
"""
|
||||
Soft-delete a workspace file.
|
||||
|
||||
The path is modified to include a deletion timestamp to free up the original
|
||||
path for new files while preserving the record for potential recovery.
|
||||
|
||||
Args:
|
||||
file_id: The file ID
|
||||
workspace_id: Optional workspace ID for validation
|
||||
|
||||
Returns:
|
||||
Updated UserWorkspaceFile instance or None if not found
|
||||
"""
|
||||
# First verify the file exists and belongs to workspace
|
||||
file = await get_workspace_file(file_id, workspace_id)
|
||||
if file is None:
|
||||
return None
|
||||
|
||||
deleted_at = datetime.now(timezone.utc)
|
||||
# Modify path to free up the unique constraint for new files at original path
|
||||
# Format: {original_path}__deleted__{timestamp}
|
||||
deleted_path = f"{file.path}__deleted__{int(deleted_at.timestamp())}"
|
||||
|
||||
updated = await UserWorkspaceFile.prisma().update(
|
||||
where={"id": file_id},
|
||||
data={
|
||||
"isDeleted": True,
|
||||
"deletedAt": deleted_at,
|
||||
"path": deleted_path,
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(f"Soft-deleted workspace file {file_id}")
|
||||
return updated
|
||||
|
||||
|
||||
async def get_workspace_total_size(workspace_id: str) -> int:
|
||||
"""
|
||||
Get the total size of all files in a workspace.
|
||||
|
||||
Args:
|
||||
workspace_id: The workspace ID
|
||||
|
||||
Returns:
|
||||
Total size in bytes
|
||||
"""
|
||||
files = await list_workspace_files(workspace_id)
|
||||
return sum(file.sizeBytes for file in files)
|
||||
@@ -236,7 +236,14 @@ async def execute_node(
|
||||
input_size = len(input_data_str)
|
||||
log_metadata.debug("Executed node with input", input=input_data_str)
|
||||
|
||||
# Create node-specific execution context to avoid race conditions
|
||||
# (multiple nodes can execute concurrently and would otherwise mutate shared state)
|
||||
execution_context = execution_context.model_copy(
|
||||
update={"node_id": node_id, "node_exec_id": node_exec_id}
|
||||
)
|
||||
|
||||
# Inject extra execution arguments for the blocks via kwargs
|
||||
# Keep individual kwargs for backwards compatibility with existing blocks
|
||||
extra_exec_kwargs: dict = {
|
||||
"graph_id": graph_id,
|
||||
"graph_version": graph_version,
|
||||
|
||||
@@ -892,11 +892,19 @@ async def add_graph_execution(
|
||||
settings = await gdb.get_graph_settings(user_id=user_id, graph_id=graph_id)
|
||||
|
||||
execution_context = ExecutionContext(
|
||||
# Execution identity
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
graph_exec_id=graph_exec.id,
|
||||
graph_version=graph_exec.graph_version,
|
||||
# Safety settings
|
||||
human_in_the_loop_safe_mode=settings.human_in_the_loop_safe_mode,
|
||||
sensitive_action_safe_mode=settings.sensitive_action_safe_mode,
|
||||
# User settings
|
||||
user_timezone=(
|
||||
user.timezone if user.timezone != USER_TIMEZONE_NOT_SET else "UTC"
|
||||
),
|
||||
# Execution hierarchy
|
||||
root_execution_id=graph_exec.id,
|
||||
)
|
||||
|
||||
|
||||
@@ -348,6 +348,7 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
|
||||
mock_graph_exec.id = "execution-id-123"
|
||||
mock_graph_exec.node_executions = [] # Add this to avoid AttributeError
|
||||
mock_graph_exec.status = ExecutionStatus.QUEUED # Required for race condition check
|
||||
mock_graph_exec.graph_version = graph_version
|
||||
mock_graph_exec.to_graph_execution_entry.return_value = mocker.MagicMock()
|
||||
|
||||
# Mock the queue and event bus
|
||||
@@ -434,6 +435,9 @@ async def test_add_graph_execution_is_repeatable(mocker: MockerFixture):
|
||||
# Create a second mock execution for the sanity check
|
||||
mock_graph_exec_2 = mocker.MagicMock(spec=GraphExecutionWithNodes)
|
||||
mock_graph_exec_2.id = "execution-id-456"
|
||||
mock_graph_exec_2.node_executions = []
|
||||
mock_graph_exec_2.status = ExecutionStatus.QUEUED
|
||||
mock_graph_exec_2.graph_version = graph_version
|
||||
mock_graph_exec_2.to_graph_execution_entry.return_value = mocker.MagicMock()
|
||||
|
||||
# Reset mocks and set up for second call
|
||||
@@ -614,6 +618,7 @@ async def test_add_graph_execution_with_nodes_to_skip(mocker: MockerFixture):
|
||||
mock_graph_exec.id = "execution-id-123"
|
||||
mock_graph_exec.node_executions = []
|
||||
mock_graph_exec.status = ExecutionStatus.QUEUED # Required for race condition check
|
||||
mock_graph_exec.graph_version = graph_version
|
||||
|
||||
# Track what's passed to to_graph_execution_entry
|
||||
captured_kwargs = {}
|
||||
|
||||
@@ -13,6 +13,7 @@ import aiohttp
|
||||
from gcloud.aio import storage as async_gcs_storage
|
||||
from google.cloud import storage as gcs_storage
|
||||
|
||||
from backend.util.gcs_utils import download_with_fresh_session, generate_signed_url
|
||||
from backend.util.settings import Config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -251,7 +252,7 @@ class CloudStorageHandler:
|
||||
f"in_task: {current_task is not None}"
|
||||
)
|
||||
|
||||
# Parse bucket and blob name from path
|
||||
# Parse bucket and blob name from path (path already has gcs:// prefix removed)
|
||||
parts = path.split("/", 1)
|
||||
if len(parts) != 2:
|
||||
raise ValueError(f"Invalid GCS path: {path}")
|
||||
@@ -261,50 +262,19 @@ class CloudStorageHandler:
|
||||
# Authorization check
|
||||
self._validate_file_access(blob_name, user_id, graph_exec_id)
|
||||
|
||||
# Use a fresh client for each download to avoid session issues
|
||||
# This is less efficient but more reliable with the executor's event loop
|
||||
logger.info("[CloudStorage] Creating fresh GCS client for download")
|
||||
|
||||
# Create a new session specifically for this download
|
||||
session = aiohttp.ClientSession(
|
||||
connector=aiohttp.TCPConnector(limit=10, force_close=True)
|
||||
logger.info(
|
||||
f"[CloudStorage] About to download from GCS - bucket: {bucket_name}, blob: {blob_name}"
|
||||
)
|
||||
|
||||
async_client = None
|
||||
try:
|
||||
# Create a new GCS client with the fresh session
|
||||
async_client = async_gcs_storage.Storage(session=session)
|
||||
|
||||
logger.info(
|
||||
f"[CloudStorage] About to download from GCS - bucket: {bucket_name}, blob: {blob_name}"
|
||||
)
|
||||
|
||||
# Download content using the fresh client
|
||||
content = await async_client.download(bucket_name, blob_name)
|
||||
content = await download_with_fresh_session(bucket_name, blob_name)
|
||||
logger.info(
|
||||
f"[CloudStorage] GCS download successful - size: {len(content)} bytes"
|
||||
)
|
||||
|
||||
# Clean up
|
||||
await async_client.close()
|
||||
await session.close()
|
||||
|
||||
return content
|
||||
|
||||
except FileNotFoundError:
|
||||
raise
|
||||
except Exception as e:
|
||||
# Always try to clean up
|
||||
if async_client is not None:
|
||||
try:
|
||||
await async_client.close()
|
||||
except Exception as cleanup_error:
|
||||
logger.warning(
|
||||
f"[CloudStorage] Error closing GCS client: {cleanup_error}"
|
||||
)
|
||||
try:
|
||||
await session.close()
|
||||
except Exception as cleanup_error:
|
||||
logger.warning(f"[CloudStorage] Error closing session: {cleanup_error}")
|
||||
|
||||
# Log the specific error for debugging
|
||||
logger.error(
|
||||
f"[CloudStorage] GCS download failed - error: {str(e)}, "
|
||||
@@ -319,10 +289,6 @@ class CloudStorageHandler:
|
||||
f"current_task: {current_task}, "
|
||||
f"bucket: {bucket_name}, blob: redacted for privacy"
|
||||
)
|
||||
|
||||
# Convert gcloud-aio exceptions to standard ones
|
||||
if "404" in str(e) or "Not Found" in str(e):
|
||||
raise FileNotFoundError(f"File not found: gcs://{path}")
|
||||
raise
|
||||
|
||||
def _validate_file_access(
|
||||
@@ -445,8 +411,7 @@ class CloudStorageHandler:
|
||||
graph_exec_id: str | None = None,
|
||||
) -> str:
|
||||
"""Generate signed URL for GCS with authorization."""
|
||||
|
||||
# Parse bucket and blob name from path
|
||||
# Parse bucket and blob name from path (path already has gcs:// prefix removed)
|
||||
parts = path.split("/", 1)
|
||||
if len(parts) != 2:
|
||||
raise ValueError(f"Invalid GCS path: {path}")
|
||||
@@ -456,21 +421,11 @@ class CloudStorageHandler:
|
||||
# Authorization check
|
||||
self._validate_file_access(blob_name, user_id, graph_exec_id)
|
||||
|
||||
# Use sync client for signed URLs since gcloud-aio doesn't support them
|
||||
sync_client = self._get_sync_gcs_client()
|
||||
bucket = sync_client.bucket(bucket_name)
|
||||
blob = bucket.blob(blob_name)
|
||||
|
||||
# Generate signed URL asynchronously using sync client
|
||||
url = await asyncio.to_thread(
|
||||
blob.generate_signed_url,
|
||||
version="v4",
|
||||
expiration=datetime.now(timezone.utc) + timedelta(hours=expiration_hours),
|
||||
method="GET",
|
||||
return await generate_signed_url(
|
||||
sync_client, bucket_name, blob_name, expiration_hours * 3600
|
||||
)
|
||||
|
||||
return url
|
||||
|
||||
async def delete_expired_files(self, provider: str = "gcs") -> int:
|
||||
"""
|
||||
Delete files that have passed their expiration time.
|
||||
|
||||
@@ -5,13 +5,26 @@ import shutil
|
||||
import tempfile
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from backend.util.cloud_storage import get_cloud_storage_handler
|
||||
from backend.util.request import Requests
|
||||
from backend.util.settings import Config
|
||||
from backend.util.type import MediaFileType
|
||||
from backend.util.virus_scanner import scan_content_safe
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
# Return format options for store_media_file
|
||||
# - "for_local_processing": Returns local file path - use with ffmpeg, MoviePy, PIL, etc.
|
||||
# - "for_external_api": Returns data URI (base64) - use when sending content to external APIs
|
||||
# - "for_block_output": Returns best format for output - workspace:// in CoPilot, data URI in graphs
|
||||
MediaReturnFormat = Literal[
|
||||
"for_local_processing", "for_external_api", "for_block_output"
|
||||
]
|
||||
|
||||
TEMP_DIR = Path(tempfile.gettempdir()).resolve()
|
||||
|
||||
# Maximum filename length (conservative limit for most filesystems)
|
||||
@@ -67,42 +80,56 @@ def clean_exec_files(graph_exec_id: str, file: str = "") -> None:
|
||||
|
||||
|
||||
async def store_media_file(
|
||||
graph_exec_id: str,
|
||||
file: MediaFileType,
|
||||
user_id: str,
|
||||
return_content: bool = False,
|
||||
execution_context: "ExecutionContext",
|
||||
*,
|
||||
return_format: MediaReturnFormat,
|
||||
) -> MediaFileType:
|
||||
"""
|
||||
Safely handle 'file' (a data URI, a URL, or a local path relative to {temp}/exec_file/{exec_id}),
|
||||
placing or verifying it under:
|
||||
Safely handle 'file' (a data URI, a URL, a workspace:// reference, or a local path
|
||||
relative to {temp}/exec_file/{exec_id}), placing or verifying it under:
|
||||
{tempdir}/exec_file/{exec_id}/...
|
||||
|
||||
If 'return_content=True', return a data URI (data:<mime>;base64,<content>).
|
||||
Otherwise, returns the file media path relative to the exec_id folder.
|
||||
For each MediaFileType input:
|
||||
- Data URI: decode and store locally
|
||||
- URL: download and store locally
|
||||
- workspace:// reference: read from workspace, store locally
|
||||
- Local path: verify it exists in exec_file directory
|
||||
|
||||
For each MediaFileType type:
|
||||
- Data URI:
|
||||
-> decode and store in a new random file in that folder
|
||||
- URL:
|
||||
-> download and store in that folder
|
||||
- Local path:
|
||||
-> interpret as relative to that folder; verify it exists
|
||||
(no copying, as it's presumably already there).
|
||||
We realpath-check so no symlink or '..' can escape the folder.
|
||||
Return format options:
|
||||
- "for_local_processing": Returns local file path - use with ffmpeg, MoviePy, PIL, etc.
|
||||
- "for_external_api": Returns data URI (base64) - use when sending to external APIs
|
||||
- "for_block_output": Returns best format for output - workspace:// in CoPilot, data URI in graphs
|
||||
|
||||
|
||||
:param graph_exec_id: The unique ID of the graph execution.
|
||||
:param file: Data URI, URL, or local (relative) path.
|
||||
:param return_content: If True, return a data URI of the file content.
|
||||
If False, return the *relative* path inside the exec_id folder.
|
||||
:return: The requested result: data URI or relative path of the media.
|
||||
:param file: Data URI, URL, workspace://, or local (relative) path.
|
||||
:param execution_context: ExecutionContext with user_id, graph_exec_id, workspace_id.
|
||||
:param return_format: What to return: "for_local_processing", "for_external_api", or "for_block_output".
|
||||
:return: The requested result based on return_format.
|
||||
"""
|
||||
# Extract values from execution_context
|
||||
graph_exec_id = execution_context.graph_exec_id
|
||||
user_id = execution_context.user_id
|
||||
|
||||
if not graph_exec_id:
|
||||
raise ValueError("execution_context.graph_exec_id is required")
|
||||
if not user_id:
|
||||
raise ValueError("execution_context.user_id is required")
|
||||
|
||||
# Create workspace_manager if we have workspace_id (with session scoping)
|
||||
# Import here to avoid circular import (file.py → workspace.py → data → blocks → file.py)
|
||||
from backend.util.workspace import WorkspaceManager
|
||||
|
||||
workspace_manager: WorkspaceManager | None = None
|
||||
if execution_context.workspace_id:
|
||||
workspace_manager = WorkspaceManager(
|
||||
user_id, execution_context.workspace_id, execution_context.session_id
|
||||
)
|
||||
# Build base path
|
||||
base_path = Path(get_exec_file_path(graph_exec_id, ""))
|
||||
base_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Security fix: Add disk space limits to prevent DoS
|
||||
MAX_FILE_SIZE = 100 * 1024 * 1024 # 100MB per file
|
||||
MAX_FILE_SIZE_BYTES = Config().max_file_size_mb * 1024 * 1024
|
||||
MAX_TOTAL_DISK_USAGE = 1024 * 1024 * 1024 # 1GB total per execution directory
|
||||
|
||||
# Check total disk usage in base_path
|
||||
@@ -142,9 +169,57 @@ async def store_media_file(
|
||||
"""
|
||||
return str(absolute_path.relative_to(base))
|
||||
|
||||
# Check if this is a cloud storage path
|
||||
# Get cloud storage handler for checking cloud paths
|
||||
cloud_storage = await get_cloud_storage_handler()
|
||||
if cloud_storage.is_cloud_path(file):
|
||||
|
||||
# Track if the input came from workspace (don't re-save it)
|
||||
is_from_workspace = file.startswith("workspace://")
|
||||
|
||||
# Check if this is a workspace file reference
|
||||
if is_from_workspace:
|
||||
if workspace_manager is None:
|
||||
raise ValueError(
|
||||
"Workspace file reference requires workspace context. "
|
||||
"This file type is only available in CoPilot sessions."
|
||||
)
|
||||
|
||||
# Parse workspace reference
|
||||
# workspace://abc123 - by file ID
|
||||
# workspace:///path/to/file.txt - by virtual path
|
||||
file_ref = file[12:] # Remove "workspace://"
|
||||
|
||||
if file_ref.startswith("/"):
|
||||
# Path reference
|
||||
workspace_content = await workspace_manager.read_file(file_ref)
|
||||
file_info = await workspace_manager.get_file_info_by_path(file_ref)
|
||||
filename = sanitize_filename(
|
||||
file_info.name if file_info else f"{uuid.uuid4()}.bin"
|
||||
)
|
||||
else:
|
||||
# ID reference
|
||||
workspace_content = await workspace_manager.read_file_by_id(file_ref)
|
||||
file_info = await workspace_manager.get_file_info(file_ref)
|
||||
filename = sanitize_filename(
|
||||
file_info.name if file_info else f"{uuid.uuid4()}.bin"
|
||||
)
|
||||
|
||||
try:
|
||||
target_path = _ensure_inside_base(base_path / filename, base_path)
|
||||
except OSError as e:
|
||||
raise ValueError(f"Invalid file path '{filename}': {e}") from e
|
||||
|
||||
# Check file size limit
|
||||
if len(workspace_content) > MAX_FILE_SIZE_BYTES:
|
||||
raise ValueError(
|
||||
f"File too large: {len(workspace_content)} bytes > {MAX_FILE_SIZE_BYTES} bytes"
|
||||
)
|
||||
|
||||
# Virus scan the workspace content before writing locally
|
||||
await scan_content_safe(workspace_content, filename=filename)
|
||||
target_path.write_bytes(workspace_content)
|
||||
|
||||
# Check if this is a cloud storage path
|
||||
elif cloud_storage.is_cloud_path(file):
|
||||
# Download from cloud storage and store locally
|
||||
cloud_content = await cloud_storage.retrieve_file(
|
||||
file, user_id=user_id, graph_exec_id=graph_exec_id
|
||||
@@ -159,9 +234,9 @@ async def store_media_file(
|
||||
raise ValueError(f"Invalid file path '{filename}': {e}") from e
|
||||
|
||||
# Check file size limit
|
||||
if len(cloud_content) > MAX_FILE_SIZE:
|
||||
if len(cloud_content) > MAX_FILE_SIZE_BYTES:
|
||||
raise ValueError(
|
||||
f"File too large: {len(cloud_content)} bytes > {MAX_FILE_SIZE} bytes"
|
||||
f"File too large: {len(cloud_content)} bytes > {MAX_FILE_SIZE_BYTES} bytes"
|
||||
)
|
||||
|
||||
# Virus scan the cloud content before writing locally
|
||||
@@ -189,9 +264,9 @@ async def store_media_file(
|
||||
content = base64.b64decode(b64_content)
|
||||
|
||||
# Check file size limit
|
||||
if len(content) > MAX_FILE_SIZE:
|
||||
if len(content) > MAX_FILE_SIZE_BYTES:
|
||||
raise ValueError(
|
||||
f"File too large: {len(content)} bytes > {MAX_FILE_SIZE} bytes"
|
||||
f"File too large: {len(content)} bytes > {MAX_FILE_SIZE_BYTES} bytes"
|
||||
)
|
||||
|
||||
# Virus scan the base64 content before writing
|
||||
@@ -199,23 +274,31 @@ async def store_media_file(
|
||||
target_path.write_bytes(content)
|
||||
|
||||
elif file.startswith(("http://", "https://")):
|
||||
# URL
|
||||
# URL - download first to get Content-Type header
|
||||
resp = await Requests().get(file)
|
||||
|
||||
# Check file size limit
|
||||
if len(resp.content) > MAX_FILE_SIZE_BYTES:
|
||||
raise ValueError(
|
||||
f"File too large: {len(resp.content)} bytes > {MAX_FILE_SIZE_BYTES} bytes"
|
||||
)
|
||||
|
||||
# Extract filename from URL path
|
||||
parsed_url = urlparse(file)
|
||||
filename = sanitize_filename(Path(parsed_url.path).name or f"{uuid.uuid4()}")
|
||||
|
||||
# If filename lacks extension, add one from Content-Type header
|
||||
if "." not in filename:
|
||||
content_type = resp.headers.get("Content-Type", "").split(";")[0].strip()
|
||||
if content_type:
|
||||
ext = _extension_from_mime(content_type)
|
||||
filename = f"{filename}{ext}"
|
||||
|
||||
try:
|
||||
target_path = _ensure_inside_base(base_path / filename, base_path)
|
||||
except OSError as e:
|
||||
raise ValueError(f"Invalid file path '{filename}': {e}") from e
|
||||
|
||||
# Download and save
|
||||
resp = await Requests().get(file)
|
||||
|
||||
# Check file size limit
|
||||
if len(resp.content) > MAX_FILE_SIZE:
|
||||
raise ValueError(
|
||||
f"File too large: {len(resp.content)} bytes > {MAX_FILE_SIZE} bytes"
|
||||
)
|
||||
|
||||
# Virus scan the downloaded content before writing
|
||||
await scan_content_safe(resp.content, filename=filename)
|
||||
target_path.write_bytes(resp.content)
|
||||
@@ -230,12 +313,44 @@ async def store_media_file(
|
||||
if not target_path.is_file():
|
||||
raise ValueError(f"Local file does not exist: {target_path}")
|
||||
|
||||
# Return result
|
||||
if return_content:
|
||||
return MediaFileType(_file_to_data_uri(target_path))
|
||||
else:
|
||||
# Return based on requested format
|
||||
if return_format == "for_local_processing":
|
||||
# Use when processing files locally with tools like ffmpeg, MoviePy, PIL
|
||||
# Returns: relative path in exec_file directory (e.g., "image.png")
|
||||
return MediaFileType(_strip_base_prefix(target_path, base_path))
|
||||
|
||||
elif return_format == "for_external_api":
|
||||
# Use when sending content to external APIs that need base64
|
||||
# Returns: data URI (e.g., "...")
|
||||
return MediaFileType(_file_to_data_uri(target_path))
|
||||
|
||||
elif return_format == "for_block_output":
|
||||
# Use when returning output from a block to user/next block
|
||||
# Returns: workspace:// ref (CoPilot) or data URI (graph execution)
|
||||
if workspace_manager is None:
|
||||
# No workspace available (graph execution without CoPilot)
|
||||
# Fallback to data URI so the content can still be used/displayed
|
||||
return MediaFileType(_file_to_data_uri(target_path))
|
||||
|
||||
# Don't re-save if input was already from workspace
|
||||
if is_from_workspace:
|
||||
# Return original workspace reference
|
||||
return MediaFileType(file)
|
||||
|
||||
# Save new content to workspace
|
||||
content = target_path.read_bytes()
|
||||
filename = target_path.name
|
||||
|
||||
file_record = await workspace_manager.write_file(
|
||||
content=content,
|
||||
filename=filename,
|
||||
overwrite=True,
|
||||
)
|
||||
return MediaFileType(f"workspace://{file_record.id}")
|
||||
|
||||
else:
|
||||
raise ValueError(f"Invalid return_format: {return_format}")
|
||||
|
||||
|
||||
def get_dir_size(path: Path) -> int:
|
||||
"""Get total size of directory."""
|
||||
|
||||
@@ -7,10 +7,22 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.util.file import store_media_file
|
||||
from backend.util.type import MediaFileType
|
||||
|
||||
|
||||
def make_test_context(
|
||||
graph_exec_id: str = "test-exec-123",
|
||||
user_id: str = "test-user-123",
|
||||
) -> ExecutionContext:
|
||||
"""Helper to create test ExecutionContext."""
|
||||
return ExecutionContext(
|
||||
user_id=user_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
)
|
||||
|
||||
|
||||
class TestFileCloudIntegration:
|
||||
"""Test cases for cloud storage integration in file utilities."""
|
||||
|
||||
@@ -70,10 +82,9 @@ class TestFileCloudIntegration:
|
||||
mock_path_class.side_effect = path_constructor
|
||||
|
||||
result = await store_media_file(
|
||||
graph_exec_id,
|
||||
MediaFileType(cloud_path),
|
||||
"test-user-123",
|
||||
return_content=False,
|
||||
file=MediaFileType(cloud_path),
|
||||
execution_context=make_test_context(graph_exec_id=graph_exec_id),
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
|
||||
# Verify cloud storage operations
|
||||
@@ -144,10 +155,9 @@ class TestFileCloudIntegration:
|
||||
mock_path_obj.name = "image.png"
|
||||
with patch("backend.util.file.Path", return_value=mock_path_obj):
|
||||
result = await store_media_file(
|
||||
graph_exec_id,
|
||||
MediaFileType(cloud_path),
|
||||
"test-user-123",
|
||||
return_content=True,
|
||||
file=MediaFileType(cloud_path),
|
||||
execution_context=make_test_context(graph_exec_id=graph_exec_id),
|
||||
return_format="for_external_api",
|
||||
)
|
||||
|
||||
# Verify result is a data URI
|
||||
@@ -198,10 +208,9 @@ class TestFileCloudIntegration:
|
||||
mock_resolved_path.relative_to.return_value = Path("test-uuid-789.txt")
|
||||
|
||||
await store_media_file(
|
||||
graph_exec_id,
|
||||
MediaFileType(data_uri),
|
||||
"test-user-123",
|
||||
return_content=False,
|
||||
file=MediaFileType(data_uri),
|
||||
execution_context=make_test_context(graph_exec_id=graph_exec_id),
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
|
||||
# Verify cloud handler was checked but not used for retrieval
|
||||
@@ -234,5 +243,7 @@ class TestFileCloudIntegration:
|
||||
FileNotFoundError, match="File not found in cloud storage"
|
||||
):
|
||||
await store_media_file(
|
||||
graph_exec_id, MediaFileType(cloud_path), "test-user-123"
|
||||
file=MediaFileType(cloud_path),
|
||||
execution_context=make_test_context(graph_exec_id=graph_exec_id),
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
|
||||
108
autogpt_platform/backend/backend/util/gcs_utils.py
Normal file
108
autogpt_platform/backend/backend/util/gcs_utils.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""
|
||||
Shared GCS utilities for workspace and cloud storage backends.
|
||||
|
||||
This module provides common functionality for working with Google Cloud Storage,
|
||||
including path parsing, client management, and signed URL generation.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import aiohttp
|
||||
from gcloud.aio import storage as async_gcs_storage
|
||||
from google.cloud import storage as gcs_storage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_gcs_path(path: str) -> tuple[str, str]:
|
||||
"""
|
||||
Parse a GCS path in the format 'gcs://bucket/blob' to (bucket, blob).
|
||||
|
||||
Args:
|
||||
path: GCS path string (e.g., "gcs://my-bucket/path/to/file")
|
||||
|
||||
Returns:
|
||||
Tuple of (bucket_name, blob_name)
|
||||
|
||||
Raises:
|
||||
ValueError: If the path format is invalid
|
||||
"""
|
||||
if not path.startswith("gcs://"):
|
||||
raise ValueError(f"Invalid GCS path: {path}")
|
||||
|
||||
path_without_prefix = path[6:] # Remove "gcs://"
|
||||
parts = path_without_prefix.split("/", 1)
|
||||
if len(parts) != 2:
|
||||
raise ValueError(f"Invalid GCS path format: {path}")
|
||||
|
||||
return parts[0], parts[1]
|
||||
|
||||
|
||||
async def download_with_fresh_session(bucket: str, blob: str) -> bytes:
|
||||
"""
|
||||
Download file content using a fresh session.
|
||||
|
||||
This approach avoids event loop issues that can occur when reusing
|
||||
sessions across different async contexts (e.g., in executors).
|
||||
|
||||
Args:
|
||||
bucket: GCS bucket name
|
||||
blob: Blob path within the bucket
|
||||
|
||||
Returns:
|
||||
File content as bytes
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the file doesn't exist
|
||||
"""
|
||||
session = aiohttp.ClientSession(
|
||||
connector=aiohttp.TCPConnector(limit=10, force_close=True)
|
||||
)
|
||||
client: async_gcs_storage.Storage | None = None
|
||||
try:
|
||||
client = async_gcs_storage.Storage(session=session)
|
||||
content = await client.download(bucket, blob)
|
||||
return content
|
||||
except Exception as e:
|
||||
if "404" in str(e) or "Not Found" in str(e):
|
||||
raise FileNotFoundError(f"File not found: gcs://{bucket}/{blob}")
|
||||
raise
|
||||
finally:
|
||||
if client:
|
||||
try:
|
||||
await client.close()
|
||||
except Exception:
|
||||
pass # Best-effort cleanup
|
||||
await session.close()
|
||||
|
||||
|
||||
async def generate_signed_url(
|
||||
sync_client: gcs_storage.Client,
|
||||
bucket_name: str,
|
||||
blob_name: str,
|
||||
expires_in: int,
|
||||
) -> str:
|
||||
"""
|
||||
Generate a signed URL for temporary access to a GCS file.
|
||||
|
||||
Uses asyncio.to_thread() to run the sync operation without blocking.
|
||||
|
||||
Args:
|
||||
sync_client: Sync GCS client with service account credentials
|
||||
bucket_name: GCS bucket name
|
||||
blob_name: Blob path within the bucket
|
||||
expires_in: URL expiration time in seconds
|
||||
|
||||
Returns:
|
||||
Signed URL string
|
||||
"""
|
||||
bucket = sync_client.bucket(bucket_name)
|
||||
blob = bucket.blob(blob_name)
|
||||
return await asyncio.to_thread(
|
||||
blob.generate_signed_url,
|
||||
version="v4",
|
||||
expiration=datetime.now(timezone.utc) + timedelta(seconds=expires_in),
|
||||
method="GET",
|
||||
)
|
||||
@@ -263,6 +263,12 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
description="The name of the Google Cloud Storage bucket for media files",
|
||||
)
|
||||
|
||||
workspace_storage_dir: str = Field(
|
||||
default="",
|
||||
description="Local directory for workspace file storage when GCS is not configured. "
|
||||
"If empty, defaults to {app_data}/workspaces. Used for self-hosted deployments.",
|
||||
)
|
||||
|
||||
reddit_user_agent: str = Field(
|
||||
default="web:AutoGPT:v0.6.0 (by /u/autogpt)",
|
||||
description="The user agent for the Reddit API",
|
||||
@@ -389,6 +395,13 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
description="Maximum file size in MB for file uploads (1-1024 MB)",
|
||||
)
|
||||
|
||||
max_file_size_mb: int = Field(
|
||||
default=100,
|
||||
ge=1,
|
||||
le=1024,
|
||||
description="Maximum file size in MB for workspace files (1-1024 MB)",
|
||||
)
|
||||
|
||||
# AutoMod configuration
|
||||
automod_enabled: bool = Field(
|
||||
default=False,
|
||||
|
||||
@@ -140,14 +140,29 @@ async def execute_block_test(block: Block):
|
||||
setattr(block, mock_name, mock_obj)
|
||||
|
||||
# Populate credentials argument(s)
|
||||
# Generate IDs for execution context
|
||||
graph_id = str(uuid.uuid4())
|
||||
node_id = str(uuid.uuid4())
|
||||
graph_exec_id = str(uuid.uuid4())
|
||||
node_exec_id = str(uuid.uuid4())
|
||||
user_id = str(uuid.uuid4())
|
||||
graph_version = 1 # Default version for tests
|
||||
|
||||
extra_exec_kwargs: dict = {
|
||||
"graph_id": str(uuid.uuid4()),
|
||||
"node_id": str(uuid.uuid4()),
|
||||
"graph_exec_id": str(uuid.uuid4()),
|
||||
"node_exec_id": str(uuid.uuid4()),
|
||||
"user_id": str(uuid.uuid4()),
|
||||
"graph_version": 1, # Default version for tests
|
||||
"execution_context": ExecutionContext(),
|
||||
"graph_id": graph_id,
|
||||
"node_id": node_id,
|
||||
"graph_exec_id": graph_exec_id,
|
||||
"node_exec_id": node_exec_id,
|
||||
"user_id": user_id,
|
||||
"graph_version": graph_version,
|
||||
"execution_context": ExecutionContext(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
graph_version=graph_version,
|
||||
node_id=node_id,
|
||||
node_exec_id=node_exec_id,
|
||||
),
|
||||
}
|
||||
input_model = cast(type[BlockSchema], block.input_schema)
|
||||
|
||||
|
||||
419
autogpt_platform/backend/backend/util/workspace.py
Normal file
419
autogpt_platform/backend/backend/util/workspace.py
Normal file
@@ -0,0 +1,419 @@
|
||||
"""
|
||||
WorkspaceManager for managing user workspace file operations.
|
||||
|
||||
This module provides a high-level interface for workspace file operations,
|
||||
combining the storage backend and database layer.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import mimetypes
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from prisma.errors import UniqueViolationError
|
||||
from prisma.models import UserWorkspaceFile
|
||||
|
||||
from backend.data.workspace import (
|
||||
count_workspace_files,
|
||||
create_workspace_file,
|
||||
get_workspace_file,
|
||||
get_workspace_file_by_path,
|
||||
list_workspace_files,
|
||||
soft_delete_workspace_file,
|
||||
)
|
||||
from backend.util.settings import Config
|
||||
from backend.util.workspace_storage import compute_file_checksum, get_workspace_storage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkspaceManager:
|
||||
"""
|
||||
Manages workspace file operations.
|
||||
|
||||
Combines storage backend operations with database record management.
|
||||
Supports session-scoped file segmentation where files are stored in
|
||||
session-specific virtual paths: /sessions/{session_id}/{filename}
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, user_id: str, workspace_id: str, session_id: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
Initialize WorkspaceManager.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID
|
||||
workspace_id: The workspace ID
|
||||
session_id: Optional session ID for session-scoped file access
|
||||
"""
|
||||
self.user_id = user_id
|
||||
self.workspace_id = workspace_id
|
||||
self.session_id = session_id
|
||||
# Session path prefix for file isolation
|
||||
self.session_path = f"/sessions/{session_id}" if session_id else ""
|
||||
|
||||
def _resolve_path(self, path: str) -> str:
|
||||
"""
|
||||
Resolve a path, defaulting to session folder if session_id is set.
|
||||
|
||||
Cross-session access is allowed by explicitly using /sessions/other-session-id/...
|
||||
|
||||
Args:
|
||||
path: Virtual path (e.g., "/file.txt" or "/sessions/abc123/file.txt")
|
||||
|
||||
Returns:
|
||||
Resolved path with session prefix if applicable
|
||||
"""
|
||||
# If path explicitly references a session folder, use it as-is
|
||||
if path.startswith("/sessions/"):
|
||||
return path
|
||||
|
||||
# If we have a session context, prepend session path
|
||||
if self.session_path:
|
||||
# Normalize the path
|
||||
if not path.startswith("/"):
|
||||
path = f"/{path}"
|
||||
return f"{self.session_path}{path}"
|
||||
|
||||
# No session context, use path as-is
|
||||
return path if path.startswith("/") else f"/{path}"
|
||||
|
||||
def _get_effective_path(
|
||||
self, path: Optional[str], include_all_sessions: bool
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Get effective path for list/count operations based on session context.
|
||||
|
||||
Args:
|
||||
path: Optional path prefix to filter
|
||||
include_all_sessions: If True, don't apply session scoping
|
||||
|
||||
Returns:
|
||||
Effective path prefix for database query
|
||||
"""
|
||||
if include_all_sessions:
|
||||
# Normalize path to ensure leading slash (stored paths are normalized)
|
||||
if path is not None and not path.startswith("/"):
|
||||
return f"/{path}"
|
||||
return path
|
||||
elif path is not None:
|
||||
# Resolve the provided path with session scoping
|
||||
return self._resolve_path(path)
|
||||
elif self.session_path:
|
||||
# Default to session folder with trailing slash to prevent prefix collisions
|
||||
# e.g., "/sessions/abc" should not match "/sessions/abc123"
|
||||
return self.session_path.rstrip("/") + "/"
|
||||
else:
|
||||
# No session context, use path as-is
|
||||
return path
|
||||
|
||||
async def read_file(self, path: str) -> bytes:
|
||||
"""
|
||||
Read file from workspace by virtual path.
|
||||
|
||||
When session_id is set, paths are resolved relative to the session folder
|
||||
unless they explicitly reference /sessions/...
|
||||
|
||||
Args:
|
||||
path: Virtual path (e.g., "/documents/report.pdf")
|
||||
|
||||
Returns:
|
||||
File content as bytes
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If file doesn't exist
|
||||
"""
|
||||
resolved_path = self._resolve_path(path)
|
||||
file = await get_workspace_file_by_path(self.workspace_id, resolved_path)
|
||||
if file is None:
|
||||
raise FileNotFoundError(f"File not found at path: {resolved_path}")
|
||||
|
||||
storage = await get_workspace_storage()
|
||||
return await storage.retrieve(file.storagePath)
|
||||
|
||||
async def read_file_by_id(self, file_id: str) -> bytes:
|
||||
"""
|
||||
Read file from workspace by file ID.
|
||||
|
||||
Args:
|
||||
file_id: The file's ID
|
||||
|
||||
Returns:
|
||||
File content as bytes
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If file doesn't exist
|
||||
"""
|
||||
file = await get_workspace_file(file_id, self.workspace_id)
|
||||
if file is None:
|
||||
raise FileNotFoundError(f"File not found: {file_id}")
|
||||
|
||||
storage = await get_workspace_storage()
|
||||
return await storage.retrieve(file.storagePath)
|
||||
|
||||
async def write_file(
|
||||
self,
|
||||
content: bytes,
|
||||
filename: str,
|
||||
path: Optional[str] = None,
|
||||
mime_type: Optional[str] = None,
|
||||
overwrite: bool = False,
|
||||
) -> UserWorkspaceFile:
|
||||
"""
|
||||
Write file to workspace.
|
||||
|
||||
When session_id is set, files are written to /sessions/{session_id}/...
|
||||
by default. Use explicit /sessions/... paths for cross-session access.
|
||||
|
||||
Args:
|
||||
content: File content as bytes
|
||||
filename: Filename for the file
|
||||
path: Virtual path (defaults to "/{filename}", session-scoped if session_id set)
|
||||
mime_type: MIME type (auto-detected if not provided)
|
||||
overwrite: Whether to overwrite existing file at path
|
||||
|
||||
Returns:
|
||||
Created UserWorkspaceFile instance
|
||||
|
||||
Raises:
|
||||
ValueError: If file exceeds size limit or path already exists
|
||||
"""
|
||||
# Enforce file size limit
|
||||
max_file_size = Config().max_file_size_mb * 1024 * 1024
|
||||
if len(content) > max_file_size:
|
||||
raise ValueError(
|
||||
f"File too large: {len(content)} bytes exceeds "
|
||||
f"{Config().max_file_size_mb}MB limit"
|
||||
)
|
||||
|
||||
# Determine path with session scoping
|
||||
if path is None:
|
||||
path = f"/{filename}"
|
||||
elif not path.startswith("/"):
|
||||
path = f"/{path}"
|
||||
|
||||
# Resolve path with session prefix
|
||||
path = self._resolve_path(path)
|
||||
|
||||
# Check if file exists at path (only error for non-overwrite case)
|
||||
# For overwrite=True, we let the write proceed and handle via UniqueViolationError
|
||||
# This ensures the new file is written to storage BEFORE the old one is deleted,
|
||||
# preventing data loss if the new write fails
|
||||
if not overwrite:
|
||||
existing = await get_workspace_file_by_path(self.workspace_id, path)
|
||||
if existing is not None:
|
||||
raise ValueError(f"File already exists at path: {path}")
|
||||
|
||||
# Auto-detect MIME type if not provided
|
||||
if mime_type is None:
|
||||
mime_type, _ = mimetypes.guess_type(filename)
|
||||
mime_type = mime_type or "application/octet-stream"
|
||||
|
||||
# Compute checksum
|
||||
checksum = compute_file_checksum(content)
|
||||
|
||||
# Generate unique file ID for storage
|
||||
file_id = str(uuid.uuid4())
|
||||
|
||||
# Store file in storage backend
|
||||
storage = await get_workspace_storage()
|
||||
storage_path = await storage.store(
|
||||
workspace_id=self.workspace_id,
|
||||
file_id=file_id,
|
||||
filename=filename,
|
||||
content=content,
|
||||
)
|
||||
|
||||
# Create database record - handle race condition where another request
|
||||
# created a file at the same path between our check and create
|
||||
try:
|
||||
file = await create_workspace_file(
|
||||
workspace_id=self.workspace_id,
|
||||
file_id=file_id,
|
||||
name=filename,
|
||||
path=path,
|
||||
storage_path=storage_path,
|
||||
mime_type=mime_type,
|
||||
size_bytes=len(content),
|
||||
checksum=checksum,
|
||||
)
|
||||
except UniqueViolationError:
|
||||
# Race condition: another request created a file at this path
|
||||
if overwrite:
|
||||
# Re-fetch and delete the conflicting file, then retry
|
||||
existing = await get_workspace_file_by_path(self.workspace_id, path)
|
||||
if existing:
|
||||
await self.delete_file(existing.id)
|
||||
# Retry the create - if this also fails, clean up storage file
|
||||
try:
|
||||
file = await create_workspace_file(
|
||||
workspace_id=self.workspace_id,
|
||||
file_id=file_id,
|
||||
name=filename,
|
||||
path=path,
|
||||
storage_path=storage_path,
|
||||
mime_type=mime_type,
|
||||
size_bytes=len(content),
|
||||
checksum=checksum,
|
||||
)
|
||||
except Exception:
|
||||
# Clean up orphaned storage file on retry failure
|
||||
try:
|
||||
await storage.delete(storage_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to clean up orphaned storage file: {e}")
|
||||
raise
|
||||
else:
|
||||
# Clean up the orphaned storage file before raising
|
||||
try:
|
||||
await storage.delete(storage_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to clean up orphaned storage file: {e}")
|
||||
raise ValueError(f"File already exists at path: {path}")
|
||||
except Exception:
|
||||
# Any other database error (connection, validation, etc.) - clean up storage
|
||||
try:
|
||||
await storage.delete(storage_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to clean up orphaned storage file: {e}")
|
||||
raise
|
||||
|
||||
logger.info(
|
||||
f"Wrote file {file.id} ({filename}) to workspace {self.workspace_id} "
|
||||
f"at path {path}, size={len(content)} bytes"
|
||||
)
|
||||
|
||||
return file
|
||||
|
||||
async def list_files(
|
||||
self,
|
||||
path: Optional[str] = None,
|
||||
limit: Optional[int] = None,
|
||||
offset: int = 0,
|
||||
include_all_sessions: bool = False,
|
||||
) -> list[UserWorkspaceFile]:
|
||||
"""
|
||||
List files in workspace.
|
||||
|
||||
When session_id is set and include_all_sessions is False (default),
|
||||
only files in the current session's folder are listed.
|
||||
|
||||
Args:
|
||||
path: Optional path prefix to filter (e.g., "/documents/")
|
||||
limit: Maximum number of files to return
|
||||
offset: Number of files to skip
|
||||
include_all_sessions: If True, list files from all sessions.
|
||||
If False (default), only list current session's files.
|
||||
|
||||
Returns:
|
||||
List of UserWorkspaceFile instances
|
||||
"""
|
||||
effective_path = self._get_effective_path(path, include_all_sessions)
|
||||
|
||||
return await list_workspace_files(
|
||||
workspace_id=self.workspace_id,
|
||||
path_prefix=effective_path,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
async def delete_file(self, file_id: str) -> bool:
|
||||
"""
|
||||
Delete a file (soft-delete).
|
||||
|
||||
Args:
|
||||
file_id: The file's ID
|
||||
|
||||
Returns:
|
||||
True if deleted, False if not found
|
||||
"""
|
||||
file = await get_workspace_file(file_id, self.workspace_id)
|
||||
if file is None:
|
||||
return False
|
||||
|
||||
# Delete from storage
|
||||
storage = await get_workspace_storage()
|
||||
try:
|
||||
await storage.delete(file.storagePath)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete file from storage: {e}")
|
||||
# Continue with database soft-delete even if storage delete fails
|
||||
|
||||
# Soft-delete database record
|
||||
result = await soft_delete_workspace_file(file_id, self.workspace_id)
|
||||
return result is not None
|
||||
|
||||
async def get_download_url(self, file_id: str, expires_in: int = 3600) -> str:
|
||||
"""
|
||||
Get download URL for a file.
|
||||
|
||||
Args:
|
||||
file_id: The file's ID
|
||||
expires_in: URL expiration in seconds (default 1 hour)
|
||||
|
||||
Returns:
|
||||
Download URL (signed URL for GCS, API endpoint for local)
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If file doesn't exist
|
||||
"""
|
||||
file = await get_workspace_file(file_id, self.workspace_id)
|
||||
if file is None:
|
||||
raise FileNotFoundError(f"File not found: {file_id}")
|
||||
|
||||
storage = await get_workspace_storage()
|
||||
return await storage.get_download_url(file.storagePath, expires_in)
|
||||
|
||||
async def get_file_info(self, file_id: str) -> Optional[UserWorkspaceFile]:
|
||||
"""
|
||||
Get file metadata.
|
||||
|
||||
Args:
|
||||
file_id: The file's ID
|
||||
|
||||
Returns:
|
||||
UserWorkspaceFile instance or None
|
||||
"""
|
||||
return await get_workspace_file(file_id, self.workspace_id)
|
||||
|
||||
async def get_file_info_by_path(self, path: str) -> Optional[UserWorkspaceFile]:
|
||||
"""
|
||||
Get file metadata by path.
|
||||
|
||||
When session_id is set, paths are resolved relative to the session folder
|
||||
unless they explicitly reference /sessions/...
|
||||
|
||||
Args:
|
||||
path: Virtual path
|
||||
|
||||
Returns:
|
||||
UserWorkspaceFile instance or None
|
||||
"""
|
||||
resolved_path = self._resolve_path(path)
|
||||
return await get_workspace_file_by_path(self.workspace_id, resolved_path)
|
||||
|
||||
async def get_file_count(
|
||||
self,
|
||||
path: Optional[str] = None,
|
||||
include_all_sessions: bool = False,
|
||||
) -> int:
|
||||
"""
|
||||
Get number of files in workspace.
|
||||
|
||||
When session_id is set and include_all_sessions is False (default),
|
||||
only counts files in the current session's folder.
|
||||
|
||||
Args:
|
||||
path: Optional path prefix to filter (e.g., "/documents/")
|
||||
include_all_sessions: If True, count all files in workspace.
|
||||
If False (default), only count current session's files.
|
||||
|
||||
Returns:
|
||||
Number of files
|
||||
"""
|
||||
effective_path = self._get_effective_path(path, include_all_sessions)
|
||||
|
||||
return await count_workspace_files(
|
||||
self.workspace_id, path_prefix=effective_path
|
||||
)
|
||||
398
autogpt_platform/backend/backend/util/workspace_storage.py
Normal file
398
autogpt_platform/backend/backend/util/workspace_storage.py
Normal file
@@ -0,0 +1,398 @@
|
||||
"""
|
||||
Workspace storage backend abstraction for supporting both cloud and local deployments.
|
||||
|
||||
This module provides a unified interface for storing workspace files, with implementations
|
||||
for Google Cloud Storage (cloud deployments) and local filesystem (self-hosted deployments).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import aiofiles
|
||||
import aiohttp
|
||||
from gcloud.aio import storage as async_gcs_storage
|
||||
from google.cloud import storage as gcs_storage
|
||||
|
||||
from backend.util.data import get_data_path
|
||||
from backend.util.gcs_utils import (
|
||||
download_with_fresh_session,
|
||||
generate_signed_url,
|
||||
parse_gcs_path,
|
||||
)
|
||||
from backend.util.settings import Config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkspaceStorageBackend(ABC):
|
||||
"""Abstract interface for workspace file storage."""
|
||||
|
||||
@abstractmethod
|
||||
async def store(
|
||||
self,
|
||||
workspace_id: str,
|
||||
file_id: str,
|
||||
filename: str,
|
||||
content: bytes,
|
||||
) -> str:
|
||||
"""
|
||||
Store file content, return storage path.
|
||||
|
||||
Args:
|
||||
workspace_id: The workspace ID
|
||||
file_id: Unique file ID for storage
|
||||
filename: Original filename
|
||||
content: File content as bytes
|
||||
|
||||
Returns:
|
||||
Storage path string (cloud path or local path)
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def retrieve(self, storage_path: str) -> bytes:
|
||||
"""
|
||||
Retrieve file content from storage.
|
||||
|
||||
Args:
|
||||
storage_path: The storage path returned from store()
|
||||
|
||||
Returns:
|
||||
File content as bytes
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def delete(self, storage_path: str) -> None:
|
||||
"""
|
||||
Delete file from storage.
|
||||
|
||||
Args:
|
||||
storage_path: The storage path to delete
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_download_url(self, storage_path: str, expires_in: int = 3600) -> str:
|
||||
"""
|
||||
Get URL for downloading the file.
|
||||
|
||||
Args:
|
||||
storage_path: The storage path
|
||||
expires_in: URL expiration time in seconds (default 1 hour)
|
||||
|
||||
Returns:
|
||||
Download URL (signed URL for GCS, direct API path for local)
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class GCSWorkspaceStorage(WorkspaceStorageBackend):
|
||||
"""Google Cloud Storage implementation for workspace storage."""
|
||||
|
||||
def __init__(self, bucket_name: str):
|
||||
self.bucket_name = bucket_name
|
||||
self._async_client: Optional[async_gcs_storage.Storage] = None
|
||||
self._sync_client: Optional[gcs_storage.Client] = None
|
||||
self._session: Optional[aiohttp.ClientSession] = None
|
||||
|
||||
async def _get_async_client(self) -> async_gcs_storage.Storage:
|
||||
"""Get or create async GCS client."""
|
||||
if self._async_client is None:
|
||||
self._session = aiohttp.ClientSession(
|
||||
connector=aiohttp.TCPConnector(limit=100, force_close=False)
|
||||
)
|
||||
self._async_client = async_gcs_storage.Storage(session=self._session)
|
||||
return self._async_client
|
||||
|
||||
def _get_sync_client(self) -> gcs_storage.Client:
|
||||
"""Get or create sync GCS client (for signed URLs)."""
|
||||
if self._sync_client is None:
|
||||
self._sync_client = gcs_storage.Client()
|
||||
return self._sync_client
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close all client connections."""
|
||||
if self._async_client is not None:
|
||||
try:
|
||||
await self._async_client.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing GCS client: {e}")
|
||||
self._async_client = None
|
||||
|
||||
if self._session is not None:
|
||||
try:
|
||||
await self._session.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing session: {e}")
|
||||
self._session = None
|
||||
|
||||
def _build_blob_name(self, workspace_id: str, file_id: str, filename: str) -> str:
|
||||
"""Build the blob path for workspace files."""
|
||||
return f"workspaces/{workspace_id}/{file_id}/{filename}"
|
||||
|
||||
async def store(
|
||||
self,
|
||||
workspace_id: str,
|
||||
file_id: str,
|
||||
filename: str,
|
||||
content: bytes,
|
||||
) -> str:
|
||||
"""Store file in GCS."""
|
||||
client = await self._get_async_client()
|
||||
blob_name = self._build_blob_name(workspace_id, file_id, filename)
|
||||
|
||||
# Upload with metadata
|
||||
upload_time = datetime.now(timezone.utc)
|
||||
await client.upload(
|
||||
self.bucket_name,
|
||||
blob_name,
|
||||
content,
|
||||
metadata={
|
||||
"uploaded_at": upload_time.isoformat(),
|
||||
"workspace_id": workspace_id,
|
||||
"file_id": file_id,
|
||||
},
|
||||
)
|
||||
|
||||
return f"gcs://{self.bucket_name}/{blob_name}"
|
||||
|
||||
async def retrieve(self, storage_path: str) -> bytes:
|
||||
"""Retrieve file from GCS."""
|
||||
bucket_name, blob_name = parse_gcs_path(storage_path)
|
||||
return await download_with_fresh_session(bucket_name, blob_name)
|
||||
|
||||
async def delete(self, storage_path: str) -> None:
|
||||
"""Delete file from GCS."""
|
||||
bucket_name, blob_name = parse_gcs_path(storage_path)
|
||||
client = await self._get_async_client()
|
||||
|
||||
try:
|
||||
await client.delete(bucket_name, blob_name)
|
||||
except Exception as e:
|
||||
if "404" not in str(e) and "Not Found" not in str(e):
|
||||
raise
|
||||
# File already deleted, that's fine
|
||||
|
||||
async def get_download_url(self, storage_path: str, expires_in: int = 3600) -> str:
|
||||
"""
|
||||
Generate download URL for GCS file.
|
||||
|
||||
Attempts to generate a signed URL if running with service account credentials.
|
||||
Falls back to an API proxy endpoint if signed URL generation fails
|
||||
(e.g., when running locally with user OAuth credentials).
|
||||
"""
|
||||
bucket_name, blob_name = parse_gcs_path(storage_path)
|
||||
|
||||
# Extract file_id from blob_name for fallback: workspaces/{workspace_id}/{file_id}/{filename}
|
||||
blob_parts = blob_name.split("/")
|
||||
file_id = blob_parts[2] if len(blob_parts) >= 3 else None
|
||||
|
||||
# Try to generate signed URL (requires service account credentials)
|
||||
try:
|
||||
sync_client = self._get_sync_client()
|
||||
return await generate_signed_url(
|
||||
sync_client, bucket_name, blob_name, expires_in
|
||||
)
|
||||
except AttributeError as e:
|
||||
# Signed URL generation requires service account with private key.
|
||||
# When running with user OAuth credentials, fall back to API proxy.
|
||||
if "private key" in str(e) and file_id:
|
||||
logger.debug(
|
||||
"Cannot generate signed URL (no service account credentials), "
|
||||
"falling back to API proxy endpoint"
|
||||
)
|
||||
return f"/api/workspace/files/{file_id}/download"
|
||||
raise
|
||||
|
||||
|
||||
class LocalWorkspaceStorage(WorkspaceStorageBackend):
|
||||
"""Local filesystem implementation for workspace storage (self-hosted deployments)."""
|
||||
|
||||
def __init__(self, base_dir: Optional[str] = None):
|
||||
"""
|
||||
Initialize local storage backend.
|
||||
|
||||
Args:
|
||||
base_dir: Base directory for workspace storage.
|
||||
If None, defaults to {app_data}/workspaces
|
||||
"""
|
||||
if base_dir:
|
||||
self.base_dir = Path(base_dir)
|
||||
else:
|
||||
self.base_dir = Path(get_data_path()) / "workspaces"
|
||||
|
||||
# Ensure base directory exists
|
||||
self.base_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def _build_file_path(self, workspace_id: str, file_id: str, filename: str) -> Path:
|
||||
"""Build the local file path with path traversal protection."""
|
||||
# Import here to avoid circular import
|
||||
# (file.py imports workspace.py which imports workspace_storage.py)
|
||||
from backend.util.file import sanitize_filename
|
||||
|
||||
# Sanitize filename to prevent path traversal (removes / and \ among others)
|
||||
safe_filename = sanitize_filename(filename)
|
||||
file_path = (self.base_dir / workspace_id / file_id / safe_filename).resolve()
|
||||
|
||||
# Verify the resolved path is still under base_dir
|
||||
if not file_path.is_relative_to(self.base_dir.resolve()):
|
||||
raise ValueError("Invalid filename: path traversal detected")
|
||||
|
||||
return file_path
|
||||
|
||||
def _parse_storage_path(self, storage_path: str) -> Path:
|
||||
"""Parse local storage path to filesystem path."""
|
||||
if storage_path.startswith("local://"):
|
||||
relative_path = storage_path[8:] # Remove "local://"
|
||||
else:
|
||||
relative_path = storage_path
|
||||
|
||||
full_path = (self.base_dir / relative_path).resolve()
|
||||
|
||||
# Security check: ensure path is under base_dir
|
||||
# Use is_relative_to() for robust path containment check
|
||||
# (handles case-insensitive filesystems and edge cases)
|
||||
if not full_path.is_relative_to(self.base_dir.resolve()):
|
||||
raise ValueError("Invalid storage path: path traversal detected")
|
||||
|
||||
return full_path
|
||||
|
||||
async def store(
|
||||
self,
|
||||
workspace_id: str,
|
||||
file_id: str,
|
||||
filename: str,
|
||||
content: bytes,
|
||||
) -> str:
|
||||
"""Store file locally."""
|
||||
file_path = self._build_file_path(workspace_id, file_id, filename)
|
||||
|
||||
# Create parent directories
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Write file asynchronously
|
||||
async with aiofiles.open(file_path, "wb") as f:
|
||||
await f.write(content)
|
||||
|
||||
# Return relative path as storage path
|
||||
relative_path = file_path.relative_to(self.base_dir)
|
||||
return f"local://{relative_path}"
|
||||
|
||||
async def retrieve(self, storage_path: str) -> bytes:
|
||||
"""Retrieve file from local storage."""
|
||||
file_path = self._parse_storage_path(storage_path)
|
||||
|
||||
if not file_path.exists():
|
||||
raise FileNotFoundError(f"File not found: {storage_path}")
|
||||
|
||||
async with aiofiles.open(file_path, "rb") as f:
|
||||
return await f.read()
|
||||
|
||||
async def delete(self, storage_path: str) -> None:
|
||||
"""Delete file from local storage."""
|
||||
file_path = self._parse_storage_path(storage_path)
|
||||
|
||||
if file_path.exists():
|
||||
# Remove file
|
||||
file_path.unlink()
|
||||
|
||||
# Clean up empty parent directories
|
||||
parent = file_path.parent
|
||||
while parent != self.base_dir:
|
||||
try:
|
||||
if parent.exists() and not any(parent.iterdir()):
|
||||
parent.rmdir()
|
||||
else:
|
||||
break
|
||||
except OSError:
|
||||
break
|
||||
parent = parent.parent
|
||||
|
||||
async def get_download_url(self, storage_path: str, expires_in: int = 3600) -> str:
|
||||
"""
|
||||
Get download URL for local file.
|
||||
|
||||
For local storage, this returns an API endpoint path.
|
||||
The actual serving is handled by the API layer.
|
||||
"""
|
||||
# Parse the storage path to get the components
|
||||
if storage_path.startswith("local://"):
|
||||
relative_path = storage_path[8:]
|
||||
else:
|
||||
relative_path = storage_path
|
||||
|
||||
# Return the API endpoint for downloading
|
||||
# The file_id is extracted from the path: {workspace_id}/{file_id}/{filename}
|
||||
parts = relative_path.split("/")
|
||||
if len(parts) >= 2:
|
||||
file_id = parts[1] # Second component is file_id
|
||||
return f"/api/workspace/files/{file_id}/download"
|
||||
else:
|
||||
raise ValueError(f"Invalid storage path format: {storage_path}")
|
||||
|
||||
|
||||
# Global storage backend instance
|
||||
_workspace_storage: Optional[WorkspaceStorageBackend] = None
|
||||
_storage_lock = asyncio.Lock()
|
||||
|
||||
|
||||
async def get_workspace_storage() -> WorkspaceStorageBackend:
|
||||
"""
|
||||
Get the workspace storage backend instance.
|
||||
|
||||
Uses GCS if media_gcs_bucket_name is configured, otherwise uses local storage.
|
||||
"""
|
||||
global _workspace_storage
|
||||
|
||||
if _workspace_storage is None:
|
||||
async with _storage_lock:
|
||||
if _workspace_storage is None:
|
||||
config = Config()
|
||||
|
||||
if config.media_gcs_bucket_name:
|
||||
logger.info(
|
||||
f"Using GCS workspace storage: {config.media_gcs_bucket_name}"
|
||||
)
|
||||
_workspace_storage = GCSWorkspaceStorage(
|
||||
config.media_gcs_bucket_name
|
||||
)
|
||||
else:
|
||||
storage_dir = (
|
||||
config.workspace_storage_dir
|
||||
if config.workspace_storage_dir
|
||||
else None
|
||||
)
|
||||
logger.info(
|
||||
f"Using local workspace storage: {storage_dir or 'default'}"
|
||||
)
|
||||
_workspace_storage = LocalWorkspaceStorage(storage_dir)
|
||||
|
||||
return _workspace_storage
|
||||
|
||||
|
||||
async def shutdown_workspace_storage() -> None:
|
||||
"""
|
||||
Properly shutdown the global workspace storage backend.
|
||||
|
||||
Closes aiohttp sessions and other resources for GCS backend.
|
||||
Should be called during application shutdown.
|
||||
"""
|
||||
global _workspace_storage
|
||||
|
||||
if _workspace_storage is not None:
|
||||
async with _storage_lock:
|
||||
if _workspace_storage is not None:
|
||||
if isinstance(_workspace_storage, GCSWorkspaceStorage):
|
||||
await _workspace_storage.close()
|
||||
_workspace_storage = None
|
||||
|
||||
|
||||
def compute_file_checksum(content: bytes) -> str:
|
||||
"""Compute SHA256 checksum of file content."""
|
||||
return hashlib.sha256(content).hexdigest()
|
||||
@@ -0,0 +1,52 @@
|
||||
-- CreateEnum
|
||||
CREATE TYPE "WorkspaceFileSource" AS ENUM ('UPLOAD', 'EXECUTION', 'COPILOT', 'IMPORT');
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "UserWorkspace" (
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||
"userId" TEXT NOT NULL,
|
||||
|
||||
CONSTRAINT "UserWorkspace_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "UserWorkspaceFile" (
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||
"workspaceId" TEXT NOT NULL,
|
||||
"name" TEXT NOT NULL,
|
||||
"path" TEXT NOT NULL,
|
||||
"storagePath" TEXT NOT NULL,
|
||||
"mimeType" TEXT NOT NULL,
|
||||
"sizeBytes" BIGINT NOT NULL,
|
||||
"checksum" TEXT,
|
||||
"isDeleted" BOOLEAN NOT NULL DEFAULT false,
|
||||
"deletedAt" TIMESTAMP(3),
|
||||
"source" "WorkspaceFileSource" NOT NULL DEFAULT 'UPLOAD',
|
||||
"sourceExecId" TEXT,
|
||||
"sourceSessionId" TEXT,
|
||||
"metadata" JSONB NOT NULL DEFAULT '{}',
|
||||
|
||||
CONSTRAINT "UserWorkspaceFile_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "UserWorkspace_userId_key" ON "UserWorkspace"("userId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "UserWorkspace_userId_idx" ON "UserWorkspace"("userId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "UserWorkspaceFile_workspaceId_isDeleted_idx" ON "UserWorkspaceFile"("workspaceId", "isDeleted");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "UserWorkspaceFile_workspaceId_path_key" ON "UserWorkspaceFile"("workspaceId", "path");
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "UserWorkspace" ADD CONSTRAINT "UserWorkspace_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "UserWorkspaceFile" ADD CONSTRAINT "UserWorkspaceFile_workspaceId_fkey" FOREIGN KEY ("workspaceId") REFERENCES "UserWorkspace"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
@@ -0,0 +1,16 @@
|
||||
/*
|
||||
Warnings:
|
||||
|
||||
- You are about to drop the column `source` on the `UserWorkspaceFile` table. All the data in the column will be lost.
|
||||
- You are about to drop the column `sourceExecId` on the `UserWorkspaceFile` table. All the data in the column will be lost.
|
||||
- You are about to drop the column `sourceSessionId` on the `UserWorkspaceFile` table. All the data in the column will be lost.
|
||||
|
||||
*/
|
||||
|
||||
-- AlterTable
|
||||
ALTER TABLE "UserWorkspaceFile" DROP COLUMN "source",
|
||||
DROP COLUMN "sourceExecId",
|
||||
DROP COLUMN "sourceSessionId";
|
||||
|
||||
-- DropEnum
|
||||
DROP TYPE "WorkspaceFileSource";
|
||||
@@ -63,6 +63,7 @@ model User {
|
||||
IntegrationWebhooks IntegrationWebhook[]
|
||||
NotificationBatches UserNotificationBatch[]
|
||||
PendingHumanReviews PendingHumanReview[]
|
||||
Workspace UserWorkspace?
|
||||
|
||||
// OAuth Provider relations
|
||||
OAuthApplications OAuthApplication[]
|
||||
@@ -137,6 +138,53 @@ model CoPilotUnderstanding {
|
||||
@@index([userId])
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////
|
||||
////////////////////////////////////////////////////////////
|
||||
//////////////// USER WORKSPACE TABLES /////////////////
|
||||
////////////////////////////////////////////////////////////
|
||||
////////////////////////////////////////////////////////////
|
||||
|
||||
// User's persistent file storage workspace
|
||||
model UserWorkspace {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
userId String @unique
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
|
||||
Files UserWorkspaceFile[]
|
||||
|
||||
@@index([userId])
|
||||
}
|
||||
|
||||
// Individual files in a user's workspace
|
||||
model UserWorkspaceFile {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
workspaceId String
|
||||
Workspace UserWorkspace @relation(fields: [workspaceId], references: [id], onDelete: Cascade)
|
||||
|
||||
// File metadata
|
||||
name String // User-visible filename
|
||||
path String // Virtual path (e.g., "/documents/report.pdf")
|
||||
storagePath String // Actual GCS or local storage path
|
||||
mimeType String
|
||||
sizeBytes BigInt
|
||||
checksum String? // SHA256 for integrity
|
||||
|
||||
// File state
|
||||
isDeleted Boolean @default(false)
|
||||
deletedAt DateTime?
|
||||
|
||||
metadata Json @default("{}")
|
||||
|
||||
@@unique([workspaceId, path])
|
||||
@@index([workspaceId, isDeleted])
|
||||
}
|
||||
|
||||
model BuilderSearchHistory {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
|
||||
@@ -151,15 +151,20 @@ class TestDecomposeGoalExternal:
|
||||
@pytest.mark.asyncio
|
||||
async def test_decompose_goal_handles_http_error(self):
|
||||
"""Test decomposition handles HTTP errors gracefully."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 500
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.side_effect = httpx.HTTPStatusError(
|
||||
"Server error", request=MagicMock(), response=MagicMock()
|
||||
"Server error", request=MagicMock(), response=mock_response
|
||||
)
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
result = await service.decompose_goal_external("Build a chatbot")
|
||||
|
||||
assert result is None
|
||||
assert result is not None
|
||||
assert result.get("type") == "error"
|
||||
assert result.get("error_type") == "http_error"
|
||||
assert "Server error" in result.get("error", "")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decompose_goal_handles_request_error(self):
|
||||
@@ -170,7 +175,10 @@ class TestDecomposeGoalExternal:
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
result = await service.decompose_goal_external("Build a chatbot")
|
||||
|
||||
assert result is None
|
||||
assert result is not None
|
||||
assert result.get("type") == "error"
|
||||
assert result.get("error_type") == "connection_error"
|
||||
assert "Connection failed" in result.get("error", "")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decompose_goal_handles_service_error(self):
|
||||
@@ -179,6 +187,7 @@ class TestDecomposeGoalExternal:
|
||||
mock_response.json.return_value = {
|
||||
"success": False,
|
||||
"error": "Internal error",
|
||||
"error_type": "internal_error",
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
@@ -188,7 +197,10 @@ class TestDecomposeGoalExternal:
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
result = await service.decompose_goal_external("Build a chatbot")
|
||||
|
||||
assert result is None
|
||||
assert result is not None
|
||||
assert result.get("type") == "error"
|
||||
assert result.get("error") == "Internal error"
|
||||
assert result.get("error_type") == "internal_error"
|
||||
|
||||
|
||||
class TestGenerateAgentExternal:
|
||||
@@ -236,7 +248,10 @@ class TestGenerateAgentExternal:
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
result = await service.generate_agent_external({"steps": []})
|
||||
|
||||
assert result is None
|
||||
assert result is not None
|
||||
assert result.get("type") == "error"
|
||||
assert result.get("error_type") == "connection_error"
|
||||
assert "Connection failed" in result.get("error", "")
|
||||
|
||||
|
||||
class TestGenerateAgentPatchExternal:
|
||||
|
||||
@@ -34,3 +34,6 @@ NEXT_PUBLIC_PREVIEW_STEALING_DEV=
|
||||
# PostHog Analytics
|
||||
NEXT_PUBLIC_POSTHOG_KEY=
|
||||
NEXT_PUBLIC_POSTHOG_HOST=https://eu.i.posthog.com
|
||||
|
||||
# OpenAI (for voice transcription)
|
||||
OPENAI_API_KEY=
|
||||
|
||||
76
autogpt_platform/frontend/CLAUDE.md
Normal file
76
autogpt_platform/frontend/CLAUDE.md
Normal file
@@ -0,0 +1,76 @@
|
||||
# CLAUDE.md - Frontend
|
||||
|
||||
This file provides guidance to Claude Code when working with the frontend.
|
||||
|
||||
## Essential Commands
|
||||
|
||||
```bash
|
||||
# Install dependencies
|
||||
pnpm i
|
||||
|
||||
# Generate API client from OpenAPI spec
|
||||
pnpm generate:api
|
||||
|
||||
# Start development server
|
||||
pnpm dev
|
||||
|
||||
# Run E2E tests
|
||||
pnpm test
|
||||
|
||||
# Run Storybook for component development
|
||||
pnpm storybook
|
||||
|
||||
# Build production
|
||||
pnpm build
|
||||
|
||||
# Format and lint
|
||||
pnpm format
|
||||
|
||||
# Type checking
|
||||
pnpm types
|
||||
```
|
||||
|
||||
### Code Style
|
||||
|
||||
- Fully capitalize acronyms in symbols, e.g. `graphID`, `useBackendAPI`
|
||||
- Use function declarations (not arrow functions) for components/handlers
|
||||
|
||||
## Architecture
|
||||
|
||||
- **Framework**: Next.js 15 App Router (client-first approach)
|
||||
- **Data Fetching**: Type-safe generated API hooks via Orval + React Query
|
||||
- **State Management**: React Query for server state, co-located UI state in components/hooks
|
||||
- **Component Structure**: Separate render logic (`.tsx`) from business logic (`use*.ts` hooks)
|
||||
- **Workflow Builder**: Visual graph editor using @xyflow/react
|
||||
- **UI Components**: shadcn/ui (Radix UI primitives) with Tailwind CSS styling
|
||||
- **Icons**: Phosphor Icons only
|
||||
- **Feature Flags**: LaunchDarkly integration
|
||||
- **Error Handling**: ErrorCard for render errors, toast for mutations, Sentry for exceptions
|
||||
- **Testing**: Playwright for E2E, Storybook for component development
|
||||
|
||||
## Environment Configuration
|
||||
|
||||
`.env.default` (defaults) → `.env` (user overrides)
|
||||
|
||||
## Feature Development
|
||||
|
||||
See @CONTRIBUTING.md for complete patterns. Quick reference:
|
||||
|
||||
1. **Pages**: Create in `src/app/(platform)/feature-name/page.tsx`
|
||||
- Extract component logic into custom hooks grouped by concern, not by component. Each hook should represent a cohesive domain of functionality (e.g., useSearch, useFilters, usePagination) rather than bundling all state into one useComponentState hook.
|
||||
- Put each hook in its own `.ts` file
|
||||
- Put sub-components in local `components/` folder
|
||||
- Component props should be `type Props = { ... }` (not exported) unless it needs to be used outside the component
|
||||
2. **Components**: Structure as `ComponentName/ComponentName.tsx` + `useComponentName.ts` + `helpers.ts`
|
||||
- Use design system components from `src/components/` (atoms, molecules, organisms)
|
||||
- Never use `src/components/__legacy__/*`
|
||||
3. **Data fetching**: Use generated API hooks from `@/app/api/__generated__/endpoints/`
|
||||
- Regenerate with `pnpm generate:api`
|
||||
- Pattern: `use{Method}{Version}{OperationName}`
|
||||
4. **Styling**: Tailwind CSS only, use design tokens, Phosphor Icons only
|
||||
5. **Testing**: Add Storybook stories for new components, Playwright for E2E
|
||||
6. **Code conventions**:
|
||||
- Use function declarations (not arrow functions) for components/handlers
|
||||
- Do not use `useCallback` or `useMemo` unless asked to optimise a given function
|
||||
- Do not type hook returns, let Typescript infer as much as possible
|
||||
- Never type with `any` unless a variable/attribute can ACTUALLY be of any type
|
||||
@@ -73,9 +73,9 @@ export function useSessionsPagination({ enabled }: UseSessionsPaginationArgs) {
|
||||
};
|
||||
|
||||
const reset = () => {
|
||||
// Only reset the offset - keep existing sessions visible during refetch
|
||||
// The effect will replace sessions when new data arrives at offset 0
|
||||
setOffset(0);
|
||||
setAccumulatedSessions([]);
|
||||
setTotalCount(null);
|
||||
};
|
||||
|
||||
return {
|
||||
|
||||
@@ -5912,6 +5912,40 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/workspace/files/{file_id}/download": {
|
||||
"get": {
|
||||
"tags": ["workspace"],
|
||||
"summary": "Download file by ID",
|
||||
"description": "Download a file by its ID.\n\nReturns the file content directly or redirects to a signed URL for GCS.",
|
||||
"operationId": "getWorkspaceDownload file by id",
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"parameters": [
|
||||
{
|
||||
"name": "file_id",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"schema": { "type": "string", "title": "File Id" }
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": { "application/json": { "schema": {} } }
|
||||
},
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/health": {
|
||||
"get": {
|
||||
"tags": ["health"],
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import {
|
||||
ApiError,
|
||||
getServerAuthToken,
|
||||
makeAuthenticatedFileUpload,
|
||||
makeAuthenticatedRequest,
|
||||
} from "@/lib/autogpt-server-api/helpers";
|
||||
@@ -15,6 +16,69 @@ function buildBackendUrl(path: string[], queryString: string): string {
|
||||
return `${environment.getAGPTServerBaseUrl()}/${backendPath}${queryString}`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if this is a workspace file download request that needs binary response handling.
|
||||
*/
|
||||
function isWorkspaceDownloadRequest(path: string[]): boolean {
|
||||
// Match pattern: api/workspace/files/{id}/download (5 segments)
|
||||
return (
|
||||
path.length == 5 &&
|
||||
path[0] === "api" &&
|
||||
path[1] === "workspace" &&
|
||||
path[2] === "files" &&
|
||||
path[path.length - 1] === "download"
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle workspace file download requests with proper binary response streaming.
|
||||
*/
|
||||
async function handleWorkspaceDownload(
|
||||
req: NextRequest,
|
||||
backendUrl: string,
|
||||
): Promise<NextResponse> {
|
||||
const token = await getServerAuthToken();
|
||||
|
||||
const headers: Record<string, string> = {};
|
||||
if (token && token !== "no-token-found") {
|
||||
headers["Authorization"] = `Bearer ${token}`;
|
||||
}
|
||||
|
||||
const response = await fetch(backendUrl, {
|
||||
method: "GET",
|
||||
headers,
|
||||
redirect: "follow", // Follow redirects to signed URLs
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
return NextResponse.json(
|
||||
{ error: `Failed to download file: ${response.statusText}` },
|
||||
{ status: response.status },
|
||||
);
|
||||
}
|
||||
|
||||
// Get the content type from the backend response
|
||||
const contentType =
|
||||
response.headers.get("Content-Type") || "application/octet-stream";
|
||||
const contentDisposition = response.headers.get("Content-Disposition");
|
||||
|
||||
// Stream the response body
|
||||
const responseHeaders: Record<string, string> = {
|
||||
"Content-Type": contentType,
|
||||
};
|
||||
|
||||
if (contentDisposition) {
|
||||
responseHeaders["Content-Disposition"] = contentDisposition;
|
||||
}
|
||||
|
||||
// Return the binary content
|
||||
const arrayBuffer = await response.arrayBuffer();
|
||||
return new NextResponse(arrayBuffer, {
|
||||
status: 200,
|
||||
headers: responseHeaders,
|
||||
});
|
||||
}
|
||||
|
||||
async function handleJsonRequest(
|
||||
req: NextRequest,
|
||||
method: string,
|
||||
@@ -180,6 +244,11 @@ async function handler(
|
||||
};
|
||||
|
||||
try {
|
||||
// Handle workspace file downloads separately (binary response)
|
||||
if (method === "GET" && isWorkspaceDownloadRequest(path)) {
|
||||
return await handleWorkspaceDownload(req, backendUrl);
|
||||
}
|
||||
|
||||
if (method === "GET" || method === "DELETE") {
|
||||
responseBody = await handleGetDeleteRequest(method, backendUrl, req);
|
||||
} else if (contentType?.includes("application/json")) {
|
||||
|
||||
77
autogpt_platform/frontend/src/app/api/transcribe/route.ts
Normal file
77
autogpt_platform/frontend/src/app/api/transcribe/route.ts
Normal file
@@ -0,0 +1,77 @@
|
||||
import { getServerAuthToken } from "@/lib/autogpt-server-api/helpers";
|
||||
import { NextRequest, NextResponse } from "next/server";
|
||||
|
||||
const WHISPER_API_URL = "https://api.openai.com/v1/audio/transcriptions";
|
||||
const MAX_FILE_SIZE = 25 * 1024 * 1024; // 25MB - Whisper's limit
|
||||
|
||||
function getExtensionFromMimeType(mimeType: string): string {
|
||||
const subtype = mimeType.split("/")[1]?.split(";")[0];
|
||||
return subtype || "webm";
|
||||
}
|
||||
|
||||
export async function POST(request: NextRequest) {
|
||||
const token = await getServerAuthToken();
|
||||
|
||||
if (!token || token === "no-token-found") {
|
||||
return NextResponse.json({ error: "Unauthorized" }, { status: 401 });
|
||||
}
|
||||
|
||||
const apiKey = process.env.OPENAI_API_KEY;
|
||||
|
||||
if (!apiKey) {
|
||||
return NextResponse.json(
|
||||
{ error: "OpenAI API key not configured" },
|
||||
{ status: 401 },
|
||||
);
|
||||
}
|
||||
|
||||
try {
|
||||
const formData = await request.formData();
|
||||
const audioFile = formData.get("audio");
|
||||
|
||||
if (!audioFile || !(audioFile instanceof Blob)) {
|
||||
return NextResponse.json(
|
||||
{ error: "No audio file provided" },
|
||||
{ status: 400 },
|
||||
);
|
||||
}
|
||||
|
||||
if (audioFile.size > MAX_FILE_SIZE) {
|
||||
return NextResponse.json(
|
||||
{ error: "File too large. Maximum size is 25MB." },
|
||||
{ status: 413 },
|
||||
);
|
||||
}
|
||||
|
||||
const ext = getExtensionFromMimeType(audioFile.type);
|
||||
const whisperFormData = new FormData();
|
||||
whisperFormData.append("file", audioFile, `recording.${ext}`);
|
||||
whisperFormData.append("model", "whisper-1");
|
||||
|
||||
const response = await fetch(WHISPER_API_URL, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
Authorization: `Bearer ${apiKey}`,
|
||||
},
|
||||
body: whisperFormData,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorData = await response.json().catch(() => ({}));
|
||||
console.error("Whisper API error:", errorData);
|
||||
return NextResponse.json(
|
||||
{ error: errorData.error?.message || "Transcription failed" },
|
||||
{ status: response.status },
|
||||
);
|
||||
}
|
||||
|
||||
const result = await response.json();
|
||||
return NextResponse.json({ text: result.text });
|
||||
} catch (error) {
|
||||
console.error("Transcription error:", error);
|
||||
return NextResponse.json(
|
||||
{ error: "Failed to process audio" },
|
||||
{ status: 500 },
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,14 @@
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { ArrowUpIcon, StopIcon } from "@phosphor-icons/react";
|
||||
import {
|
||||
ArrowUpIcon,
|
||||
CircleNotchIcon,
|
||||
MicrophoneIcon,
|
||||
StopIcon,
|
||||
} from "@phosphor-icons/react";
|
||||
import { RecordingIndicator } from "./components/RecordingIndicator";
|
||||
import { useChatInput } from "./useChatInput";
|
||||
import { useVoiceRecording } from "./useVoiceRecording";
|
||||
|
||||
export interface Props {
|
||||
onSend: (message: string) => void;
|
||||
@@ -21,13 +28,36 @@ export function ChatInput({
|
||||
className,
|
||||
}: Props) {
|
||||
const inputId = "chat-input";
|
||||
const { value, handleKeyDown, handleSubmit, handleChange, hasMultipleLines } =
|
||||
useChatInput({
|
||||
onSend,
|
||||
disabled: disabled || isStreaming,
|
||||
maxRows: 4,
|
||||
inputId,
|
||||
});
|
||||
const {
|
||||
value,
|
||||
setValue,
|
||||
handleKeyDown: baseHandleKeyDown,
|
||||
handleSubmit,
|
||||
handleChange,
|
||||
hasMultipleLines,
|
||||
} = useChatInput({
|
||||
onSend,
|
||||
disabled: disabled || isStreaming,
|
||||
maxRows: 4,
|
||||
inputId,
|
||||
});
|
||||
|
||||
const {
|
||||
isRecording,
|
||||
isTranscribing,
|
||||
elapsedTime,
|
||||
toggleRecording,
|
||||
handleKeyDown,
|
||||
showMicButton,
|
||||
isInputDisabled,
|
||||
audioStream,
|
||||
} = useVoiceRecording({
|
||||
setValue,
|
||||
disabled: disabled || isStreaming,
|
||||
isStreaming,
|
||||
value,
|
||||
baseHandleKeyDown,
|
||||
});
|
||||
|
||||
return (
|
||||
<form onSubmit={handleSubmit} className={cn("relative flex-1", className)}>
|
||||
@@ -35,8 +65,11 @@ export function ChatInput({
|
||||
<div
|
||||
id={`${inputId}-wrapper`}
|
||||
className={cn(
|
||||
"relative overflow-hidden border border-neutral-200 bg-white shadow-sm",
|
||||
"focus-within:border-zinc-400 focus-within:ring-1 focus-within:ring-zinc-400",
|
||||
"relative overflow-hidden border bg-white shadow-sm",
|
||||
"focus-within:ring-1",
|
||||
isRecording
|
||||
? "border-red-400 focus-within:border-red-400 focus-within:ring-red-400"
|
||||
: "border-neutral-200 focus-within:border-zinc-400 focus-within:ring-zinc-400",
|
||||
hasMultipleLines ? "rounded-xlarge" : "rounded-full",
|
||||
)}
|
||||
>
|
||||
@@ -46,48 +79,94 @@ export function ChatInput({
|
||||
value={value}
|
||||
onChange={handleChange}
|
||||
onKeyDown={handleKeyDown}
|
||||
placeholder={placeholder}
|
||||
disabled={disabled || isStreaming}
|
||||
placeholder={
|
||||
isTranscribing
|
||||
? "Transcribing..."
|
||||
: isRecording
|
||||
? ""
|
||||
: placeholder
|
||||
}
|
||||
disabled={isInputDisabled}
|
||||
rows={1}
|
||||
className={cn(
|
||||
"w-full resize-none overflow-y-auto border-0 bg-transparent text-[1rem] leading-6 text-black",
|
||||
"placeholder:text-zinc-400",
|
||||
"focus:outline-none focus:ring-0",
|
||||
"disabled:text-zinc-500",
|
||||
hasMultipleLines ? "pb-6 pl-4 pr-4 pt-2" : "pb-4 pl-4 pr-14 pt-4",
|
||||
hasMultipleLines
|
||||
? "pb-6 pl-4 pr-4 pt-2"
|
||||
: showMicButton
|
||||
? "pb-4 pl-14 pr-14 pt-4"
|
||||
: "pb-4 pl-4 pr-14 pt-4",
|
||||
)}
|
||||
/>
|
||||
{isRecording && !value && (
|
||||
<div className="pointer-events-none absolute inset-0 flex items-center justify-center">
|
||||
<RecordingIndicator
|
||||
elapsedTime={elapsedTime}
|
||||
audioStream={audioStream}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
<span id="chat-input-hint" className="sr-only">
|
||||
Press Enter to send, Shift+Enter for new line
|
||||
Press Enter to send, Shift+Enter for new line, Space to record voice
|
||||
</span>
|
||||
|
||||
{isStreaming ? (
|
||||
<Button
|
||||
type="button"
|
||||
variant="icon"
|
||||
size="icon"
|
||||
aria-label="Stop generating"
|
||||
onClick={onStop}
|
||||
className="absolute bottom-[7px] right-2 border-red-600 bg-red-600 text-white hover:border-red-800 hover:bg-red-800"
|
||||
>
|
||||
<StopIcon className="h-4 w-4" weight="bold" />
|
||||
</Button>
|
||||
) : (
|
||||
<Button
|
||||
type="submit"
|
||||
variant="icon"
|
||||
size="icon"
|
||||
aria-label="Send message"
|
||||
className={cn(
|
||||
"absolute bottom-[7px] right-2 border-zinc-800 bg-zinc-800 text-white hover:border-zinc-900 hover:bg-zinc-900",
|
||||
(disabled || !value.trim()) && "opacity-20",
|
||||
)}
|
||||
disabled={disabled || !value.trim()}
|
||||
>
|
||||
<ArrowUpIcon className="h-4 w-4" weight="bold" />
|
||||
</Button>
|
||||
{showMicButton && (
|
||||
<div className="absolute bottom-[7px] left-2 flex items-center gap-1">
|
||||
<Button
|
||||
type="button"
|
||||
variant="icon"
|
||||
size="icon"
|
||||
aria-label={isRecording ? "Stop recording" : "Start recording"}
|
||||
onClick={toggleRecording}
|
||||
disabled={disabled || isTranscribing}
|
||||
className={cn(
|
||||
isRecording
|
||||
? "animate-pulse border-red-500 bg-red-500 text-white hover:border-red-600 hover:bg-red-600"
|
||||
: isTranscribing
|
||||
? "border-zinc-300 bg-zinc-100 text-zinc-400"
|
||||
: "border-zinc-300 bg-white text-zinc-500 hover:border-zinc-400 hover:bg-zinc-50 hover:text-zinc-700",
|
||||
)}
|
||||
>
|
||||
{isTranscribing ? (
|
||||
<CircleNotchIcon className="h-4 w-4 animate-spin" />
|
||||
) : (
|
||||
<MicrophoneIcon className="h-4 w-4" weight="bold" />
|
||||
)}
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="absolute bottom-[7px] right-2 flex items-center gap-1">
|
||||
{isStreaming ? (
|
||||
<Button
|
||||
type="button"
|
||||
variant="icon"
|
||||
size="icon"
|
||||
aria-label="Stop generating"
|
||||
onClick={onStop}
|
||||
className="border-red-600 bg-red-600 text-white hover:border-red-800 hover:bg-red-800"
|
||||
>
|
||||
<StopIcon className="h-4 w-4" weight="bold" />
|
||||
</Button>
|
||||
) : (
|
||||
<Button
|
||||
type="submit"
|
||||
variant="icon"
|
||||
size="icon"
|
||||
aria-label="Send message"
|
||||
className={cn(
|
||||
"border-zinc-800 bg-zinc-800 text-white hover:border-zinc-900 hover:bg-zinc-900",
|
||||
(disabled || !value.trim() || isRecording) && "opacity-20",
|
||||
)}
|
||||
disabled={disabled || !value.trim() || isRecording}
|
||||
>
|
||||
<ArrowUpIcon className="h-4 w-4" weight="bold" />
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</form>
|
||||
);
|
||||
|
||||
@@ -0,0 +1,142 @@
|
||||
"use client";
|
||||
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
|
||||
interface Props {
|
||||
stream: MediaStream | null;
|
||||
barCount?: number;
|
||||
barWidth?: number;
|
||||
barGap?: number;
|
||||
barColor?: string;
|
||||
minBarHeight?: number;
|
||||
maxBarHeight?: number;
|
||||
}
|
||||
|
||||
export function AudioWaveform({
|
||||
stream,
|
||||
barCount = 24,
|
||||
barWidth = 3,
|
||||
barGap = 2,
|
||||
barColor = "#ef4444", // red-500
|
||||
minBarHeight = 4,
|
||||
maxBarHeight = 32,
|
||||
}: Props) {
|
||||
const [bars, setBars] = useState<number[]>(() =>
|
||||
Array(barCount).fill(minBarHeight),
|
||||
);
|
||||
const analyserRef = useRef<AnalyserNode | null>(null);
|
||||
const audioContextRef = useRef<AudioContext | null>(null);
|
||||
const sourceRef = useRef<MediaStreamAudioSourceNode | null>(null);
|
||||
const animationRef = useRef<number | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
if (!stream) {
|
||||
setBars(Array(barCount).fill(minBarHeight));
|
||||
return;
|
||||
}
|
||||
|
||||
// Create audio context and analyser
|
||||
const audioContext = new AudioContext();
|
||||
const analyser = audioContext.createAnalyser();
|
||||
analyser.fftSize = 512;
|
||||
analyser.smoothingTimeConstant = 0.8;
|
||||
|
||||
// Connect the stream to the analyser
|
||||
const source = audioContext.createMediaStreamSource(stream);
|
||||
source.connect(analyser);
|
||||
|
||||
audioContextRef.current = audioContext;
|
||||
analyserRef.current = analyser;
|
||||
sourceRef.current = source;
|
||||
|
||||
const timeData = new Uint8Array(analyser.frequencyBinCount);
|
||||
|
||||
const updateBars = () => {
|
||||
if (!analyserRef.current) return;
|
||||
|
||||
analyserRef.current.getByteTimeDomainData(timeData);
|
||||
|
||||
// Distribute time-domain data across bars
|
||||
// This shows waveform amplitude, making all bars respond to audio
|
||||
const newBars: number[] = [];
|
||||
const samplesPerBar = timeData.length / barCount;
|
||||
|
||||
for (let i = 0; i < barCount; i++) {
|
||||
// Sample waveform data for this bar
|
||||
let maxAmplitude = 0;
|
||||
const startIdx = Math.floor(i * samplesPerBar);
|
||||
const endIdx = Math.floor((i + 1) * samplesPerBar);
|
||||
|
||||
for (let j = startIdx; j < endIdx && j < timeData.length; j++) {
|
||||
// Convert to amplitude (distance from center 128)
|
||||
const amplitude = Math.abs(timeData[j] - 128);
|
||||
maxAmplitude = Math.max(maxAmplitude, amplitude);
|
||||
}
|
||||
|
||||
// Map amplitude (0-128) to bar height
|
||||
const normalized = (maxAmplitude / 128) * 255;
|
||||
const height =
|
||||
minBarHeight + (normalized / 255) * (maxBarHeight - minBarHeight);
|
||||
newBars.push(height);
|
||||
}
|
||||
|
||||
setBars(newBars);
|
||||
animationRef.current = requestAnimationFrame(updateBars);
|
||||
};
|
||||
|
||||
updateBars();
|
||||
|
||||
return () => {
|
||||
if (animationRef.current) {
|
||||
cancelAnimationFrame(animationRef.current);
|
||||
}
|
||||
if (sourceRef.current) {
|
||||
sourceRef.current.disconnect();
|
||||
}
|
||||
if (audioContextRef.current) {
|
||||
audioContextRef.current.close();
|
||||
}
|
||||
analyserRef.current = null;
|
||||
audioContextRef.current = null;
|
||||
sourceRef.current = null;
|
||||
};
|
||||
}, [stream, barCount, minBarHeight, maxBarHeight]);
|
||||
|
||||
const totalWidth = barCount * barWidth + (barCount - 1) * barGap;
|
||||
|
||||
return (
|
||||
<div
|
||||
className="flex items-center justify-center"
|
||||
style={{
|
||||
width: totalWidth,
|
||||
height: maxBarHeight,
|
||||
gap: barGap,
|
||||
}}
|
||||
>
|
||||
{bars.map((height, i) => {
|
||||
const barHeight = Math.max(minBarHeight, height);
|
||||
return (
|
||||
<div
|
||||
key={i}
|
||||
className="relative"
|
||||
style={{
|
||||
width: barWidth,
|
||||
height: maxBarHeight,
|
||||
}}
|
||||
>
|
||||
<div
|
||||
className="absolute left-0 rounded-full transition-[height] duration-75"
|
||||
style={{
|
||||
width: barWidth,
|
||||
height: barHeight,
|
||||
top: "50%",
|
||||
transform: "translateY(-50%)",
|
||||
backgroundColor: barColor,
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
import { formatElapsedTime } from "../helpers";
|
||||
import { AudioWaveform } from "./AudioWaveform";
|
||||
|
||||
type Props = {
|
||||
elapsedTime: number;
|
||||
audioStream: MediaStream | null;
|
||||
};
|
||||
|
||||
export function RecordingIndicator({ elapsedTime, audioStream }: Props) {
|
||||
return (
|
||||
<div className="flex items-center gap-3">
|
||||
<AudioWaveform
|
||||
stream={audioStream}
|
||||
barCount={20}
|
||||
barWidth={3}
|
||||
barGap={2}
|
||||
barColor="#ef4444"
|
||||
minBarHeight={4}
|
||||
maxBarHeight={24}
|
||||
/>
|
||||
<span className="min-w-[3ch] text-sm font-medium text-red-500">
|
||||
{formatElapsedTime(elapsedTime)}
|
||||
</span>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,6 @@
|
||||
export function formatElapsedTime(ms: number): string {
|
||||
const seconds = Math.floor(ms / 1000);
|
||||
const minutes = Math.floor(seconds / 60);
|
||||
const remainingSeconds = seconds % 60;
|
||||
return `${minutes}:${remainingSeconds.toString().padStart(2, "0")}`;
|
||||
}
|
||||
@@ -6,7 +6,7 @@ import {
|
||||
useState,
|
||||
} from "react";
|
||||
|
||||
interface UseChatInputArgs {
|
||||
interface Args {
|
||||
onSend: (message: string) => void;
|
||||
disabled?: boolean;
|
||||
maxRows?: number;
|
||||
@@ -18,7 +18,7 @@ export function useChatInput({
|
||||
disabled = false,
|
||||
maxRows = 5,
|
||||
inputId = "chat-input",
|
||||
}: UseChatInputArgs) {
|
||||
}: Args) {
|
||||
const [value, setValue] = useState("");
|
||||
const [hasMultipleLines, setHasMultipleLines] = useState(false);
|
||||
|
||||
|
||||
@@ -0,0 +1,240 @@
|
||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||
import React, {
|
||||
KeyboardEvent,
|
||||
useCallback,
|
||||
useEffect,
|
||||
useRef,
|
||||
useState,
|
||||
} from "react";
|
||||
|
||||
const MAX_RECORDING_DURATION = 2 * 60 * 1000; // 2 minutes in ms
|
||||
|
||||
interface Args {
|
||||
setValue: React.Dispatch<React.SetStateAction<string>>;
|
||||
disabled?: boolean;
|
||||
isStreaming?: boolean;
|
||||
value: string;
|
||||
baseHandleKeyDown: (event: KeyboardEvent<HTMLTextAreaElement>) => void;
|
||||
}
|
||||
|
||||
export function useVoiceRecording({
|
||||
setValue,
|
||||
disabled = false,
|
||||
isStreaming = false,
|
||||
value,
|
||||
baseHandleKeyDown,
|
||||
}: Args) {
|
||||
const [isRecording, setIsRecording] = useState(false);
|
||||
const [isTranscribing, setIsTranscribing] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const [elapsedTime, setElapsedTime] = useState(0);
|
||||
|
||||
const mediaRecorderRef = useRef<MediaRecorder | null>(null);
|
||||
const chunksRef = useRef<Blob[]>([]);
|
||||
const timerRef = useRef<NodeJS.Timeout | null>(null);
|
||||
const startTimeRef = useRef<number>(0);
|
||||
const streamRef = useRef<MediaStream | null>(null);
|
||||
const isRecordingRef = useRef(false);
|
||||
|
||||
const isSupported =
|
||||
typeof window !== "undefined" &&
|
||||
!!(navigator.mediaDevices && navigator.mediaDevices.getUserMedia);
|
||||
|
||||
const clearTimer = useCallback(() => {
|
||||
if (timerRef.current) {
|
||||
clearInterval(timerRef.current);
|
||||
timerRef.current = null;
|
||||
}
|
||||
}, []);
|
||||
|
||||
const cleanup = useCallback(() => {
|
||||
clearTimer();
|
||||
if (streamRef.current) {
|
||||
streamRef.current.getTracks().forEach((track) => track.stop());
|
||||
streamRef.current = null;
|
||||
}
|
||||
mediaRecorderRef.current = null;
|
||||
chunksRef.current = [];
|
||||
setElapsedTime(0);
|
||||
}, [clearTimer]);
|
||||
|
||||
const handleTranscription = useCallback(
|
||||
(text: string) => {
|
||||
setValue((prev) => {
|
||||
const trimmedPrev = prev.trim();
|
||||
if (trimmedPrev) {
|
||||
return `${trimmedPrev} ${text}`;
|
||||
}
|
||||
return text;
|
||||
});
|
||||
},
|
||||
[setValue],
|
||||
);
|
||||
|
||||
const transcribeAudio = useCallback(
|
||||
async (audioBlob: Blob) => {
|
||||
setIsTranscribing(true);
|
||||
setError(null);
|
||||
|
||||
try {
|
||||
const formData = new FormData();
|
||||
formData.append("audio", audioBlob);
|
||||
|
||||
const response = await fetch("/api/transcribe", {
|
||||
method: "POST",
|
||||
body: formData,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const data = await response.json().catch(() => ({}));
|
||||
throw new Error(data.error || "Transcription failed");
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
if (data.text) {
|
||||
handleTranscription(data.text);
|
||||
}
|
||||
} catch (err) {
|
||||
const message =
|
||||
err instanceof Error ? err.message : "Transcription failed";
|
||||
setError(message);
|
||||
console.error("Transcription error:", err);
|
||||
} finally {
|
||||
setIsTranscribing(false);
|
||||
}
|
||||
},
|
||||
[handleTranscription],
|
||||
);
|
||||
|
||||
const stopRecording = useCallback(() => {
|
||||
if (mediaRecorderRef.current && isRecordingRef.current) {
|
||||
mediaRecorderRef.current.stop();
|
||||
isRecordingRef.current = false;
|
||||
setIsRecording(false);
|
||||
clearTimer();
|
||||
}
|
||||
}, [clearTimer]);
|
||||
|
||||
const startRecording = useCallback(async () => {
|
||||
if (disabled || isRecordingRef.current || isTranscribing) return;
|
||||
|
||||
setError(null);
|
||||
chunksRef.current = [];
|
||||
|
||||
try {
|
||||
const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
|
||||
streamRef.current = stream;
|
||||
|
||||
const mediaRecorder = new MediaRecorder(stream, {
|
||||
mimeType: MediaRecorder.isTypeSupported("audio/webm")
|
||||
? "audio/webm"
|
||||
: "audio/mp4",
|
||||
});
|
||||
|
||||
mediaRecorderRef.current = mediaRecorder;
|
||||
|
||||
mediaRecorder.ondataavailable = (event) => {
|
||||
if (event.data.size > 0) {
|
||||
chunksRef.current.push(event.data);
|
||||
}
|
||||
};
|
||||
|
||||
mediaRecorder.onstop = async () => {
|
||||
const audioBlob = new Blob(chunksRef.current, {
|
||||
type: mediaRecorder.mimeType,
|
||||
});
|
||||
|
||||
// Cleanup stream
|
||||
if (streamRef.current) {
|
||||
streamRef.current.getTracks().forEach((track) => track.stop());
|
||||
streamRef.current = null;
|
||||
}
|
||||
|
||||
if (audioBlob.size > 0) {
|
||||
await transcribeAudio(audioBlob);
|
||||
}
|
||||
};
|
||||
|
||||
mediaRecorder.start(1000); // Collect data every second
|
||||
isRecordingRef.current = true;
|
||||
setIsRecording(true);
|
||||
startTimeRef.current = Date.now();
|
||||
|
||||
// Start elapsed time timer
|
||||
timerRef.current = setInterval(() => {
|
||||
const elapsed = Date.now() - startTimeRef.current;
|
||||
setElapsedTime(elapsed);
|
||||
|
||||
// Auto-stop at max duration
|
||||
if (elapsed >= MAX_RECORDING_DURATION) {
|
||||
stopRecording();
|
||||
}
|
||||
}, 100);
|
||||
} catch (err) {
|
||||
console.error("Failed to start recording:", err);
|
||||
if (err instanceof DOMException && err.name === "NotAllowedError") {
|
||||
setError("Microphone permission denied");
|
||||
} else {
|
||||
setError("Failed to access microphone");
|
||||
}
|
||||
cleanup();
|
||||
}
|
||||
}, [disabled, isTranscribing, stopRecording, transcribeAudio, cleanup]);
|
||||
|
||||
const toggleRecording = useCallback(() => {
|
||||
if (isRecording) {
|
||||
stopRecording();
|
||||
} else {
|
||||
startRecording();
|
||||
}
|
||||
}, [isRecording, startRecording, stopRecording]);
|
||||
|
||||
const { toast } = useToast();
|
||||
|
||||
useEffect(() => {
|
||||
if (error) {
|
||||
toast({
|
||||
title: "Voice recording failed",
|
||||
description: error,
|
||||
variant: "destructive",
|
||||
});
|
||||
}
|
||||
}, [error, toast]);
|
||||
|
||||
const handleKeyDown = useCallback(
|
||||
(event: KeyboardEvent<HTMLTextAreaElement>) => {
|
||||
if (event.key === " " && !value.trim() && !isTranscribing) {
|
||||
event.preventDefault();
|
||||
toggleRecording();
|
||||
return;
|
||||
}
|
||||
baseHandleKeyDown(event);
|
||||
},
|
||||
[value, isTranscribing, toggleRecording, baseHandleKeyDown],
|
||||
);
|
||||
|
||||
const showMicButton = isSupported && !isStreaming;
|
||||
const isInputDisabled = disabled || isStreaming || isTranscribing;
|
||||
|
||||
// Cleanup on unmount
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
cleanup();
|
||||
};
|
||||
}, [cleanup]);
|
||||
|
||||
return {
|
||||
isRecording,
|
||||
isTranscribing,
|
||||
error,
|
||||
elapsedTime,
|
||||
startRecording,
|
||||
stopRecording,
|
||||
toggleRecording,
|
||||
isSupported,
|
||||
handleKeyDown,
|
||||
showMicButton,
|
||||
isInputDisabled,
|
||||
audioStream: streamRef.current,
|
||||
};
|
||||
}
|
||||
@@ -1,6 +1,8 @@
|
||||
"use client";
|
||||
|
||||
import { getGetWorkspaceDownloadFileByIdUrl } from "@/app/api/__generated__/endpoints/workspace/workspace";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { EyeSlash } from "@phosphor-icons/react";
|
||||
import React from "react";
|
||||
import ReactMarkdown from "react-markdown";
|
||||
import remarkGfm from "remark-gfm";
|
||||
@@ -29,12 +31,88 @@ interface InputProps extends React.InputHTMLAttributes<HTMLInputElement> {
|
||||
type?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Converts a workspace:// URL to a proxy URL that routes through Next.js to the backend.
|
||||
* workspace://abc123 -> /api/proxy/api/workspace/files/abc123/download
|
||||
*
|
||||
* Uses the generated API URL helper and routes through the Next.js proxy
|
||||
* which handles authentication and proper backend routing.
|
||||
*/
|
||||
/**
|
||||
* URL transformer for ReactMarkdown.
|
||||
* Converts workspace:// URLs to proxy URLs that route through Next.js to the backend.
|
||||
* workspace://abc123 -> /api/proxy/api/workspace/files/abc123/download
|
||||
*
|
||||
* This is needed because ReactMarkdown sanitizes URLs and only allows
|
||||
* http, https, mailto, and tel protocols by default.
|
||||
*/
|
||||
function resolveWorkspaceUrl(src: string): string {
|
||||
if (src.startsWith("workspace://")) {
|
||||
const fileId = src.replace("workspace://", "");
|
||||
// Use the generated API URL helper to get the correct path
|
||||
const apiPath = getGetWorkspaceDownloadFileByIdUrl(fileId);
|
||||
// Route through the Next.js proxy (same pattern as customMutator for client-side)
|
||||
return `/api/proxy${apiPath}`;
|
||||
}
|
||||
return src;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if the image URL is a workspace file (AI cannot see these yet).
|
||||
* After URL transformation, workspace files have URLs like /api/proxy/api/workspace/files/...
|
||||
*/
|
||||
function isWorkspaceImage(src: string | undefined): boolean {
|
||||
return src?.includes("/workspace/files/") ?? false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Custom image component that shows an indicator when the AI cannot see the image.
|
||||
* Note: src is already transformed by urlTransform, so workspace:// is now /api/workspace/...
|
||||
*/
|
||||
function MarkdownImage(props: Record<string, unknown>) {
|
||||
const src = props.src as string | undefined;
|
||||
const alt = props.alt as string | undefined;
|
||||
|
||||
const aiCannotSee = isWorkspaceImage(src);
|
||||
|
||||
// If no src, show a placeholder
|
||||
if (!src) {
|
||||
return (
|
||||
<span className="my-2 inline-block rounded border border-amber-200 bg-amber-50 px-2 py-1 text-sm text-amber-700">
|
||||
[Image: {alt || "missing src"}]
|
||||
</span>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<span className="relative my-2 inline-block">
|
||||
{/* eslint-disable-next-line @next/next/no-img-element */}
|
||||
<img
|
||||
src={src}
|
||||
alt={alt || "Image"}
|
||||
className="h-auto max-w-full rounded-md border border-zinc-200"
|
||||
loading="lazy"
|
||||
/>
|
||||
{aiCannotSee && (
|
||||
<span
|
||||
className="absolute bottom-2 right-2 flex items-center gap-1 rounded bg-black/70 px-2 py-1 text-xs text-white"
|
||||
title="The AI cannot see this image"
|
||||
>
|
||||
<EyeSlash size={14} />
|
||||
<span>AI cannot see this image</span>
|
||||
</span>
|
||||
)}
|
||||
</span>
|
||||
);
|
||||
}
|
||||
|
||||
export function MarkdownContent({ content, className }: MarkdownContentProps) {
|
||||
return (
|
||||
<div className={cn("markdown-content", className)}>
|
||||
<ReactMarkdown
|
||||
skipHtml={true}
|
||||
remarkPlugins={[remarkGfm]}
|
||||
urlTransform={resolveWorkspaceUrl}
|
||||
components={{
|
||||
code: ({ children, className, ...props }: CodeProps) => {
|
||||
const isInline = !className?.includes("language-");
|
||||
@@ -206,6 +284,9 @@ export function MarkdownContent({ content, className }: MarkdownContentProps) {
|
||||
{children}
|
||||
</td>
|
||||
),
|
||||
img: ({ src, alt, ...props }) => (
|
||||
<MarkdownImage src={src} alt={alt} {...props} />
|
||||
),
|
||||
}}
|
||||
>
|
||||
{content}
|
||||
|
||||
@@ -37,6 +37,87 @@ export function getErrorMessage(result: unknown): string {
|
||||
return "An error occurred";
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a value is a workspace file reference.
|
||||
*/
|
||||
function isWorkspaceRef(value: unknown): value is string {
|
||||
return typeof value === "string" && value.startsWith("workspace://");
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a workspace reference appears to be an image based on common patterns.
|
||||
* Since workspace refs don't have extensions, we check the context or assume image
|
||||
* for certain block types.
|
||||
*
|
||||
* TODO: Replace keyword matching with MIME type encoded in workspace ref.
|
||||
* e.g., workspace://abc123#image/png or workspace://abc123#video/mp4
|
||||
* This would let frontend render correctly without fragile keyword matching.
|
||||
*/
|
||||
function isLikelyImageRef(value: string, outputKey?: string): boolean {
|
||||
if (!isWorkspaceRef(value)) return false;
|
||||
|
||||
// Check output key name for video-related hints (these are NOT images)
|
||||
const videoKeywords = ["video", "mp4", "mov", "avi", "webm", "movie", "clip"];
|
||||
if (outputKey) {
|
||||
const lowerKey = outputKey.toLowerCase();
|
||||
if (videoKeywords.some((kw) => lowerKey.includes(kw))) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Check output key name for image-related hints
|
||||
const imageKeywords = [
|
||||
"image",
|
||||
"img",
|
||||
"photo",
|
||||
"picture",
|
||||
"thumbnail",
|
||||
"avatar",
|
||||
"icon",
|
||||
"screenshot",
|
||||
];
|
||||
if (outputKey) {
|
||||
const lowerKey = outputKey.toLowerCase();
|
||||
if (imageKeywords.some((kw) => lowerKey.includes(kw))) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// Default to treating workspace refs as potential images
|
||||
// since that's the most common case for generated content
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Format a single output value, converting workspace refs to markdown images.
|
||||
*/
|
||||
function formatOutputValue(value: unknown, outputKey?: string): string {
|
||||
if (isWorkspaceRef(value) && isLikelyImageRef(value, outputKey)) {
|
||||
// Format as markdown image
|
||||
return ``;
|
||||
}
|
||||
|
||||
if (typeof value === "string") {
|
||||
// Check for data URIs (images)
|
||||
if (value.startsWith("data:image/")) {
|
||||
return ``;
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
if (Array.isArray(value)) {
|
||||
return value
|
||||
.map((item, idx) => formatOutputValue(item, `${outputKey}_${idx}`))
|
||||
.join("\n\n");
|
||||
}
|
||||
|
||||
if (typeof value === "object" && value !== null) {
|
||||
return JSON.stringify(value, null, 2);
|
||||
}
|
||||
|
||||
return String(value);
|
||||
}
|
||||
|
||||
function getToolCompletionPhrase(toolName: string): string {
|
||||
const toolCompletionPhrases: Record<string, string> = {
|
||||
add_understanding: "Updated your business information",
|
||||
@@ -127,10 +208,26 @@ export function formatToolResponse(result: unknown, toolName: string): string {
|
||||
|
||||
case "block_output":
|
||||
const blockName = (response.block_name as string) || "Block";
|
||||
const outputs = response.outputs as Record<string, unknown> | undefined;
|
||||
const outputs = response.outputs as Record<string, unknown[]> | undefined;
|
||||
if (outputs && Object.keys(outputs).length > 0) {
|
||||
const outputKeys = Object.keys(outputs);
|
||||
return `${blockName} executed successfully. Outputs: ${outputKeys.join(", ")}`;
|
||||
const formattedOutputs: string[] = [];
|
||||
|
||||
for (const [key, values] of Object.entries(outputs)) {
|
||||
if (!Array.isArray(values) || values.length === 0) continue;
|
||||
|
||||
// Format each value in the output array
|
||||
for (const value of values) {
|
||||
const formatted = formatOutputValue(value, key);
|
||||
if (formatted) {
|
||||
formattedOutputs.push(formatted);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (formattedOutputs.length > 0) {
|
||||
return `${blockName} executed successfully.\n\n${formattedOutputs.join("\n\n")}`;
|
||||
}
|
||||
return `${blockName} executed successfully.`;
|
||||
}
|
||||
return `${blockName} executed successfully.`;
|
||||
|
||||
|
||||
@@ -516,7 +516,7 @@ export type GraphValidationErrorResponse = {
|
||||
|
||||
/* *** LIBRARY *** */
|
||||
|
||||
/* Mirror of backend/server/v2/library/model.py:LibraryAgent */
|
||||
/* Mirror of backend/api/features/library/model.py:LibraryAgent */
|
||||
export type LibraryAgent = {
|
||||
id: LibraryAgentID;
|
||||
graph_id: GraphID;
|
||||
@@ -616,7 +616,7 @@ export enum LibraryAgentSortEnum {
|
||||
|
||||
/* *** CREDENTIALS *** */
|
||||
|
||||
/* Mirror of backend/server/integrations/router.py:CredentialsMetaResponse */
|
||||
/* Mirror of backend/api/features/integrations/router.py:CredentialsMetaResponse */
|
||||
export type CredentialsMetaResponse = {
|
||||
id: string;
|
||||
provider: CredentialsProviderName;
|
||||
@@ -628,13 +628,13 @@ export type CredentialsMetaResponse = {
|
||||
is_system?: boolean;
|
||||
};
|
||||
|
||||
/* Mirror of backend/server/integrations/router.py:CredentialsDeletionResponse */
|
||||
/* Mirror of backend/api/features/integrations/router.py:CredentialsDeletionResponse */
|
||||
export type CredentialsDeleteResponse = {
|
||||
deleted: true;
|
||||
revoked: boolean | null;
|
||||
};
|
||||
|
||||
/* Mirror of backend/server/integrations/router.py:CredentialsDeletionNeedsConfirmationResponse */
|
||||
/* Mirror of backend/api/features/integrations/router.py:CredentialsDeletionNeedsConfirmationResponse */
|
||||
export type CredentialsDeleteNeedConfirmationResponse = {
|
||||
deleted: false;
|
||||
need_confirmation: true;
|
||||
@@ -888,7 +888,7 @@ export type Schedule = {
|
||||
|
||||
export type ScheduleID = Brand<string, "ScheduleID">;
|
||||
|
||||
/* Mirror of backend/server/routers/v1.py:ScheduleCreationRequest */
|
||||
/* Mirror of backend/api/features/v1.py:ScheduleCreationRequest */
|
||||
export type ScheduleCreatable = {
|
||||
graph_id: GraphID;
|
||||
graph_version: number;
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
[flake8]
|
||||
max-line-length = 88
|
||||
extend-ignore = E203
|
||||
exclude =
|
||||
.tox,
|
||||
__pycache__,
|
||||
*.pyc,
|
||||
.env
|
||||
venv*/*,
|
||||
.venv/*,
|
||||
reports/*,
|
||||
dist/*,
|
||||
data/*,
|
||||
.env,
|
||||
venv*,
|
||||
.venv,
|
||||
reports,
|
||||
dist,
|
||||
data,
|
||||
.benchmark_workspaces,
|
||||
.autogpt,
|
||||
|
||||
291
classic/CLAUDE.md
Normal file
291
classic/CLAUDE.md
Normal file
@@ -0,0 +1,291 @@
|
||||
# CLAUDE.md
|
||||
|
||||
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||
|
||||
## Project Overview
|
||||
|
||||
AutoGPT Classic is an experimental, **unsupported** project demonstrating autonomous GPT-4 operation. Dependencies will not be updated, and the codebase contains known vulnerabilities. This is preserved for educational/historical purposes.
|
||||
|
||||
## Repository Structure
|
||||
|
||||
```
|
||||
classic/
|
||||
├── pyproject.toml # Single consolidated Poetry project
|
||||
├── poetry.lock # Single lock file
|
||||
├── forge/
|
||||
│ └── forge/ # Core agent framework package
|
||||
├── original_autogpt/
|
||||
│ └── autogpt/ # AutoGPT agent package
|
||||
├── direct_benchmark/
|
||||
│ └── direct_benchmark/ # Benchmark harness package
|
||||
└── benchmark/ # Challenge definitions (data, not code)
|
||||
```
|
||||
|
||||
All packages are managed by a single `pyproject.toml` at the classic/ root.
|
||||
|
||||
## Common Commands
|
||||
|
||||
### Setup & Install
|
||||
```bash
|
||||
# Install everything from classic/ directory
|
||||
cd classic
|
||||
poetry install
|
||||
```
|
||||
|
||||
### Running Agents
|
||||
```bash
|
||||
# Run forge agent
|
||||
poetry run python -m forge
|
||||
|
||||
# Run original autogpt server
|
||||
poetry run serve --debug
|
||||
|
||||
# Run autogpt CLI
|
||||
poetry run autogpt
|
||||
```
|
||||
|
||||
Agents run on `http://localhost:8000` by default.
|
||||
|
||||
### Benchmarking
|
||||
```bash
|
||||
# Run benchmarks
|
||||
poetry run direct-benchmark run
|
||||
|
||||
# Run specific strategies and models
|
||||
poetry run direct-benchmark run \
|
||||
--strategies one_shot,rewoo \
|
||||
--models claude \
|
||||
--parallel 4
|
||||
|
||||
# Run a single test
|
||||
poetry run direct-benchmark run --tests ReadFile
|
||||
|
||||
# List available commands
|
||||
poetry run direct-benchmark --help
|
||||
```
|
||||
|
||||
### Testing
|
||||
```bash
|
||||
poetry run pytest # All tests
|
||||
poetry run pytest forge/tests/ # Forge tests only
|
||||
poetry run pytest original_autogpt/tests/ # AutoGPT tests only
|
||||
poetry run pytest -k test_name # Single test by name
|
||||
poetry run pytest path/to/test.py # Specific test file
|
||||
poetry run pytest --cov # With coverage
|
||||
```
|
||||
|
||||
### Linting & Formatting
|
||||
|
||||
Run from the classic/ directory:
|
||||
|
||||
```bash
|
||||
# Format everything (recommended to run together)
|
||||
poetry run black . && poetry run isort .
|
||||
|
||||
# Check formatting (CI-style, no changes)
|
||||
poetry run black --check . && poetry run isort --check-only .
|
||||
|
||||
# Lint
|
||||
poetry run flake8 # Style linting
|
||||
|
||||
# Type check
|
||||
poetry run pyright # Type checking (some errors are expected in infrastructure code)
|
||||
```
|
||||
|
||||
Note: Always run linters over the entire directory, not specific files, for best results.
|
||||
|
||||
## Architecture
|
||||
|
||||
### Forge (Core Framework)
|
||||
The `forge` package is the foundation that other components depend on:
|
||||
- `forge/agent/` - Agent implementation and protocols
|
||||
- `forge/llm/` - Multi-provider LLM integrations (OpenAI, Anthropic, Groq, LiteLLM)
|
||||
- `forge/components/` - Reusable agent components
|
||||
- `forge/file_storage/` - File system abstraction
|
||||
- `forge/config/` - Configuration management
|
||||
|
||||
### Original AutoGPT
|
||||
- `original_autogpt/autogpt/app/` - CLI application entry points
|
||||
- `original_autogpt/autogpt/agents/` - Agent implementations
|
||||
- `original_autogpt/autogpt/agent_factory/` - Agent creation logic
|
||||
|
||||
### Direct Benchmark
|
||||
Benchmark harness for testing agent performance:
|
||||
- `direct_benchmark/direct_benchmark/` - CLI and harness code
|
||||
- `benchmark/agbenchmark/challenges/` - Test cases organized by category (code, retrieval, data, etc.)
|
||||
- Reports generated in `direct_benchmark/reports/`
|
||||
|
||||
### Package Structure
|
||||
All three packages are included in a single Poetry project. Imports are fully qualified:
|
||||
- `from forge.agent.base import BaseAgent`
|
||||
- `from autogpt.agents.agent import Agent`
|
||||
- `from direct_benchmark.harness import BenchmarkHarness`
|
||||
|
||||
## Code Style
|
||||
|
||||
- Python 3.12 target
|
||||
- Line length: 88 characters (Black default)
|
||||
- Black for formatting, isort for imports (profile="black")
|
||||
- Type hints with Pyright checking
|
||||
|
||||
## Testing Patterns
|
||||
|
||||
- Async support via pytest-asyncio
|
||||
- Fixtures defined in `conftest.py` files provide: `tmp_project_root`, `storage`, `config`, `llm_provider`, `agent`
|
||||
- Tests requiring API keys (OPENAI_API_KEY, ANTHROPIC_API_KEY) will skip if not set
|
||||
|
||||
## Environment Setup
|
||||
|
||||
Copy `.env.example` to `.env` in the relevant directory and add your API keys:
|
||||
```bash
|
||||
cp .env.example .env
|
||||
# Edit .env with your OPENAI_API_KEY, etc.
|
||||
```
|
||||
|
||||
## Workspaces
|
||||
|
||||
Agents operate within a **workspace** - a directory containing all agent data and files. The workspace root defaults to the current working directory.
|
||||
|
||||
### Workspace Structure
|
||||
|
||||
```
|
||||
{workspace}/
|
||||
├── .autogpt/
|
||||
│ ├── autogpt.yaml # Workspace-level permissions
|
||||
│ ├── ap_server.db # Agent Protocol database (server mode)
|
||||
│ └── agents/
|
||||
│ └── AutoGPT-{agent_id}/
|
||||
│ ├── state.json # Agent profile, directives, action history
|
||||
│ ├── permissions.yaml # Agent-specific permission overrides
|
||||
│ └── workspace/ # Agent's sandboxed working directory
|
||||
```
|
||||
|
||||
### Key Concepts
|
||||
|
||||
- **Multiple agents** can coexist in the same workspace (each gets its own subdirectory)
|
||||
- **File access** is sandboxed to the agent's `workspace/` directory by default
|
||||
- **State persistence** - agent state saves to `state.json` and survives across sessions
|
||||
- **Storage backends** - supports local filesystem, S3, and GCS (via `FILE_STORAGE_BACKEND` env var)
|
||||
|
||||
### Specifying a Workspace
|
||||
|
||||
```bash
|
||||
# Default: uses current directory
|
||||
cd /path/to/my/project && poetry run autogpt
|
||||
|
||||
# Or specify explicitly via CLI (if supported)
|
||||
poetry run autogpt --workspace /path/to/workspace
|
||||
```
|
||||
|
||||
## Settings Location
|
||||
|
||||
Configuration uses a **layered system** with three levels (in order of precedence):
|
||||
|
||||
### 1. Environment Variables (Global)
|
||||
|
||||
Loaded from `.env` file in the working directory:
|
||||
|
||||
```bash
|
||||
# Required
|
||||
OPENAI_API_KEY=sk-...
|
||||
|
||||
# Optional LLM settings
|
||||
SMART_LLM=gpt-4o # Model for complex reasoning
|
||||
FAST_LLM=gpt-4o-mini # Model for simple tasks
|
||||
EMBEDDING_MODEL=text-embedding-3-small
|
||||
|
||||
# Optional search providers (for web search component)
|
||||
TAVILY_API_KEY=tvly-...
|
||||
SERPER_API_KEY=...
|
||||
GOOGLE_API_KEY=...
|
||||
GOOGLE_CUSTOM_SEARCH_ENGINE_ID=...
|
||||
|
||||
# Optional infrastructure
|
||||
LOG_LEVEL=DEBUG # DEBUG, INFO, WARNING, ERROR
|
||||
DATABASE_STRING=sqlite:///agent.db # Agent Protocol database
|
||||
PORT=8000 # Server port
|
||||
FILE_STORAGE_BACKEND=local # local, s3, or gcs
|
||||
```
|
||||
|
||||
### 2. Workspace Settings (`{workspace}/.autogpt/autogpt.yaml`)
|
||||
|
||||
Workspace-wide permissions that apply to **all agents** in this workspace:
|
||||
|
||||
```yaml
|
||||
allow:
|
||||
- read_file({workspace}/**)
|
||||
- write_to_file({workspace}/**)
|
||||
- list_folder({workspace}/**)
|
||||
- web_search(*)
|
||||
|
||||
deny:
|
||||
- read_file(**.env)
|
||||
- read_file(**.env.*)
|
||||
- read_file(**.key)
|
||||
- read_file(**.pem)
|
||||
- execute_shell(rm -rf:*)
|
||||
- execute_shell(sudo:*)
|
||||
```
|
||||
|
||||
Auto-generated with sensible defaults if missing.
|
||||
|
||||
### 3. Agent Settings (`{workspace}/.autogpt/agents/{id}/permissions.yaml`)
|
||||
|
||||
Agent-specific permission overrides:
|
||||
|
||||
```yaml
|
||||
allow:
|
||||
- execute_python(*)
|
||||
- web_search(*)
|
||||
|
||||
deny:
|
||||
- execute_shell(*)
|
||||
```
|
||||
|
||||
## Permissions
|
||||
|
||||
The permission system uses **pattern matching** with a **first-match-wins** evaluation order.
|
||||
|
||||
### Permission Check Order
|
||||
|
||||
1. Agent deny list → **Block**
|
||||
2. Workspace deny list → **Block**
|
||||
3. Agent allow list → **Allow**
|
||||
4. Workspace allow list → **Allow**
|
||||
5. Session denied list → **Block** (commands denied during this session)
|
||||
6. **Prompt user** → Interactive approval (if in interactive mode)
|
||||
|
||||
### Pattern Syntax
|
||||
|
||||
Format: `command_name(glob_pattern)`
|
||||
|
||||
| Pattern | Description |
|
||||
|---------|-------------|
|
||||
| `read_file({workspace}/**)` | Read any file in workspace (recursive) |
|
||||
| `write_to_file({workspace}/*.txt)` | Write only .txt files in workspace root |
|
||||
| `execute_shell(python:**)` | Execute Python commands only |
|
||||
| `execute_shell(git:*)` | Execute any git command |
|
||||
| `web_search(*)` | Allow all web searches |
|
||||
|
||||
Special tokens:
|
||||
- `{workspace}` - Replaced with actual workspace path
|
||||
- `**` - Matches any path including `/`
|
||||
- `*` - Matches any characters except `/`
|
||||
|
||||
### Interactive Approval Scopes
|
||||
|
||||
When prompted for permission, users can choose:
|
||||
|
||||
| Scope | Effect |
|
||||
|-------|--------|
|
||||
| **Once** | Allow this one time only (not saved) |
|
||||
| **Agent** | Always allow for this agent (saves to agent `permissions.yaml`) |
|
||||
| **Workspace** | Always allow for all agents (saves to `autogpt.yaml`) |
|
||||
| **Deny** | Deny this command (saves to appropriate deny list) |
|
||||
|
||||
### Default Security
|
||||
|
||||
Out of the box, the following are **denied by default**:
|
||||
- Reading sensitive files (`.env`, `.key`, `.pem`)
|
||||
- Destructive shell commands (`rm -rf`, `sudo`)
|
||||
- Operations outside the workspace directory
|
||||
@@ -1,182 +0,0 @@
|
||||
## CLI Documentation
|
||||
|
||||
This document describes how to interact with the project's CLI (Command Line Interface). It includes the types of outputs you can expect from each command. Note that the `agents stop` command will terminate any process running on port 8000.
|
||||
|
||||
### 1. Entry Point for the CLI
|
||||
|
||||
Running the `./run` command without any parameters will display the help message, which provides a list of available commands and options. Additionally, you can append `--help` to any command to view help information specific to that command.
|
||||
|
||||
```sh
|
||||
./run
|
||||
```
|
||||
|
||||
**Output**:
|
||||
|
||||
```
|
||||
Usage: cli.py [OPTIONS] COMMAND [ARGS]...
|
||||
|
||||
Options:
|
||||
--help Show this message and exit.
|
||||
|
||||
Commands:
|
||||
agent Commands to create, start and stop agents
|
||||
benchmark Commands to start the benchmark and list tests and categories
|
||||
setup Installs dependencies needed for your system.
|
||||
```
|
||||
|
||||
If you need assistance with any command, simply add the `--help` parameter to the end of your command, like so:
|
||||
|
||||
```sh
|
||||
./run COMMAND --help
|
||||
```
|
||||
|
||||
This will display a detailed help message regarding that specific command, including a list of any additional options and arguments it accepts.
|
||||
|
||||
### 2. Setup Command
|
||||
|
||||
```sh
|
||||
./run setup
|
||||
```
|
||||
|
||||
**Output**:
|
||||
|
||||
```
|
||||
Setup initiated
|
||||
Installation has been completed.
|
||||
```
|
||||
|
||||
This command initializes the setup of the project.
|
||||
|
||||
### 3. Agents Commands
|
||||
|
||||
**a. List All Agents**
|
||||
|
||||
```sh
|
||||
./run agent list
|
||||
```
|
||||
|
||||
**Output**:
|
||||
|
||||
```
|
||||
Available agents: 🤖
|
||||
🐙 forge
|
||||
🐙 autogpt
|
||||
```
|
||||
|
||||
Lists all the available agents.
|
||||
|
||||
**b. Create a New Agent**
|
||||
|
||||
```sh
|
||||
./run agent create my_agent
|
||||
```
|
||||
|
||||
**Output**:
|
||||
|
||||
```
|
||||
🎉 New agent 'my_agent' created and switched to the new directory in agents folder.
|
||||
```
|
||||
|
||||
Creates a new agent named 'my_agent'.
|
||||
|
||||
**c. Start an Agent**
|
||||
|
||||
```sh
|
||||
./run agent start my_agent
|
||||
```
|
||||
|
||||
**Output**:
|
||||
|
||||
```
|
||||
... (ASCII Art representing the agent startup)
|
||||
[Date and Time] [forge.sdk.db] [DEBUG] 🐛 Initializing AgentDB with database_string: sqlite:///agent.db
|
||||
[Date and Time] [forge.sdk.agent] [INFO] 📝 Agent server starting on http://0.0.0.0:8000
|
||||
```
|
||||
|
||||
Starts the 'my_agent' and displays startup ASCII art and logs.
|
||||
|
||||
**d. Stop an Agent**
|
||||
|
||||
```sh
|
||||
./run agent stop
|
||||
```
|
||||
|
||||
**Output**:
|
||||
|
||||
```
|
||||
Agent stopped
|
||||
```
|
||||
|
||||
Stops the running agent.
|
||||
|
||||
### 4. Benchmark Commands
|
||||
|
||||
**a. List Benchmark Categories**
|
||||
|
||||
```sh
|
||||
./run benchmark categories list
|
||||
```
|
||||
|
||||
**Output**:
|
||||
|
||||
```
|
||||
Available categories: 📚
|
||||
📖 code
|
||||
📖 safety
|
||||
📖 memory
|
||||
... (and so on)
|
||||
```
|
||||
|
||||
Lists all available benchmark categories.
|
||||
|
||||
**b. List Benchmark Tests**
|
||||
|
||||
```sh
|
||||
./run benchmark tests list
|
||||
```
|
||||
|
||||
**Output**:
|
||||
|
||||
```
|
||||
Available tests: 📚
|
||||
📖 interface
|
||||
🔬 Search - TestSearch
|
||||
🔬 Write File - TestWriteFile
|
||||
... (and so on)
|
||||
```
|
||||
|
||||
Lists all available benchmark tests.
|
||||
|
||||
**c. Show Details of a Benchmark Test**
|
||||
|
||||
```sh
|
||||
./run benchmark tests details TestWriteFile
|
||||
```
|
||||
|
||||
**Output**:
|
||||
|
||||
```
|
||||
TestWriteFile
|
||||
-------------
|
||||
|
||||
Category: interface
|
||||
Task: Write the word 'Washington' to a .txt file
|
||||
... (and other details)
|
||||
```
|
||||
|
||||
Displays the details of the 'TestWriteFile' benchmark test.
|
||||
|
||||
**d. Start Benchmark for the Agent**
|
||||
|
||||
```sh
|
||||
./run benchmark start my_agent
|
||||
```
|
||||
|
||||
**Output**:
|
||||
|
||||
```
|
||||
(more details about the testing process shown whilst the test are running)
|
||||
============= 13 failed, 1 passed in 0.97s ============...
|
||||
```
|
||||
|
||||
Displays the results of the benchmark tests on 'my_agent'.
|
||||
@@ -2,7 +2,7 @@
|
||||
ARG BUILD_TYPE=dev
|
||||
|
||||
# Use an official Python base image from the Docker Hub
|
||||
FROM python:3.10-slim AS autogpt-base
|
||||
FROM python:3.12-slim AS autogpt-base
|
||||
|
||||
# Install browsers
|
||||
RUN apt-get update && apt-get install -y \
|
||||
@@ -34,9 +34,6 @@ COPY original_autogpt/pyproject.toml original_autogpt/poetry.lock ./
|
||||
# Include forge so it can be used as a path dependency
|
||||
COPY forge/ ../forge
|
||||
|
||||
# Include frontend
|
||||
COPY frontend/ ../frontend
|
||||
|
||||
# Set the entrypoint
|
||||
ENTRYPOINT ["poetry", "run", "autogpt"]
|
||||
CMD []
|
||||
|
||||
@@ -1,173 +0,0 @@
|
||||
# Quickstart Guide
|
||||
|
||||
> For the complete getting started [tutorial series](https://aiedge.medium.com/autogpt-forge-e3de53cc58ec) <- click here
|
||||
|
||||
Welcome to the Quickstart Guide! This guide will walk you through setting up, building, and running your own AutoGPT agent. Whether you're a seasoned AI developer or just starting out, this guide will provide you with the steps to jumpstart your journey in AI development with AutoGPT.
|
||||
|
||||
## System Requirements
|
||||
|
||||
This project supports Linux (Debian-based), Mac, and Windows Subsystem for Linux (WSL). If you use a Windows system, you must install WSL. You can find the installation instructions for WSL [here](https://learn.microsoft.com/en-us/windows/wsl/).
|
||||
|
||||
|
||||
## Getting Setup
|
||||
1. **Fork the Repository**
|
||||
To fork the repository, follow these steps:
|
||||
- Navigate to the main page of the repository.
|
||||
|
||||

|
||||
- In the top-right corner of the page, click Fork.
|
||||
|
||||

|
||||
- On the next page, select your GitHub account to create the fork.
|
||||
- Wait for the forking process to complete. You now have a copy of the repository in your GitHub account.
|
||||
|
||||
2. **Clone the Repository**
|
||||
To clone the repository, you need to have Git installed on your system. If you don't have Git installed, download it from [here](https://git-scm.com/downloads). Once you have Git installed, follow these steps:
|
||||
- Open your terminal.
|
||||
- Navigate to the directory where you want to clone the repository.
|
||||
- Run the git clone command for the fork you just created
|
||||
|
||||

|
||||
|
||||
- Then open your project in your ide
|
||||
|
||||

|
||||
|
||||
4. **Setup the Project**
|
||||
Next, we need to set up the required dependencies. We have a tool to help you perform all the tasks on the repo.
|
||||
It can be accessed by running the `run` command by typing `./run` in the terminal.
|
||||
|
||||
The first command you need to use is `./run setup.` This will guide you through setting up your system.
|
||||
Initially, you will get instructions for installing Flutter and Chrome and setting up your GitHub access token like the following image:
|
||||
|
||||

|
||||
|
||||
### For Windows Users
|
||||
|
||||
If you're a Windows user and experience issues after installing WSL, follow the steps below to resolve them.
|
||||
|
||||
#### Update WSL
|
||||
Run the following command in Powershell or Command Prompt:
|
||||
1. Enable the optional WSL and Virtual Machine Platform components.
|
||||
2. Download and install the latest Linux kernel.
|
||||
3. Set WSL 2 as the default.
|
||||
4. Download and install the Ubuntu Linux distribution (a reboot may be required).
|
||||
|
||||
```shell
|
||||
wsl --install
|
||||
```
|
||||
|
||||
For more detailed information and additional steps, refer to [Microsoft's WSL Setup Environment Documentation](https://learn.microsoft.com/en-us/windows/wsl/setup/environment).
|
||||
|
||||
#### Resolve FileNotFoundError or "No such file or directory" Errors
|
||||
When you run `./run setup`, if you encounter errors like `No such file or directory` or `FileNotFoundError`, it might be because Windows-style line endings (CRLF - Carriage Return Line Feed) are not compatible with Unix/Linux style line endings (LF - Line Feed).
|
||||
|
||||
To resolve this, you can use the `dos2unix` utility to convert the line endings in your script from CRLF to LF. Here’s how to install and run `dos2unix` on the script:
|
||||
|
||||
```shell
|
||||
sudo apt update
|
||||
sudo apt install dos2unix
|
||||
dos2unix ./run
|
||||
```
|
||||
|
||||
After executing the above commands, running `./run setup` should work successfully.
|
||||
|
||||
#### Store Project Files within the WSL File System
|
||||
If you continue to experience issues, consider storing your project files within the WSL file system instead of the Windows file system. This method avoids path translations and permissions issues and provides a more consistent development environment.
|
||||
|
||||
You can keep running the command to get feedback on where you are up to with your setup.
|
||||
When setup has been completed, the command will return an output like this:
|
||||
|
||||

|
||||
|
||||
## Creating Your Agent
|
||||
|
||||
After completing the setup, the next step is to create your agent template.
|
||||
Execute the command `./run agent create YOUR_AGENT_NAME`, where `YOUR_AGENT_NAME` should be replaced with your chosen name.
|
||||
|
||||
Tips for naming your agent:
|
||||
* Give it its own unique name, or name it after yourself
|
||||
* Include an important aspect of your agent in the name, such as its purpose
|
||||
|
||||
Examples: `SwiftyosAssistant`, `PwutsPRAgent`, `MySuperAgent`
|
||||
|
||||

|
||||
|
||||
## Running your Agent
|
||||
|
||||
Your agent can be started using the command: `./run agent start YOUR_AGENT_NAME`
|
||||
|
||||
This starts the agent on the URL: `http://localhost:8000/`
|
||||
|
||||

|
||||
|
||||
The front end can be accessed from `http://localhost:8000/`; first, you must log in using either a Google account or your GitHub account.
|
||||
|
||||

|
||||
|
||||
Upon logging in, you will get a page that looks something like this: your task history down the left-hand side of the page, and the 'chat' window to send tasks to your agent.
|
||||
|
||||

|
||||
|
||||
When you have finished with your agent or just need to restart it, use Ctl-C to end the session. Then, you can re-run the start command.
|
||||
|
||||
If you are having issues and want to ensure the agent has been stopped, there is a `./run agent stop` command, which will kill the process using port 8000, which should be the agent.
|
||||
|
||||
## Benchmarking your Agent
|
||||
|
||||
The benchmarking system can also be accessed using the CLI too:
|
||||
|
||||
```bash
|
||||
agpt % ./run benchmark
|
||||
Usage: cli.py benchmark [OPTIONS] COMMAND [ARGS]...
|
||||
|
||||
Commands to start the benchmark and list tests and categories
|
||||
|
||||
Options:
|
||||
--help Show this message and exit.
|
||||
|
||||
Commands:
|
||||
categories Benchmark categories group command
|
||||
start Starts the benchmark command
|
||||
tests Benchmark tests group command
|
||||
agpt % ./run benchmark categories
|
||||
Usage: cli.py benchmark categories [OPTIONS] COMMAND [ARGS]...
|
||||
|
||||
Benchmark categories group command
|
||||
|
||||
Options:
|
||||
--help Show this message and exit.
|
||||
|
||||
Commands:
|
||||
list List benchmark categories command
|
||||
agpt % ./run benchmark tests
|
||||
Usage: cli.py benchmark tests [OPTIONS] COMMAND [ARGS]...
|
||||
|
||||
Benchmark tests group command
|
||||
|
||||
Options:
|
||||
--help Show this message and exit.
|
||||
|
||||
Commands:
|
||||
details Benchmark test details command
|
||||
list List benchmark tests command
|
||||
```
|
||||
|
||||
The benchmark has been split into different categories of skills you can test your agent on. You can see what categories are available with
|
||||
```bash
|
||||
./run benchmark categories list
|
||||
# And what tests are available with
|
||||
./run benchmark tests list
|
||||
```
|
||||
|
||||

|
||||
|
||||
|
||||
Finally, you can run the benchmark with
|
||||
|
||||
```bash
|
||||
./run benchmark start YOUR_AGENT_NAME
|
||||
|
||||
```
|
||||
|
||||
>
|
||||
@@ -4,7 +4,7 @@ AutoGPT Classic was an experimental project to demonstrate autonomous GPT-4 oper
|
||||
|
||||
## Project Status
|
||||
|
||||
⚠️ **This project is unsupported, and dependencies will not be updated. It was an experiment that has concluded its initial research phase. If you want to use AutoGPT, you should use the [AutoGPT Platform](/autogpt_platform)**
|
||||
**This project is unsupported, and dependencies will not be updated.** It was an experiment that has concluded its initial research phase. If you want to use AutoGPT, you should use the [AutoGPT Platform](/autogpt_platform).
|
||||
|
||||
For those interested in autonomous AI agents, we recommend exploring more actively maintained alternatives or referring to this codebase for educational purposes only.
|
||||
|
||||
@@ -16,37 +16,171 @@ AutoGPT Classic was one of the first implementations of autonomous AI agents - A
|
||||
- Learn from the results and adjust its approach
|
||||
- Chain multiple actions together to achieve an objective
|
||||
|
||||
## Key Features
|
||||
|
||||
- 🔄 Autonomous task chaining
|
||||
- 🛠 Tool and API integration capabilities
|
||||
- 💾 Memory management for context retention
|
||||
- 🔍 Web browsing and information gathering
|
||||
- 📝 File operations and content creation
|
||||
- 🔄 Self-prompting and task breakdown
|
||||
|
||||
## Structure
|
||||
|
||||
The project is organized into several key components:
|
||||
- `/benchmark` - Performance testing tools
|
||||
- `/forge` - Core autonomous agent framework
|
||||
- `/frontend` - User interface components
|
||||
- `/original_autogpt` - Original implementation
|
||||
```
|
||||
classic/
|
||||
├── pyproject.toml # Single consolidated Poetry project
|
||||
├── poetry.lock # Single lock file
|
||||
├── forge/ # Core autonomous agent framework
|
||||
├── original_autogpt/ # Original implementation
|
||||
├── direct_benchmark/ # Benchmark harness
|
||||
└── benchmark/ # Challenge definitions (data)
|
||||
```
|
||||
|
||||
## Getting Started
|
||||
|
||||
While this project is no longer actively maintained, you can still explore the codebase:
|
||||
### Prerequisites
|
||||
|
||||
- Python 3.12+
|
||||
- [Poetry](https://python-poetry.org/docs/#installation)
|
||||
|
||||
### Installation
|
||||
|
||||
1. Clone the repository:
|
||||
```bash
|
||||
# Clone the repository
|
||||
git clone https://github.com/Significant-Gravitas/AutoGPT.git
|
||||
cd classic
|
||||
|
||||
# Install everything
|
||||
poetry install
|
||||
```
|
||||
|
||||
2. Review the documentation:
|
||||
- For reference, see the [documentation](https://docs.agpt.co). You can browse at the same point in time as this commit so the docs don't change.
|
||||
- Check `CLI-USAGE.md` for command-line interface details
|
||||
- Refer to `TROUBLESHOOTING.md` for common issues
|
||||
### Configuration
|
||||
|
||||
Configuration uses a layered system:
|
||||
|
||||
1. **Environment variables** (`.env` file)
|
||||
2. **Workspace settings** (`.autogpt/autogpt.yaml`)
|
||||
3. **Agent settings** (`.autogpt/agents/{id}/permissions.yaml`)
|
||||
|
||||
Copy the example environment file and add your API keys:
|
||||
|
||||
```bash
|
||||
cp .env.example .env
|
||||
```
|
||||
|
||||
Key environment variables:
|
||||
```bash
|
||||
# Required
|
||||
OPENAI_API_KEY=sk-...
|
||||
|
||||
# Optional LLM settings
|
||||
SMART_LLM=gpt-4o # Model for complex reasoning
|
||||
FAST_LLM=gpt-4o-mini # Model for simple tasks
|
||||
|
||||
# Optional search providers
|
||||
TAVILY_API_KEY=tvly-...
|
||||
SERPER_API_KEY=...
|
||||
|
||||
# Optional infrastructure
|
||||
LOG_LEVEL=DEBUG
|
||||
PORT=8000
|
||||
FILE_STORAGE_BACKEND=local # local, s3, or gcs
|
||||
```
|
||||
|
||||
### Running
|
||||
|
||||
All commands run from the `classic/` directory:
|
||||
|
||||
```bash
|
||||
# Run forge agent
|
||||
poetry run python -m forge
|
||||
|
||||
# Run original autogpt server
|
||||
poetry run serve --debug
|
||||
|
||||
# Run autogpt CLI
|
||||
poetry run autogpt
|
||||
```
|
||||
|
||||
Agents run on `http://localhost:8000` by default.
|
||||
|
||||
### Benchmarking
|
||||
|
||||
```bash
|
||||
poetry run direct-benchmark run
|
||||
```
|
||||
|
||||
### Testing
|
||||
|
||||
```bash
|
||||
poetry run pytest # All tests
|
||||
poetry run pytest forge/tests/ # Forge tests only
|
||||
poetry run pytest original_autogpt/tests/ # AutoGPT tests only
|
||||
```
|
||||
|
||||
## Workspaces
|
||||
|
||||
Agents operate within a **workspace** directory that contains all agent data and files:
|
||||
|
||||
```
|
||||
{workspace}/
|
||||
├── .autogpt/
|
||||
│ ├── autogpt.yaml # Workspace-level permissions
|
||||
│ ├── ap_server.db # Agent Protocol database (server mode)
|
||||
│ └── agents/
|
||||
│ └── AutoGPT-{agent_id}/
|
||||
│ ├── state.json # Agent profile, directives, history
|
||||
│ ├── permissions.yaml # Agent-specific permissions
|
||||
│ └── workspace/ # Agent's sandboxed working directory
|
||||
```
|
||||
|
||||
- The workspace defaults to the current working directory
|
||||
- Multiple agents can coexist in the same workspace
|
||||
- Agent file access is sandboxed to their `workspace/` subdirectory
|
||||
- State persists across sessions via `state.json`
|
||||
|
||||
## Permissions
|
||||
|
||||
AutoGPT uses a **layered permission system** with pattern matching:
|
||||
|
||||
### Permission Files
|
||||
|
||||
| File | Scope | Location |
|
||||
|------|-------|----------|
|
||||
| `autogpt.yaml` | All agents in workspace | `.autogpt/autogpt.yaml` |
|
||||
| `permissions.yaml` | Single agent | `.autogpt/agents/{id}/permissions.yaml` |
|
||||
|
||||
### Permission Format
|
||||
|
||||
```yaml
|
||||
allow:
|
||||
- read_file({workspace}/**) # Read any file in workspace
|
||||
- write_to_file({workspace}/**) # Write any file in workspace
|
||||
- web_search(*) # All web searches
|
||||
|
||||
deny:
|
||||
- read_file(**.env) # Block .env files
|
||||
- execute_shell(sudo:*) # Block sudo commands
|
||||
```
|
||||
|
||||
### Check Order (First Match Wins)
|
||||
|
||||
1. Agent deny → Block
|
||||
2. Workspace deny → Block
|
||||
3. Agent allow → Allow
|
||||
4. Workspace allow → Allow
|
||||
5. Prompt user → Interactive approval
|
||||
|
||||
### Interactive Approval
|
||||
|
||||
When prompted, users can approve commands with different scopes:
|
||||
- **Once** - Allow this one time only
|
||||
- **Agent** - Always allow for this agent
|
||||
- **Workspace** - Always allow for all agents
|
||||
- **Deny** - Block this command
|
||||
|
||||
### Default Security
|
||||
|
||||
Denied by default:
|
||||
- Sensitive files (`.env`, `.key`, `.pem`)
|
||||
- Destructive commands (`rm -rf`, `sudo`)
|
||||
- Operations outside the workspace
|
||||
|
||||
## Security Notice
|
||||
|
||||
This codebase has **known vulnerabilities** and issues with its dependencies. It will not be updated to new dependencies. Use for educational purposes only.
|
||||
|
||||
## License
|
||||
|
||||
@@ -55,27 +189,3 @@ This project segment is licensed under the MIT License - see the [LICENSE](LICEN
|
||||
## Documentation
|
||||
|
||||
Please refer to the [documentation](https://docs.agpt.co) for more detailed information about the project's architecture and concepts.
|
||||
You can browse at the same point in time as this commit so the docs don't change.
|
||||
|
||||
## Historical Impact
|
||||
|
||||
AutoGPT Classic played a significant role in advancing the field of autonomous AI agents:
|
||||
- Demonstrated practical implementation of AI autonomy
|
||||
- Inspired numerous derivative projects and research
|
||||
- Contributed to the development of AI agent architectures
|
||||
- Helped identify key challenges in AI autonomy
|
||||
|
||||
## Security Notice
|
||||
|
||||
If you're studying this codebase, please understand this has KNOWN vulnerabilities and issues with its dependencies. It will not be updated to new dependencies.
|
||||
|
||||
## Community & Support
|
||||
|
||||
While active development has concluded:
|
||||
- The codebase remains available for study and reference
|
||||
- Historical discussions can be found in project issues
|
||||
- Related research and developments continue in the broader AI agent community
|
||||
|
||||
## Acknowledgments
|
||||
|
||||
Thanks to all contributors who participated in this experimental project and helped advance the field of autonomous AI agents.
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
AGENT_NAME=mini-agi
|
||||
REPORTS_FOLDER="reports/mini-agi"
|
||||
OPENAI_API_KEY="sk-" # for LLM eval
|
||||
BUILD_SKILL_TREE=false # set to true to build the skill tree.
|
||||
@@ -1,12 +0,0 @@
|
||||
[flake8]
|
||||
max-line-length = 88
|
||||
# Ignore rules that conflict with Black code style
|
||||
extend-ignore = E203, W503
|
||||
exclude =
|
||||
__pycache__/,
|
||||
*.pyc,
|
||||
.pytest_cache/,
|
||||
venv*/,
|
||||
.venv/,
|
||||
reports/,
|
||||
agbenchmark/reports/,
|
||||
174
classic/benchmark/.gitignore
vendored
174
classic/benchmark/.gitignore
vendored
@@ -1,174 +0,0 @@
|
||||
agbenchmark_config/workspace/
|
||||
backend/backend_stdout.txt
|
||||
reports/df*.pkl
|
||||
reports/raw*
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control.
|
||||
# https://pdm.fming.dev/#use-with-ide
|
||||
.pdm.toml
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
.idea/
|
||||
.DS_Store
|
||||
```
|
||||
secrets.json
|
||||
agbenchmark_config/challenges_already_beaten.json
|
||||
agbenchmark_config/challenges/pri_*
|
||||
agbenchmark_config/updates.json
|
||||
agbenchmark_config/reports/*
|
||||
agbenchmark_config/reports/success_rate.json
|
||||
agbenchmark_config/reports/regression_tests.json
|
||||
@@ -1,21 +0,0 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2024 AutoGPT
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
@@ -1,25 +0,0 @@
|
||||
# Auto-GPT Benchmarks
|
||||
|
||||
Built for the purpose of benchmarking the performance of agents regardless of how they work.
|
||||
|
||||
Objectively know how well your agent is performing in categories like code, retrieval, memory, and safety.
|
||||
|
||||
Save time and money while doing it through smart dependencies. The best part? It's all automated.
|
||||
|
||||
## Scores:
|
||||
|
||||
<img width="733" alt="Screenshot 2023-07-25 at 10 35 01 AM" src="https://github.com/Significant-Gravitas/Auto-GPT-Benchmarks/assets/9652976/98963e0b-18b9-4b17-9a6a-4d3e4418af70">
|
||||
|
||||
## Ranking overall:
|
||||
|
||||
- 1- [Beebot](https://github.com/AutoPackAI/beebot)
|
||||
- 2- [mini-agi](https://github.com/muellerberndt/mini-agi)
|
||||
- 3- [Auto-GPT](https://github.com/Significant-Gravitas/AutoGPT)
|
||||
|
||||
## Detailed results:
|
||||
|
||||
<img width="733" alt="Screenshot 2023-07-25 at 10 42 15 AM" src="https://github.com/Significant-Gravitas/Auto-GPT-Benchmarks/assets/9652976/39be464c-c842-4437-b28a-07d878542a83">
|
||||
|
||||
[Click here to see the results and the raw data!](https://docs.google.com/spreadsheets/d/1WXm16P2AHNbKpkOI0LYBpcsGG0O7D8HYTG5Uj0PaJjA/edit#gid=203558751)!
|
||||
|
||||
More agents coming soon !
|
||||
@@ -1,69 +0,0 @@
|
||||
## As a user
|
||||
|
||||
1. `pip install auto-gpt-benchmarks`
|
||||
2. Add boilerplate code to run and kill agent
|
||||
3. `agbenchmark`
|
||||
- `--category challenge_category` to run tests in a specific category
|
||||
- `--mock` to only run mock tests if they exists for each test
|
||||
- `--noreg` to skip any tests that have passed in the past. When you run without this flag and a previous challenge that passed fails, it will now not be regression tests
|
||||
4. We call boilerplate code for your agent
|
||||
5. Show pass rate of tests, logs, and any other metrics
|
||||
|
||||
## Contributing
|
||||
|
||||
##### Diagrams: https://whimsical.com/agbenchmark-5n4hXBq1ZGzBwRsK4TVY7x
|
||||
|
||||
### To run the existing mocks
|
||||
|
||||
1. clone the repo `auto-gpt-benchmarks`
|
||||
2. `pip install poetry`
|
||||
3. `poetry shell`
|
||||
4. `poetry install`
|
||||
5. `cp .env_example .env`
|
||||
6. `git submodule update --init --remote --recursive`
|
||||
7. `uvicorn server:app --reload`
|
||||
8. `agbenchmark --mock`
|
||||
Keep config the same and watch the logs :)
|
||||
|
||||
### To run with mini-agi
|
||||
|
||||
1. Navigate to `auto-gpt-benchmarks/agent/mini-agi`
|
||||
2. `pip install -r requirements.txt`
|
||||
3. `cp .env_example .env`, set `PROMPT_USER=false` and add your `OPENAI_API_KEY=`. Sset `MODEL="gpt-3.5-turbo"` if you don't have access to `gpt-4` yet. Also make sure you have Python 3.10^ installed
|
||||
4. set `AGENT_NAME=mini-agi` in `.env` file and where you want your `REPORTS_FOLDER` to be
|
||||
5. Make sure to follow the commands above, and remove mock flag `agbenchmark`
|
||||
|
||||
- To add requirements `poetry add requirement`.
|
||||
|
||||
Feel free to create prs to merge with `main` at will (but also feel free to ask for review) - if you can't send msg in R&D chat for access.
|
||||
|
||||
If you push at any point and break things - it'll happen to everyone - fix it asap. Step 1 is to revert `master` to last working commit
|
||||
|
||||
Let people know what beautiful code you write does, document everything well
|
||||
|
||||
Share your progress :)
|
||||
|
||||
#### Dataset
|
||||
|
||||
Manually created, existing challenges within Auto-Gpt, https://osu-nlp-group.github.io/Mind2Web/
|
||||
|
||||
## How do I add new agents to agbenchmark ?
|
||||
|
||||
Example with smol developer.
|
||||
|
||||
1- Create a github branch with your agent following the same pattern as this example:
|
||||
|
||||
https://github.com/smol-ai/developer/pull/114/files
|
||||
|
||||
2- Create the submodule and the github workflow by following the same pattern as this example:
|
||||
|
||||
https://github.com/Significant-Gravitas/Auto-GPT-Benchmarks/pull/48/files
|
||||
|
||||
## How do I run agent in different environments?
|
||||
|
||||
**To just use as the benchmark for your agent**. `pip install` the package and run `agbenchmark`
|
||||
|
||||
**For internal Auto-GPT ci runs**, specify the `AGENT_NAME` you want you use and set the `HOME_ENV`.
|
||||
Ex. `AGENT_NAME=mini-agi`
|
||||
|
||||
**To develop agent alongside benchmark**, you can specify the `AGENT_NAME` you want you use and add as a submodule to the repo
|
||||
@@ -1,352 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
import click
|
||||
from click_default_group import DefaultGroup
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from agbenchmark.config import AgentBenchmarkConfig
|
||||
from agbenchmark.utils.logging import configure_logging
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# try:
|
||||
# if os.getenv("HELICONE_API_KEY"):
|
||||
# import helicone # noqa
|
||||
|
||||
# helicone_enabled = True
|
||||
# else:
|
||||
# helicone_enabled = False
|
||||
# except ImportError:
|
||||
# helicone_enabled = False
|
||||
|
||||
|
||||
class InvalidInvocationError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
BENCHMARK_START_TIME_DT = datetime.now(timezone.utc)
|
||||
BENCHMARK_START_TIME = BENCHMARK_START_TIME_DT.strftime("%Y-%m-%dT%H:%M:%S+00:00")
|
||||
|
||||
|
||||
# if helicone_enabled:
|
||||
# from helicone.lock import HeliconeLockManager
|
||||
|
||||
# HeliconeLockManager.write_custom_property(
|
||||
# "benchmark_start_time", BENCHMARK_START_TIME
|
||||
# )
|
||||
|
||||
|
||||
@click.group(cls=DefaultGroup, default_if_no_args=True)
|
||||
@click.option("--debug", is_flag=True, help="Enable debug output")
|
||||
def cli(
|
||||
debug: bool,
|
||||
) -> Any:
|
||||
configure_logging(logging.DEBUG if debug else logging.INFO)
|
||||
|
||||
|
||||
@cli.command(hidden=True)
|
||||
def start():
|
||||
raise DeprecationWarning(
|
||||
"`agbenchmark start` is deprecated. Use `agbenchmark run` instead."
|
||||
)
|
||||
|
||||
|
||||
@cli.command(default=True)
|
||||
@click.option(
|
||||
"-N", "--attempts", default=1, help="Number of times to run each challenge."
|
||||
)
|
||||
@click.option(
|
||||
"-c",
|
||||
"--category",
|
||||
multiple=True,
|
||||
help="(+) Select a category to run.",
|
||||
)
|
||||
@click.option(
|
||||
"-s",
|
||||
"--skip-category",
|
||||
multiple=True,
|
||||
help="(+) Exclude a category from running.",
|
||||
)
|
||||
@click.option("--test", multiple=True, help="(+) Select a test to run.")
|
||||
@click.option("--maintain", is_flag=True, help="Run only regression tests.")
|
||||
@click.option("--improve", is_flag=True, help="Run only non-regression tests.")
|
||||
@click.option(
|
||||
"--explore",
|
||||
is_flag=True,
|
||||
help="Run only challenges that have never been beaten.",
|
||||
)
|
||||
@click.option(
|
||||
"--no-dep",
|
||||
is_flag=True,
|
||||
help="Run all (selected) challenges, regardless of dependency success/failure.",
|
||||
)
|
||||
@click.option("--cutoff", type=int, help="Override the challenge time limit (seconds).")
|
||||
@click.option("--nc", is_flag=True, help="Disable the challenge time limit.")
|
||||
@click.option("--mock", is_flag=True, help="Run with mock")
|
||||
@click.option("--keep-answers", is_flag=True, help="Keep answers")
|
||||
@click.option(
|
||||
"--backend",
|
||||
is_flag=True,
|
||||
help="Write log output to a file instead of the terminal.",
|
||||
)
|
||||
# @click.argument(
|
||||
# "agent_path",
|
||||
# type=click.Path(exists=True, file_okay=False, path_type=Path),
|
||||
# required=False,
|
||||
# )
|
||||
def run(
|
||||
maintain: bool,
|
||||
improve: bool,
|
||||
explore: bool,
|
||||
mock: bool,
|
||||
no_dep: bool,
|
||||
nc: bool,
|
||||
keep_answers: bool,
|
||||
test: tuple[str],
|
||||
category: tuple[str],
|
||||
skip_category: tuple[str],
|
||||
attempts: int,
|
||||
cutoff: Optional[int] = None,
|
||||
backend: Optional[bool] = False,
|
||||
# agent_path: Optional[Path] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Run the benchmark on the agent in the current directory.
|
||||
|
||||
Options marked with (+) can be specified multiple times, to select multiple items.
|
||||
"""
|
||||
from agbenchmark.main import run_benchmark, validate_args
|
||||
|
||||
agbenchmark_config = AgentBenchmarkConfig.load()
|
||||
logger.debug(f"agbenchmark_config: {agbenchmark_config.agbenchmark_config_dir}")
|
||||
try:
|
||||
validate_args(
|
||||
maintain=maintain,
|
||||
improve=improve,
|
||||
explore=explore,
|
||||
tests=test,
|
||||
categories=category,
|
||||
skip_categories=skip_category,
|
||||
no_cutoff=nc,
|
||||
cutoff=cutoff,
|
||||
)
|
||||
except InvalidInvocationError as e:
|
||||
logger.error("Error: " + "\n".join(e.args))
|
||||
sys.exit(1)
|
||||
|
||||
original_stdout = sys.stdout # Save the original standard output
|
||||
exit_code = None
|
||||
|
||||
if backend:
|
||||
with open("backend/backend_stdout.txt", "w") as f:
|
||||
sys.stdout = f
|
||||
exit_code = run_benchmark(
|
||||
config=agbenchmark_config,
|
||||
maintain=maintain,
|
||||
improve=improve,
|
||||
explore=explore,
|
||||
mock=mock,
|
||||
no_dep=no_dep,
|
||||
no_cutoff=nc,
|
||||
keep_answers=keep_answers,
|
||||
tests=test,
|
||||
categories=category,
|
||||
skip_categories=skip_category,
|
||||
attempts_per_challenge=attempts,
|
||||
cutoff=cutoff,
|
||||
)
|
||||
|
||||
sys.stdout = original_stdout
|
||||
|
||||
else:
|
||||
exit_code = run_benchmark(
|
||||
config=agbenchmark_config,
|
||||
maintain=maintain,
|
||||
improve=improve,
|
||||
explore=explore,
|
||||
mock=mock,
|
||||
no_dep=no_dep,
|
||||
no_cutoff=nc,
|
||||
keep_answers=keep_answers,
|
||||
tests=test,
|
||||
categories=category,
|
||||
skip_categories=skip_category,
|
||||
attempts_per_challenge=attempts,
|
||||
cutoff=cutoff,
|
||||
)
|
||||
|
||||
sys.exit(exit_code)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.option("--port", type=int, help="Port to run the API on.")
|
||||
def serve(port: Optional[int] = None):
|
||||
"""Serve the benchmark frontend and API on port 8080."""
|
||||
import uvicorn
|
||||
|
||||
from agbenchmark.app import setup_fastapi_app
|
||||
|
||||
config = AgentBenchmarkConfig.load()
|
||||
app = setup_fastapi_app(config)
|
||||
|
||||
# Run the FastAPI application using uvicorn
|
||||
port = port or int(os.getenv("PORT", 8080))
|
||||
uvicorn.run(app, host="0.0.0.0", port=port)
|
||||
|
||||
|
||||
@cli.command()
|
||||
def config():
|
||||
"""Displays info regarding the present AGBenchmark config."""
|
||||
from .utils.utils import pretty_print_model
|
||||
|
||||
try:
|
||||
config = AgentBenchmarkConfig.load()
|
||||
except FileNotFoundError as e:
|
||||
click.echo(e, err=True)
|
||||
return 1
|
||||
|
||||
pretty_print_model(config, include_header=False)
|
||||
|
||||
|
||||
@cli.group()
|
||||
def challenge():
|
||||
logging.getLogger().setLevel(logging.WARNING)
|
||||
|
||||
|
||||
@challenge.command("list")
|
||||
@click.option(
|
||||
"--all", "include_unavailable", is_flag=True, help="Include unavailable challenges."
|
||||
)
|
||||
@click.option(
|
||||
"--names", "only_names", is_flag=True, help="List only the challenge names."
|
||||
)
|
||||
@click.option("--json", "output_json", is_flag=True)
|
||||
def list_challenges(include_unavailable: bool, only_names: bool, output_json: bool):
|
||||
"""Lists [available|all] challenges."""
|
||||
import json
|
||||
|
||||
from tabulate import tabulate
|
||||
|
||||
from .challenges.builtin import load_builtin_challenges
|
||||
from .challenges.webarena import load_webarena_challenges
|
||||
from .utils.data_types import Category, DifficultyLevel
|
||||
from .utils.utils import sorted_by_enum_index
|
||||
|
||||
DIFFICULTY_COLORS = {
|
||||
difficulty: color
|
||||
for difficulty, color in zip(
|
||||
DifficultyLevel,
|
||||
["black", "blue", "cyan", "green", "yellow", "red", "magenta", "white"],
|
||||
)
|
||||
}
|
||||
CATEGORY_COLORS = {
|
||||
category: f"bright_{color}"
|
||||
for category, color in zip(
|
||||
Category,
|
||||
["blue", "cyan", "green", "yellow", "magenta", "red", "white", "black"],
|
||||
)
|
||||
}
|
||||
|
||||
# Load challenges
|
||||
challenges = filter(
|
||||
lambda c: c.info.available or include_unavailable,
|
||||
[
|
||||
*load_builtin_challenges(),
|
||||
*load_webarena_challenges(skip_unavailable=False),
|
||||
],
|
||||
)
|
||||
challenges = sorted_by_enum_index(
|
||||
challenges, DifficultyLevel, key=lambda c: c.info.difficulty
|
||||
)
|
||||
|
||||
if only_names:
|
||||
if output_json:
|
||||
click.echo(json.dumps([c.info.name for c in challenges]))
|
||||
return
|
||||
|
||||
for c in challenges:
|
||||
click.echo(
|
||||
click.style(c.info.name, fg=None if c.info.available else "black")
|
||||
)
|
||||
return
|
||||
|
||||
if output_json:
|
||||
click.echo(
|
||||
json.dumps([json.loads(c.info.model_dump_json()) for c in challenges])
|
||||
)
|
||||
return
|
||||
|
||||
headers = tuple(
|
||||
click.style(h, bold=True) for h in ("Name", "Difficulty", "Categories")
|
||||
)
|
||||
table = [
|
||||
tuple(
|
||||
v if challenge.info.available else click.style(v, fg="black")
|
||||
for v in (
|
||||
challenge.info.name,
|
||||
(
|
||||
click.style(
|
||||
challenge.info.difficulty.value,
|
||||
fg=DIFFICULTY_COLORS[challenge.info.difficulty],
|
||||
)
|
||||
if challenge.info.difficulty
|
||||
else click.style("-", fg="black")
|
||||
),
|
||||
" ".join(
|
||||
click.style(cat.value, fg=CATEGORY_COLORS[cat])
|
||||
for cat in sorted_by_enum_index(challenge.info.category, Category)
|
||||
),
|
||||
)
|
||||
)
|
||||
for challenge in challenges
|
||||
]
|
||||
click.echo(tabulate(table, headers=headers))
|
||||
|
||||
|
||||
@challenge.command()
|
||||
@click.option("--json", is_flag=True)
|
||||
@click.argument("name")
|
||||
def info(name: str, json: bool):
|
||||
from itertools import chain
|
||||
|
||||
from .challenges.builtin import load_builtin_challenges
|
||||
from .challenges.webarena import load_webarena_challenges
|
||||
from .utils.utils import pretty_print_model
|
||||
|
||||
for challenge in chain(
|
||||
load_builtin_challenges(),
|
||||
load_webarena_challenges(skip_unavailable=False),
|
||||
):
|
||||
if challenge.info.name != name:
|
||||
continue
|
||||
|
||||
if json:
|
||||
click.echo(challenge.info.model_dump_json())
|
||||
break
|
||||
|
||||
pretty_print_model(challenge.info)
|
||||
break
|
||||
else:
|
||||
click.echo(click.style(f"Unknown challenge '{name}'", fg="red"), err=True)
|
||||
|
||||
|
||||
@cli.command()
|
||||
def version():
|
||||
"""Print version info for the AGBenchmark application."""
|
||||
import toml
|
||||
|
||||
package_root = Path(__file__).resolve().parent.parent
|
||||
pyproject = toml.load(package_root / "pyproject.toml")
|
||||
version = pyproject["tool"]["poetry"]["version"]
|
||||
click.echo(f"AGBenchmark version {version}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
@@ -1,111 +0,0 @@
|
||||
import logging
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import AsyncIterator, Optional
|
||||
|
||||
from agent_protocol_client import (
|
||||
AgentApi,
|
||||
ApiClient,
|
||||
Configuration,
|
||||
Step,
|
||||
TaskRequestBody,
|
||||
)
|
||||
|
||||
from agbenchmark.agent_interface import get_list_of_file_paths
|
||||
from agbenchmark.config import AgentBenchmarkConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def run_api_agent(
|
||||
task: str,
|
||||
config: AgentBenchmarkConfig,
|
||||
timeout: int,
|
||||
artifacts_location: Optional[Path] = None,
|
||||
*,
|
||||
mock: bool = False,
|
||||
) -> AsyncIterator[Step]:
|
||||
configuration = Configuration(host=config.host)
|
||||
async with ApiClient(configuration) as api_client:
|
||||
api_instance = AgentApi(api_client)
|
||||
task_request_body = TaskRequestBody(input=task, additional_input=None)
|
||||
|
||||
start_time = time.time()
|
||||
response = await api_instance.create_agent_task(
|
||||
task_request_body=task_request_body
|
||||
)
|
||||
task_id = response.task_id
|
||||
|
||||
if artifacts_location:
|
||||
logger.debug("Uploading task input artifacts to agent...")
|
||||
await upload_artifacts(
|
||||
api_instance, artifacts_location, task_id, "artifacts_in"
|
||||
)
|
||||
|
||||
logger.debug("Running agent until finished or timeout...")
|
||||
while True:
|
||||
step = await api_instance.execute_agent_task_step(task_id=task_id)
|
||||
yield step
|
||||
|
||||
if time.time() - start_time > timeout:
|
||||
raise TimeoutError("Time limit exceeded")
|
||||
if step and mock:
|
||||
step.is_last = True
|
||||
if not step or step.is_last:
|
||||
break
|
||||
|
||||
if artifacts_location:
|
||||
# In "mock" mode, we cheat by giving the correct artifacts to pass the test
|
||||
if mock:
|
||||
logger.debug("Uploading mock artifacts to agent...")
|
||||
await upload_artifacts(
|
||||
api_instance, artifacts_location, task_id, "artifacts_out"
|
||||
)
|
||||
|
||||
logger.debug("Downloading agent artifacts...")
|
||||
await download_agent_artifacts_into_folder(
|
||||
api_instance, task_id, config.temp_folder
|
||||
)
|
||||
|
||||
|
||||
async def download_agent_artifacts_into_folder(
|
||||
api_instance: AgentApi, task_id: str, folder: Path
|
||||
):
|
||||
artifacts = await api_instance.list_agent_task_artifacts(task_id=task_id)
|
||||
|
||||
for artifact in artifacts.artifacts:
|
||||
# current absolute path of the directory of the file
|
||||
if artifact.relative_path:
|
||||
path: str = (
|
||||
artifact.relative_path
|
||||
if not artifact.relative_path.startswith("/")
|
||||
else artifact.relative_path[1:]
|
||||
)
|
||||
folder = (folder / path).parent
|
||||
|
||||
if not folder.exists():
|
||||
folder.mkdir(parents=True)
|
||||
|
||||
file_path = folder / artifact.file_name
|
||||
logger.debug(f"Downloading agent artifact {artifact.file_name} to {folder}")
|
||||
with open(file_path, "wb") as f:
|
||||
content = await api_instance.download_agent_task_artifact(
|
||||
task_id=task_id, artifact_id=artifact.artifact_id
|
||||
)
|
||||
|
||||
f.write(content)
|
||||
|
||||
|
||||
async def upload_artifacts(
|
||||
api_instance: AgentApi, artifacts_location: Path, task_id: str, type: str
|
||||
) -> None:
|
||||
for file_path in get_list_of_file_paths(artifacts_location, type):
|
||||
relative_path: Optional[str] = "/".join(
|
||||
str(file_path).split(f"{type}/", 1)[-1].split("/")[:-1]
|
||||
)
|
||||
if not relative_path:
|
||||
relative_path = None
|
||||
|
||||
await api_instance.upload_agent_task_artifacts(
|
||||
task_id=task_id, file=str(file_path), relative_path=relative_path
|
||||
)
|
||||
@@ -1,27 +0,0 @@
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
HELICONE_GRAPHQL_LOGS = os.getenv("HELICONE_GRAPHQL_LOGS", "").lower() == "true"
|
||||
|
||||
|
||||
def get_list_of_file_paths(
|
||||
challenge_dir_path: str | Path, artifact_folder_name: str
|
||||
) -> list[Path]:
|
||||
source_dir = Path(challenge_dir_path) / artifact_folder_name
|
||||
if not source_dir.exists():
|
||||
return []
|
||||
return list(source_dir.iterdir())
|
||||
|
||||
|
||||
def copy_challenge_artifacts_into_workspace(
|
||||
challenge_dir_path: str | Path, artifact_folder_name: str, workspace: str | Path
|
||||
) -> None:
|
||||
file_paths = get_list_of_file_paths(challenge_dir_path, artifact_folder_name)
|
||||
for file_path in file_paths:
|
||||
if file_path.is_file():
|
||||
shutil.copy(file_path, workspace)
|
||||
@@ -1,339 +0,0 @@
|
||||
import datetime
|
||||
import glob
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
from collections import deque
|
||||
from multiprocessing import Process
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
import psutil
|
||||
from agent_protocol_client import AgentApi, ApiClient, ApiException, Configuration
|
||||
from agent_protocol_client.models import Task, TaskRequestBody
|
||||
from fastapi import APIRouter, FastAPI, HTTPException, Request, Response
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel, ConfigDict, ValidationError
|
||||
|
||||
from agbenchmark.challenges import ChallengeInfo
|
||||
from agbenchmark.config import AgentBenchmarkConfig
|
||||
from agbenchmark.reports.processing.report_types_v2 import (
|
||||
BenchmarkRun,
|
||||
Metrics,
|
||||
RepositoryInfo,
|
||||
RunDetails,
|
||||
TaskInfo,
|
||||
)
|
||||
from agbenchmark.schema import TaskEvalRequestBody
|
||||
from agbenchmark.utils.utils import write_pretty_json
|
||||
|
||||
sys.path.append(str(Path(__file__).parent.parent))
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CHALLENGES: dict[str, ChallengeInfo] = {}
|
||||
challenges_path = Path(__file__).parent / "challenges"
|
||||
challenge_spec_files = deque(
|
||||
glob.glob(
|
||||
f"{challenges_path}/**/data.json",
|
||||
recursive=True,
|
||||
)
|
||||
)
|
||||
|
||||
logger.debug("Loading challenges...")
|
||||
while challenge_spec_files:
|
||||
challenge_spec_file = Path(challenge_spec_files.popleft())
|
||||
challenge_relpath = challenge_spec_file.relative_to(challenges_path.parent)
|
||||
if challenge_relpath.is_relative_to("challenges/deprecated"):
|
||||
continue
|
||||
|
||||
logger.debug(f"Loading {challenge_relpath}...")
|
||||
try:
|
||||
challenge_info = ChallengeInfo.model_validate_json(
|
||||
challenge_spec_file.read_text()
|
||||
)
|
||||
except ValidationError as e:
|
||||
if logging.getLogger().level == logging.DEBUG:
|
||||
logger.warning(f"Spec file {challenge_relpath} failed to load:\n{e}")
|
||||
logger.debug(f"Invalid challenge spec: {challenge_spec_file.read_text()}")
|
||||
continue
|
||||
|
||||
if not challenge_info.eval_id:
|
||||
challenge_info.eval_id = str(uuid.uuid4())
|
||||
# this will sort all the keys of the JSON systematically
|
||||
# so that the order is always the same
|
||||
write_pretty_json(challenge_info.model_dump(), challenge_spec_file)
|
||||
|
||||
CHALLENGES[challenge_info.eval_id] = challenge_info
|
||||
|
||||
|
||||
class BenchmarkTaskInfo(BaseModel):
|
||||
task_id: str
|
||||
start_time: datetime.datetime
|
||||
challenge_info: ChallengeInfo
|
||||
|
||||
|
||||
task_informations: dict[str, BenchmarkTaskInfo] = {}
|
||||
|
||||
|
||||
def find_agbenchmark_without_uvicorn():
|
||||
pids = []
|
||||
for process in psutil.process_iter(
|
||||
attrs=[
|
||||
"pid",
|
||||
"cmdline",
|
||||
"name",
|
||||
"username",
|
||||
"status",
|
||||
"cpu_percent",
|
||||
"memory_info",
|
||||
"create_time",
|
||||
"cwd",
|
||||
"connections",
|
||||
]
|
||||
):
|
||||
try:
|
||||
# Convert the process.info dictionary values to strings and concatenate them
|
||||
full_info = " ".join([str(v) for k, v in process.as_dict().items()])
|
||||
|
||||
if "agbenchmark" in full_info and "uvicorn" not in full_info:
|
||||
pids.append(process.pid)
|
||||
except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
|
||||
pass
|
||||
return pids
|
||||
|
||||
|
||||
class CreateReportRequest(BaseModel):
|
||||
test: str
|
||||
test_run_id: str
|
||||
# category: Optional[str] = []
|
||||
mock: Optional[bool] = False
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
updates_list = []
|
||||
|
||||
origins = [
|
||||
"http://localhost:8000",
|
||||
"http://localhost:8080",
|
||||
"http://127.0.0.1:5000",
|
||||
"http://localhost:5000",
|
||||
]
|
||||
|
||||
|
||||
def stream_output(pipe):
|
||||
for line in pipe:
|
||||
print(line, end="")
|
||||
|
||||
|
||||
def setup_fastapi_app(agbenchmark_config: AgentBenchmarkConfig) -> FastAPI:
|
||||
from agbenchmark.agent_api_interface import upload_artifacts
|
||||
from agbenchmark.challenges import get_challenge_from_source_uri
|
||||
from agbenchmark.main import run_benchmark
|
||||
|
||||
configuration = Configuration(
|
||||
host=agbenchmark_config.host or "http://localhost:8000"
|
||||
)
|
||||
app = FastAPI()
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
router = APIRouter()
|
||||
|
||||
@router.post("/reports")
|
||||
def run_single_test(body: CreateReportRequest) -> dict:
|
||||
pids = find_agbenchmark_without_uvicorn()
|
||||
logger.info(f"pids already running with agbenchmark: {pids}")
|
||||
|
||||
logger.debug(f"Request to /reports: {body.model_dump()}")
|
||||
|
||||
# Start the benchmark in a separate thread
|
||||
benchmark_process = Process(
|
||||
target=lambda: run_benchmark(
|
||||
config=agbenchmark_config,
|
||||
tests=(body.test,),
|
||||
mock=body.mock or False,
|
||||
)
|
||||
)
|
||||
benchmark_process.start()
|
||||
|
||||
# Wait for the benchmark to finish, with a timeout of 200 seconds
|
||||
timeout = 200
|
||||
start_time = time.time()
|
||||
while benchmark_process.is_alive():
|
||||
if time.time() - start_time > timeout:
|
||||
logger.warning(f"Benchmark run timed out after {timeout} seconds")
|
||||
benchmark_process.terminate()
|
||||
break
|
||||
time.sleep(1)
|
||||
else:
|
||||
logger.debug(f"Benchmark finished running in {time.time() - start_time} s")
|
||||
|
||||
# List all folders in the current working directory
|
||||
reports_folder = agbenchmark_config.reports_folder
|
||||
folders = [folder for folder in reports_folder.iterdir() if folder.is_dir()]
|
||||
|
||||
# Sort the folders based on their names
|
||||
sorted_folders = sorted(folders, key=lambda x: x.name)
|
||||
|
||||
# Get the last folder
|
||||
latest_folder = sorted_folders[-1] if sorted_folders else None
|
||||
|
||||
# Read report.json from this folder
|
||||
if latest_folder:
|
||||
report_path = latest_folder / "report.json"
|
||||
logger.debug(f"Getting latest report from {report_path}")
|
||||
if report_path.exists():
|
||||
with report_path.open() as file:
|
||||
data = json.load(file)
|
||||
logger.debug(f"Report data: {data}")
|
||||
else:
|
||||
raise HTTPException(
|
||||
502,
|
||||
"Could not get result after running benchmark: "
|
||||
f"'report.json' does not exist in '{latest_folder}'",
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
504, "Could not get result after running benchmark: no reports found"
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
@router.post("/agent/tasks", tags=["agent"])
|
||||
async def create_agent_task(task_eval_request: TaskEvalRequestBody) -> Task:
|
||||
"""
|
||||
Creates a new task using the provided TaskEvalRequestBody and returns a Task.
|
||||
|
||||
Args:
|
||||
task_eval_request: `TaskRequestBody` including an eval_id.
|
||||
|
||||
Returns:
|
||||
Task: A new task with task_id, input, additional_input,
|
||||
and empty lists for artifacts and steps.
|
||||
|
||||
Example:
|
||||
Request (TaskEvalRequestBody defined in schema.py):
|
||||
{
|
||||
...,
|
||||
"eval_id": "50da533e-3904-4401-8a07-c49adf88b5eb"
|
||||
}
|
||||
|
||||
Response (Task defined in `agent_protocol_client.models`):
|
||||
{
|
||||
"task_id": "50da533e-3904-4401-8a07-c49adf88b5eb",
|
||||
"input": "Write the word 'Washington' to a .txt file",
|
||||
"artifacts": []
|
||||
}
|
||||
"""
|
||||
try:
|
||||
challenge_info = CHALLENGES[task_eval_request.eval_id]
|
||||
async with ApiClient(configuration) as api_client:
|
||||
api_instance = AgentApi(api_client)
|
||||
task_input = challenge_info.task
|
||||
|
||||
task_request_body = TaskRequestBody(
|
||||
input=task_input, additional_input=None
|
||||
)
|
||||
task_response = await api_instance.create_agent_task(
|
||||
task_request_body=task_request_body
|
||||
)
|
||||
task_info = BenchmarkTaskInfo(
|
||||
task_id=task_response.task_id,
|
||||
start_time=datetime.datetime.now(datetime.timezone.utc),
|
||||
challenge_info=challenge_info,
|
||||
)
|
||||
task_informations[task_info.task_id] = task_info
|
||||
|
||||
if input_artifacts_dir := challenge_info.task_artifacts_dir:
|
||||
await upload_artifacts(
|
||||
api_instance,
|
||||
input_artifacts_dir,
|
||||
task_response.task_id,
|
||||
"artifacts_in",
|
||||
)
|
||||
return task_response
|
||||
except ApiException as e:
|
||||
logger.error(f"Error whilst trying to create a task:\n{e}")
|
||||
logger.error(
|
||||
"The above error was caused while processing request: "
|
||||
f"{task_eval_request}"
|
||||
)
|
||||
raise HTTPException(500)
|
||||
|
||||
@router.post("/agent/tasks/{task_id}/steps")
|
||||
async def proxy(request: Request, task_id: str):
|
||||
timeout = httpx.Timeout(300.0, read=300.0) # 5 minutes
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
# Construct the new URL
|
||||
new_url = f"{configuration.host}/ap/v1/agent/tasks/{task_id}/steps"
|
||||
|
||||
# Forward the request
|
||||
response = await client.post(
|
||||
new_url,
|
||||
content=await request.body(),
|
||||
headers=dict(request.headers),
|
||||
)
|
||||
|
||||
# Return the response from the forwarded request
|
||||
return Response(content=response.content, status_code=response.status_code)
|
||||
|
||||
@router.post("/agent/tasks/{task_id}/evaluations")
|
||||
async def create_evaluation(task_id: str) -> BenchmarkRun:
|
||||
task_info = task_informations[task_id]
|
||||
challenge = get_challenge_from_source_uri(task_info.challenge_info.source_uri)
|
||||
try:
|
||||
async with ApiClient(configuration) as api_client:
|
||||
api_instance = AgentApi(api_client)
|
||||
eval_results = await challenge.evaluate_task_state(
|
||||
api_instance, task_id
|
||||
)
|
||||
|
||||
eval_info = BenchmarkRun(
|
||||
repository_info=RepositoryInfo(),
|
||||
run_details=RunDetails(
|
||||
command=f"agbenchmark --test={challenge.info.name}",
|
||||
benchmark_start_time=(
|
||||
task_info.start_time.strftime("%Y-%m-%dT%H:%M:%S+00:00")
|
||||
),
|
||||
test_name=challenge.info.name,
|
||||
),
|
||||
task_info=TaskInfo(
|
||||
data_path=challenge.info.source_uri,
|
||||
is_regression=None,
|
||||
category=[c.value for c in challenge.info.category],
|
||||
task=challenge.info.task,
|
||||
answer=challenge.info.reference_answer or "",
|
||||
description=challenge.info.description or "",
|
||||
),
|
||||
metrics=Metrics(
|
||||
success=all(e.passed for e in eval_results),
|
||||
success_percentage=(
|
||||
100 * sum(e.score for e in eval_results) / len(eval_results)
|
||||
if eval_results # avoid division by 0
|
||||
else 0
|
||||
),
|
||||
attempted=True,
|
||||
),
|
||||
config={},
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Returning evaluation data:\n{eval_info.model_dump_json(indent=4)}"
|
||||
)
|
||||
return eval_info
|
||||
except ApiException as e:
|
||||
logger.error(f"Error {e} whilst trying to evaluate task: {task_id}")
|
||||
raise HTTPException(500)
|
||||
|
||||
app.include_router(router, prefix="/ap/v1")
|
||||
|
||||
return app
|
||||
@@ -1,128 +0,0 @@
|
||||
import json
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field, ValidationInfo, field_validator
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
def _calculate_info_test_path(base_path: Path, benchmark_start_time: datetime) -> Path:
|
||||
"""
|
||||
Calculates the path to the directory where the test report will be saved.
|
||||
"""
|
||||
# Ensure the reports path exists
|
||||
base_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Get current UTC date-time stamp
|
||||
date_stamp = benchmark_start_time.strftime("%Y%m%dT%H%M%S")
|
||||
|
||||
# Default run name
|
||||
run_name = "full_run"
|
||||
|
||||
# Map command-line arguments to their respective labels
|
||||
arg_labels = {
|
||||
"--test": None,
|
||||
"--category": None,
|
||||
"--maintain": "maintain",
|
||||
"--improve": "improve",
|
||||
"--explore": "explore",
|
||||
}
|
||||
|
||||
# Identify the relevant command-line argument
|
||||
for arg, label in arg_labels.items():
|
||||
if arg in sys.argv:
|
||||
test_arg = sys.argv[sys.argv.index(arg) + 1] if label is None else None
|
||||
run_name = arg.strip("--")
|
||||
if test_arg:
|
||||
run_name = f"{run_name}_{test_arg}"
|
||||
break
|
||||
|
||||
# Create the full new directory path with ISO standard UTC date-time stamp
|
||||
report_path = base_path / f"{date_stamp}_{run_name}"
|
||||
|
||||
# Ensure the new directory is created
|
||||
# FIXME: this is not a desirable side-effect of loading the config
|
||||
report_path.mkdir(exist_ok=True)
|
||||
|
||||
return report_path
|
||||
|
||||
|
||||
class AgentBenchmarkConfig(BaseSettings, extra="allow"):
|
||||
"""
|
||||
Configuration model and loader for the AGBenchmark.
|
||||
|
||||
Projects that want to use AGBenchmark should contain an agbenchmark_config folder
|
||||
with a config.json file that - at minimum - specifies the `host` at which the
|
||||
subject application exposes an Agent Protocol compliant API.
|
||||
"""
|
||||
|
||||
agbenchmark_config_dir: Path = Field(exclude=True)
|
||||
"""Path to the agbenchmark_config folder of the subject agent application."""
|
||||
|
||||
categories: list[str] | None = None
|
||||
"""Categories to benchmark the agent for. If omitted, all categories are assumed."""
|
||||
|
||||
host: str
|
||||
"""Host (scheme://address:port) of the subject agent application."""
|
||||
|
||||
reports_folder: Path = Field(None)
|
||||
"""
|
||||
Path to the folder where new reports should be stored.
|
||||
Defaults to {agbenchmark_config_dir}/reports.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def load(cls, config_dir: Optional[Path] = None) -> "AgentBenchmarkConfig":
|
||||
config_dir = config_dir or cls.find_config_folder()
|
||||
with (config_dir / "config.json").open("r") as f:
|
||||
return cls(
|
||||
agbenchmark_config_dir=config_dir,
|
||||
**json.load(f),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def find_config_folder(for_dir: Path = Path.cwd()) -> Path:
|
||||
"""
|
||||
Find the closest ancestor folder containing an agbenchmark_config folder,
|
||||
and returns the path of that agbenchmark_config folder.
|
||||
"""
|
||||
current_directory = for_dir
|
||||
while current_directory != Path("/"):
|
||||
if (path := current_directory / "agbenchmark_config").exists():
|
||||
if (path / "config.json").is_file():
|
||||
return path
|
||||
current_directory = current_directory.parent
|
||||
raise FileNotFoundError(
|
||||
"No 'agbenchmark_config' directory found in the path hierarchy."
|
||||
)
|
||||
|
||||
@property
|
||||
def config_file(self) -> Path:
|
||||
return self.agbenchmark_config_dir / "config.json"
|
||||
|
||||
@field_validator("reports_folder", mode="before")
|
||||
def set_reports_folder(cls, value: Path, info: ValidationInfo):
|
||||
if not value:
|
||||
return info.data["agbenchmark_config_dir"] / "reports"
|
||||
return value
|
||||
|
||||
def get_report_dir(self, benchmark_start_time: datetime) -> Path:
|
||||
return _calculate_info_test_path(self.reports_folder, benchmark_start_time)
|
||||
|
||||
@property
|
||||
def regression_tests_file(self) -> Path:
|
||||
return self.reports_folder / "regression_tests.json"
|
||||
|
||||
@property
|
||||
def success_rate_file(self) -> Path:
|
||||
return self.reports_folder / "success_rate.json"
|
||||
|
||||
@property
|
||||
def challenges_already_beaten_file(self) -> Path:
|
||||
return self.agbenchmark_config_dir / "challenges_already_beaten.json"
|
||||
|
||||
@property
|
||||
def temp_folder(self) -> Path:
|
||||
return self.agbenchmark_config_dir / "temp_folder"
|
||||
@@ -1,339 +0,0 @@
|
||||
import contextlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from agbenchmark.challenges import OPTIONAL_CATEGORIES, BaseChallenge
|
||||
from agbenchmark.config import AgentBenchmarkConfig
|
||||
from agbenchmark.reports.processing.report_types import Test
|
||||
from agbenchmark.reports.ReportManager import RegressionTestsTracker
|
||||
from agbenchmark.reports.reports import (
|
||||
add_test_result_to_report,
|
||||
make_empty_test_report,
|
||||
session_finish,
|
||||
)
|
||||
from agbenchmark.utils.data_types import Category
|
||||
|
||||
GLOBAL_TIMEOUT = (
|
||||
1500 # The tests will stop after 25 minutes so we can send the reports.
|
||||
)
|
||||
|
||||
agbenchmark_config = AgentBenchmarkConfig.load()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
pytest_plugins = ["agbenchmark.utils.dependencies"]
|
||||
collect_ignore = ["challenges"]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def config() -> AgentBenchmarkConfig:
|
||||
return agbenchmark_config
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def temp_folder() -> Generator[Path, None, None]:
|
||||
"""
|
||||
Pytest fixture that sets up and tears down the temporary folder for each test.
|
||||
It is automatically used in every test due to the 'autouse=True' parameter.
|
||||
"""
|
||||
|
||||
# create output directory if it doesn't exist
|
||||
if not os.path.exists(agbenchmark_config.temp_folder):
|
||||
os.makedirs(agbenchmark_config.temp_folder, exist_ok=True)
|
||||
|
||||
yield agbenchmark_config.temp_folder
|
||||
# teardown after test function completes
|
||||
if not os.getenv("KEEP_TEMP_FOLDER_FILES"):
|
||||
for filename in os.listdir(agbenchmark_config.temp_folder):
|
||||
file_path = os.path.join(agbenchmark_config.temp_folder, filename)
|
||||
try:
|
||||
if os.path.isfile(file_path) or os.path.islink(file_path):
|
||||
os.unlink(file_path)
|
||||
elif os.path.isdir(file_path):
|
||||
shutil.rmtree(file_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete {file_path}. Reason: {e}")
|
||||
|
||||
|
||||
def pytest_addoption(parser: pytest.Parser) -> None:
|
||||
"""
|
||||
Pytest hook that adds command-line options to the `pytest` command.
|
||||
The added options are specific to agbenchmark and control its behavior:
|
||||
* `--mock` is used to run the tests in mock mode.
|
||||
* `--host` is used to specify the host for the tests.
|
||||
* `--category` is used to run only tests of a specific category.
|
||||
* `--nc` is used to run the tests without caching.
|
||||
* `--cutoff` is used to specify a cutoff time for the tests.
|
||||
* `--improve` is used to run only the tests that are marked for improvement.
|
||||
* `--maintain` is used to run only the tests that are marked for maintenance.
|
||||
* `--explore` is used to run the tests in exploration mode.
|
||||
* `--test` is used to run a specific test.
|
||||
* `--no-dep` is used to run the tests without dependencies.
|
||||
* `--keep-answers` is used to keep the answers of the tests.
|
||||
|
||||
Args:
|
||||
parser: The Pytest CLI parser to which the command-line options are added.
|
||||
"""
|
||||
parser.addoption("-N", "--attempts", action="store")
|
||||
parser.addoption("--no-dep", action="store_true")
|
||||
parser.addoption("--mock", action="store_true")
|
||||
parser.addoption("--host", default=None)
|
||||
parser.addoption("--nc", action="store_true")
|
||||
parser.addoption("--cutoff", action="store")
|
||||
parser.addoption("--category", action="append")
|
||||
parser.addoption("--test", action="append")
|
||||
parser.addoption("--improve", action="store_true")
|
||||
parser.addoption("--maintain", action="store_true")
|
||||
parser.addoption("--explore", action="store_true")
|
||||
parser.addoption("--keep-answers", action="store_true")
|
||||
|
||||
|
||||
def pytest_configure(config: pytest.Config) -> None:
|
||||
# Register category markers to prevent "unknown marker" warnings
|
||||
for category in Category:
|
||||
config.addinivalue_line("markers", f"{category.value}: {category}")
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def check_regression(request: pytest.FixtureRequest) -> None:
|
||||
"""
|
||||
Fixture that checks for every test if it should be treated as a regression test,
|
||||
and whether to skip it based on that.
|
||||
|
||||
The test name is retrieved from the `request` object. Regression reports are loaded
|
||||
from the path specified in the benchmark configuration.
|
||||
|
||||
Effect:
|
||||
* If the `--improve` option is used and the current test is considered a regression
|
||||
test, it is skipped.
|
||||
* If the `--maintain` option is used and the current test is not considered a
|
||||
regression test, it is also skipped.
|
||||
|
||||
Args:
|
||||
request: The request object from which the test name and the benchmark
|
||||
configuration are retrieved.
|
||||
"""
|
||||
with contextlib.suppress(FileNotFoundError):
|
||||
rt_tracker = RegressionTestsTracker(agbenchmark_config.regression_tests_file)
|
||||
|
||||
assert isinstance(request.node, pytest.Function)
|
||||
assert isinstance(request.node.parent, pytest.Class)
|
||||
test_name = request.node.parent.name
|
||||
challenge_location = getattr(request.node.cls, "CHALLENGE_LOCATION", "")
|
||||
skip_string = f"Skipping {test_name} at {challenge_location}"
|
||||
|
||||
# Check if the test name exists in the regression tests
|
||||
is_regression_test = rt_tracker.has_regression_test(test_name)
|
||||
if request.config.getoption("--improve") and is_regression_test:
|
||||
pytest.skip(f"{skip_string} because it's a regression test")
|
||||
elif request.config.getoption("--maintain") and not is_regression_test:
|
||||
pytest.skip(f"{skip_string} because it's not a regression test")
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope="session")
|
||||
def mock(request: pytest.FixtureRequest) -> bool:
|
||||
"""
|
||||
Pytest fixture that retrieves the value of the `--mock` command-line option.
|
||||
The `--mock` option is used to run the tests in mock mode.
|
||||
|
||||
Args:
|
||||
request: The `pytest.FixtureRequest` from which the `--mock` option value
|
||||
is retrieved.
|
||||
|
||||
Returns:
|
||||
bool: Whether `--mock` is set for this session.
|
||||
"""
|
||||
mock = request.config.getoption("--mock")
|
||||
assert isinstance(mock, bool)
|
||||
return mock
|
||||
|
||||
|
||||
test_reports: dict[str, Test] = {}
|
||||
|
||||
|
||||
def pytest_runtest_makereport(item: pytest.Item, call: pytest.CallInfo) -> None:
|
||||
"""
|
||||
Pytest hook that is called when a test report is being generated.
|
||||
It is used to generate and finalize reports for each test.
|
||||
|
||||
Args:
|
||||
item: The test item for which the report is being generated.
|
||||
call: The call object from which the test result is retrieved.
|
||||
"""
|
||||
challenge: type[BaseChallenge] = item.cls # type: ignore
|
||||
challenge_id = challenge.info.eval_id
|
||||
|
||||
if challenge_id not in test_reports:
|
||||
test_reports[challenge_id] = make_empty_test_report(challenge.info)
|
||||
|
||||
if call.when == "setup":
|
||||
test_name = item.nodeid.split("::")[1]
|
||||
item.user_properties.append(("test_name", test_name))
|
||||
|
||||
if call.when == "call":
|
||||
add_test_result_to_report(
|
||||
test_reports[challenge_id], item, call, agbenchmark_config
|
||||
)
|
||||
|
||||
|
||||
def timeout_monitor(start_time: int) -> None:
|
||||
"""
|
||||
Function that limits the total execution time of the test suite.
|
||||
This function is supposed to be run in a separate thread and calls `pytest.exit`
|
||||
if the total execution time has exceeded the global timeout.
|
||||
|
||||
Args:
|
||||
start_time (int): The start time of the test suite.
|
||||
"""
|
||||
while time.time() - start_time < GLOBAL_TIMEOUT:
|
||||
time.sleep(1) # check every second
|
||||
|
||||
pytest.exit("Test suite exceeded the global timeout", returncode=1)
|
||||
|
||||
|
||||
def pytest_sessionstart(session: pytest.Session) -> None:
|
||||
"""
|
||||
Pytest hook that is called at the start of a test session.
|
||||
|
||||
Sets up and runs a `timeout_monitor` in a separate thread.
|
||||
"""
|
||||
start_time = time.time()
|
||||
t = threading.Thread(target=timeout_monitor, args=(start_time,))
|
||||
t.daemon = True # Daemon threads are abruptly stopped at shutdown
|
||||
t.start()
|
||||
|
||||
|
||||
def pytest_sessionfinish(session: pytest.Session) -> None:
|
||||
"""
|
||||
Pytest hook that is called at the end of a test session.
|
||||
|
||||
Finalizes and saves the test reports.
|
||||
"""
|
||||
session_finish(agbenchmark_config)
|
||||
|
||||
|
||||
def pytest_generate_tests(metafunc: pytest.Metafunc):
|
||||
n = metafunc.config.getoption("-N")
|
||||
metafunc.parametrize("i_attempt", range(int(n)) if type(n) is str else [0])
|
||||
|
||||
|
||||
def pytest_collection_modifyitems(
|
||||
items: list[pytest.Function], config: pytest.Config
|
||||
) -> None:
|
||||
"""
|
||||
Pytest hook that is called after initial test collection has been performed.
|
||||
Modifies the collected test items based on the agent benchmark configuration,
|
||||
adding the dependency marker and category markers.
|
||||
|
||||
Args:
|
||||
items: The collected test items to be modified.
|
||||
config: The active pytest configuration.
|
||||
"""
|
||||
rt_tracker = RegressionTestsTracker(agbenchmark_config.regression_tests_file)
|
||||
|
||||
try:
|
||||
challenges_beaten_in_the_past = json.loads(
|
||||
agbenchmark_config.challenges_already_beaten_file.read_bytes()
|
||||
)
|
||||
except FileNotFoundError:
|
||||
challenges_beaten_in_the_past = {}
|
||||
|
||||
selected_tests: tuple[str] = config.getoption("--test") # type: ignore
|
||||
selected_categories: tuple[str] = config.getoption("--category") # type: ignore
|
||||
|
||||
# Can't use a for-loop to remove items in-place
|
||||
i = 0
|
||||
while i < len(items):
|
||||
item = items[i]
|
||||
assert item.cls and issubclass(item.cls, BaseChallenge)
|
||||
challenge = item.cls
|
||||
challenge_name = challenge.info.name
|
||||
|
||||
if not issubclass(challenge, BaseChallenge):
|
||||
item.warn(
|
||||
pytest.PytestCollectionWarning(
|
||||
f"Non-challenge item collected: {challenge}"
|
||||
)
|
||||
)
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# --test: remove the test from the set if it's not specifically selected
|
||||
if selected_tests and challenge.info.name not in selected_tests:
|
||||
items.remove(item)
|
||||
continue
|
||||
|
||||
# Filter challenges for --maintain, --improve, and --explore:
|
||||
# --maintain -> only challenges expected to be passed (= regression tests)
|
||||
# --improve -> only challenges that so far are not passed (reliably)
|
||||
# --explore -> only challenges that have never been passed
|
||||
is_regression_test = rt_tracker.has_regression_test(challenge.info.name)
|
||||
has_been_passed = challenges_beaten_in_the_past.get(challenge.info.name, False)
|
||||
if (
|
||||
(config.getoption("--maintain") and not is_regression_test)
|
||||
or (config.getoption("--improve") and is_regression_test)
|
||||
or (config.getoption("--explore") and has_been_passed)
|
||||
):
|
||||
items.remove(item)
|
||||
continue
|
||||
|
||||
dependencies = challenge.info.dependencies
|
||||
if (
|
||||
config.getoption("--test")
|
||||
or config.getoption("--no-dep")
|
||||
or config.getoption("--maintain")
|
||||
):
|
||||
# Ignore dependencies:
|
||||
# --test -> user selected specific tests to run, don't care about deps
|
||||
# --no-dep -> ignore dependency relations regardless of test selection
|
||||
# --maintain -> all "regression" tests must pass, so run all of them
|
||||
dependencies = []
|
||||
elif config.getoption("--improve"):
|
||||
# Filter dependencies, keep only deps that are not "regression" tests
|
||||
dependencies = [
|
||||
d for d in dependencies if not rt_tracker.has_regression_test(d)
|
||||
]
|
||||
|
||||
# Set category markers
|
||||
challenge_categories = set(c.value for c in challenge.info.category)
|
||||
for category in challenge_categories:
|
||||
item.add_marker(category)
|
||||
|
||||
# Enforce category selection
|
||||
if selected_categories:
|
||||
if not challenge_categories.intersection(set(selected_categories)):
|
||||
items.remove(item)
|
||||
continue
|
||||
# # Filter dependencies, keep only deps from selected categories
|
||||
# dependencies = [
|
||||
# d for d in dependencies
|
||||
# if not set(d.categories).intersection(set(selected_categories))
|
||||
# ]
|
||||
|
||||
# Skip items in optional categories that are not selected for the subject agent
|
||||
challenge_optional_categories = challenge_categories & set(OPTIONAL_CATEGORIES)
|
||||
if challenge_optional_categories and not (
|
||||
agbenchmark_config.categories
|
||||
and challenge_optional_categories.issubset(
|
||||
set(agbenchmark_config.categories)
|
||||
)
|
||||
):
|
||||
logger.debug(
|
||||
f"Skipping {challenge_name}: "
|
||||
f"category {' and '.join(challenge_optional_categories)} is optional, "
|
||||
"and not explicitly selected in the benchmark config."
|
||||
)
|
||||
items.remove(item)
|
||||
continue
|
||||
|
||||
# Add marker for the DependencyManager
|
||||
item.add_marker(pytest.mark.depends(on=dependencies, name=challenge_name))
|
||||
|
||||
i += 1
|
||||
@@ -1,26 +0,0 @@
|
||||
"""
|
||||
AGBenchmark's test discovery endpoint for Pytest.
|
||||
|
||||
This module is picked up by Pytest's *_test.py file matching pattern, and all challenge
|
||||
classes in the module that conform to the `Test*` pattern are collected.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
from itertools import chain
|
||||
|
||||
from agbenchmark.challenges.builtin import load_builtin_challenges
|
||||
from agbenchmark.challenges.webarena import load_webarena_challenges
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DATA_CATEGORY = {}
|
||||
|
||||
# Load challenges and attach them to this module
|
||||
for challenge in chain(load_builtin_challenges(), load_webarena_challenges()):
|
||||
# Attach the Challenge class to this module so it can be discovered by pytest
|
||||
module = importlib.import_module(__name__)
|
||||
setattr(module, challenge.__name__, challenge)
|
||||
|
||||
# Build a map of challenge names and their primary category
|
||||
DATA_CATEGORY[challenge.info.name] = challenge.info.category[0].value
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user