Mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-02-05 12:25:04 -05:00)

Compare commits — 1 commit: make-old-w... → fix/file-i...

| Author | SHA1 | Date |
|---|---|---|
|  | 8b1f312126 |  |

.github/workflows/classic-autogpt-ci.yml (vendored) — 73 changes
@@ -6,15 +6,11 @@ on:
     paths:
       - '.github/workflows/classic-autogpt-ci.yml'
       - 'classic/original_autogpt/**'
-      - 'classic/direct_benchmark/**'
-      - 'classic/forge/**'
   pull_request:
     branches: [ master, dev, release-* ]
     paths:
       - '.github/workflows/classic-autogpt-ci.yml'
       - 'classic/original_autogpt/**'
-      - 'classic/direct_benchmark/**'
-      - 'classic/forge/**'
 
 concurrency:
   group: ${{ format('classic-autogpt-ci-{0}', github.head_ref && format('{0}-{1}', github.event_name, github.event.pull_request.number) || github.sha) }}
@@ -23,22 +19,47 @@ concurrency:
 defaults:
   run:
     shell: bash
-    working-directory: classic
+    working-directory: classic/original_autogpt
 
 jobs:
   test:
     permissions:
       contents: read
     timeout-minutes: 30
-    runs-on: ubuntu-latest
+    strategy:
+      fail-fast: false
+      matrix:
+        python-version: ["3.10"]
+        platform-os: [ubuntu, macos, macos-arm64, windows]
+    runs-on: ${{ matrix.platform-os != 'macos-arm64' && format('{0}-latest', matrix.platform-os) || 'macos-14' }}
+
     steps:
-      - name: Start MinIO service
+      # Quite slow on macOS (2~4 minutes to set up Docker)
+      # - name: Set up Docker (macOS)
+      #   if: runner.os == 'macOS'
+      #   uses: crazy-max/ghaction-setup-docker@v3
+
+      - name: Start MinIO service (Linux)
+        if: runner.os == 'Linux'
         working-directory: '.'
         run: |
           docker pull minio/minio:edge-cicd
           docker run -d -p 9000:9000 minio/minio:edge-cicd
 
+      - name: Start MinIO service (macOS)
+        if: runner.os == 'macOS'
+        working-directory: ${{ runner.temp }}
+        run: |
+          brew install minio/stable/minio
+          mkdir data
+          minio server ./data &
+
+      # No MinIO on Windows:
+      #  - Windows doesn't support running Linux Docker containers
+      #  - It doesn't seem possible to start background processes on Windows. They are
+      #    killed after the step returns.
+      # See: https://github.com/actions/runner/issues/598#issuecomment-2011890429
+
       - name: Checkout repository
         uses: actions/checkout@v4
         with:
@@ -50,23 +71,41 @@ jobs:
           git config --global user.name "Auto-GPT-Bot"
           git config --global user.email "github-bot@agpt.co"
 
-      - name: Set up Python 3.12
+      - name: Set up Python ${{ matrix.python-version }}
         uses: actions/setup-python@v5
         with:
-          python-version: "3.12"
+          python-version: ${{ matrix.python-version }}
 
       - id: get_date
         name: Get date
         run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT
 
       - name: Set up Python dependency cache
+        # On Windows, unpacking cached dependencies takes longer than just installing them
+        if: runner.os != 'Windows'
         uses: actions/cache@v4
         with:
-          path: ~/.cache/pypoetry
-          key: poetry-${{ runner.os }}-${{ hashFiles('classic/poetry.lock') }}
+          path: ${{ runner.os == 'macOS' && '~/Library/Caches/pypoetry' || '~/.cache/pypoetry' }}
+          key: poetry-${{ runner.os }}-${{ hashFiles('classic/original_autogpt/poetry.lock') }}
 
-      - name: Install Poetry
-        run: curl -sSL https://install.python-poetry.org | python3 -
+      - name: Install Poetry (Unix)
+        if: runner.os != 'Windows'
+        run: |
+          curl -sSL https://install.python-poetry.org | python3 -
+
+          if [ "${{ runner.os }}" = "macOS" ]; then
+            PATH="$HOME/.local/bin:$PATH"
+            echo "$HOME/.local/bin" >> $GITHUB_PATH
+          fi
+
+      - name: Install Poetry (Windows)
+        if: runner.os == 'Windows'
+        shell: pwsh
+        run: |
+          (Invoke-WebRequest -Uri https://install.python-poetry.org -UseBasicParsing).Content | python -
+
+          $env:PATH += ";$env:APPDATA\Python\Scripts"
+          echo "$env:APPDATA\Python\Scripts" >> $env:GITHUB_PATH
 
       - name: Install Python dependencies
         run: poetry install
@@ -77,12 +116,12 @@ jobs:
             --cov=autogpt --cov-branch --cov-report term-missing --cov-report xml \
             --numprocesses=logical --durations=10 \
             --junitxml=junit.xml -o junit_family=legacy \
-            original_autogpt/tests/unit original_autogpt/tests/integration
+            tests/unit tests/integration
         env:
           CI: true
           PLAIN_OUTPUT: True
           OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
-          S3_ENDPOINT_URL: http://127.0.0.1:9000
+          S3_ENDPOINT_URL: ${{ runner.os != 'Windows' && 'http://127.0.0.1:9000' || '' }}
           AWS_ACCESS_KEY_ID: minioadmin
           AWS_SECRET_ACCESS_KEY: minioadmin
 
@@ -96,11 +135,11 @@ jobs:
         uses: codecov/codecov-action@v5
         with:
           token: ${{ secrets.CODECOV_TOKEN }}
-          flags: autogpt-agent
+          flags: autogpt-agent,${{ runner.os }}
 
       - name: Upload logs to artifact
         if: always()
         uses: actions/upload-artifact@v4
         with:
           name: test-logs
-          path: classic/logs/
+          path: classic/original_autogpt/logs/
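
The new `runs-on` line above uses GitHub's expression operators as a conditional: `a && b || c` evaluates to `b` when `a` is truthy, otherwise `c`. A minimal Python sketch of the mapping it produces for the matrix values (function name and asserts are illustrative, not part of the workflow):

# Sketch: mirrors the runs-on expression in plain Python.
def runner_label(platform_os: str) -> str:
    """Map a matrix platform-os value to a concrete runner label."""
    if platform_os != "macos-arm64":
        return f"{platform_os}-latest"   # ubuntu-latest, macos-latest, windows-latest
    return "macos-14"                    # macos-arm64 runs on the Apple Silicon image

assert runner_label("ubuntu") == "ubuntu-latest"
assert runner_label("macos-arm64") == "macos-14"
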

.github/workflows/classic-autogpts-ci.yml (vendored) — 36 changes
@@ -11,6 +11,9 @@ on:
       - 'classic/original_autogpt/**'
       - 'classic/forge/**'
       - 'classic/benchmark/**'
+      - 'classic/run'
+      - 'classic/cli.py'
+      - 'classic/setup.py'
       - '!**/*.md'
   pull_request:
     branches: [ master, dev, release-* ]
@@ -19,6 +22,9 @@ on:
       - 'classic/original_autogpt/**'
       - 'classic/forge/**'
       - 'classic/benchmark/**'
+      - 'classic/run'
+      - 'classic/cli.py'
+      - 'classic/setup.py'
       - '!**/*.md'
 
 defaults:
@@ -29,9 +35,13 @@ defaults:
 jobs:
   serve-agent-protocol:
     runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        agent-name: [ original_autogpt ]
+      fail-fast: false
     timeout-minutes: 20
     env:
-      min-python-version: '3.12'
+      min-python-version: '3.10'
     steps:
       - name: Checkout repository
         uses: actions/checkout@v4
@@ -45,22 +55,22 @@ jobs:
           python-version: ${{ env.min-python-version }}
 
       - name: Install Poetry
+        working-directory: ./classic/${{ matrix.agent-name }}/
         run: |
           curl -sSL https://install.python-poetry.org | python -
 
-      - name: Install dependencies
-        run: poetry install
-
-      - name: Run smoke tests with direct-benchmark
+      - name: Run regression tests
         run: |
-          poetry run direct-benchmark run \
-            --strategies one_shot \
-            --models claude \
-            --tests ReadFile,WriteFile \
-            --json
+          ./run agent start ${{ matrix.agent-name }}
+          cd ${{ matrix.agent-name }}
+          poetry run agbenchmark --mock --test=BasicRetrieval --test=Battleship --test=WebArenaTask_0
+          poetry run agbenchmark --test=WriteFile
         env:
           OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
-          ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
+          AGENT_NAME: ${{ matrix.agent-name }}
           REQUESTS_CA_BUNDLE: /etc/ssl/certs/ca-certificates.crt
-          NONINTERACTIVE_MODE: "true"
-          CI: true
+          HELICONE_CACHE_ENABLED: false
+          HELICONE_PROPERTY_AGENT: ${{ matrix.agent-name }}
+          REPORTS_FOLDER: ${{ format('../../reports/{0}', matrix.agent-name) }}
+          TELEMETRY_ENVIRONMENT: autogpt-ci
+          TELEMETRY_OPT_IN: ${{ github.ref_name == 'master' }}
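
The new "Run regression tests" step starts the agent's Agent Protocol server with `./run agent start` and then points agbenchmark at it in mock mode. A hedged Python sketch of that flow, assuming the same commands and that the agent process keeps serving until terminated (the `smoke_test` helper is illustrative only):

import subprocess

AGENT = "original_autogpt"  # matches matrix.agent-name in the workflow above

def smoke_test() -> None:
    # Start the Agent Protocol server (assumed to keep running until terminated).
    server = subprocess.Popen(["./run", "agent", "start", AGENT], cwd="classic")
    try:
        # Run a handful of known challenges against it in mock mode.
        subprocess.run(
            ["poetry", "run", "agbenchmark", "--mock",
             "--test=BasicRetrieval", "--test=Battleship", "--test=WebArenaTask_0"],
            cwd=f"classic/{AGENT}",
            check=True,
        )
    finally:
        server.terminate()
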

.github/workflows/classic-benchmark-ci.yml (vendored) — 194 changes
@@ -1,21 +1,17 @@
-name: Classic - Direct Benchmark CI
+name: Classic - AGBenchmark CI
 
 on:
   push:
     branches: [ master, dev, ci-test* ]
     paths:
-      - 'classic/direct_benchmark/**'
-      - 'classic/benchmark/agbenchmark/challenges/**'
-      - 'classic/original_autogpt/**'
-      - 'classic/forge/**'
+      - 'classic/benchmark/**'
+      - '!classic/benchmark/reports/**'
       - .github/workflows/classic-benchmark-ci.yml
   pull_request:
     branches: [ master, dev, release-* ]
     paths:
-      - 'classic/direct_benchmark/**'
-      - 'classic/benchmark/agbenchmark/challenges/**'
-      - 'classic/original_autogpt/**'
-      - 'classic/forge/**'
+      - 'classic/benchmark/**'
+      - '!classic/benchmark/reports/**'
      - .github/workflows/classic-benchmark-ci.yml
 
 concurrency:
@@ -27,16 +23,23 @@ defaults:
     shell: bash
 
 env:
-  min-python-version: '3.12'
+  min-python-version: '3.10'
 
 jobs:
-  benchmark-tests:
-    runs-on: ubuntu-latest
+  test:
+    permissions:
+      contents: read
     timeout-minutes: 30
+    strategy:
+      fail-fast: false
+      matrix:
+        python-version: ["3.10"]
+        platform-os: [ubuntu, macos, macos-arm64, windows]
+    runs-on: ${{ matrix.platform-os != 'macos-arm64' && format('{0}-latest', matrix.platform-os) || 'macos-14' }}
     defaults:
       run:
         shell: bash
-        working-directory: classic
+        working-directory: classic/benchmark
     steps:
       - name: Checkout repository
         uses: actions/checkout@v4
@@ -44,88 +47,71 @@ jobs:
           fetch-depth: 0
           submodules: true
 
-      - name: Set up Python ${{ env.min-python-version }}
+      - name: Set up Python ${{ matrix.python-version }}
         uses: actions/setup-python@v5
         with:
-          python-version: ${{ env.min-python-version }}
+          python-version: ${{ matrix.python-version }}
 
       - name: Set up Python dependency cache
+        # On Windows, unpacking cached dependencies takes longer than just installing them
+        if: runner.os != 'Windows'
         uses: actions/cache@v4
         with:
-          path: ~/.cache/pypoetry
-          key: poetry-${{ runner.os }}-${{ hashFiles('classic/poetry.lock') }}
+          path: ${{ runner.os == 'macOS' && '~/Library/Caches/pypoetry' || '~/.cache/pypoetry' }}
+          key: poetry-${{ runner.os }}-${{ hashFiles('classic/benchmark/poetry.lock') }}
 
-      - name: Install Poetry
+      - name: Install Poetry (Unix)
+        if: runner.os != 'Windows'
         run: |
           curl -sSL https://install.python-poetry.org | python3 -
 
-      - name: Install dependencies
+          if [ "${{ runner.os }}" = "macOS" ]; then
+            PATH="$HOME/.local/bin:$PATH"
+            echo "$HOME/.local/bin" >> $GITHUB_PATH
+          fi
+
+      - name: Install Poetry (Windows)
+        if: runner.os == 'Windows'
+        shell: pwsh
+        run: |
+          (Invoke-WebRequest -Uri https://install.python-poetry.org -UseBasicParsing).Content | python -
+
+          $env:PATH += ";$env:APPDATA\Python\Scripts"
+          echo "$env:APPDATA\Python\Scripts" >> $env:GITHUB_PATH
+
+      - name: Install Python dependencies
         run: poetry install
 
-      - name: Run basic benchmark tests
+      - name: Run pytest with coverage
         run: |
-          echo "Testing ReadFile challenge with one_shot strategy..."
-          poetry run direct-benchmark run \
-            --fresh \
-            --strategies one_shot \
-            --models claude \
-            --tests ReadFile \
-            --json
-
-          echo "Testing WriteFile challenge..."
-          poetry run direct-benchmark run \
-            --fresh \
-            --strategies one_shot \
-            --models claude \
-            --tests WriteFile \
-            --json
+          poetry run pytest -vv \
+            --cov=agbenchmark --cov-branch --cov-report term-missing --cov-report xml \
+            --durations=10 \
+            --junitxml=junit.xml -o junit_family=legacy \
+            tests
         env:
           CI: true
-          ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
           OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
-          NONINTERACTIVE_MODE: "true"
 
-      - name: Test category filtering
-        run: |
-          echo "Testing coding category..."
-          poetry run direct-benchmark run \
-            --fresh \
-            --strategies one_shot \
-            --models claude \
-            --categories coding \
-            --tests ReadFile,WriteFile \
-            --json
-        env:
-          CI: true
-          ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
-          OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
-          NONINTERACTIVE_MODE: "true"
+      - name: Upload test results to Codecov
+        if: ${{ !cancelled() }} # Run even if tests fail
+        uses: codecov/test-results-action@v1
+        with:
+          token: ${{ secrets.CODECOV_TOKEN }}
 
-      - name: Test multiple strategies
-        run: |
-          echo "Testing multiple strategies..."
-          poetry run direct-benchmark run \
-            --fresh \
-            --strategies one_shot,plan_execute \
-            --models claude \
-            --tests ReadFile \
-            --parallel 2 \
-            --json
-        env:
-          CI: true
-          ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
-          OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
-          NONINTERACTIVE_MODE: "true"
+      - name: Upload coverage reports to Codecov
+        uses: codecov/codecov-action@v5
+        with:
+          token: ${{ secrets.CODECOV_TOKEN }}
+          flags: agbenchmark,${{ runner.os }}
 
-  # Run regression tests on maintain challenges
-  regression-tests:
+  self-test-with-agent:
     runs-on: ubuntu-latest
-    timeout-minutes: 45
-    if: github.ref == 'refs/heads/master' || github.ref == 'refs/heads/dev'
-    defaults:
-      run:
-        shell: bash
-        working-directory: classic
+    strategy:
+      matrix:
+        agent-name: [forge]
+      fail-fast: false
+    timeout-minutes: 20
     steps:
       - name: Checkout repository
         uses: actions/checkout@v4
@@ -140,23 +126,51 @@ jobs:
 
       - name: Install Poetry
         run: |
-          curl -sSL https://install.python-poetry.org | python3 -
-
-      - name: Install dependencies
-        run: poetry install
+          curl -sSL https://install.python-poetry.org | python -
 
       - name: Run regression tests
+        working-directory: classic
         run: |
-          echo "Running regression tests (previously beaten challenges)..."
-          poetry run direct-benchmark run \
-            --fresh \
-            --strategies one_shot \
-            --models claude \
-            --maintain \
-            --parallel 4 \
-            --json
+          ./run agent start ${{ matrix.agent-name }}
+          cd ${{ matrix.agent-name }}
+
+          set +e # Ignore non-zero exit codes and continue execution
+          echo "Running the following command: poetry run agbenchmark --maintain --mock"
+          poetry run agbenchmark --maintain --mock
+          EXIT_CODE=$?
+          set -e # Stop ignoring non-zero exit codes
+          # Check if the exit code was 5, and if so, exit with 0 instead
+          if [ $EXIT_CODE -eq 5 ]; then
+            echo "regression_tests.json is empty."
+          fi
+
+          echo "Running the following command: poetry run agbenchmark --mock"
+          poetry run agbenchmark --mock
+
+          echo "Running the following command: poetry run agbenchmark --mock --category=data"
+          poetry run agbenchmark --mock --category=data
+
+          echo "Running the following command: poetry run agbenchmark --mock --category=coding"
+          poetry run agbenchmark --mock --category=coding
+
+          # echo "Running the following command: poetry run agbenchmark --test=WriteFile"
+          # poetry run agbenchmark --test=WriteFile
+          cd ../benchmark
+          poetry install
+          echo "Adding the BUILD_SKILL_TREE environment variable. This will attempt to add new elements in the skill tree. If new elements are added, the CI fails because they should have been pushed"
+          export BUILD_SKILL_TREE=true
+
+          # poetry run agbenchmark --mock
+
+          # CHANGED=$(git diff --name-only | grep -E '(agbenchmark/challenges)|(../classic/frontend/assets)') || echo "No diffs"
+          # if [ ! -z "$CHANGED" ]; then
+          #   echo "There are unstaged changes please run agbenchmark and commit those changes since they are needed."
+          #   echo "$CHANGED"
+          #   exit 1
+          # else
+          #   echo "No unstaged changes."
+          # fi
         env:
-          CI: true
-          ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
           OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
-          NONINTERACTIVE_MODE: "true"
+          TELEMETRY_ENVIRONMENT: autogpt-benchmark-ci
+          TELEMETRY_OPT_IN: ${{ github.ref_name == 'master' }}
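
The regression step above wraps the first agbenchmark call in `set +e` / `EXIT_CODE=$?` so that exit code 5 (pytest's "no tests were collected", i.e. an empty regression_tests.json) is tolerated while other failures still fail the job. A rough Python equivalent of that guard, assuming agbenchmark is invoked through Poetry as in the workflow (the helper name is illustrative):

import subprocess

def run_benchmark(args: list[str]) -> None:
    proc = subprocess.run(["poetry", "run", "agbenchmark", *args])
    if proc.returncode == 5:
        # Nothing collected (empty regression suite) -- not treated as a failure.
        print("regression_tests.json is empty.")
    elif proc.returncode != 0:
        # Real failures still abort, mirroring `set -e` after the guarded call.
        raise SystemExit(proc.returncode)

# run_benchmark(["--maintain", "--mock"])
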

.github/workflows/classic-forge-ci.yml (vendored) — 182 changes
@@ -6,11 +6,13 @@ on:
     paths:
       - '.github/workflows/classic-forge-ci.yml'
       - 'classic/forge/**'
+      - '!classic/forge/tests/vcr_cassettes'
   pull_request:
     branches: [ master, dev, release-* ]
     paths:
       - '.github/workflows/classic-forge-ci.yml'
       - 'classic/forge/**'
+      - '!classic/forge/tests/vcr_cassettes'
 
 concurrency:
   group: ${{ format('forge-ci-{0}', github.head_ref && format('{0}-{1}', github.event_name, github.event.pull_request.number) || github.sha) }}
@@ -19,38 +21,115 @@ concurrency:
 defaults:
   run:
     shell: bash
-    working-directory: classic
+    working-directory: classic/forge
 
 jobs:
   test:
     permissions:
       contents: read
     timeout-minutes: 30
-    runs-on: ubuntu-latest
+    strategy:
+      fail-fast: false
+      matrix:
+        python-version: ["3.10"]
+        platform-os: [ubuntu, macos, macos-arm64, windows]
+    runs-on: ${{ matrix.platform-os != 'macos-arm64' && format('{0}-latest', matrix.platform-os) || 'macos-14' }}
+
     steps:
-      - name: Start MinIO service
+      # Quite slow on macOS (2~4 minutes to set up Docker)
+      # - name: Set up Docker (macOS)
+      #   if: runner.os == 'macOS'
+      #   uses: crazy-max/ghaction-setup-docker@v3
+
+      - name: Start MinIO service (Linux)
+        if: runner.os == 'Linux'
         working-directory: '.'
         run: |
           docker pull minio/minio:edge-cicd
           docker run -d -p 9000:9000 minio/minio:edge-cicd
 
+      - name: Start MinIO service (macOS)
+        if: runner.os == 'macOS'
+        working-directory: ${{ runner.temp }}
+        run: |
+          brew install minio/stable/minio
+          mkdir data
+          minio server ./data &
+
+      # No MinIO on Windows:
+      #  - Windows doesn't support running Linux Docker containers
+      #  - It doesn't seem possible to start background processes on Windows. They are
+      #    killed after the step returns.
+      # See: https://github.com/actions/runner/issues/598#issuecomment-2011890429
+
       - name: Checkout repository
         uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+          submodules: true
 
-      - name: Set up Python 3.12
+      - name: Checkout cassettes
+        if: ${{ startsWith(github.event_name, 'pull_request') }}
+        env:
+          PR_BASE: ${{ github.event.pull_request.base.ref }}
+          PR_BRANCH: ${{ github.event.pull_request.head.ref }}
+          PR_AUTHOR: ${{ github.event.pull_request.user.login }}
+        run: |
+          cassette_branch="${PR_AUTHOR}-${PR_BRANCH}"
+          cassette_base_branch="${PR_BASE}"
+          cd tests/vcr_cassettes
+
+          if ! git ls-remote --exit-code --heads origin $cassette_base_branch ; then
+            cassette_base_branch="master"
+          fi
+
+          if git ls-remote --exit-code --heads origin $cassette_branch ; then
+            git fetch origin $cassette_branch
+            git fetch origin $cassette_base_branch
+
+            git checkout $cassette_branch
+
+            # Pick non-conflicting cassette updates from the base branch
+            git merge --no-commit --strategy-option=ours origin/$cassette_base_branch
+            echo "Using cassettes from mirror branch '$cassette_branch'," \
+              "synced to upstream branch '$cassette_base_branch'."
+          else
+            git checkout -b $cassette_branch
+            echo "Branch '$cassette_branch' does not exist in cassette submodule." \
+              "Using cassettes from '$cassette_base_branch'."
+          fi
+
+      - name: Set up Python ${{ matrix.python-version }}
         uses: actions/setup-python@v5
         with:
-          python-version: "3.12"
+          python-version: ${{ matrix.python-version }}
 
       - name: Set up Python dependency cache
+        # On Windows, unpacking cached dependencies takes longer than just installing them
+        if: runner.os != 'Windows'
         uses: actions/cache@v4
         with:
-          path: ~/.cache/pypoetry
-          key: poetry-${{ runner.os }}-${{ hashFiles('classic/poetry.lock') }}
+          path: ${{ runner.os == 'macOS' && '~/Library/Caches/pypoetry' || '~/.cache/pypoetry' }}
+          key: poetry-${{ runner.os }}-${{ hashFiles('classic/forge/poetry.lock') }}
 
-      - name: Install Poetry
-        run: curl -sSL https://install.python-poetry.org | python3 -
+      - name: Install Poetry (Unix)
+        if: runner.os != 'Windows'
+        run: |
+          curl -sSL https://install.python-poetry.org | python3 -
+
+          if [ "${{ runner.os }}" = "macOS" ]; then
+            PATH="$HOME/.local/bin:$PATH"
+            echo "$HOME/.local/bin" >> $GITHUB_PATH
+          fi
+
+      - name: Install Poetry (Windows)
+        if: runner.os == 'Windows'
+        shell: pwsh
+        run: |
+          (Invoke-WebRequest -Uri https://install.python-poetry.org -UseBasicParsing).Content | python -
+
+          $env:PATH += ";$env:APPDATA\Python\Scripts"
+          echo "$env:APPDATA\Python\Scripts" >> $env:GITHUB_PATH
 
       - name: Install Python dependencies
         run: poetry install
@@ -61,15 +140,12 @@ jobs:
             --cov=forge --cov-branch --cov-report term-missing --cov-report xml \
             --durations=10 \
             --junitxml=junit.xml -o junit_family=legacy \
-            forge/forge forge/tests
+            forge
         env:
           CI: true
           PLAIN_OUTPUT: True
-          # API keys - tests that need these will skip if not available
-          # Secrets are not available to fork PRs (GitHub security feature)
           OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
-          ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
-          S3_ENDPOINT_URL: http://127.0.0.1:9000
+          S3_ENDPOINT_URL: ${{ runner.os != 'Windows' && 'http://127.0.0.1:9000' || '' }}
           AWS_ACCESS_KEY_ID: minioadmin
           AWS_SECRET_ACCESS_KEY: minioadmin
 
@@ -83,11 +159,85 @@ jobs:
         uses: codecov/codecov-action@v5
         with:
           token: ${{ secrets.CODECOV_TOKEN }}
-          flags: forge
+          flags: forge,${{ runner.os }}
 
+      - id: setup_git_auth
+        name: Set up git token authentication
+        # Cassettes may be pushed even when tests fail
+        if: success() || failure()
+        run: |
+          config_key="http.${{ github.server_url }}/.extraheader"
+          if [ "${{ runner.os }}" = 'macOS' ]; then
+            base64_pat=$(echo -n "pat:${{ secrets.PAT_REVIEW }}" | base64)
+          else
+            base64_pat=$(echo -n "pat:${{ secrets.PAT_REVIEW }}" | base64 -w0)
+          fi
+
+          git config "$config_key" \
+            "Authorization: Basic $base64_pat"
+
+          cd tests/vcr_cassettes
+          git config "$config_key" \
+            "Authorization: Basic $base64_pat"
+
+          echo "config_key=$config_key" >> $GITHUB_OUTPUT
+
+      - id: push_cassettes
+        name: Push updated cassettes
+        # For pull requests, push updated cassettes even when tests fail
+        if: github.event_name == 'push' || (! github.event.pull_request.head.repo.fork && (success() || failure()))
+        env:
+          PR_BRANCH: ${{ github.event.pull_request.head.ref }}
+          PR_AUTHOR: ${{ github.event.pull_request.user.login }}
+        run: |
+          if [ "${{ startsWith(github.event_name, 'pull_request') }}" = "true" ]; then
+            is_pull_request=true
+            cassette_branch="${PR_AUTHOR}-${PR_BRANCH}"
+          else
+            cassette_branch="${{ github.ref_name }}"
+          fi
+
+          cd tests/vcr_cassettes
+          # Commit & push changes to cassettes if any
+          if ! git diff --quiet; then
+            git add .
+            git commit -m "Auto-update cassettes"
+            git push origin HEAD:$cassette_branch
+            if [ ! $is_pull_request ]; then
+              cd ../..
+              git add tests/vcr_cassettes
+              git commit -m "Update cassette submodule"
+              git push origin HEAD:$cassette_branch
+            fi
+            echo "updated=true" >> $GITHUB_OUTPUT
+          else
+            echo "updated=false" >> $GITHUB_OUTPUT
+            echo "No cassette changes to commit"
+          fi
+
+      - name: Post Set up git token auth
+        if: steps.setup_git_auth.outcome == 'success'
+        run: |
+          git config --unset-all '${{ steps.setup_git_auth.outputs.config_key }}'
+          git submodule foreach git config --unset-all '${{ steps.setup_git_auth.outputs.config_key }}'
+
+      - name: Apply "behaviour change" label and comment on PR
+        if: ${{ startsWith(github.event_name, 'pull_request') }}
+        run: |
+          PR_NUMBER="${{ github.event.pull_request.number }}"
+          TOKEN="${{ secrets.PAT_REVIEW }}"
+          REPO="${{ github.repository }}"
+
+          if [[ "${{ steps.push_cassettes.outputs.updated }}" == "true" ]]; then
+            echo "Adding label and comment..."
+            echo $TOKEN | gh auth login --with-token
+            gh issue edit $PR_NUMBER --add-label "behaviour change"
+            gh issue comment $PR_NUMBER --body "You changed AutoGPT's behaviour on ${{ runner.os }}. The cassettes have been updated and will be merged to the submodule when this Pull Request gets merged."
+          fi
+
       - name: Upload logs to artifact
         if: always()
         uses: actions/upload-artifact@v4
         with:
           name: test-logs
-          path: classic/logs/
+          path: classic/forge/logs/
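
The "Checkout cassettes" step selects which VCR cassette branch to use: a per-PR mirror branch named "<author>-<branch>" if it exists, otherwise the PR's base branch, falling back to master when the base has no cassette branch. A hedged Python sketch of just that selection logic (the function and its return shape are illustrative; the real step does the work with git commands as shown above):

def pick_cassette_branch(
    pr_author: str, pr_branch: str, pr_base: str, remote_branches: set[str]
) -> tuple[str, str, bool]:
    """Return (cassette_branch, base_branch, mirror_branch_exists)."""
    cassette_branch = f"{pr_author}-{pr_branch}"
    base_branch = pr_base if pr_base in remote_branches else "master"
    return cassette_branch, base_branch, cassette_branch in remote_branches

# pick_cassette_branch("alice", "fix-llm", "dev", {"master", "dev"})
# -> ("alice-fix-llm", "dev", False): create the mirror branch from dev's cassettes.
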

.github/workflows/classic-python-checks.yml (vendored) — 67 changes
@@ -7,9 +7,7 @@ on:
       - '.github/workflows/classic-python-checks-ci.yml'
       - 'classic/original_autogpt/**'
       - 'classic/forge/**'
-      - 'classic/direct_benchmark/**'
-      - 'classic/pyproject.toml'
-      - 'classic/poetry.lock'
+      - 'classic/benchmark/**'
       - '**.py'
       - '!classic/forge/tests/vcr_cassettes'
   pull_request:
@@ -18,9 +16,7 @@ on:
       - '.github/workflows/classic-python-checks-ci.yml'
       - 'classic/original_autogpt/**'
       - 'classic/forge/**'
-      - 'classic/direct_benchmark/**'
-      - 'classic/pyproject.toml'
-      - 'classic/poetry.lock'
+      - 'classic/benchmark/**'
       - '**.py'
       - '!classic/forge/tests/vcr_cassettes'
 
@@ -31,13 +27,44 @@ concurrency:
 defaults:
   run:
     shell: bash
-    working-directory: classic
 
 jobs:
+  get-changed-parts:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v4
+
+      - id: changes-in
+        name: Determine affected subprojects
+        uses: dorny/paths-filter@v3
+        with:
+          filters: |
+            original_autogpt:
+              - classic/original_autogpt/autogpt/**
+              - classic/original_autogpt/tests/**
+              - classic/original_autogpt/poetry.lock
+            forge:
+              - classic/forge/forge/**
+              - classic/forge/tests/**
+              - classic/forge/poetry.lock
+            benchmark:
+              - classic/benchmark/agbenchmark/**
+              - classic/benchmark/tests/**
+              - classic/benchmark/poetry.lock
+    outputs:
+      changed-parts: ${{ steps.changes-in.outputs.changes }}
+
   lint:
+    needs: get-changed-parts
     runs-on: ubuntu-latest
     env:
-      min-python-version: "3.12"
+      min-python-version: "3.10"
 
+    strategy:
+      matrix:
+        sub-package: ${{ fromJson(needs.get-changed-parts.outputs.changed-parts) }}
+      fail-fast: false
+
     steps:
       - name: Checkout repository
@@ -54,31 +81,42 @@ jobs:
         uses: actions/cache@v4
         with:
           path: ~/.cache/pypoetry
-          key: ${{ runner.os }}-poetry-${{ hashFiles('classic/poetry.lock') }}
+          key: ${{ runner.os }}-poetry-${{ hashFiles(format('{0}/poetry.lock', matrix.sub-package)) }}
 
       - name: Install Poetry
         run: curl -sSL https://install.python-poetry.org | python3 -
 
+      # Install dependencies
+
       - name: Install Python dependencies
-        run: poetry install
+        run: poetry -C classic/${{ matrix.sub-package }} install
 
       # Lint
 
       - name: Lint (isort)
         run: poetry run isort --check .
+        working-directory: classic/${{ matrix.sub-package }}
 
       - name: Lint (Black)
         if: success() || failure()
         run: poetry run black --check .
+        working-directory: classic/${{ matrix.sub-package }}
 
       - name: Lint (Flake8)
         if: success() || failure()
         run: poetry run flake8 .
+        working-directory: classic/${{ matrix.sub-package }}
 
   types:
+    needs: get-changed-parts
     runs-on: ubuntu-latest
     env:
-      min-python-version: "3.12"
+      min-python-version: "3.10"
 
+    strategy:
+      matrix:
+        sub-package: ${{ fromJson(needs.get-changed-parts.outputs.changed-parts) }}
+      fail-fast: false
+
     steps:
       - name: Checkout repository
@@ -95,16 +133,19 @@ jobs:
         uses: actions/cache@v4
         with:
           path: ~/.cache/pypoetry
-          key: ${{ runner.os }}-poetry-${{ hashFiles('classic/poetry.lock') }}
+          key: ${{ runner.os }}-poetry-${{ hashFiles(format('{0}/poetry.lock', matrix.sub-package)) }}
 
       - name: Install Poetry
         run: curl -sSL https://install.python-poetry.org | python3 -
 
+      # Install dependencies
+
       - name: Install Python dependencies
-        run: poetry install
+        run: poetry -C classic/${{ matrix.sub-package }} install
 
       # Typecheck
 
       - name: Typecheck
         if: success() || failure()
         run: poetry run pyright
+        working-directory: classic/${{ matrix.sub-package }}
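
The new get-changed-parts job feeds the lint/types matrices: dorny/paths-filter matches the changed files against the per-subproject globs and exposes a JSON list of matching filter names in its `changes` output, which `fromJson()` turns into `matrix.sub-package`. A hedged Python approximation of that computation (fnmatch is a simplification of paths-filter's real glob handling; names here are illustrative):

import fnmatch
import json

FILTERS = {
    "original_autogpt": ["classic/original_autogpt/autogpt/**",
                         "classic/original_autogpt/tests/**",
                         "classic/original_autogpt/poetry.lock"],
    "forge": ["classic/forge/forge/**", "classic/forge/tests/**",
              "classic/forge/poetry.lock"],
    "benchmark": ["classic/benchmark/agbenchmark/**", "classic/benchmark/tests/**",
                  "classic/benchmark/poetry.lock"],
}

def changed_parts(changed_files: list[str]) -> str:
    """Return the JSON list that would populate matrix.sub-package."""
    parts = [
        name for name, globs in FILTERS.items()
        if any(fnmatch.fnmatch(path, glob.replace("**", "*"))
               for path in changed_files for glob in globs)
    ]
    return json.dumps(parts)

# changed_parts(["classic/forge/forge/llm.py"]) -> '["forge"]'
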

.gitignore (vendored) — 12 changes
@@ -3,7 +3,6 @@
 classic/original_autogpt/keys.py
 classic/original_autogpt/*.json
 auto_gpt_workspace/*
-.autogpt/
 *.mpeg
 .env
 # Root .env files
@@ -160,10 +159,6 @@ CURRENT_BULLETIN.md
 
 # AgBenchmark
 classic/benchmark/agbenchmark/reports/
-classic/reports/
-classic/direct_benchmark/reports/
-classic/.benchmark_workspaces/
-classic/direct_benchmark/.benchmark_workspaces/
 
 # Nodejs
 package-lock.json
@@ -182,13 +177,6 @@ autogpt_platform/backend/settings.py
 
 *.ign.*
 .test-contents
-**/.claude/settings.local.json
 .claude/settings.local.json
 CLAUDE.local.md
 /autogpt_platform/backend/logs
-
-# Test database
-test.db
-
-# Next.js
-.next
@@ -43,10 +43,29 @@ repos:
         pass_filenames: false
 
       - id: poetry-install
-        name: Check & Install dependencies - Classic
-        alias: poetry-install-classic
-        entry: poetry -C classic install
-        files: ^classic/poetry\.lock$
+        name: Check & Install dependencies - Classic - AutoGPT
+        alias: poetry-install-classic-autogpt
+        entry: poetry -C classic/original_autogpt install
+        # include forge source (since it's a path dependency)
+        files: ^classic/(original_autogpt|forge)/poetry\.lock$
+        types: [file]
+        language: system
+        pass_filenames: false
+
+      - id: poetry-install
+        name: Check & Install dependencies - Classic - Forge
+        alias: poetry-install-classic-forge
+        entry: poetry -C classic/forge install
+        files: ^classic/forge/poetry\.lock$
+        types: [file]
+        language: system
+        pass_filenames: false
+
+      - id: poetry-install
+        name: Check & Install dependencies - Classic - Benchmark
+        alias: poetry-install-classic-benchmark
+        entry: poetry -C classic/benchmark install
+        files: ^classic/benchmark/poetry\.lock$
         types: [file]
         language: system
         pass_filenames: false
@@ -97,10 +116,26 @@ repos:
         language: system
 
       - id: isort
-        name: Lint (isort) - Classic
-        alias: isort-classic
-        entry: bash -c 'cd classic && poetry run isort $(echo "$@" | sed "s|classic/||g")' --
-        files: ^classic/(original_autogpt|forge|direct_benchmark)/
+        name: Lint (isort) - Classic - AutoGPT
+        alias: isort-classic-autogpt
+        entry: poetry -P classic/original_autogpt run isort -p autogpt
+        files: ^classic/original_autogpt/
+        types: [file, python]
+        language: system
+
+      - id: isort
+        name: Lint (isort) - Classic - Forge
+        alias: isort-classic-forge
+        entry: poetry -P classic/forge run isort -p forge
+        files: ^classic/forge/
+        types: [file, python]
+        language: system
+
+      - id: isort
+        name: Lint (isort) - Classic - Benchmark
+        alias: isort-classic-benchmark
+        entry: poetry -P classic/benchmark run isort -p agbenchmark
+        files: ^classic/benchmark/
         types: [file, python]
         language: system
 
@@ -114,13 +149,26 @@ repos:
 
   - repo: https://github.com/PyCQA/flake8
     rev: 7.0.0
-    # Use consolidated flake8 config at classic/.flake8
+    # To have flake8 load the config of the individual subprojects, we have to call
+    # them separately.
     hooks:
       - id: flake8
-        name: Lint (Flake8) - Classic
-        alias: flake8-classic
-        files: ^classic/(original_autogpt|forge|direct_benchmark)/
-        args: [--config=classic/.flake8]
+        name: Lint (Flake8) - Classic - AutoGPT
+        alias: flake8-classic-autogpt
+        files: ^classic/original_autogpt/(autogpt|scripts|tests)/
+        args: [--config=classic/original_autogpt/.flake8]
+
+      - id: flake8
+        name: Lint (Flake8) - Classic - Forge
+        alias: flake8-classic-forge
+        files: ^classic/forge/(forge|tests)/
+        args: [--config=classic/forge/.flake8]
+
+      - id: flake8
+        name: Lint (Flake8) - Classic - Benchmark
+        alias: flake8-classic-benchmark
+        files: ^classic/benchmark/(agbenchmark|tests)/((?!reports).)*[/.]
+        args: [--config=classic/benchmark/.flake8]
 
   - repo: local
     hooks:
@@ -156,10 +204,29 @@ repos:
         pass_filenames: false
 
       - id: pyright
-        name: Typecheck - Classic
-        alias: pyright-classic
-        entry: poetry -C classic run pyright
-        files: ^classic/(original_autogpt|forge|direct_benchmark)/.*\.py$|^classic/poetry\.lock$
+        name: Typecheck - Classic - AutoGPT
+        alias: pyright-classic-autogpt
+        entry: poetry -C classic/original_autogpt run pyright
+        # include forge source (since it's a path dependency) but exclude *_test.py files:
+        files: ^(classic/original_autogpt/((autogpt|scripts|tests)/|poetry\.lock$)|classic/forge/(forge/.*(?<!_test)\.py|poetry\.lock)$)
+        types: [file]
+        language: system
+        pass_filenames: false
+
+      - id: pyright
+        name: Typecheck - Classic - Forge
+        alias: pyright-classic-forge
+        entry: poetry -C classic/forge run pyright
+        files: ^classic/forge/(forge/|poetry\.lock$)
+        types: [file]
+        language: system
+        pass_filenames: false
+
+      - id: pyright
+        name: Typecheck - Classic - Benchmark
+        alias: pyright-classic-benchmark
+        entry: poetry -C classic/benchmark run pyright
+        files: ^classic/benchmark/(agbenchmark/|tests/|poetry\.lock$)
         types: [file]
        language: system
         pass_filenames: false
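
pre-commit decides whether to run each hook by applying its `files` value as a Python regular expression (a search, not a full match) against every staged path; the per-subproject regexes above are what split the old single "Classic" hooks into AutoGPT/Forge/Benchmark variants. A small, hedged sketch of that gating, reusing two of the regexes shown above (the dict and helper are illustrative):

import re

FILES_RE = {
    "poetry-install-classic-forge": r"^classic/forge/poetry\.lock$",
    "flake8-classic-benchmark": r"^classic/benchmark/(agbenchmark|tests)/((?!reports).)*[/.]",
}

def hook_triggers(hook: str, changed_paths: list[str]) -> bool:
    """True if at least one changed path matches the hook's files pattern."""
    pattern = re.compile(FILES_RE[hook])
    return any(pattern.search(path) for path in changed_paths)

# hook_triggers("poetry-install-classic-forge", ["classic/forge/poetry.lock"]) -> True
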
@@ -1,368 +0,0 @@
|
|||||||
"""Redis Streams consumer for operation completion messages.
|
|
||||||
|
|
||||||
This module provides a consumer (ChatCompletionConsumer) that listens for
|
|
||||||
completion notifications (OperationCompleteMessage) from external services
|
|
||||||
(like Agent Generator) and triggers the appropriate stream registry and
|
|
||||||
chat service updates via process_operation_success/process_operation_failure.
|
|
||||||
|
|
||||||
Why Redis Streams instead of RabbitMQ?
|
|
||||||
--------------------------------------
|
|
||||||
While the project typically uses RabbitMQ for async task queues (e.g., execution
|
|
||||||
queue), Redis Streams was chosen for chat completion notifications because:
|
|
||||||
|
|
||||||
1. **Unified Infrastructure**: The SSE reconnection feature already uses Redis
|
|
||||||
Streams (via stream_registry) for message persistence and replay. Using Redis
|
|
||||||
Streams for completion notifications keeps all chat streaming infrastructure
|
|
||||||
in one system, simplifying operations and reducing cross-system coordination.
|
|
||||||
|
|
||||||
2. **Message Replay**: Redis Streams support XREAD with arbitrary message IDs,
|
|
||||||
allowing consumers to replay missed messages after reconnection. This aligns
|
|
||||||
with the SSE reconnection pattern where clients can resume from last_message_id.
|
|
||||||
|
|
||||||
3. **Consumer Groups with XAUTOCLAIM**: Redis consumer groups provide automatic
|
|
||||||
load balancing across pods with explicit message claiming (XAUTOCLAIM) for
|
|
||||||
recovering from dead consumers - ideal for the completion callback pattern.
|
|
||||||
|
|
||||||
4. **Lower Latency**: For real-time SSE updates, Redis (already in-memory for
|
|
||||||
stream_registry) provides lower latency than an additional RabbitMQ hop.
|
|
||||||
|
|
||||||
5. **Atomicity with Task State**: Completion processing often needs to update
|
|
||||||
task metadata stored in Redis. Keeping both in Redis enables simpler
|
|
||||||
transactional semantics without distributed coordination.
|
|
||||||
|
|
||||||
The consumer uses Redis Streams with consumer groups for reliable message
|
|
||||||
processing across multiple platform pods, with XAUTOCLAIM for reclaiming
|
|
||||||
stale pending messages from dead consumers.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import uuid
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import orjson
|
|
||||||
from prisma import Prisma
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from redis.exceptions import ResponseError
|
|
||||||
|
|
||||||
from backend.data.redis_client import get_redis_async
|
|
||||||
|
|
||||||
from . import stream_registry
|
|
||||||
from .completion_handler import process_operation_failure, process_operation_success
|
|
||||||
from .config import ChatConfig
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
config = ChatConfig()
|
|
||||||
|
|
||||||
|
|
||||||
class OperationCompleteMessage(BaseModel):
|
|
||||||
"""Message format for operation completion notifications."""
|
|
||||||
|
|
||||||
operation_id: str
|
|
||||||
task_id: str
|
|
||||||
success: bool
|
|
||||||
result: dict | str | None = None
|
|
||||||
error: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionConsumer:
|
|
||||||
"""Consumer for chat operation completion messages from Redis Streams.
|
|
||||||
|
|
||||||
This consumer initializes its own Prisma client in start() to ensure
|
|
||||||
database operations work correctly within this async context.
|
|
||||||
|
|
||||||
Uses Redis consumer groups to allow multiple platform pods to consume
|
|
||||||
messages reliably with automatic redelivery on failure.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self._consumer_task: asyncio.Task | None = None
|
|
||||||
self._running = False
|
|
||||||
self._prisma: Prisma | None = None
|
|
||||||
self._consumer_name = f"consumer-{uuid.uuid4().hex[:8]}"
|
|
||||||
|
|
||||||
async def start(self) -> None:
|
|
||||||
"""Start the completion consumer."""
|
|
||||||
if self._running:
|
|
||||||
logger.warning("Completion consumer already running")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Create consumer group if it doesn't exist
|
|
||||||
try:
|
|
||||||
redis = await get_redis_async()
|
|
||||||
await redis.xgroup_create(
|
|
||||||
config.stream_completion_name,
|
|
||||||
config.stream_consumer_group,
|
|
||||||
id="0",
|
|
||||||
mkstream=True,
|
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
f"Created consumer group '{config.stream_consumer_group}' "
|
|
||||||
f"on stream '{config.stream_completion_name}'"
|
|
||||||
)
|
|
||||||
except ResponseError as e:
|
|
||||||
if "BUSYGROUP" in str(e):
|
|
||||||
logger.debug(
|
|
||||||
f"Consumer group '{config.stream_consumer_group}' already exists"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise
|
|
||||||
|
|
||||||
self._running = True
|
|
||||||
self._consumer_task = asyncio.create_task(self._consume_messages())
|
|
||||||
logger.info(
|
|
||||||
f"Chat completion consumer started (consumer: {self._consumer_name})"
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _ensure_prisma(self) -> Prisma:
|
|
||||||
"""Lazily initialize Prisma client on first use."""
|
|
||||||
if self._prisma is None:
|
|
||||||
database_url = os.getenv("DATABASE_URL", "postgresql://localhost:5432")
|
|
||||||
self._prisma = Prisma(datasource={"url": database_url})
|
|
||||||
await self._prisma.connect()
|
|
||||||
logger.info("[COMPLETION] Consumer Prisma client connected (lazy init)")
|
|
||||||
return self._prisma
|
|
||||||
|
|
||||||
async def stop(self) -> None:
|
|
||||||
"""Stop the completion consumer."""
|
|
||||||
self._running = False
|
|
||||||
|
|
||||||
if self._consumer_task:
|
|
||||||
self._consumer_task.cancel()
|
|
||||||
try:
|
|
||||||
await self._consumer_task
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
self._consumer_task = None
|
|
||||||
|
|
||||||
if self._prisma:
|
|
||||||
await self._prisma.disconnect()
|
|
||||||
self._prisma = None
|
|
||||||
logger.info("[COMPLETION] Consumer Prisma client disconnected")
|
|
||||||
|
|
||||||
logger.info("Chat completion consumer stopped")
|
|
||||||
|
|
||||||
async def _consume_messages(self) -> None:
|
|
||||||
"""Main message consumption loop with retry logic."""
|
|
||||||
max_retries = 10
|
|
||||||
retry_delay = 5 # seconds
|
|
||||||
retry_count = 0
|
|
||||||
block_timeout = 5000 # milliseconds
|
|
||||||
|
|
||||||
while self._running and retry_count < max_retries:
|
|
||||||
try:
|
|
||||||
redis = await get_redis_async()
|
|
||||||
|
|
||||||
# Reset retry count on successful connection
|
|
||||||
retry_count = 0
|
|
||||||
|
|
||||||
while self._running:
|
|
||||||
# First, claim any stale pending messages from dead consumers
|
|
||||||
# Redis does NOT auto-redeliver pending messages; we must explicitly
|
|
||||||
# claim them using XAUTOCLAIM
|
|
||||||
try:
|
|
||||||
claimed_result = await redis.xautoclaim(
|
|
||||||
name=config.stream_completion_name,
|
|
||||||
groupname=config.stream_consumer_group,
|
|
||||||
consumername=self._consumer_name,
|
|
||||||
min_idle_time=config.stream_claim_min_idle_ms,
|
|
||||||
start_id="0-0",
|
|
||||||
count=10,
|
|
||||||
)
|
|
||||||
# xautoclaim returns: (next_start_id, [(id, data), ...], [deleted_ids])
|
|
||||||
if claimed_result and len(claimed_result) >= 2:
|
|
||||||
claimed_entries = claimed_result[1]
|
|
||||||
if claimed_entries:
|
|
||||||
logger.info(
|
|
||||||
f"Claimed {len(claimed_entries)} stale pending messages"
|
|
||||||
)
|
|
||||||
for entry_id, data in claimed_entries:
|
|
||||||
if not self._running:
|
|
||||||
return
|
|
||||||
await self._process_entry(redis, entry_id, data)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"XAUTOCLAIM failed (non-fatal): {e}")
|
|
||||||
|
|
||||||
# Read new messages from the stream
|
|
||||||
messages = await redis.xreadgroup(
|
|
||||||
groupname=config.stream_consumer_group,
|
|
||||||
consumername=self._consumer_name,
|
|
||||||
streams={config.stream_completion_name: ">"},
|
|
||||||
block=block_timeout,
|
|
||||||
count=10,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not messages:
|
|
||||||
continue
|
|
||||||
|
|
||||||
for stream_name, entries in messages:
|
|
||||||
for entry_id, data in entries:
|
|
||||||
if not self._running:
                            return
                        await self._process_entry(redis, entry_id, data)

            except asyncio.CancelledError:
                logger.info("Consumer cancelled")
                return
            except Exception as e:
                retry_count += 1
                logger.error(
                    f"Consumer error (retry {retry_count}/{max_retries}): {e}",
                    exc_info=True,
                )
                if self._running and retry_count < max_retries:
                    await asyncio.sleep(retry_delay)
                else:
                    logger.error("Max retries reached, stopping consumer")
                    return

    async def _process_entry(
        self, redis: Any, entry_id: str, data: dict[str, Any]
    ) -> None:
        """Process a single stream entry and acknowledge it on success.

        Args:
            redis: Redis client connection
            entry_id: The stream entry ID
            data: The entry data dict
        """
        try:
            # Handle the message
            message_data = data.get("data")
            if message_data:
                await self._handle_message(
                    message_data.encode()
                    if isinstance(message_data, str)
                    else message_data
                )

            # Acknowledge the message after successful processing
            await redis.xack(
                config.stream_completion_name,
                config.stream_consumer_group,
                entry_id,
            )
        except Exception as e:
            logger.error(
                f"Error processing completion message {entry_id}: {e}",
                exc_info=True,
            )
            # Message remains in pending state and will be claimed by
            # XAUTOCLAIM after min_idle_time expires

    async def _handle_message(self, body: bytes) -> None:
        """Handle a completion message using our own Prisma client."""
        try:
            data = orjson.loads(body)
            message = OperationCompleteMessage(**data)
        except Exception as e:
            logger.error(f"Failed to parse completion message: {e}")
            return

        logger.info(
            f"[COMPLETION] Received completion for operation {message.operation_id} "
            f"(task_id={message.task_id}, success={message.success})"
        )

        # Find task in registry
        task = await stream_registry.find_task_by_operation_id(message.operation_id)
        if task is None:
            task = await stream_registry.get_task(message.task_id)

        if task is None:
            logger.warning(
                f"[COMPLETION] Task not found for operation {message.operation_id} "
                f"(task_id={message.task_id})"
            )
            return

        logger.info(
            f"[COMPLETION] Found task: task_id={task.task_id}, "
            f"session_id={task.session_id}, tool_call_id={task.tool_call_id}"
        )

        # Guard against empty task fields
        if not task.task_id or not task.session_id or not task.tool_call_id:
            logger.error(
                f"[COMPLETION] Task has empty critical fields! "
                f"task_id={task.task_id!r}, session_id={task.session_id!r}, "
                f"tool_call_id={task.tool_call_id!r}"
            )
            return

        if message.success:
            await self._handle_success(task, message)
        else:
            await self._handle_failure(task, message)

    async def _handle_success(
        self,
        task: stream_registry.ActiveTask,
        message: OperationCompleteMessage,
    ) -> None:
        """Handle successful operation completion."""
        prisma = await self._ensure_prisma()
        await process_operation_success(task, message.result, prisma)

    async def _handle_failure(
        self,
        task: stream_registry.ActiveTask,
        message: OperationCompleteMessage,
    ) -> None:
        """Handle failed operation completion."""
        prisma = await self._ensure_prisma()
        await process_operation_failure(task, message.error, prisma)


# Module-level consumer instance
_consumer: ChatCompletionConsumer | None = None


async def start_completion_consumer() -> None:
    """Start the global completion consumer."""
    global _consumer
    if _consumer is None:
        _consumer = ChatCompletionConsumer()
        await _consumer.start()


async def stop_completion_consumer() -> None:
    """Stop the global completion consumer."""
    global _consumer
    if _consumer:
        await _consumer.stop()
        _consumer = None


async def publish_operation_complete(
    operation_id: str,
    task_id: str,
    success: bool,
    result: dict | str | None = None,
    error: str | None = None,
) -> None:
    """Publish an operation completion message to Redis Streams.

    Args:
        operation_id: The operation ID that completed.
        task_id: The task ID associated with the operation.
        success: Whether the operation succeeded.
        result: The result data (for success).
        error: The error message (for failure).
    """
    message = OperationCompleteMessage(
        operation_id=operation_id,
        task_id=task_id,
        success=success,
        result=result,
        error=error,
    )

    redis = await get_redis_async()
    await redis.xadd(
        config.stream_completion_name,
        {"data": message.model_dump_json()},
        maxlen=config.stream_max_length,
    )
    logger.info(f"Published completion for operation {operation_id}")
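For context, the consumer above relies on Redis Streams consumer-group semantics: entries are read with XREADGROUP, acknowledged with XACK only after successful processing, and stale pending entries from dead consumers are reclaimed with XAUTOCLAIM. The following is a minimal standalone sketch of that cycle, assuming redis.asyncio and the default stream/group names from ChatConfig; the consumer name and wiring are illustrative, not the project's actual implementation.

# Minimal sketch of the XREADGROUP / XACK / XAUTOCLAIM cycle this consumer depends on.
# Assumes redis.asyncio and the default names from ChatConfig; illustrative only.
import asyncio

import redis.asyncio as redis_asyncio

STREAM = "chat:completions"   # config.stream_completion_name (default)
GROUP = "chat_consumers"      # config.stream_consumer_group (default)
CONSUMER = "worker-1"         # per-pod consumer name (illustrative)


async def consume_once(r: redis_asyncio.Redis) -> None:
    # Create the group idempotently; BUSYGROUP just means it already exists.
    try:
        await r.xgroup_create(STREAM, GROUP, id="0", mkstream=True)
    except redis_asyncio.ResponseError as e:
        if "BUSYGROUP" not in str(e):
            raise

    # Read new entries for this consumer, blocking up to 5 seconds.
    entries = await r.xreadgroup(GROUP, CONSUMER, {STREAM: ">"}, count=10, block=5000)
    for _stream, items in entries or []:
        for entry_id, data in items:
            print("processing", entry_id, data)
            await r.xack(STREAM, GROUP, entry_id)  # ack only after success

    # Reclaim entries another consumer failed to ack within 60s
    # (mirrors stream_claim_min_idle_ms).
    resp = await r.xautoclaim(
        STREAM, GROUP, CONSUMER, min_idle_time=60000, start_id="0-0", count=10
    )
    for entry_id, data in resp[1]:
        print("reclaimed", entry_id, data)
        await r.xack(STREAM, GROUP, entry_id)


async def main() -> None:
    async with redis_asyncio.Redis(decode_responses=True) as r:
        await consume_once(r)


if __name__ == "__main__":
    asyncio.run(main())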
@@ -1,344 +0,0 @@
|
|||||||
"""Shared completion handling for operation success and failure.
|
|
||||||
|
|
||||||
This module provides common logic for handling operation completion from both:
|
|
||||||
- The Redis Streams consumer (completion_consumer.py)
|
|
||||||
- The HTTP webhook endpoint (routes.py)
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import orjson
|
|
||||||
from prisma import Prisma
|
|
||||||
|
|
||||||
from . import service as chat_service
|
|
||||||
from . import stream_registry
|
|
||||||
from .response_model import StreamError, StreamToolOutputAvailable
|
|
||||||
from .tools.models import ErrorResponse
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Tools that produce agent_json that needs to be saved to library
|
|
||||||
AGENT_GENERATION_TOOLS = {"create_agent", "edit_agent"}
|
|
||||||
|
|
||||||
# Keys that should be stripped from agent_json when returning in error responses
|
|
||||||
SENSITIVE_KEYS = frozenset(
|
|
||||||
{
|
|
||||||
"api_key",
|
|
||||||
"apikey",
|
|
||||||
"api_secret",
|
|
||||||
"password",
|
|
||||||
"secret",
|
|
||||||
"credentials",
|
|
||||||
"credential",
|
|
||||||
"token",
|
|
||||||
"access_token",
|
|
||||||
"refresh_token",
|
|
||||||
"private_key",
|
|
||||||
"privatekey",
|
|
||||||
"auth",
|
|
||||||
"authorization",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_agent_json(obj: Any) -> Any:
|
|
||||||
"""Recursively sanitize agent_json by removing sensitive keys.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
obj: The object to sanitize (dict, list, or primitive)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Sanitized copy with sensitive keys removed/redacted
|
|
||||||
"""
|
|
||||||
if isinstance(obj, dict):
|
|
||||||
return {
|
|
||||||
k: "[REDACTED]" if k.lower() in SENSITIVE_KEYS else _sanitize_agent_json(v)
|
|
||||||
for k, v in obj.items()
|
|
||||||
}
|
|
||||||
elif isinstance(obj, list):
|
|
||||||
return [_sanitize_agent_json(item) for item in obj]
|
|
||||||
else:
|
|
||||||
return obj
|
|
||||||
|
|
||||||
|
|
||||||
class ToolMessageUpdateError(Exception):
|
|
||||||
"""Raised when updating a tool message in the database fails."""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
async def _update_tool_message(
|
|
||||||
session_id: str,
|
|
||||||
tool_call_id: str,
|
|
||||||
content: str,
|
|
||||||
prisma_client: Prisma | None,
|
|
||||||
) -> None:
|
|
||||||
"""Update tool message in database.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session_id: The session ID
|
|
||||||
tool_call_id: The tool call ID to update
|
|
||||||
content: The new content for the message
|
|
||||||
prisma_client: Optional Prisma client. If None, uses chat_service.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ToolMessageUpdateError: If the database update fails. The caller should
|
|
||||||
handle this to avoid marking the task as completed with inconsistent state.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
if prisma_client:
|
|
||||||
# Use provided Prisma client (for consumer with its own connection)
|
|
||||||
updated_count = await prisma_client.chatmessage.update_many(
|
|
||||||
where={
|
|
||||||
"sessionId": session_id,
|
|
||||||
"toolCallId": tool_call_id,
|
|
||||||
},
|
|
||||||
data={"content": content},
|
|
||||||
)
|
|
||||||
# Check if any rows were updated - 0 means message not found
|
|
||||||
if updated_count == 0:
|
|
||||||
raise ToolMessageUpdateError(
|
|
||||||
f"No message found with tool_call_id={tool_call_id} in session {session_id}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Use service function (for webhook endpoint)
|
|
||||||
await chat_service._update_pending_operation(
|
|
||||||
session_id=session_id,
|
|
||||||
tool_call_id=tool_call_id,
|
|
||||||
result=content,
|
|
||||||
)
|
|
||||||
except ToolMessageUpdateError:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[COMPLETION] Failed to update tool message: {e}", exc_info=True)
|
|
||||||
raise ToolMessageUpdateError(
|
|
||||||
f"Failed to update tool message for tool_call_id={tool_call_id}: {e}"
|
|
||||||
) from e
|
|
||||||
|
|
||||||
|
|
||||||
def serialize_result(result: dict | list | str | int | float | bool | None) -> str:
|
|
||||||
"""Serialize result to JSON string with sensible defaults.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
result: The result to serialize. Can be a dict, list, string,
|
|
||||||
number, boolean, or None.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
JSON string representation of the result. Returns '{"status": "completed"}'
|
|
||||||
only when result is explicitly None.
|
|
||||||
"""
|
|
||||||
if isinstance(result, str):
|
|
||||||
return result
|
|
||||||
if result is None:
|
|
||||||
return '{"status": "completed"}'
|
|
||||||
return orjson.dumps(result).decode("utf-8")
|
|
||||||
|
|
||||||
|
|
||||||
async def _save_agent_from_result(
|
|
||||||
result: dict[str, Any],
|
|
||||||
user_id: str | None,
|
|
||||||
tool_name: str,
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Save agent to library if result contains agent_json.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
result: The result dict that may contain agent_json
|
|
||||||
user_id: The user ID to save the agent for
|
|
||||||
tool_name: The tool name (create_agent or edit_agent)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Updated result dict with saved agent details, or original result if no agent_json
|
|
||||||
"""
|
|
||||||
if not user_id:
|
|
||||||
logger.warning("[COMPLETION] Cannot save agent: no user_id in task")
|
|
||||||
return result
|
|
||||||
|
|
||||||
agent_json = result.get("agent_json")
|
|
||||||
if not agent_json:
|
|
||||||
logger.warning(
|
|
||||||
f"[COMPLETION] {tool_name} completed but no agent_json in result"
|
|
||||||
)
|
|
||||||
return result
|
|
||||||
|
|
||||||
try:
|
|
||||||
from .tools.agent_generator import save_agent_to_library
|
|
||||||
|
|
||||||
is_update = tool_name == "edit_agent"
|
|
||||||
created_graph, library_agent = await save_agent_to_library(
|
|
||||||
agent_json, user_id, is_update=is_update
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"[COMPLETION] Saved agent '{created_graph.name}' to library "
|
|
||||||
f"(graph_id={created_graph.id}, library_agent_id={library_agent.id})"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Return a response similar to AgentSavedResponse
|
|
||||||
return {
|
|
||||||
"type": "agent_saved",
|
|
||||||
"message": f"Agent '{created_graph.name}' has been saved to your library!",
|
|
||||||
"agent_id": created_graph.id,
|
|
||||||
"agent_name": created_graph.name,
|
|
||||||
"library_agent_id": library_agent.id,
|
|
||||||
"library_agent_link": f"/library/agents/{library_agent.id}",
|
|
||||||
"agent_page_link": f"/build?flowID={created_graph.id}",
|
|
||||||
}
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"[COMPLETION] Failed to save agent to library: {e}",
|
|
||||||
exc_info=True,
|
|
||||||
)
|
|
||||||
# Return error but don't fail the whole operation
|
|
||||||
# Sanitize agent_json to remove sensitive keys before returning
|
|
||||||
return {
|
|
||||||
"type": "error",
|
|
||||||
"message": f"Agent was generated but failed to save: {str(e)}",
|
|
||||||
"error": str(e),
|
|
||||||
"agent_json": _sanitize_agent_json(agent_json),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async def process_operation_success(
|
|
||||||
task: stream_registry.ActiveTask,
|
|
||||||
result: dict | str | None,
|
|
||||||
prisma_client: Prisma | None = None,
|
|
||||||
) -> None:
|
|
||||||
"""Handle successful operation completion.
|
|
||||||
|
|
||||||
Publishes the result to the stream registry, updates the database,
|
|
||||||
generates LLM continuation, and marks the task as completed.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task: The active task that completed
|
|
||||||
result: The result data from the operation
|
|
||||||
prisma_client: Optional Prisma client for database operations.
|
|
||||||
If None, uses chat_service._update_pending_operation instead.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ToolMessageUpdateError: If the database update fails. The task will be
|
|
||||||
marked as failed instead of completed to avoid inconsistent state.
|
|
||||||
"""
|
|
||||||
# For agent generation tools, save the agent to library
|
|
||||||
if task.tool_name in AGENT_GENERATION_TOOLS and isinstance(result, dict):
|
|
||||||
result = await _save_agent_from_result(result, task.user_id, task.tool_name)
|
|
||||||
|
|
||||||
# Serialize result for output (only substitute default when result is exactly None)
|
|
||||||
result_output = result if result is not None else {"status": "completed"}
|
|
||||||
output_str = (
|
|
||||||
result_output
|
|
||||||
if isinstance(result_output, str)
|
|
||||||
else orjson.dumps(result_output).decode("utf-8")
|
|
||||||
)
|
|
||||||
|
|
||||||
# Publish result to stream registry
|
|
||||||
await stream_registry.publish_chunk(
|
|
||||||
task.task_id,
|
|
||||||
StreamToolOutputAvailable(
|
|
||||||
toolCallId=task.tool_call_id,
|
|
||||||
toolName=task.tool_name,
|
|
||||||
output=output_str,
|
|
||||||
success=True,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Update pending operation in database
|
|
||||||
# If this fails, we must not continue to mark the task as completed
|
|
||||||
result_str = serialize_result(result)
|
|
||||||
try:
|
|
||||||
await _update_tool_message(
|
|
||||||
session_id=task.session_id,
|
|
||||||
tool_call_id=task.tool_call_id,
|
|
||||||
content=result_str,
|
|
||||||
prisma_client=prisma_client,
|
|
||||||
)
|
|
||||||
except ToolMessageUpdateError:
|
|
||||||
# DB update failed - mark task as failed to avoid inconsistent state
|
|
||||||
logger.error(
|
|
||||||
f"[COMPLETION] DB update failed for task {task.task_id}, "
|
|
||||||
"marking as failed instead of completed"
|
|
||||||
)
|
|
||||||
await stream_registry.publish_chunk(
|
|
||||||
task.task_id,
|
|
||||||
StreamError(errorText="Failed to save operation result to database"),
|
|
||||||
)
|
|
||||||
await stream_registry.mark_task_completed(task.task_id, status="failed")
|
|
||||||
raise
|
|
||||||
|
|
||||||
# Generate LLM continuation with streaming
|
|
||||||
try:
|
|
||||||
await chat_service._generate_llm_continuation_with_streaming(
|
|
||||||
session_id=task.session_id,
|
|
||||||
user_id=task.user_id,
|
|
||||||
task_id=task.task_id,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"[COMPLETION] Failed to generate LLM continuation: {e}",
|
|
||||||
exc_info=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mark task as completed and release Redis lock
|
|
||||||
await stream_registry.mark_task_completed(task.task_id, status="completed")
|
|
||||||
try:
|
|
||||||
await chat_service._mark_operation_completed(task.tool_call_id)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"[COMPLETION] Successfully processed completion for task {task.task_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def process_operation_failure(
|
|
||||||
task: stream_registry.ActiveTask,
|
|
||||||
error: str | None,
|
|
||||||
prisma_client: Prisma | None = None,
|
|
||||||
) -> None:
|
|
||||||
"""Handle failed operation completion.
|
|
||||||
|
|
||||||
Publishes the error to the stream registry, updates the database with
|
|
||||||
the error response, and marks the task as failed.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task: The active task that failed
|
|
||||||
error: The error message from the operation
|
|
||||||
prisma_client: Optional Prisma client for database operations.
|
|
||||||
If None, uses chat_service._update_pending_operation instead.
|
|
||||||
"""
|
|
||||||
error_msg = error or "Operation failed"
|
|
||||||
|
|
||||||
# Publish error to stream registry
|
|
||||||
await stream_registry.publish_chunk(
|
|
||||||
task.task_id,
|
|
||||||
StreamError(errorText=error_msg),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Update pending operation with error
|
|
||||||
# If this fails, we still continue to mark the task as failed
|
|
||||||
error_response = ErrorResponse(
|
|
||||||
message=error_msg,
|
|
||||||
error=error,
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
await _update_tool_message(
|
|
||||||
session_id=task.session_id,
|
|
||||||
tool_call_id=task.tool_call_id,
|
|
||||||
content=error_response.model_dump_json(),
|
|
||||||
prisma_client=prisma_client,
|
|
||||||
)
|
|
||||||
except ToolMessageUpdateError:
|
|
||||||
# DB update failed - log but continue with cleanup
|
|
||||||
logger.error(
|
|
||||||
f"[COMPLETION] DB update failed while processing failure for task {task.task_id}, "
|
|
||||||
"continuing with cleanup"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mark task as failed and release Redis lock
|
|
||||||
await stream_registry.mark_task_completed(task.task_id, status="failed")
|
|
||||||
try:
|
|
||||||
await chat_service._mark_operation_completed(task.tool_call_id)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")
|
|
||||||
|
|
||||||
logger.info(f"[COMPLETION] Processed failure for task {task.task_id}: {error_msg}")
|
|
||||||
@@ -44,48 +44,6 @@ class ChatConfig(BaseSettings):
        description="TTL in seconds for long-running operation tracking in Redis (safety net if pod dies)",
    )

    # Stream registry configuration for SSE reconnection
    stream_ttl: int = Field(
        default=3600,
        description="TTL in seconds for stream data in Redis (1 hour)",
    )
    stream_max_length: int = Field(
        default=10000,
        description="Maximum number of messages to store per stream",
    )

    # Redis Streams configuration for completion consumer
    stream_completion_name: str = Field(
        default="chat:completions",
        description="Redis Stream name for operation completions",
    )
    stream_consumer_group: str = Field(
        default="chat_consumers",
        description="Consumer group name for completion stream",
    )
    stream_claim_min_idle_ms: int = Field(
        default=60000,
        description="Minimum idle time in milliseconds before claiming pending messages from dead consumers",
    )

    # Redis key prefixes for stream registry
    task_meta_prefix: str = Field(
        default="chat:task:meta:",
        description="Prefix for task metadata hash keys",
    )
    task_stream_prefix: str = Field(
        default="chat:stream:",
        description="Prefix for task message stream keys",
    )
    task_op_prefix: str = Field(
        default="chat:task:op:",
        description="Prefix for operation ID to task ID mapping keys",
    )
    internal_api_key: str | None = Field(
        default=None,
        description="API key for internal webhook callbacks (env: CHAT_INTERNAL_API_KEY)",
    )

    # Langfuse Prompt Management Configuration
    # Note: Langfuse credentials are in Settings().secrets (settings.py)
    langfuse_prompt_name: str = Field(

@@ -124,14 +82,6 @@ class ChatConfig(BaseSettings):
        v = "https://openrouter.ai/api/v1"
        return v

    @field_validator("internal_api_key", mode="before")
    @classmethod
    def get_internal_api_key(cls, v):
        """Get internal API key from environment if not provided."""
        if v is None:
            v = os.getenv("CHAT_INTERNAL_API_KEY")
        return v

    # Prompt paths for different contexts
    PROMPT_PATHS: dict[str, str] = {
        "default": "prompts/chat_system.md",
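The removed block above pairs a `Field(default=None)` declaration with a `field_validator` that falls back to an environment variable. A self-contained sketch of that pattern is below; the class and field names are illustrative, not the project's actual ChatConfig.

# Standalone illustration of the Field-default + env-fallback validator pattern
# removed above. Class and variable names are illustrative only.
import os

from pydantic import Field, field_validator
from pydantic_settings import BaseSettings


class ExampleConfig(BaseSettings):
    stream_completion_name: str = Field(
        default="chat:completions",
        description="Redis Stream name for operation completions",
    )
    internal_api_key: str | None = Field(
        default=None,
        description="API key for internal webhook callbacks",
    )

    @field_validator("internal_api_key", mode="before")
    @classmethod
    def get_internal_api_key(cls, v):
        """Fall back to CHAT_INTERNAL_API_KEY when no value is provided."""
        if v is None:
            v = os.getenv("CHAT_INTERNAL_API_KEY")
        return v


if __name__ == "__main__":
    os.environ.setdefault("CHAT_INTERNAL_API_KEY", "example-key")
    cfg = ExampleConfig()
    print(cfg.stream_completion_name, bool(cfg.internal_api_key))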
@@ -52,10 +52,6 @@ class StreamStart(StreamBaseResponse):

    type: ResponseType = ResponseType.START
    messageId: str = Field(..., description="Unique message ID")
    taskId: str | None = Field(
        default=None,
        description="Task ID for SSE reconnection. Clients can reconnect using GET /tasks/{taskId}/stream",
    )


class StreamFinish(StreamBaseResponse):
@@ -1,23 +1,19 @@
"""Chat API routes for chat session management and streaming via SSE."""

import logging
import uuid as uuid_module
from collections.abc import AsyncGenerator
from typing import Annotated

from autogpt_libs import auth
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Security
from fastapi import APIRouter, Depends, Query, Security
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from backend.util.exceptions import NotFoundError

from . import service as chat_service
from . import stream_registry
from .completion_handler import process_operation_failure, process_operation_success
from .config import ChatConfig
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
from .response_model import StreamFinish, StreamHeartbeat, StreamStart

config = ChatConfig()

@@ -59,15 +55,6 @@ class CreateSessionResponse(BaseModel):
    user_id: str | None


class ActiveStreamInfo(BaseModel):
    """Information about an active stream for reconnection."""

    task_id: str
    last_message_id: str  # Redis Stream message ID for resumption
    operation_id: str  # Operation ID for completion tracking
    tool_name: str  # Name of the tool being executed


class SessionDetailResponse(BaseModel):
    """Response model providing complete details for a chat session, including messages."""

@@ -76,7 +63,6 @@ class SessionDetailResponse(BaseModel):
    updated_at: str
    user_id: str | None
    messages: list[dict]
    active_stream: ActiveStreamInfo | None = None  # Present if stream is still active


class SessionSummaryResponse(BaseModel):
@@ -95,14 +81,6 @@ class ListSessionsResponse(BaseModel):
    total: int


class OperationCompleteRequest(BaseModel):
    """Request model for external completion webhook."""

    success: bool
    result: dict | str | None = None
    error: str | None = None


# ========== Routes ==========


@@ -188,14 +166,13 @@ async def get_session(
    Retrieve the details of a specific chat session.

    Looks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.
    If there's an active stream for this session, returns the task_id for reconnection.

    Args:
        session_id: The unique identifier for the desired chat session.
        user_id: The optional authenticated user ID, or None for anonymous access.

    Returns:
        SessionDetailResponse: Details for the requested session, including active_stream info if applicable.
        SessionDetailResponse: Details for the requested session, or None if not found.

    """
    session = await get_chat_session(session_id, user_id)
@@ -203,28 +180,11 @@ async def get_session(
        raise NotFoundError(f"Session {session_id} not found.")

    messages = [message.model_dump() for message in session.messages]

    # Check if there's an active stream for this session
    active_stream_info = None
    active_task, last_message_id = await stream_registry.get_active_task_for_session(
        session_id, user_id
    )
    if active_task:
        # Filter out the in-progress assistant message from the session response.
        # The client will receive the complete assistant response through the SSE
        # stream replay instead, preventing duplicate content.
        if messages and messages[-1].get("role") == "assistant":
            messages = messages[:-1]

        # Use "0-0" as last_message_id to replay the stream from the beginning.
        # Since we filtered out the cached assistant message, the client needs
        # the full stream to reconstruct the response.
        active_stream_info = ActiveStreamInfo(
            task_id=active_task.task_id,
            last_message_id="0-0",
            operation_id=active_task.operation_id,
            tool_name=active_task.tool_name,
        )

    logger.info(
        f"Returning session {session_id}: "
        f"message_count={len(messages)}, "
        f"roles={[m.get('role') for m in messages]}"
    )

    return SessionDetailResponse(
        id=session.session_id,
@@ -232,7 +192,6 @@ async def get_session(
        updated_at=session.updated_at.isoformat(),
        user_id=session.user_id or None,
        messages=messages,
        active_stream=active_stream_info,
    )

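The active_stream field removed above is what a client would use to resume a stream after a reload: fetch the session detail, and if active_stream is present, re-attach to the task stream. A rough client-side sketch follows; the base URL, auth, and the session detail path are assumptions, while the /tasks/{task_id}/stream path and last_message_id parameter follow the routes shown in this diff.

# Rough client-side sketch of the reconnection flow implied by active_stream.
# BASE_URL and the session detail path are placeholders/assumptions.
import asyncio

import httpx

BASE_URL = "http://localhost:8000/api/chat"  # placeholder


async def resume_if_active(session_id: str) -> None:
    async with httpx.AsyncClient(base_url=BASE_URL, timeout=None) as client:
        detail = (await client.get(f"/sessions/{session_id}")).json()
        active = detail.get("active_stream")
        if not active:
            return  # nothing in flight

        # Re-attach to the task stream, replaying from the last seen message ID.
        url = f"/tasks/{active['task_id']}/stream"
        params = {"last_message_id": active.get("last_message_id", "0-0")}
        async with client.stream("GET", url, params=params) as response:
            async for line in response.aiter_lines():
                if line.startswith("data: "):
                    print(line[len("data: "):])


if __name__ == "__main__":
    asyncio.run(resume_if_active("example-session-id"))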
@@ -252,112 +211,49 @@ async def stream_chat_post(
    - Tool call UI elements (if invoked)
    - Tool execution results

    The AI generation runs in a background task that continues even if the client disconnects.
    All chunks are written to Redis for reconnection support. If the client disconnects,
    they can reconnect using GET /tasks/{task_id}/stream to resume from where they left off.

    Args:
        session_id: The chat session identifier to associate with the streamed messages.
        request: Request body containing message, is_user_message, and optional context.
        user_id: Optional authenticated user ID.
    Returns:
        StreamingResponse: SSE-formatted response chunks. First chunk is a "start" event
        StreamingResponse: SSE-formatted response chunks.
            containing the task_id for reconnection.

    """
    import asyncio

    session = await _validate_and_get_session(session_id, user_id)

    # Create a task in the stream registry for reconnection support
    task_id = str(uuid_module.uuid4())
    operation_id = str(uuid_module.uuid4())
    await stream_registry.create_task(
        task_id=task_id,
        session_id=session_id,
        user_id=user_id,
        tool_call_id="chat_stream",  # Not a tool call, but needed for the model
        tool_name="chat",
        operation_id=operation_id,
    )

    # Background task that runs the AI generation independently of SSE connection
    async def run_ai_generation():
        try:
            # Emit a start event with task_id for reconnection
            start_chunk = StreamStart(messageId=task_id, taskId=task_id)
            await stream_registry.publish_chunk(task_id, start_chunk)

            async for chunk in chat_service.stream_chat_completion(
                session_id,
                request.message,
                is_user_message=request.is_user_message,
                user_id=user_id,
                session=session,  # Pass pre-fetched session to avoid double-fetch
                context=request.context,
            ):
                # Write to Redis (subscribers will receive via XREAD)
                await stream_registry.publish_chunk(task_id, chunk)

            # Mark task as completed
            await stream_registry.mark_task_completed(task_id, "completed")
        except Exception as e:
            logger.error(
                f"Error in background AI generation for session {session_id}: {e}"
            )
            await stream_registry.mark_task_completed(task_id, "failed")

    # Start the AI generation in a background task
    bg_task = asyncio.create_task(run_ai_generation())
    await stream_registry.set_task_asyncio_task(task_id, bg_task)

    # SSE endpoint that subscribes to the task's stream
    async def event_generator() -> AsyncGenerator[str, None]:
        subscriber_queue = None
        try:
            # Subscribe to the task stream (this replays existing messages + live updates)
            subscriber_queue = await stream_registry.subscribe_to_task(
                task_id=task_id,
                user_id=user_id,
                last_message_id="0-0",  # Get all messages from the beginning
            )

            if subscriber_queue is None:
                yield StreamFinish().to_sse()
                yield "data: [DONE]\n\n"
                return

            # Read from the subscriber queue and yield to SSE
            while True:
                try:
                    chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
                    yield chunk.to_sse()

                    # Check for finish signal
                    if isinstance(chunk, StreamFinish):
                        break
                except asyncio.TimeoutError:
                    # Send heartbeat to keep connection alive
                    yield StreamHeartbeat().to_sse()

        except GeneratorExit:
            pass  # Client disconnected - background task continues
        except Exception as e:
            logger.error(f"Error in SSE stream for task {task_id}: {e}")
        finally:
            # Unsubscribe when client disconnects or stream ends to prevent resource leak
            if subscriber_queue is not None:
                try:
                    await stream_registry.unsubscribe_from_task(
                        task_id, subscriber_queue
                    )
                except Exception as unsub_err:
                    logger.error(
                        f"Error unsubscribing from task {task_id}: {unsub_err}",
                        exc_info=True,
                    )
            # AI SDK protocol termination - always yield even if unsubscribe fails
            yield "data: [DONE]\n\n"

    async def event_generator() -> AsyncGenerator[str, None]:
        chunk_count = 0
        first_chunk_type: str | None = None
        async for chunk in chat_service.stream_chat_completion(
            session_id,
            request.message,
            is_user_message=request.is_user_message,
            user_id=user_id,
            session=session,  # Pass pre-fetched session to avoid double-fetch
            context=request.context,
        ):
            if chunk_count < 3:
                logger.info(
                    "Chat stream chunk",
                    extra={
                        "session_id": session_id,
                        "chunk_type": str(chunk.type),
                    },
                )
            if not first_chunk_type:
                first_chunk_type = str(chunk.type)
            chunk_count += 1
            yield chunk.to_sse()
        logger.info(
            "Chat stream completed",
            extra={
                "session_id": session_id,
                "chunk_count": chunk_count,
                "first_chunk_type": first_chunk_type,
            },
        )
        # AI SDK protocol termination
        yield "data: [DONE]\n\n"

    return StreamingResponse(
        event_generator(),
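The removed event_generator above follows a common SSE pattern: a background producer writes chunks to a queue, and the generator drains that queue with a timeout so it can emit heartbeats while the stream is idle. A stripped-down, framework-free sketch of that pattern is below; it is purely illustrative and involves no Redis, FastAPI, or the project's stream_registry.

# Stripped-down producer/consumer sketch of the queue-with-heartbeat pattern used
# by the removed event_generator. Purely illustrative.
import asyncio
from collections.abc import AsyncGenerator


async def producer(queue: "asyncio.Queue[str]") -> None:
    for i in range(3):
        await asyncio.sleep(2)
        await queue.put(f"chunk {i}")
    await queue.put("[FINISH]")


async def sse_events(queue: "asyncio.Queue[str]") -> AsyncGenerator[str, None]:
    while True:
        try:
            # Wait for the next chunk, but wake up periodically for heartbeats.
            chunk = await asyncio.wait_for(queue.get(), timeout=1.0)
        except asyncio.TimeoutError:
            yield "event: heartbeat\n\n"  # keep the connection alive
            continue
        yield f"data: {chunk}\n\n"
        if chunk == "[FINISH]":
            return


async def main() -> None:
    queue: "asyncio.Queue[str]" = asyncio.Queue()
    task = asyncio.create_task(producer(queue))
    async for event in sse_events(queue):
        print(event, end="")
    await task


if __name__ == "__main__":
    asyncio.run(main())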
@@ -470,251 +366,6 @@ async def session_assign_user(
|
|||||||
return {"status": "ok"}
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
|
||||||
# ========== Task Streaming (SSE Reconnection) ==========
|
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
|
||||||
"/tasks/{task_id}/stream",
|
|
||||||
)
|
|
||||||
async def stream_task(
|
|
||||||
task_id: str,
|
|
||||||
user_id: str | None = Depends(auth.get_user_id),
|
|
||||||
last_message_id: str = Query(
|
|
||||||
default="0-0",
|
|
||||||
description="Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.",
|
|
||||||
),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Reconnect to a long-running task's SSE stream.
|
|
||||||
|
|
||||||
When a long-running operation (like agent generation) starts, the client
|
|
||||||
receives a task_id. If the connection drops, the client can reconnect
|
|
||||||
using this endpoint to resume receiving updates.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task_id: The task ID from the operation_started response.
|
|
||||||
user_id: Authenticated user ID for ownership validation.
|
|
||||||
last_message_id: Last Redis Stream message ID received ("0-0" for full replay).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
StreamingResponse: SSE-formatted response chunks starting after last_message_id.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
HTTPException: 404 if task not found, 410 if task expired, 403 if access denied.
|
|
||||||
"""
|
|
||||||
# Check task existence and expiry before subscribing
|
|
||||||
task, error_code = await stream_registry.get_task_with_expiry_info(task_id)
|
|
||||||
|
|
||||||
if error_code == "TASK_EXPIRED":
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=410,
|
|
||||||
detail={
|
|
||||||
"code": "TASK_EXPIRED",
|
|
||||||
"message": "This operation has expired. Please try again.",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
if error_code == "TASK_NOT_FOUND":
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=404,
|
|
||||||
detail={
|
|
||||||
"code": "TASK_NOT_FOUND",
|
|
||||||
"message": f"Task {task_id} not found.",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Validate ownership if task has an owner
|
|
||||||
if task and task.user_id and user_id != task.user_id:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=403,
|
|
||||||
detail={
|
|
||||||
"code": "ACCESS_DENIED",
|
|
||||||
"message": "You do not have access to this task.",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get subscriber queue from stream registry
|
|
||||||
subscriber_queue = await stream_registry.subscribe_to_task(
|
|
||||||
task_id=task_id,
|
|
||||||
user_id=user_id,
|
|
||||||
last_message_id=last_message_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
if subscriber_queue is None:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=404,
|
|
||||||
detail={
|
|
||||||
"code": "TASK_NOT_FOUND",
|
|
||||||
"message": f"Task {task_id} not found or access denied.",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
async def event_generator() -> AsyncGenerator[str, None]:
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
heartbeat_interval = 15.0 # Send heartbeat every 15 seconds
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
# Wait for next chunk with timeout for heartbeats
|
|
||||||
chunk = await asyncio.wait_for(
|
|
||||||
subscriber_queue.get(), timeout=heartbeat_interval
|
|
||||||
)
|
|
||||||
yield chunk.to_sse()
|
|
||||||
|
|
||||||
# Check for finish signal
|
|
||||||
if isinstance(chunk, StreamFinish):
|
|
||||||
break
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
# Send heartbeat to keep connection alive
|
|
||||||
yield StreamHeartbeat().to_sse()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in task stream {task_id}: {e}", exc_info=True)
|
|
||||||
finally:
|
|
||||||
# Unsubscribe when client disconnects or stream ends
|
|
||||||
try:
|
|
||||||
await stream_registry.unsubscribe_from_task(task_id, subscriber_queue)
|
|
||||||
except Exception as unsub_err:
|
|
||||||
logger.error(
|
|
||||||
f"Error unsubscribing from task {task_id}: {unsub_err}",
|
|
||||||
exc_info=True,
|
|
||||||
)
|
|
||||||
# AI SDK protocol termination - always yield even if unsubscribe fails
|
|
||||||
yield "data: [DONE]\n\n"
|
|
||||||
|
|
||||||
return StreamingResponse(
|
|
||||||
event_generator(),
|
|
||||||
media_type="text/event-stream",
|
|
||||||
headers={
|
|
||||||
"Cache-Control": "no-cache",
|
|
||||||
"Connection": "keep-alive",
|
|
||||||
"X-Accel-Buffering": "no",
|
|
||||||
"x-vercel-ai-ui-message-stream": "v1",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
|
||||||
"/tasks/{task_id}",
|
|
||||||
)
|
|
||||||
async def get_task_status(
|
|
||||||
task_id: str,
|
|
||||||
user_id: str | None = Depends(auth.get_user_id),
|
|
||||||
) -> dict:
|
|
||||||
"""
|
|
||||||
Get the status of a long-running task.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task_id: The task ID to check.
|
|
||||||
user_id: Authenticated user ID for ownership validation.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: Task status including task_id, status, tool_name, and operation_id.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
NotFoundError: If task_id is not found or user doesn't have access.
|
|
||||||
"""
|
|
||||||
task = await stream_registry.get_task(task_id)
|
|
||||||
|
|
||||||
if task is None:
|
|
||||||
raise NotFoundError(f"Task {task_id} not found.")
|
|
||||||
|
|
||||||
# Validate ownership - if task has an owner, requester must match
|
|
||||||
if task.user_id and user_id != task.user_id:
|
|
||||||
raise NotFoundError(f"Task {task_id} not found.")
|
|
||||||
|
|
||||||
return {
|
|
||||||
"task_id": task.task_id,
|
|
||||||
"session_id": task.session_id,
|
|
||||||
"status": task.status,
|
|
||||||
"tool_name": task.tool_name,
|
|
||||||
"operation_id": task.operation_id,
|
|
||||||
"created_at": task.created_at.isoformat(),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ========== External Completion Webhook ==========
|
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
|
||||||
"/operations/{operation_id}/complete",
|
|
||||||
status_code=200,
|
|
||||||
)
|
|
||||||
async def complete_operation(
|
|
||||||
operation_id: str,
|
|
||||||
request: OperationCompleteRequest,
|
|
||||||
x_api_key: str | None = Header(default=None),
|
|
||||||
) -> dict:
|
|
||||||
"""
|
|
||||||
External completion webhook for long-running operations.
|
|
||||||
|
|
||||||
Called by Agent Generator (or other services) when an operation completes.
|
|
||||||
This triggers the stream registry to publish completion and continue LLM generation.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
operation_id: The operation ID to complete.
|
|
||||||
request: Completion payload with success status and result/error.
|
|
||||||
x_api_key: Internal API key for authentication.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: Status of the completion.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
HTTPException: If API key is invalid or operation not found.
|
|
||||||
"""
|
|
||||||
# Validate internal API key - reject if not configured or invalid
|
|
||||||
if not config.internal_api_key:
|
|
||||||
logger.error(
|
|
||||||
"Operation complete webhook rejected: CHAT_INTERNAL_API_KEY not configured"
|
|
||||||
)
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=503,
|
|
||||||
detail="Webhook not available: internal API key not configured",
|
|
||||||
)
|
|
||||||
if x_api_key != config.internal_api_key:
|
|
||||||
raise HTTPException(status_code=401, detail="Invalid API key")
|
|
||||||
|
|
||||||
# Find task by operation_id
|
|
||||||
task = await stream_registry.find_task_by_operation_id(operation_id)
|
|
||||||
if task is None:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=404,
|
|
||||||
detail=f"Operation {operation_id} not found",
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Received completion webhook for operation {operation_id} "
|
|
||||||
f"(task_id={task.task_id}, success={request.success})"
|
|
||||||
)
|
|
||||||
|
|
||||||
if request.success:
|
|
||||||
await process_operation_success(task, request.result)
|
|
||||||
else:
|
|
||||||
await process_operation_failure(task, request.error)
|
|
||||||
|
|
||||||
return {"status": "ok", "task_id": task.task_id}
|
|
||||||
|
|
||||||
|
|
||||||
# ========== Configuration ==========
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/config/ttl", status_code=200)
|
|
||||||
async def get_ttl_config() -> dict:
|
|
||||||
"""
|
|
||||||
Get the stream TTL configuration.
|
|
||||||
|
|
||||||
Returns the Time-To-Live settings for chat streams, which determines
|
|
||||||
how long clients can reconnect to an active stream.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: TTL configuration with seconds and milliseconds values.
|
|
||||||
"""
|
|
||||||
return {
|
|
||||||
"stream_ttl_seconds": config.stream_ttl,
|
|
||||||
"stream_ttl_ms": config.stream_ttl * 1000,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ========== Health Check ==========
|
# ========== Health Check ==========
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -36,7 +36,6 @@ from backend.util.exceptions import NotFoundError
from backend.util.settings import Settings

from . import db as chat_db
from . import stream_registry
from .config import ChatConfig
from .model import (
    ChatMessage,
@@ -1185,9 +1184,8 @@ async def _yield_tool_call(
        )
        return

    # Generate operation ID and task ID
    # Generate operation ID
    operation_id = str(uuid_module.uuid4())
    task_id = str(uuid_module.uuid4())

    # Build a user-friendly message based on tool and arguments
    if tool_name == "create_agent":
@@ -1230,16 +1228,6 @@ async def _yield_tool_call(

    # Wrap session save and task creation in try-except to release lock on failure
    try:
        # Create task in stream registry for SSE reconnection support
        await stream_registry.create_task(
            task_id=task_id,
            session_id=session.session_id,
            user_id=session.user_id,
            tool_call_id=tool_call_id,
            tool_name=tool_name,
            operation_id=operation_id,
        )

        # Save assistant message with tool_call FIRST (required by LLM)
        assistant_message = ChatMessage(
            role="assistant",
@@ -1261,27 +1249,23 @@ async def _yield_tool_call(
        session.messages.append(pending_message)
        await upsert_chat_session(session)
        logger.info(
            f"Saved pending operation {operation_id} (task_id={task_id}) "
            f"Saved pending operation {operation_id} for tool {tool_name} "
            f"for tool {tool_name} in session {session.session_id}"
            f"in session {session.session_id}"
        )

        # Store task reference in module-level set to prevent GC before completion
        bg_task = asyncio.create_task(
        task = asyncio.create_task(
            _execute_long_running_tool_with_streaming(
            _execute_long_running_tool(
                tool_name=tool_name,
                parameters=arguments,
                tool_call_id=tool_call_id,
                operation_id=operation_id,
                task_id=task_id,
                session_id=session.session_id,
                user_id=session.user_id,
            )
        )
        _background_tasks.add(bg_task)
        _background_tasks.add(task)
        bg_task.add_done_callback(_background_tasks.discard)
        task.add_done_callback(_background_tasks.discard)

        # Associate the asyncio task with the stream registry task
        await stream_registry.set_task_asyncio_task(task_id, bg_task)
    except Exception as e:
        # Roll back appended messages to prevent data corruption on subsequent saves
        if (
@@ -1299,11 +1283,6 @@ async def _yield_tool_call(

        # Release the Redis lock since the background task won't be spawned
        await _mark_operation_completed(tool_call_id)
        # Mark stream registry task as failed if it was created
        try:
            await stream_registry.mark_task_completed(task_id, status="failed")
        except Exception:
            pass
        logger.error(
            f"Failed to setup long-running tool {tool_name}: {e}", exc_info=True
        )
@@ -1317,7 +1296,6 @@ async def _yield_tool_call(
            message=started_msg,
            operation_id=operation_id,
            tool_name=tool_name,
            task_id=task_id,  # Include task_id for SSE reconnection
        ).model_dump_json(),
        success=True,
    )
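Both sides of the hunk above keep the spawned asyncio task in a module-level set and discard it when done; asyncio itself only holds a weak reference to tasks, so without the strong reference a fire-and-forget task can be garbage-collected before it finishes. A minimal standalone illustration of that pattern, separate from the project's code:

# Minimal illustration of keeping strong references to fire-and-forget tasks,
# as done with _background_tasks in the hunk above. Standalone example only.
import asyncio

_background_tasks: set[asyncio.Task] = set()


async def do_work(name: str) -> None:
    await asyncio.sleep(0.1)
    print(f"{name} done")


def spawn(name: str) -> None:
    task = asyncio.create_task(do_work(name))
    _background_tasks.add(task)                         # strong reference until done
    task.add_done_callback(_background_tasks.discard)   # drop it afterwards


async def main() -> None:
    for i in range(3):
        spawn(f"job-{i}")
    await asyncio.sleep(0.5)  # let the background jobs finish


if __name__ == "__main__":
    asyncio.run(main())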
@@ -1387,9 +1365,6 @@ async def _execute_long_running_tool(
|
|||||||
|
|
||||||
This function runs independently of the SSE connection, so the operation
|
This function runs independently of the SSE connection, so the operation
|
||||||
survives if the user closes their browser tab.
|
survives if the user closes their browser tab.
|
||||||
|
|
||||||
NOTE: This is the legacy function without stream registry support.
|
|
||||||
Use _execute_long_running_tool_with_streaming for new implementations.
|
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Load fresh session (not stale reference)
|
# Load fresh session (not stale reference)
|
||||||
@@ -1442,133 +1417,6 @@ async def _execute_long_running_tool(
|
|||||||
await _mark_operation_completed(tool_call_id)
|
await _mark_operation_completed(tool_call_id)
|
||||||
|
|
||||||
|
|
||||||
async def _execute_long_running_tool_with_streaming(
|
|
||||||
tool_name: str,
|
|
||||||
parameters: dict[str, Any],
|
|
||||||
tool_call_id: str,
|
|
||||||
operation_id: str,
|
|
||||||
task_id: str,
|
|
||||||
session_id: str,
|
|
||||||
user_id: str | None,
|
|
||||||
) -> None:
|
|
||||||
"""Execute a long-running tool with stream registry support for SSE reconnection.
|
|
||||||
|
|
||||||
This function runs independently of the SSE connection, publishes progress
|
|
||||||
to the stream registry, and survives if the user closes their browser tab.
|
|
||||||
Clients can reconnect via GET /chat/tasks/{task_id}/stream to resume streaming.
|
|
||||||
|
|
||||||
If the external service returns a 202 Accepted (async), this function exits
|
|
||||||
early and lets the Redis Streams completion consumer handle the rest.
|
|
||||||
"""
|
|
||||||
# Track whether we delegated to async processing - if so, the Redis Streams
|
|
||||||
# completion consumer (stream_registry / completion_consumer) will handle cleanup, not us
|
|
||||||
delegated_to_async = False
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Load fresh session (not stale reference)
|
|
||||||
session = await get_chat_session(session_id, user_id)
|
|
||||||
if not session:
|
|
||||||
logger.error(f"Session {session_id} not found for background tool")
|
|
||||||
await stream_registry.mark_task_completed(task_id, status="failed")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Pass operation_id and task_id to the tool for async processing
|
|
||||||
enriched_parameters = {
|
|
||||||
**parameters,
|
|
||||||
"_operation_id": operation_id,
|
|
||||||
"_task_id": task_id,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Execute the actual tool
|
|
||||||
result = await execute_tool(
|
|
||||||
tool_name=tool_name,
|
|
||||||
parameters=enriched_parameters,
|
|
||||||
tool_call_id=tool_call_id,
|
|
||||||
user_id=user_id,
|
|
||||||
session=session,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check if the tool result indicates async processing
|
|
||||||
# (e.g., Agent Generator returned 202 Accepted)
|
|
||||||
try:
|
|
||||||
if isinstance(result.output, dict):
|
|
||||||
result_data = result.output
|
|
||||||
elif result.output:
|
|
||||||
result_data = orjson.loads(result.output)
|
|
||||||
else:
|
|
||||||
result_data = {}
|
|
||||||
if result_data.get("status") == "accepted":
|
|
||||||
logger.info(
|
|
||||||
f"Tool {tool_name} delegated to async processing "
|
|
||||||
f"(operation_id={operation_id}, task_id={task_id}). "
|
|
||||||
f"Redis Streams completion consumer will handle the rest."
|
|
||||||
)
|
|
||||||
# Don't publish result, don't continue with LLM, and don't cleanup
|
|
||||||
# The Redis Streams consumer (completion_consumer) will handle
|
|
||||||
# everything when the external service completes via webhook
|
|
||||||
delegated_to_async = True
|
|
||||||
return
|
|
||||||
except (orjson.JSONDecodeError, TypeError):
|
|
||||||
pass # Not JSON or not async - continue normally
|
|
||||||
|
|
||||||
# Publish tool result to stream registry
|
|
||||||
await stream_registry.publish_chunk(task_id, result)
|
|
||||||
|
|
||||||
# Update the pending message with result
|
|
||||||
result_str = (
|
|
||||||
result.output
|
|
||||||
if isinstance(result.output, str)
|
|
||||||
else orjson.dumps(result.output).decode("utf-8")
|
|
||||||
)
|
|
||||||
await _update_pending_operation(
|
|
||||||
session_id=session_id,
|
|
||||||
tool_call_id=tool_call_id,
|
|
||||||
result=result_str,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Background tool {tool_name} completed for session {session_id} "
|
|
||||||
f"(task_id={task_id})"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Generate LLM continuation and stream chunks to registry
|
|
||||||
await _generate_llm_continuation_with_streaming(
|
|
||||||
session_id=session_id,
|
|
||||||
user_id=user_id,
|
|
||||||
task_id=task_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mark task as completed in stream registry
|
|
||||||
await stream_registry.mark_task_completed(task_id, status="completed")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Background tool {tool_name} failed: {e}", exc_info=True)
|
|
||||||
error_response = ErrorResponse(
|
|
||||||
message=f"Tool {tool_name} failed: {str(e)}",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Publish error to stream registry followed by finish event
|
|
||||||
await stream_registry.publish_chunk(
|
|
||||||
task_id,
|
|
||||||
StreamError(errorText=str(e)),
|
|
||||||
)
|
|
||||||
await stream_registry.publish_chunk(task_id, StreamFinish())
|
|
||||||
|
|
||||||
await _update_pending_operation(
|
|
||||||
session_id=session_id,
|
|
||||||
tool_call_id=tool_call_id,
|
|
||||||
result=error_response.model_dump_json(),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mark task as failed in stream registry
|
|
||||||
await stream_registry.mark_task_completed(task_id, status="failed")
|
|
||||||
finally:
|
|
||||||
# Only cleanup if we didn't delegate to async processing
|
|
||||||
# For async path, the Redis Streams completion consumer handles cleanup
|
|
||||||
if not delegated_to_async:
|
|
||||||
await _mark_operation_completed(tool_call_id)
|
|
||||||
|
|
||||||
|
|
||||||
async def _update_pending_operation(
|
async def _update_pending_operation(
|
||||||
session_id: str,
|
session_id: str,
|
||||||
tool_call_id: str,
|
tool_call_id: str,
|
||||||
@@ -1749,128 +1597,3 @@ async def _generate_llm_continuation(
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to generate LLM continuation: {e}", exc_info=True)
|
logger.error(f"Failed to generate LLM continuation: {e}", exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
async def _generate_llm_continuation_with_streaming(
|
|
||||||
session_id: str,
|
|
||||||
user_id: str | None,
|
|
||||||
task_id: str,
|
|
||||||
) -> None:
|
|
||||||
"""Generate an LLM response with streaming to the stream registry.
|
|
||||||
|
|
||||||
This is called by background tasks to continue the conversation
|
|
||||||
after a tool result is saved. Chunks are published to the stream registry
|
|
||||||
so reconnecting clients can receive them.
|
|
||||||
"""
|
|
||||||
import uuid as uuid_module
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Load fresh session from DB (bypass cache to get the updated tool result)
|
|
||||||
await invalidate_session_cache(session_id)
|
|
||||||
        session = await get_chat_session(session_id, user_id)
        if not session:
            logger.error(f"Session {session_id} not found for LLM continuation")
            return

        # Build system prompt
        system_prompt, _ = await _build_system_prompt(user_id)

        # Build messages in OpenAI format
        messages = session.to_openai_messages()
        if system_prompt:
            from openai.types.chat import ChatCompletionSystemMessageParam

            system_message = ChatCompletionSystemMessageParam(
                role="system",
                content=system_prompt,
            )
            messages = [system_message] + messages

        # Build extra_body for tracing
        extra_body: dict[str, Any] = {
            "posthogProperties": {
                "environment": settings.config.app_env.value,
            },
        }
        if user_id:
            extra_body["user"] = user_id[:128]
            extra_body["posthogDistinctId"] = user_id
        if session_id:
            extra_body["session_id"] = session_id[:128]

        # Make streaming LLM call (no tools - just text response)
        from typing import cast

        from openai.types.chat import ChatCompletionMessageParam

        # Generate unique IDs for AI SDK protocol
        message_id = str(uuid_module.uuid4())
        text_block_id = str(uuid_module.uuid4())

        # Publish start event
        await stream_registry.publish_chunk(task_id, StreamStart(messageId=message_id))
        await stream_registry.publish_chunk(task_id, StreamTextStart(id=text_block_id))

        # Stream the response
        stream = await client.chat.completions.create(
            model=config.model,
            messages=cast(list[ChatCompletionMessageParam], messages),
            extra_body=extra_body,
            stream=True,
        )

        assistant_content = ""
        async for chunk in stream:
            if chunk.choices and chunk.choices[0].delta.content:
                delta = chunk.choices[0].delta.content
                assistant_content += delta
                # Publish delta to stream registry
                await stream_registry.publish_chunk(
                    task_id,
                    StreamTextDelta(id=text_block_id, delta=delta),
                )

        # Publish end events
        await stream_registry.publish_chunk(task_id, StreamTextEnd(id=text_block_id))

        if assistant_content:
            # Reload session from DB to avoid race condition with user messages
            fresh_session = await get_chat_session(session_id, user_id)
            if not fresh_session:
                logger.error(
                    f"Session {session_id} disappeared during LLM continuation"
                )
                return

            # Save assistant message to database
            assistant_message = ChatMessage(
                role="assistant",
                content=assistant_content,
            )
            fresh_session.messages.append(assistant_message)

            # Save to database (not cache) to persist the response
            await upsert_chat_session(fresh_session)

            # Invalidate cache so next poll/refresh gets fresh data
            await invalidate_session_cache(session_id)

            logger.info(
                f"Generated streaming LLM continuation for session {session_id} "
                f"(task_id={task_id}), response length: {len(assistant_content)}"
            )
        else:
            logger.warning(
                f"Streaming LLM continuation returned empty response for {session_id}"
            )

    except Exception as e:
        logger.error(
            f"Failed to generate streaming LLM continuation: {e}", exc_info=True
        )
        # Publish error to stream registry followed by finish event
        await stream_registry.publish_chunk(
            task_id,
            StreamError(errorText=f"Failed to generate response: {e}"),
        )
        await stream_registry.publish_chunk(task_id, StreamFinish())
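For orientation (not part of the diff), the publish calls above emit one fixed sequence of chunk types per continuation; the ordering below is inferred from the code and the payload values are illustrative only:

# StreamStart(messageId=...)             - opens the assistant message
# StreamTextStart(id=text_block_id)      - opens the single text block
# StreamTextDelta(id=..., delta=...)     - one per streamed content delta
# StreamTextEnd(id=text_block_id)        - closes the text block
# StreamError(...), then StreamFinish()  - emitted only on the error path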
@@ -1,704 +0,0 @@
"""Stream registry for managing reconnectable SSE streams.

This module provides a registry for tracking active streaming tasks and their
messages. It uses Redis for all state management (no in-memory state), making
pods stateless and horizontally scalable.

Architecture:
- Redis Stream: Persists all messages for replay and real-time delivery
- Redis Hash: Task metadata (status, session_id, etc.)

Subscribers:
1. Replay missed messages from Redis Stream (XREAD)
2. Listen for live updates via blocking XREAD
3. No in-memory state required on the subscribing pod
"""

import asyncio
import logging
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, Literal

import orjson

from backend.data.redis_client import get_redis_async

from .config import ChatConfig
from .response_model import StreamBaseResponse, StreamError, StreamFinish

logger = logging.getLogger(__name__)
config = ChatConfig()

# Track background tasks for this pod (just the asyncio.Task reference, not subscribers)
_local_tasks: dict[str, asyncio.Task] = {}

# Track listener tasks per subscriber queue for cleanup
# Maps queue id() to (task_id, asyncio.Task) for proper cleanup on unsubscribe
_listener_tasks: dict[int, tuple[str, asyncio.Task]] = {}

# Timeout for putting chunks into subscriber queues (seconds)
# If the queue is full and doesn't drain within this time, send an overflow error
QUEUE_PUT_TIMEOUT = 5.0

# Lua script for atomic compare-and-swap status update (idempotent completion)
# Returns 1 if status was updated, 0 if already completed/failed
COMPLETE_TASK_SCRIPT = """
local current = redis.call("HGET", KEYS[1], "status")
if current == "running" then
    redis.call("HSET", KEYS[1], "status", ARGV[1])
    return 1
end
return 0
"""


@dataclass
class ActiveTask:
    """Represents an active streaming task (metadata only, no in-memory queues)."""

    task_id: str
    session_id: str
    user_id: str | None
    tool_call_id: str
    tool_name: str
    operation_id: str
    status: Literal["running", "completed", "failed"] = "running"
    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    asyncio_task: asyncio.Task | None = None


def _get_task_meta_key(task_id: str) -> str:
    """Get Redis key for task metadata."""
    return f"{config.task_meta_prefix}{task_id}"


def _get_task_stream_key(task_id: str) -> str:
    """Get Redis key for task message stream."""
    return f"{config.task_stream_prefix}{task_id}"


def _get_operation_mapping_key(operation_id: str) -> str:
    """Get Redis key for operation_id to task_id mapping."""
    return f"{config.task_op_prefix}{operation_id}"


async def create_task(
    task_id: str,
    session_id: str,
    user_id: str | None,
    tool_call_id: str,
    tool_name: str,
    operation_id: str,
) -> ActiveTask:
    """Create a new streaming task in Redis.

    Args:
        task_id: Unique identifier for the task
        session_id: Chat session ID
        user_id: User ID (may be None for anonymous)
        tool_call_id: Tool call ID from the LLM
        tool_name: Name of the tool being executed
        operation_id: Operation ID for webhook callbacks

    Returns:
        The created ActiveTask instance (metadata only)
    """
    task = ActiveTask(
        task_id=task_id,
        session_id=session_id,
        user_id=user_id,
        tool_call_id=tool_call_id,
        tool_name=tool_name,
        operation_id=operation_id,
    )

    # Store metadata in Redis
    redis = await get_redis_async()
    meta_key = _get_task_meta_key(task_id)
    op_key = _get_operation_mapping_key(operation_id)

    await redis.hset(  # type: ignore[misc]
        meta_key,
        mapping={
            "task_id": task_id,
            "session_id": session_id,
            "user_id": user_id or "",
            "tool_call_id": tool_call_id,
            "tool_name": tool_name,
            "operation_id": operation_id,
            "status": task.status,
            "created_at": task.created_at.isoformat(),
        },
    )
    await redis.expire(meta_key, config.stream_ttl)

    # Create operation_id -> task_id mapping for webhook lookups
    await redis.set(op_key, task_id, ex=config.stream_ttl)

    logger.debug(f"Created task {task_id} for session {session_id}")

    return task


async def publish_chunk(
    task_id: str,
    chunk: StreamBaseResponse,
) -> str:
    """Publish a chunk to Redis Stream.

    All delivery is via Redis Streams - no in-memory state.

    Args:
        task_id: Task ID to publish to
        chunk: The stream response chunk to publish

    Returns:
        The Redis Stream message ID
    """
    chunk_json = chunk.model_dump_json()
    message_id = "0-0"

    try:
        redis = await get_redis_async()
        stream_key = _get_task_stream_key(task_id)

        # Write to Redis Stream for persistence and real-time delivery
        raw_id = await redis.xadd(
            stream_key,
            {"data": chunk_json},
            maxlen=config.stream_max_length,
        )
        message_id = raw_id if isinstance(raw_id, str) else raw_id.decode()

        # Set TTL on stream to match task metadata TTL
        await redis.expire(stream_key, config.stream_ttl)
    except Exception as e:
        logger.error(
            f"Failed to publish chunk for task {task_id}: {e}",
            exc_info=True,
        )

    return message_id


async def subscribe_to_task(
    task_id: str,
    user_id: str | None,
    last_message_id: str = "0-0",
) -> asyncio.Queue[StreamBaseResponse] | None:
    """Subscribe to a task's stream with replay of missed messages.

    This is fully stateless - uses Redis Stream for replay and pub/sub for live updates.

    Args:
        task_id: Task ID to subscribe to
        user_id: User ID for ownership validation
        last_message_id: Last Redis Stream message ID received ("0-0" for full replay)

    Returns:
        An asyncio Queue that will receive stream chunks, or None if task not found
        or user doesn't have access
    """
    redis = await get_redis_async()
    meta_key = _get_task_meta_key(task_id)
    meta: dict[Any, Any] = await redis.hgetall(meta_key)  # type: ignore[misc]

    if not meta:
        logger.debug(f"Task {task_id} not found in Redis")
        return None

    # Note: Redis client uses decode_responses=True, so keys are strings
    task_status = meta.get("status", "")
    task_user_id = meta.get("user_id", "") or None

    # Validate ownership - if task has an owner, requester must match
    if task_user_id:
        if user_id != task_user_id:
            logger.warning(
                f"User {user_id} denied access to task {task_id} "
                f"owned by {task_user_id}"
            )
            return None

    subscriber_queue: asyncio.Queue[StreamBaseResponse] = asyncio.Queue()
    stream_key = _get_task_stream_key(task_id)

    # Step 1: Replay messages from Redis Stream
    messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000)

    replayed_count = 0
    replay_last_id = last_message_id
    if messages:
        for _stream_name, stream_messages in messages:
            for msg_id, msg_data in stream_messages:
                replay_last_id = msg_id if isinstance(msg_id, str) else msg_id.decode()
                # Note: Redis client uses decode_responses=True, so keys are strings
                if "data" in msg_data:
                    try:
                        chunk_data = orjson.loads(msg_data["data"])
                        chunk = _reconstruct_chunk(chunk_data)
                        if chunk:
                            await subscriber_queue.put(chunk)
                            replayed_count += 1
                    except Exception as e:
                        logger.warning(f"Failed to replay message: {e}")

    logger.debug(f"Task {task_id}: replayed {replayed_count} messages")

    # Step 2: If task is still running, start stream listener for live updates
    if task_status == "running":
        listener_task = asyncio.create_task(
            _stream_listener(task_id, subscriber_queue, replay_last_id)
        )
        # Track listener task for cleanup on unsubscribe
        _listener_tasks[id(subscriber_queue)] = (task_id, listener_task)
    else:
        # Task is completed/failed - add finish marker
        await subscriber_queue.put(StreamFinish())

    return subscriber_queue


async def _stream_listener(
    task_id: str,
    subscriber_queue: asyncio.Queue[StreamBaseResponse],
    last_replayed_id: str,
) -> None:
    """Listen to Redis Stream for new messages using blocking XREAD.

    This approach avoids the duplicate message issue that can occur with pub/sub
    when messages are published during the gap between replay and subscription.

    Args:
        task_id: Task ID to listen for
        subscriber_queue: Queue to deliver messages to
        last_replayed_id: Last message ID from replay (continue from here)
    """
    queue_id = id(subscriber_queue)
    # Track the last successfully delivered message ID for recovery hints
    last_delivered_id = last_replayed_id

    try:
        redis = await get_redis_async()
        stream_key = _get_task_stream_key(task_id)
        current_id = last_replayed_id

        while True:
            # Block for up to 30 seconds waiting for new messages
            # This allows periodic checking if task is still running
            messages = await redis.xread(
                {stream_key: current_id}, block=30000, count=100
            )

            if not messages:
                # Timeout - check if task is still running
                meta_key = _get_task_meta_key(task_id)
                status = await redis.hget(meta_key, "status")  # type: ignore[misc]
                if status and status != "running":
                    try:
                        await asyncio.wait_for(
                            subscriber_queue.put(StreamFinish()),
                            timeout=QUEUE_PUT_TIMEOUT,
                        )
                    except asyncio.TimeoutError:
                        logger.warning(
                            f"Timeout delivering finish event for task {task_id}"
                        )
                    break
                continue

            for _stream_name, stream_messages in messages:
                for msg_id, msg_data in stream_messages:
                    current_id = msg_id if isinstance(msg_id, str) else msg_id.decode()

                    if "data" not in msg_data:
                        continue

                    try:
                        chunk_data = orjson.loads(msg_data["data"])
                        chunk = _reconstruct_chunk(chunk_data)
                        if chunk:
                            try:
                                await asyncio.wait_for(
                                    subscriber_queue.put(chunk),
                                    timeout=QUEUE_PUT_TIMEOUT,
                                )
                                # Update last delivered ID on successful delivery
                                last_delivered_id = current_id
                            except asyncio.TimeoutError:
                                logger.warning(
                                    f"Subscriber queue full for task {task_id}, "
                                    f"message delivery timed out after {QUEUE_PUT_TIMEOUT}s"
                                )
                                # Send overflow error with recovery info
                                try:
                                    overflow_error = StreamError(
                                        errorText="Message delivery timeout - some messages may have been missed",
                                        code="QUEUE_OVERFLOW",
                                        details={
                                            "last_delivered_id": last_delivered_id,
                                            "recovery_hint": f"Reconnect with last_message_id={last_delivered_id}",
                                        },
                                    )
                                    subscriber_queue.put_nowait(overflow_error)
                                except asyncio.QueueFull:
                                    # Queue is completely stuck, nothing more we can do
                                    logger.error(
                                        f"Cannot deliver overflow error for task {task_id}, "
                                        "queue completely blocked"
                                    )

                            # Stop listening on finish
                            if isinstance(chunk, StreamFinish):
                                return
                    except Exception as e:
                        logger.warning(f"Error processing stream message: {e}")

    except asyncio.CancelledError:
        logger.debug(f"Stream listener cancelled for task {task_id}")
        raise  # Re-raise to propagate cancellation
    except Exception as e:
        logger.error(f"Stream listener error for task {task_id}: {e}")
        # On error, send finish to unblock subscriber
        try:
            await asyncio.wait_for(
                subscriber_queue.put(StreamFinish()),
                timeout=QUEUE_PUT_TIMEOUT,
            )
        except (asyncio.TimeoutError, asyncio.QueueFull):
            logger.warning(
                f"Could not deliver finish event for task {task_id} after error"
            )
    finally:
        # Clean up listener task mapping on exit
        _listener_tasks.pop(queue_id, None)


async def mark_task_completed(
    task_id: str,
    status: Literal["completed", "failed"] = "completed",
) -> bool:
    """Mark a task as completed and publish finish event.

    This is idempotent - calling multiple times with the same task_id is safe.
    Uses atomic compare-and-swap via Lua script to prevent race conditions.
    Status is updated first (source of truth), then finish event is published (best-effort).

    Args:
        task_id: Task ID to mark as completed
        status: Final status ("completed" or "failed")

    Returns:
        True if task was newly marked completed, False if already completed/failed
    """
    redis = await get_redis_async()
    meta_key = _get_task_meta_key(task_id)

    # Atomic compare-and-swap: only update if status is "running"
    # This prevents race conditions when multiple callers try to complete simultaneously
    result = await redis.eval(COMPLETE_TASK_SCRIPT, 1, meta_key, status)  # type: ignore[misc]

    if result == 0:
        logger.debug(f"Task {task_id} already completed/failed, skipping")
        return False

    # THEN publish finish event (best-effort - listeners can detect via status polling)
    try:
        await publish_chunk(task_id, StreamFinish())
    except Exception as e:
        logger.error(
            f"Failed to publish finish event for task {task_id}: {e}. "
            "Listeners will detect completion via status polling."
        )

    # Clean up local task reference if exists
    _local_tasks.pop(task_id, None)
    return True


async def find_task_by_operation_id(operation_id: str) -> ActiveTask | None:
    """Find a task by its operation ID.

    Used by webhook callbacks to locate the task to update.

    Args:
        operation_id: Operation ID to search for

    Returns:
        ActiveTask if found, None otherwise
    """
    redis = await get_redis_async()
    op_key = _get_operation_mapping_key(operation_id)
    task_id = await redis.get(op_key)

    if not task_id:
        return None

    task_id_str = task_id.decode() if isinstance(task_id, bytes) else task_id
    return await get_task(task_id_str)


async def get_task(task_id: str) -> ActiveTask | None:
    """Get a task by its ID from Redis.

    Args:
        task_id: Task ID to look up

    Returns:
        ActiveTask if found, None otherwise
    """
    redis = await get_redis_async()
    meta_key = _get_task_meta_key(task_id)
    meta: dict[Any, Any] = await redis.hgetall(meta_key)  # type: ignore[misc]

    if not meta:
        return None

    # Note: Redis client uses decode_responses=True, so keys/values are strings
    return ActiveTask(
        task_id=meta.get("task_id", ""),
        session_id=meta.get("session_id", ""),
        user_id=meta.get("user_id", "") or None,
        tool_call_id=meta.get("tool_call_id", ""),
        tool_name=meta.get("tool_name", ""),
        operation_id=meta.get("operation_id", ""),
        status=meta.get("status", "running"),  # type: ignore[arg-type]
    )


async def get_task_with_expiry_info(
    task_id: str,
) -> tuple[ActiveTask | None, str | None]:
    """Get a task by its ID with expiration detection.

    Returns (task, error_code) where error_code is:
    - None if task found
    - "TASK_EXPIRED" if stream exists but metadata is gone (TTL expired)
    - "TASK_NOT_FOUND" if neither exists

    Args:
        task_id: Task ID to look up

    Returns:
        Tuple of (ActiveTask or None, error_code or None)
    """
    redis = await get_redis_async()
    meta_key = _get_task_meta_key(task_id)
    stream_key = _get_task_stream_key(task_id)

    meta: dict[Any, Any] = await redis.hgetall(meta_key)  # type: ignore[misc]

    if not meta:
        # Check if stream still has data (metadata expired but stream hasn't)
        stream_len = await redis.xlen(stream_key)
        if stream_len > 0:
            return None, "TASK_EXPIRED"
        return None, "TASK_NOT_FOUND"

    # Note: Redis client uses decode_responses=True, so keys/values are strings
    return (
        ActiveTask(
            task_id=meta.get("task_id", ""),
            session_id=meta.get("session_id", ""),
            user_id=meta.get("user_id", "") or None,
            tool_call_id=meta.get("tool_call_id", ""),
            tool_name=meta.get("tool_name", ""),
            operation_id=meta.get("operation_id", ""),
            status=meta.get("status", "running"),  # type: ignore[arg-type]
        ),
        None,
    )


async def get_active_task_for_session(
    session_id: str,
    user_id: str | None = None,
) -> tuple[ActiveTask | None, str]:
    """Get the active (running) task for a session, if any.

    Scans Redis for tasks matching the session_id with status="running".

    Args:
        session_id: Session ID to look up
        user_id: User ID for ownership validation (optional)

    Returns:
        Tuple of (ActiveTask if found and running, last_message_id from Redis Stream)
    """
    redis = await get_redis_async()

    # Scan Redis for task metadata keys
    cursor = 0
    tasks_checked = 0

    while True:
        cursor, keys = await redis.scan(
            cursor, match=f"{config.task_meta_prefix}*", count=100
        )

        for key in keys:
            tasks_checked += 1
            meta: dict[Any, Any] = await redis.hgetall(key)  # type: ignore[misc]
            if not meta:
                continue

            # Note: Redis client uses decode_responses=True, so keys/values are strings
            task_session_id = meta.get("session_id", "")
            task_status = meta.get("status", "")
            task_user_id = meta.get("user_id", "") or None
            task_id = meta.get("task_id", "")

            if task_session_id == session_id and task_status == "running":
                # Validate ownership - if task has an owner, requester must match
                if task_user_id and user_id != task_user_id:
                    continue

                # Get the last message ID from Redis Stream
                stream_key = _get_task_stream_key(task_id)
                last_id = "0-0"
                try:
                    messages = await redis.xrevrange(stream_key, count=1)
                    if messages:
                        msg_id = messages[0][0]
                        last_id = msg_id if isinstance(msg_id, str) else msg_id.decode()
                except Exception as e:
                    logger.warning(f"Failed to get last message ID: {e}")

                return (
                    ActiveTask(
                        task_id=task_id,
                        session_id=task_session_id,
                        user_id=task_user_id,
                        tool_call_id=meta.get("tool_call_id", ""),
                        tool_name=meta.get("tool_name", ""),
                        operation_id=meta.get("operation_id", ""),
                        status="running",
                    ),
                    last_id,
                )

        if cursor == 0:
            break

    return None, "0-0"


def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
    """Reconstruct a StreamBaseResponse from JSON data.

    Args:
        chunk_data: Parsed JSON data from Redis

    Returns:
        Reconstructed response object, or None if unknown type
    """
    from .response_model import (
        ResponseType,
        StreamError,
        StreamFinish,
        StreamHeartbeat,
        StreamStart,
        StreamTextDelta,
        StreamTextEnd,
        StreamTextStart,
        StreamToolInputAvailable,
        StreamToolInputStart,
        StreamToolOutputAvailable,
        StreamUsage,
    )

    # Map response types to their corresponding classes
    type_to_class: dict[str, type[StreamBaseResponse]] = {
        ResponseType.START.value: StreamStart,
        ResponseType.FINISH.value: StreamFinish,
        ResponseType.TEXT_START.value: StreamTextStart,
        ResponseType.TEXT_DELTA.value: StreamTextDelta,
        ResponseType.TEXT_END.value: StreamTextEnd,
        ResponseType.TOOL_INPUT_START.value: StreamToolInputStart,
        ResponseType.TOOL_INPUT_AVAILABLE.value: StreamToolInputAvailable,
        ResponseType.TOOL_OUTPUT_AVAILABLE.value: StreamToolOutputAvailable,
        ResponseType.ERROR.value: StreamError,
        ResponseType.USAGE.value: StreamUsage,
        ResponseType.HEARTBEAT.value: StreamHeartbeat,
    }

    chunk_type = chunk_data.get("type")
    chunk_class = type_to_class.get(chunk_type)  # type: ignore[arg-type]

    if chunk_class is None:
        logger.warning(f"Unknown chunk type: {chunk_type}")
        return None

    try:
        return chunk_class(**chunk_data)
    except Exception as e:
        logger.warning(f"Failed to reconstruct chunk of type {chunk_type}: {e}")
        return None


async def set_task_asyncio_task(task_id: str, asyncio_task: asyncio.Task) -> None:
    """Track the asyncio.Task for a task (local reference only).

    This is just for cleanup purposes - the task state is in Redis.

    Args:
        task_id: Task ID
        asyncio_task: The asyncio Task to track
    """
    _local_tasks[task_id] = asyncio_task


async def unsubscribe_from_task(
    task_id: str,
    subscriber_queue: asyncio.Queue[StreamBaseResponse],
) -> None:
    """Clean up when a subscriber disconnects.

    Cancels the XREAD-based listener task associated with this subscriber queue
    to prevent resource leaks.

    Args:
        task_id: Task ID
        subscriber_queue: The subscriber's queue used to look up the listener task
    """
    queue_id = id(subscriber_queue)
    listener_entry = _listener_tasks.pop(queue_id, None)

    if listener_entry is None:
        logger.debug(
            f"No listener task found for task {task_id} queue {queue_id} "
            "(may have already completed)"
        )
        return

    stored_task_id, listener_task = listener_entry

    if stored_task_id != task_id:
        logger.warning(
            f"Task ID mismatch in unsubscribe: expected {task_id}, "
            f"found {stored_task_id}"
        )

    if listener_task.done():
        logger.debug(f"Listener task for task {task_id} already completed")
        return

    # Cancel the listener task
    listener_task.cancel()

    try:
        # Wait for the task to be cancelled with a timeout
        await asyncio.wait_for(listener_task, timeout=5.0)
    except asyncio.CancelledError:
        # Expected - the task was successfully cancelled
        pass
    except asyncio.TimeoutError:
        logger.warning(
            f"Timeout waiting for listener task cancellation for task {task_id}"
        )
    except Exception as e:
        logger.error(f"Error during listener task cancellation for task {task_id}: {e}")

    logger.debug(f"Successfully unsubscribed from task {task_id}")
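As a usage sketch (not part of the diff), a subscriber built on the registry functions above would replay missed chunks and then drain live ones until the finish marker. The helper below is hypothetical wiring, written as if it lived in the same module:

# Hypothetical consumer of the registry API shown above.
async def drain_task_stream(task_id: str, user_id: str | None) -> list[StreamBaseResponse]:
    received: list[StreamBaseResponse] = []
    # subscribe_to_task replays persisted messages, then keeps listening via XREAD
    queue = await subscribe_to_task(task_id, user_id)
    if queue is None:
        return received  # unknown task, expired metadata, or ownership check failed
    try:
        while True:
            chunk = await queue.get()
            received.append(chunk)
            if isinstance(chunk, StreamFinish):
                break  # terminal marker published by mark_task_completed()
    finally:
        await unsubscribe_from_task(task_id, queue)  # cancels the XREAD listener task
    return received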
@@ -10,7 +10,6 @@ from .add_understanding import AddUnderstandingTool
from .agent_output import AgentOutputTool
from .base import BaseTool
from .create_agent import CreateAgentTool
from .customize_agent import CustomizeAgentTool
from .edit_agent import EditAgentTool
from .find_agent import FindAgentTool
from .find_block import FindBlockTool
@@ -35,7 +34,6 @@ logger = logging.getLogger(__name__)
TOOL_REGISTRY: dict[str, BaseTool] = {
    "add_understanding": AddUnderstandingTool(),
    "create_agent": CreateAgentTool(),
    "customize_agent": CustomizeAgentTool(),
    "edit_agent": EditAgentTool(),
    "find_agent": FindAgentTool(),
    "find_block": FindBlockTool(),
@@ -8,7 +8,6 @@ from .core import (
    DecompositionStep,
    LibraryAgentSummary,
    MarketplaceAgentSummary,
    customize_template,
    decompose_goal,
    enrich_library_agents_from_steps,
    extract_search_terms_from_steps,
@@ -20,7 +19,6 @@ from .core import (
    get_library_agent_by_graph_id,
    get_library_agent_by_id,
    get_library_agents_for_generation,
    graph_to_json,
    json_to_graph,
    save_agent_to_library,
    search_marketplace_agents_for_generation,
@@ -38,7 +36,6 @@ __all__ = [
    "LibraryAgentSummary",
    "MarketplaceAgentSummary",
    "check_external_service_health",
    "customize_template",
    "decompose_goal",
    "enrich_library_agents_from_steps",
    "extract_search_terms_from_steps",
@@ -51,7 +48,6 @@ __all__ = [
    "get_library_agent_by_id",
    "get_library_agents_for_generation",
    "get_user_message_for_error",
    "graph_to_json",
    "is_external_service_configured",
    "json_to_graph",
    "save_agent_to_library",
@@ -19,7 +19,6 @@ from backend.data.graph import (
from backend.util.exceptions import DatabaseError, NotFoundError

from .service import (
    customize_template_external,
    decompose_goal_external,
    generate_agent_external,
    generate_agent_patch_external,
@@ -550,21 +549,15 @@ async def decompose_goal(
async def generate_agent(
    instructions: DecompositionResult | dict[str, Any],
    library_agents: list[AgentSummary] | list[dict[str, Any]] | None = None,
    operation_id: str | None = None,
    task_id: str | None = None,
) -> dict[str, Any] | None:
    """Generate agent JSON from instructions.

    Args:
        instructions: Structured instructions from decompose_goal
        library_agents: User's library agents available for sub-agent composition
        operation_id: Operation ID for async processing (enables Redis Streams
            completion notification)
        task_id: Task ID for async processing (enables Redis Streams persistence
            and SSE delivery)

    Returns:
        Agent JSON dict, {"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error
        Agent JSON dict, error dict {"type": "error", ...}, or None on error

    Raises:
        AgentGeneratorNotConfiguredError: If the external service is not configured.
@@ -572,13 +565,8 @@ async def generate_agent(
    _check_service_configured()
    logger.info("Calling external Agent Generator service for generate_agent")
    result = await generate_agent_external(
        dict(instructions), _to_dict_list(library_agents), operation_id, task_id
        dict(instructions), _to_dict_list(library_agents)
    )

    # Don't modify async response
    if result and result.get("status") == "accepted":
        return result

    if result:
        if isinstance(result, dict) and result.get("type") == "error":
            return result
@@ -752,15 +740,32 @@ async def save_agent_to_library(
    return created_graph, library_agents[0]


def graph_to_json(graph: Graph) -> dict[str, Any]:
async def get_agent_as_json(
    """Convert a Graph object to JSON format for the agent generator.
    agent_id: str, user_id: str | None
) -> dict[str, Any] | None:
    """Fetch an agent and convert to JSON format for editing.

    Args:
        graph: Graph object to convert
        agent_id: Graph ID or library agent ID
        user_id: User ID

    Returns:
        Agent as JSON dict
        Agent as JSON dict or None if not found
    """
    graph = await get_graph(agent_id, version=None, user_id=user_id)

    if not graph and user_id:
        try:
            library_agent = await library_db.get_library_agent(agent_id, user_id)
            graph = await get_graph(
                library_agent.graph_id, version=None, user_id=user_id
            )
        except NotFoundError:
            pass

    if not graph:
        return None

    nodes = []
    for node in graph.nodes:
        nodes.append(
@@ -797,41 +802,10 @@ def graph_to_json(graph: Graph) -> dict[str, Any]:
    }


async def get_agent_as_json(
    agent_id: str, user_id: str | None
) -> dict[str, Any] | None:
    """Fetch an agent and convert to JSON format for editing.

    Args:
        agent_id: Graph ID or library agent ID
        user_id: User ID

    Returns:
        Agent as JSON dict or None if not found
    """
    graph = await get_graph(agent_id, version=None, user_id=user_id)

    if not graph and user_id:
        try:
            library_agent = await library_db.get_library_agent(agent_id, user_id)
            graph = await get_graph(
                library_agent.graph_id, version=None, user_id=user_id
            )
        except NotFoundError:
            pass

    if not graph:
        return None

    return graph_to_json(graph)


async def generate_agent_patch(
    update_request: str,
    current_agent: dict[str, Any],
    library_agents: list[AgentSummary] | None = None,
    operation_id: str | None = None,
    task_id: str | None = None,
) -> dict[str, Any] | None:
    """Update an existing agent using natural language.

@@ -844,12 +818,10 @@ async def generate_agent_patch(
        update_request: Natural language description of changes
        current_agent: Current agent JSON
        library_agents: User's library agents available for sub-agent composition
        operation_id: Operation ID for async processing (enables Redis Streams callback)
        task_id: Task ID for async processing (enables Redis Streams callback)

    Returns:
        Updated agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
        {"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error
        error dict {"type": "error", ...}, or None on unexpected error

    Raises:
        AgentGeneratorNotConfiguredError: If the external service is not configured.
@@ -857,43 +829,5 @@ async def generate_agent_patch(
    _check_service_configured()
    logger.info("Calling external Agent Generator service for generate_agent_patch")
    return await generate_agent_patch_external(
        update_request,
        update_request, current_agent, _to_dict_list(library_agents)
        current_agent,
        _to_dict_list(library_agents),
        operation_id,
        task_id,
    )


async def customize_template(
    template_agent: dict[str, Any],
    modification_request: str,
    context: str = "",
) -> dict[str, Any] | None:
    """Customize a template/marketplace agent using natural language.

    This is used when users want to modify a template or marketplace agent
    to fit their specific needs before adding it to their library.

    The external Agent Generator service handles:
    - Understanding the modification request
    - Applying changes to the template
    - Fixing and validating the result

    Args:
        template_agent: The template agent JSON to customize
        modification_request: Natural language description of customizations
        context: Additional context (e.g., answers to previous questions)

    Returns:
        Customized agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
        error dict {"type": "error", ...}, or None on unexpected error

    Raises:
        AgentGeneratorNotConfiguredError: If the external service is not configured.
    """
    _check_service_configured()
    logger.info("Calling external Agent Generator service for customize_template")
    return await customize_template_external(
        template_agent, modification_request, context
    )
@@ -212,45 +212,24 @@ async def decompose_goal_external(
async def generate_agent_external(
    instructions: dict[str, Any],
    library_agents: list[dict[str, Any]] | None = None,
    operation_id: str | None = None,
    task_id: str | None = None,
) -> dict[str, Any] | None:
    """Call the external service to generate an agent from instructions.

    Args:
        instructions: Structured instructions from decompose_goal
        library_agents: User's library agents available for sub-agent composition
        operation_id: Operation ID for async processing (enables Redis Streams callback)
        task_id: Task ID for async processing (enables Redis Streams callback)

    Returns:
        Agent JSON dict, {"status": "accepted"} for async, or error dict {"type": "error", ...} on error
        Agent JSON dict on success, or error dict {"type": "error", ...} on error
    """
    client = _get_client()

    # Build request payload
    payload: dict[str, Any] = {"instructions": instructions}
    if library_agents:
        payload["library_agents"] = library_agents
    if operation_id and task_id:
        payload["operation_id"] = operation_id
        payload["task_id"] = task_id

    try:
        response = await client.post("/api/generate-agent", json=payload)

        # Handle 202 Accepted for async processing
        if response.status_code == 202:
            logger.info(
                f"Agent Generator accepted async request "
                f"(operation_id={operation_id}, task_id={task_id})"
            )
            return {
                "status": "accepted",
                "operation_id": operation_id,
                "task_id": task_id,
            }

        response.raise_for_status()
        data = response.json()

@@ -282,8 +261,6 @@ async def generate_agent_patch_external(
    update_request: str,
    current_agent: dict[str, Any],
    library_agents: list[dict[str, Any]] | None = None,
    operation_id: str | None = None,
    task_id: str | None = None,
) -> dict[str, Any] | None:
    """Call the external service to generate a patch for an existing agent.

@@ -291,40 +268,21 @@ async def generate_agent_patch_external(
        update_request: Natural language description of changes
        current_agent: Current agent JSON
        library_agents: User's library agents available for sub-agent composition
        operation_id: Operation ID for async processing (enables Redis Streams callback)
        task_id: Task ID for async processing (enables Redis Streams callback)

    Returns:
        Updated agent JSON, clarifying questions dict, {"status": "accepted"} for async, or error dict on error
        Updated agent JSON, clarifying questions dict, or error dict on error
    """
    client = _get_client()

    # Build request payload
    payload: dict[str, Any] = {
        "update_request": update_request,
        "current_agent_json": current_agent,
    }
    if library_agents:
        payload["library_agents"] = library_agents
    if operation_id and task_id:
        payload["operation_id"] = operation_id
        payload["task_id"] = task_id

    try:
        response = await client.post("/api/update-agent", json=payload)

        # Handle 202 Accepted for async processing
        if response.status_code == 202:
            logger.info(
                f"Agent Generator accepted async update request "
                f"(operation_id={operation_id}, task_id={task_id})"
            )
            return {
                "status": "accepted",
                "operation_id": operation_id,
                "task_id": task_id,
            }

        response.raise_for_status()
        data = response.json()

@@ -368,77 +326,6 @@ async def generate_agent_patch_external(
    return _create_error_response(error_msg, "unexpected_error")


async def customize_template_external(
    template_agent: dict[str, Any],
    modification_request: str,
    context: str = "",
) -> dict[str, Any] | None:
    """Call the external service to customize a template/marketplace agent.

    Args:
        template_agent: The template agent JSON to customize
        modification_request: Natural language description of customizations
        context: Additional context (e.g., answers to previous questions)

    Returns:
        Customized agent JSON, clarifying questions dict, or error dict on error
    """
    client = _get_client()

    request = modification_request
    if context:
        request = f"{modification_request}\n\nAdditional context from user:\n{context}"

    payload: dict[str, Any] = {
        "template_agent_json": template_agent,
        "modification_request": request,
    }

    try:
        response = await client.post("/api/template-modification", json=payload)
        response.raise_for_status()
        data = response.json()

        if not data.get("success"):
            error_msg = data.get("error", "Unknown error from Agent Generator")
            error_type = data.get("error_type", "unknown")
            logger.error(
                f"Agent Generator template customization failed: {error_msg} "
                f"(type: {error_type})"
            )
            return _create_error_response(error_msg, error_type)

        # Check if it's clarifying questions
        if data.get("type") == "clarifying_questions":
            return {
                "type": "clarifying_questions",
                "questions": data.get("questions", []),
            }

        # Check if it's an error passed through
        if data.get("type") == "error":
            return _create_error_response(
                data.get("error", "Unknown error"),
                data.get("error_type", "unknown"),
            )

        # Otherwise return the customized agent JSON
        return data.get("agent_json")

    except httpx.HTTPStatusError as e:
        error_type, error_msg = _classify_http_error(e)
        logger.error(error_msg)
        return _create_error_response(error_msg, error_type)
    except httpx.RequestError as e:
        error_type, error_msg = _classify_request_error(e)
        logger.error(error_msg)
        return _create_error_response(error_msg, error_type)
    except Exception as e:
        error_msg = f"Unexpected error calling Agent Generator: {e}"
        logger.error(error_msg)
        return _create_error_response(error_msg, "unexpected_error")


async def get_blocks_external() -> list[dict[str, Any]] | None:
    """Get available blocks from the external service.

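For reference (not part of the diff), the request bodies assembled above have the following shape; only the keys come from the code, the values are placeholders:

# POST /api/generate-agent
generate_payload = {
    "instructions": {},    # structured output of decompose_goal
    "library_agents": [],  # optional; user's library agents
    # "operation_id" / "task_id" were only added when async (202) processing was requested
}

# POST /api/update-agent
update_payload = {
    "update_request": "add a webhook trigger",  # placeholder text
    "current_agent_json": {},                   # current agent graph as JSON
    "library_agents": [],                       # optional
}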
@@ -18,7 +18,6 @@ from .base import BaseTool
from .models import (
    AgentPreviewResponse,
    AgentSavedResponse,
    AsyncProcessingResponse,
    ClarificationNeededResponse,
    ClarifyingQuestion,
    ErrorResponse,
@@ -99,10 +98,6 @@ class CreateAgentTool(BaseTool):
        save = kwargs.get("save", True)
        session_id = session.session_id if session else None

        # Extract async processing params (passed by long-running tool handler)
        operation_id = kwargs.get("_operation_id")
        task_id = kwargs.get("_task_id")

        if not description:
            return ErrorResponse(
                message="Please provide a description of what the agent should do.",
@@ -224,12 +219,7 @@ class CreateAgentTool(BaseTool):
            logger.warning(f"Failed to enrich library agents from steps: {e}")

        try:
            agent_json = await generate_agent(
            agent_json = await generate_agent(decomposition_result, library_agents)
                decomposition_result,
                library_agents,
                operation_id=operation_id,
                task_id=task_id,
            )
        except AgentGeneratorNotConfiguredError:
            return ErrorResponse(
                message=(
@@ -273,19 +263,6 @@ class CreateAgentTool(BaseTool):
                session_id=session_id,
            )

        # Check if Agent Generator accepted for async processing
        if agent_json.get("status") == "accepted":
            logger.info(
                f"Agent generation delegated to async processing "
                f"(operation_id={operation_id}, task_id={task_id})"
            )
            return AsyncProcessingResponse(
                message="Agent generation started. You'll be notified when it's complete.",
                operation_id=operation_id,
                task_id=task_id,
                session_id=session_id,
            )

        agent_name = agent_json.get("name", "Generated Agent")
        agent_description = agent_json.get("description", "")
        node_count = len(agent_json.get("nodes", []))
@@ -1,337 +0,0 @@
"""CustomizeAgentTool - Customizes marketplace/template agents using natural language."""

import logging
from typing import Any

from backend.api.features.chat.model import ChatSession
from backend.api.features.store import db as store_db
from backend.api.features.store.exceptions import AgentNotFoundError

from .agent_generator import (
    AgentGeneratorNotConfiguredError,
    customize_template,
    get_user_message_for_error,
    graph_to_json,
    save_agent_to_library,
)
from .base import BaseTool
from .models import (
    AgentPreviewResponse,
    AgentSavedResponse,
    ClarificationNeededResponse,
    ClarifyingQuestion,
    ErrorResponse,
    ToolResponseBase,
)

logger = logging.getLogger(__name__)


class CustomizeAgentTool(BaseTool):
    """Tool for customizing marketplace/template agents using natural language."""

    @property
    def name(self) -> str:
        return "customize_agent"

    @property
    def description(self) -> str:
        return (
            "Customize a marketplace or template agent using natural language. "
            "Takes an existing agent from the marketplace and modifies it based on "
            "the user's requirements before adding to their library."
        )

    @property
    def requires_auth(self) -> bool:
        return True

    @property
    def is_long_running(self) -> bool:
        return True

    @property
    def parameters(self) -> dict[str, Any]:
        return {
            "type": "object",
            "properties": {
                "agent_id": {
                    "type": "string",
                    "description": (
                        "The marketplace agent ID in format 'creator/slug' "
                        "(e.g., 'autogpt/newsletter-writer'). "
                        "Get this from find_agent results."
                    ),
                },
                "modifications": {
                    "type": "string",
                    "description": (
                        "Natural language description of how to customize the agent. "
                        "Be specific about what changes you want to make."
                    ),
                },
                "context": {
                    "type": "string",
                    "description": (
                        "Additional context or answers to previous clarifying questions."
                    ),
                },
                "save": {
                    "type": "boolean",
                    "description": (
                        "Whether to save the customized agent to the user's library. "
                        "Default is true. Set to false for preview only."
                    ),
                    "default": True,
                },
            },
            "required": ["agent_id", "modifications"],
        }

    async def _execute(
        self,
        user_id: str | None,
        session: ChatSession,
        **kwargs,
    ) -> ToolResponseBase:
        """Execute the customize_agent tool.

        Flow:
        1. Parse the agent ID to get creator/slug
        2. Fetch the template agent from the marketplace
        3. Call customize_template with the modification request
        4. Preview or save based on the save parameter
        """
        agent_id = kwargs.get("agent_id", "").strip()
        modifications = kwargs.get("modifications", "").strip()
        context = kwargs.get("context", "")
        save = kwargs.get("save", True)
        session_id = session.session_id if session else None

        if not agent_id:
            return ErrorResponse(
                message="Please provide the marketplace agent ID (e.g., 'creator/agent-name').",
                error="missing_agent_id",
                session_id=session_id,
            )

        if not modifications:
            return ErrorResponse(
                message="Please describe how you want to customize this agent.",
                error="missing_modifications",
                session_id=session_id,
            )

        # Parse agent_id in format "creator/slug"
        parts = [p.strip() for p in agent_id.split("/")]
        if len(parts) != 2 or not parts[0] or not parts[1]:
            return ErrorResponse(
                message=(
                    f"Invalid agent ID format: '{agent_id}'. "
                    "Expected format is 'creator/agent-name' "
                    "(e.g., 'autogpt/newsletter-writer')."
                ),
                error="invalid_agent_id_format",
                session_id=session_id,
            )

        creator_username, agent_slug = parts

        # Fetch the marketplace agent details
        try:
            agent_details = await store_db.get_store_agent_details(
                username=creator_username, agent_name=agent_slug
            )
        except AgentNotFoundError:
            return ErrorResponse(
                message=(
                    f"Could not find marketplace agent '{agent_id}'. "
                    "Please check the agent ID and try again."
                ),
                error="agent_not_found",
                session_id=session_id,
            )
        except Exception as e:
            logger.error(f"Error fetching marketplace agent {agent_id}: {e}")
            return ErrorResponse(
                message="Failed to fetch the marketplace agent. Please try again.",
                error="fetch_error",
                session_id=session_id,
            )

        if not agent_details.store_listing_version_id:
            return ErrorResponse(
                message=(
                    f"The agent '{agent_id}' does not have an available version. "
                    "Please try a different agent."
                ),
                error="no_version_available",
                session_id=session_id,
            )

        # Get the full agent graph
        try:
            graph = await store_db.get_agent(agent_details.store_listing_version_id)
            template_agent = graph_to_json(graph)
        except Exception as e:
            logger.error(f"Error fetching agent graph for {agent_id}: {e}")
            return ErrorResponse(
                message="Failed to fetch the agent configuration. Please try again.",
                error="graph_fetch_error",
                session_id=session_id,
            )

        # Call customize_template
        try:
            result = await customize_template(
                template_agent=template_agent,
                modification_request=modifications,
                context=context,
            )
        except AgentGeneratorNotConfiguredError:
            return ErrorResponse(
                message=(
                    "Agent customization is not available. "
                    "The Agent Generator service is not configured."
                ),
                error="service_not_configured",
                session_id=session_id,
            )
        except Exception as e:
            logger.error(f"Error calling customize_template for {agent_id}: {e}")
            return ErrorResponse(
                message=(
                    "Failed to customize the agent due to a service error. "
                    "Please try again."
                ),
                error="customization_service_error",
                session_id=session_id,
            )

        if result is None:
            return ErrorResponse(
                message=(
                    "Failed to customize the agent. "
                    "The agent generation service may be unavailable or timed out. "
                    "Please try again."
                ),
                error="customization_failed",
                session_id=session_id,
            )

        # Handle error response
        if isinstance(result, dict) and result.get("type") == "error":
            error_msg = result.get("error", "Unknown error")
            error_type = result.get("error_type", "unknown")
            user_message = get_user_message_for_error(
                error_type,
                operation="customize the agent",
                llm_parse_message=(
                    "The AI had trouble customizing the agent. "
                    "Please try again or simplify your request."
                ),
                validation_message=(
                    "The customized agent failed validation. "
                    "Please try rephrasing your request."
                ),
                error_details=error_msg,
            )
            return ErrorResponse(
                message=user_message,
                error=f"customization_failed:{error_type}",
                session_id=session_id,
            )

        # Handle clarifying questions
|
|
||||||
if isinstance(result, dict) and result.get("type") == "clarifying_questions":
|
|
||||||
questions = result.get("questions") or []
|
|
||||||
if not isinstance(questions, list):
|
|
||||||
logger.error(
|
|
||||||
f"Unexpected clarifying questions format: {type(questions)}"
|
|
||||||
)
|
|
||||||
questions = []
|
|
||||||
return ClarificationNeededResponse(
|
|
||||||
message=(
|
|
||||||
"I need some more information to customize this agent. "
|
|
||||||
"Please answer the following questions:"
|
|
||||||
),
|
|
||||||
questions=[
|
|
||||||
ClarifyingQuestion(
|
|
||||||
question=q.get("question", ""),
|
|
||||||
keyword=q.get("keyword", ""),
|
|
||||||
example=q.get("example"),
|
|
||||||
)
|
|
||||||
for q in questions
|
|
||||||
if isinstance(q, dict)
|
|
||||||
],
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Result should be the customized agent JSON
|
|
||||||
if not isinstance(result, dict):
|
|
||||||
logger.error(f"Unexpected customize_template response type: {type(result)}")
|
|
||||||
return ErrorResponse(
|
|
||||||
message="Failed to customize the agent due to an unexpected response.",
|
|
||||||
error="unexpected_response_type",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
customized_agent = result
|
|
||||||
|
|
||||||
agent_name = customized_agent.get(
|
|
||||||
"name", f"Customized {agent_details.agent_name}"
|
|
||||||
)
|
|
||||||
agent_description = customized_agent.get("description", "")
|
|
||||||
nodes = customized_agent.get("nodes")
|
|
||||||
links = customized_agent.get("links")
|
|
||||||
node_count = len(nodes) if isinstance(nodes, list) else 0
|
|
||||||
link_count = len(links) if isinstance(links, list) else 0
|
|
||||||
|
|
||||||
if not save:
|
|
||||||
return AgentPreviewResponse(
|
|
||||||
message=(
|
|
||||||
f"I've customized the agent '{agent_details.agent_name}'. "
|
|
||||||
f"The customized agent has {node_count} blocks. "
|
|
||||||
f"Review it and call customize_agent with save=true to save it."
|
|
||||||
),
|
|
||||||
agent_json=customized_agent,
|
|
||||||
agent_name=agent_name,
|
|
||||||
description=agent_description,
|
|
||||||
node_count=node_count,
|
|
||||||
link_count=link_count,
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not user_id:
|
|
||||||
return ErrorResponse(
|
|
||||||
message="You must be logged in to save agents.",
|
|
||||||
error="auth_required",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Save to user's library
|
|
||||||
try:
|
|
||||||
created_graph, library_agent = await save_agent_to_library(
|
|
||||||
customized_agent, user_id, is_update=False
|
|
||||||
)
|
|
||||||
|
|
||||||
return AgentSavedResponse(
|
|
||||||
message=(
|
|
||||||
f"Customized agent '{created_graph.name}' "
|
|
||||||
f"(based on '{agent_details.agent_name}') "
|
|
||||||
f"has been saved to your library!"
|
|
||||||
),
|
|
||||||
agent_id=created_graph.id,
|
|
||||||
agent_name=created_graph.name,
|
|
||||||
library_agent_id=library_agent.id,
|
|
||||||
library_agent_link=f"/library/agents/{library_agent.id}",
|
|
||||||
agent_page_link=f"/build?flowID={created_graph.id}",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error saving customized agent: {e}")
|
|
||||||
return ErrorResponse(
|
|
||||||
message="Failed to save the customized agent. Please try again.",
|
|
||||||
error="save_failed",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
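For context, a minimal sketch of the keyword arguments the customize_agent tool above accepts, based on its parameters schema; the agent ID and modification text are illustrative values only, not taken from the diff:

# Hypothetical call arguments matching the tool's JSON schema above.
kwargs = {
    "agent_id": "autogpt/newsletter-writer",  # 'creator/slug' from find_agent results
    "modifications": "Send the newsletter weekly instead of daily",
    "context": "",   # optional answers to earlier clarifying questions
    "save": False,   # preview only; defaults to True when omitted
}
# _execute() parses "creator/slug", fetches the store listing, calls
# customize_template(), and returns a preview response because save is False.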
@@ -17,7 +17,6 @@ from .base import BaseTool
 from .models import (
     AgentPreviewResponse,
     AgentSavedResponse,
-    AsyncProcessingResponse,
     ClarificationNeededResponse,
     ClarifyingQuestion,
     ErrorResponse,

@@ -105,10 +104,6 @@ class EditAgentTool(BaseTool):
         save = kwargs.get("save", True)
         session_id = session.session_id if session else None

-        # Extract async processing params (passed by long-running tool handler)
-        operation_id = kwargs.get("_operation_id")
-        task_id = kwargs.get("_task_id")
-
         if not agent_id:
             return ErrorResponse(
                 message="Please provide the agent ID to edit.",

@@ -154,11 +149,7 @@ class EditAgentTool(BaseTool):

         try:
             result = await generate_agent_patch(
-                update_request,
-                current_agent,
-                library_agents,
-                operation_id=operation_id,
-                task_id=task_id,
+                update_request, current_agent, library_agents
             )
         except AgentGeneratorNotConfiguredError:
             return ErrorResponse(

@@ -178,20 +169,6 @@ class EditAgentTool(BaseTool):
                 session_id=session_id,
             )

-        # Check if Agent Generator accepted for async processing
-        if result.get("status") == "accepted":
-            logger.info(
-                f"Agent edit delegated to async processing "
-                f"(operation_id={operation_id}, task_id={task_id})"
-            )
-            return AsyncProcessingResponse(
-                message="Agent edit started. You'll be notified when it's complete.",
-                operation_id=operation_id,
-                task_id=task_id,
-                session_id=session_id,
-            )
-
-        # Check if the result is an error from the external service
         if isinstance(result, dict) and result.get("type") == "error":
             error_msg = result.get("error", "Unknown error")
             error_type = result.get("error_type", "unknown")

@@ -38,8 +38,6 @@ class ResponseType(str, Enum):
     OPERATION_STARTED = "operation_started"
     OPERATION_PENDING = "operation_pending"
     OPERATION_IN_PROGRESS = "operation_in_progress"
-    # Input validation
-    INPUT_VALIDATION_ERROR = "input_validation_error"


 # Base response model

@@ -70,10 +68,6 @@ class AgentInfo(BaseModel):
     has_external_trigger: bool | None = None
     new_output: bool | None = None
     graph_id: str | None = None
-    inputs: dict[str, Any] | None = Field(
-        default=None,
-        description="Input schema for the agent, including field names, types, and defaults",
-    )


 class AgentsFoundResponse(ToolResponseBase):

@@ -200,20 +194,6 @@ class ErrorResponse(ToolResponseBase):
     details: dict[str, Any] | None = None


-class InputValidationErrorResponse(ToolResponseBase):
-    """Response when run_agent receives unknown input fields."""
-
-    type: ResponseType = ResponseType.INPUT_VALIDATION_ERROR
-    unrecognized_fields: list[str] = Field(
-        description="List of input field names that were not recognized"
-    )
-    inputs: dict[str, Any] = Field(
-        description="The agent's valid input schema for reference"
-    )
-    graph_id: str | None = None
-    graph_version: int | None = None
-
-
 # Agent output models
 class ExecutionOutputInfo(BaseModel):
     """Summary of a single execution's outputs."""

@@ -372,15 +352,11 @@ class OperationStartedResponse(ToolResponseBase):

     This is returned immediately to the client while the operation continues
     to execute. The user can close the tab and check back later.
-
-    The task_id can be used to reconnect to the SSE stream via
-    GET /chat/tasks/{task_id}/stream?last_idx=0
     """

     type: ResponseType = ResponseType.OPERATION_STARTED
     operation_id: str
     tool_name: str
-    task_id: str | None = None  # For SSE reconnection


 class OperationPendingResponse(ToolResponseBase):

@@ -404,20 +380,3 @@ class OperationInProgressResponse(ToolResponseBase):

     type: ResponseType = ResponseType.OPERATION_IN_PROGRESS
     tool_call_id: str
-
-
-class AsyncProcessingResponse(ToolResponseBase):
-    """Response when an operation has been delegated to async processing.
-
-    This is returned by tools when the external service accepts the request
-    for async processing (HTTP 202 Accepted). The Redis Streams completion
-    consumer will handle the result when the external service completes.
-
-    The status field is specifically "accepted" to allow the long-running tool
-    handler to detect this response and skip LLM continuation.
-    """
-
-    type: ResponseType = ResponseType.OPERATION_STARTED
-    status: str = "accepted"  # Must be "accepted" for detection
-    operation_id: str | None = None
-    task_id: str | None = None

@@ -30,7 +30,6 @@ from .models import (
     ErrorResponse,
     ExecutionOptions,
     ExecutionStartedResponse,
-    InputValidationErrorResponse,
     SetupInfo,
     SetupRequirementsResponse,
     ToolResponseBase,

@@ -274,22 +273,6 @@ class RunAgentTool(BaseTool):
         input_properties = graph.input_schema.get("properties", {})
         required_fields = set(graph.input_schema.get("required", []))
         provided_inputs = set(params.inputs.keys())
-        valid_fields = set(input_properties.keys())
-
-        # Check for unknown input fields
-        unrecognized_fields = provided_inputs - valid_fields
-        if unrecognized_fields:
-            return InputValidationErrorResponse(
-                message=(
-                    f"Unknown input field(s) provided: {', '.join(sorted(unrecognized_fields))}. "
-                    f"Agent was not executed. Please use the correct field names from the schema."
-                ),
-                session_id=session_id,
-                unrecognized_fields=sorted(unrecognized_fields),
-                inputs=graph.input_schema,
-                graph_id=graph.id,
-                graph_version=graph.version,
-            )

         # If agent has inputs but none were provided AND use_defaults is not set,
         # always show what's available first so user can decide

@@ -402,42 +402,3 @@ async def test_run_agent_schedule_without_name(setup_test_data):
     # Should return error about missing schedule_name
     assert result_data.get("type") == "error"
     assert "schedule_name" in result_data["message"].lower()
-
-
-@pytest.mark.asyncio(loop_scope="session")
-async def test_run_agent_rejects_unknown_input_fields(setup_test_data):
-    """Test that run_agent returns input_validation_error for unknown input fields."""
-    user = setup_test_data["user"]
-    store_submission = setup_test_data["store_submission"]
-
-    tool = RunAgentTool()
-    agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
-    session = make_session(user_id=user.id)
-
-    # Execute with unknown input field names
-    response = await tool.execute(
-        user_id=user.id,
-        session_id=str(uuid.uuid4()),
-        tool_call_id=str(uuid.uuid4()),
-        username_agent_slug=agent_marketplace_id,
-        inputs={
-            "unknown_field": "some value",
-            "another_unknown": "another value",
-        },
-        session=session,
-    )
-
-    assert response is not None
-    assert hasattr(response, "output")
-    assert isinstance(response.output, str)
-    result_data = orjson.loads(response.output)
-
-    # Should return input_validation_error type with unrecognized fields
-    assert result_data.get("type") == "input_validation_error"
-    assert "unrecognized_fields" in result_data
-    assert set(result_data["unrecognized_fields"]) == {
-        "another_unknown",
-        "unknown_field",
-    }
-    assert "inputs" in result_data  # Contains the valid schema
-    assert "Agent was not executed" in result_data["message"]

@@ -5,8 +5,6 @@ import uuid
 from collections import defaultdict
 from typing import Any

-from pydantic_core import PydanticUndefined
-
 from backend.api.features.chat.model import ChatSession
 from backend.data.block import get_block
 from backend.data.execution import ExecutionContext

@@ -77,22 +75,15 @@ class RunBlockTool(BaseTool):
         self,
         user_id: str,
         block: Any,
-        input_data: dict[str, Any] | None = None,
     ) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
         """
         Check if user has required credentials for a block.

-        Args:
-            user_id: User ID
-            block: Block to check credentials for
-            input_data: Input data for the block (used to determine provider via discriminator)
-
         Returns:
             tuple[matched_credentials, missing_credentials]
         """
         matched_credentials: dict[str, CredentialsMetaInput] = {}
         missing_credentials: list[CredentialsMetaInput] = []
-        input_data = input_data or {}

         # Get credential field info from block's input schema
         credentials_fields_info = block.input_schema.get_credentials_fields_info()

@@ -105,33 +96,14 @@ class RunBlockTool(BaseTool):
         available_creds = await creds_manager.store.get_all_creds(user_id)

         for field_name, field_info in credentials_fields_info.items():
-            effective_field_info = field_info
-            if field_info.discriminator and field_info.discriminator_mapping:
-                # Get discriminator from input, falling back to schema default
-                discriminator_value = input_data.get(field_info.discriminator)
-                if discriminator_value is None:
-                    field = block.input_schema.model_fields.get(
-                        field_info.discriminator
-                    )
-                    if field and field.default is not PydanticUndefined:
-                        discriminator_value = field.default
-
-                if (
-                    discriminator_value
-                    and discriminator_value in field_info.discriminator_mapping
-                ):
-                    effective_field_info = field_info.discriminate(discriminator_value)
-                    logger.debug(
-                        f"Discriminated provider for {field_name}: "
-                        f"{discriminator_value} -> {effective_field_info.provider}"
-                    )
-
+            # field_info.provider is a frozenset of acceptable providers
+            # field_info.supported_types is a frozenset of acceptable types
             matching_cred = next(
                 (
                     cred
                     for cred in available_creds
-                    if cred.provider in effective_field_info.provider
-                    and cred.type in effective_field_info.supported_types
+                    if cred.provider in field_info.provider
+                    and cred.type in field_info.supported_types
                 ),
                 None,
             )

@@ -145,8 +117,8 @@ class RunBlockTool(BaseTool):
                 )
             else:
                 # Create a placeholder for the missing credential
-                provider = next(iter(effective_field_info.provider), "unknown")
-                cred_type = next(iter(effective_field_info.supported_types), "api_key")
+                provider = next(iter(field_info.provider), "unknown")
+                cred_type = next(iter(field_info.supported_types), "api_key")
                 missing_credentials.append(
                     CredentialsMetaInput(
                         id=field_name,

@@ -214,9 +186,10 @@ class RunBlockTool(BaseTool):

         logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")

+        # Check credentials
         creds_manager = IntegrationCredentialsManager()
         matched_credentials, missing_credentials = await self._check_block_credentials(
-            user_id, block, input_data
+            user_id, block
         )

         if missing_credentials:

@@ -8,12 +8,7 @@ from backend.api.features.library import model as library_model
 from backend.api.features.store import db as store_db
 from backend.data import graph as graph_db
 from backend.data.graph import GraphModel
-from backend.data.model import (
-    CredentialsFieldInfo,
-    CredentialsMetaInput,
-    HostScopedCredentials,
-    OAuth2Credentials,
-)
+from backend.data.model import Credentials, CredentialsFieldInfo, CredentialsMetaInput
 from backend.integrations.creds_manager import IntegrationCredentialsManager
 from backend.util.exceptions import NotFoundError

@@ -278,14 +273,7 @@ async def match_user_credentials_to_graph(
                 for cred in available_creds
                 if cred.provider in credential_requirements.provider
                 and cred.type in credential_requirements.supported_types
-                and (
-                    cred.type != "oauth2"
-                    or _credential_has_required_scopes(cred, credential_requirements)
-                )
-                and (
-                    cred.type != "host_scoped"
-                    or _credential_is_for_host(cred, credential_requirements)
-                )
+                and _credential_has_required_scopes(cred, credential_requirements)
             ),
             None,
         )

@@ -330,10 +318,19 @@ async def match_user_credentials_to_graph(


 def _credential_has_required_scopes(
-    credential: OAuth2Credentials,
+    credential: Credentials,
     requirements: CredentialsFieldInfo,
 ) -> bool:
-    """Check if an OAuth2 credential has all the scopes required by the input."""
+    """
+    Check if a credential has all the scopes required by the block.
+
+    For OAuth2 credentials, verifies that the credential's scopes are a superset
+    of the required scopes. For other credential types, returns True (no scope check).
+    """
+    # Only OAuth2 credentials have scopes to check
+    if credential.type != "oauth2":
+        return True
+
     # If no scopes are required, any credential matches
     if not requirements.required_scopes:
         return True

@@ -342,22 +339,6 @@ def _credential_has_required_scopes(
     return set(credential.scopes).issuperset(requirements.required_scopes)


-def _credential_is_for_host(
-    credential: HostScopedCredentials,
-    requirements: CredentialsFieldInfo,
-) -> bool:
-    """Check if a host-scoped credential matches the host required by the input."""
-    # We need to know the host to match host-scoped credentials to.
-    # Graph.aggregate_credentials_inputs() adds the node's set URL value (if any)
-    # to discriminator_values. No discriminator_values -> no host to match against.
-    if not requirements.discriminator_values:
-        return True
-
-    # Check that credential host matches required host.
-    # Host-scoped credential inputs are grouped by host, so any item from the set works.
-    return credential.matches_url(list(requirements.discriminator_values)[0])
-
-
 async def check_user_has_required_credentials(
     user_id: str,
     required_credentials: list[CredentialsMetaInput],
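As a side note, the scope check in the hunks above reduces to a set-superset test. A minimal standalone sketch, with made-up scope values (the real values come from the credential and the block's requirements), mirroring _credential_has_required_scopes() for an oauth2 credential:

credential_scopes = {"read", "write", "issues:read"}
required_scopes = {"read", "issues:read"}

# True when every required scope is present on the credential.
has_required = set(credential_scopes).issuperset(required_scopes)
assert has_required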
@@ -454,9 +454,6 @@
     cleanup_embeddings: list,
 ):
     """Test unified search pagination works correctly."""
-    # Use a unique search term to avoid matching other test data
-    unique_term = f"xyzpagtest{uuid.uuid4().hex[:8]}"
-
     # Create multiple items
     content_ids = []
     for i in range(5):

@@ -468,14 +465,14 @@
             content_type=ContentType.BLOCK,
             content_id=content_id,
             embedding=mock_embedding,
-            searchable_text=f"{unique_term} item number {i}",
+            searchable_text=f"pagination test item number {i}",
             metadata={"index": i},
             user_id=None,
         )

     # Get first page
     page1_results, total1 = await unified_hybrid_search(
-        query=unique_term,
+        query="pagination test",
         content_types=[ContentType.BLOCK],
         page=1,
         page_size=2,

@@ -483,7 +480,7 @@

     # Get second page
     page2_results, total2 = await unified_hybrid_search(
-        query=unique_term,
+        query="pagination test",
         content_types=[ContentType.BLOCK],
         page=2,
         page_size=2,

@@ -40,10 +40,6 @@ import backend.data.user
 import backend.integrations.webhooks.utils
 import backend.util.service
 import backend.util.settings
-from backend.api.features.chat.completion_consumer import (
-    start_completion_consumer,
-    stop_completion_consumer,
-)
 from backend.blocks.llm import DEFAULT_LLM_MODEL
 from backend.data.model import Credentials
 from backend.integrations.providers import ProviderName

@@ -122,21 +118,9 @@ async def lifespan_context(app: fastapi.FastAPI):
     await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
     await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()

-    # Start chat completion consumer for Redis Streams notifications
-    try:
-        await start_completion_consumer()
-    except Exception as e:
-        logger.warning(f"Could not start chat completion consumer: {e}")
-
     with launch_darkly_context():
         yield

-    # Stop chat completion consumer
-    try:
-        await stop_completion_consumer()
-    except Exception as e:
-        logger.warning(f"Error stopping chat completion consumer: {e}")
-
     try:
         await shutdown_cloud_storage_handler()
     except Exception as e:

@@ -162,16 +162,8 @@ class LinearClient:
                 "searchTerm": team_name,
             }

-            result = await self.query(query, variables)
-            nodes = result["teams"]["nodes"]
-
-            if not nodes:
-                raise LinearAPIException(
-                    f"Team '{team_name}' not found. Check the team name or key and try again.",
-                    status_code=404,
-                )
-
-            return nodes[0]["id"]
+            team_id = await self.query(query, variables)
+            return team_id["teams"]["nodes"][0]["id"]
         except LinearAPIException as e:
             raise e

@@ -248,44 +240,17 @@ class LinearClient:
         except LinearAPIException as e:
             raise e

-    async def try_search_issues(
-        self,
-        term: str,
-        max_results: int = 10,
-        team_id: str | None = None,
-    ) -> list[Issue]:
+    async def try_search_issues(self, term: str) -> list[Issue]:
         try:
             query = """
-                query SearchIssues(
-                    $term: String!,
-                    $first: Int,
-                    $teamId: String
-                ) {
-                    searchIssues(
-                        term: $term,
-                        first: $first,
-                        teamId: $teamId
-                    ) {
+                query SearchIssues($term: String!, $includeComments: Boolean!) {
+                    searchIssues(term: $term, includeComments: $includeComments) {
                         nodes {
                             id
                             identifier
                             title
                             description
                             priority
-                            createdAt
-                            state {
-                                id
-                                name
-                                type
-                            }
-                            project {
-                                id
-                                name
-                            }
-                            assignee {
-                                id
-                                name
-                            }
                         }
                     }
                 }

@@ -293,8 +258,7 @@ class LinearClient:

             variables: dict[str, Any] = {
                 "term": term,
-                "first": max_results,
-                "teamId": team_id,
+                "includeComments": True,
             }

             issues = await self.query(query, variables)

@@ -17,7 +17,7 @@ from ._config import (
     LinearScope,
     linear,
 )
-from .models import CreateIssueResponse, Issue, State
+from .models import CreateIssueResponse, Issue


 class LinearCreateIssueBlock(Block):

@@ -135,20 +135,9 @@ class LinearSearchIssuesBlock(Block):
             description="Linear credentials with read permissions",
             required_scopes={LinearScope.READ},
         )
-        max_results: int = SchemaField(
-            description="Maximum number of results to return",
-            default=10,
-            ge=1,
-            le=100,
-        )
-        team_name: str | None = SchemaField(
-            description="Optional team name to filter results (e.g., 'Internal', 'Open Source')",
-            default=None,
-        )

     class Output(BlockSchemaOutput):
         issues: list[Issue] = SchemaField(description="List of issues")
-        error: str = SchemaField(description="Error message if the search failed")

     def __init__(self):
         super().__init__(

@@ -156,11 +145,8 @@ class LinearSearchIssuesBlock(Block):
             description="Searches for issues on Linear",
             input_schema=self.Input,
             output_schema=self.Output,
-            categories={BlockCategory.PRODUCTIVITY, BlockCategory.ISSUE_TRACKING},
             test_input={
                 "term": "Test issue",
-                "max_results": 10,
-                "team_name": None,
                 "credentials": TEST_CREDENTIALS_INPUT_OAUTH,
             },
             test_credentials=TEST_CREDENTIALS_OAUTH,

@@ -170,14 +156,10 @@ class LinearSearchIssuesBlock(Block):
                 [
                     Issue(
                         id="abc123",
-                        identifier="TST-123",
+                        identifier="abc123",
                         title="Test issue",
                         description="Test description",
                         priority=1,
-                        state=State(
-                            id="state1", name="In Progress", type="started"
-                        ),
-                        createdAt="2026-01-15T10:00:00.000Z",
                     )
                 ],
             )

@@ -186,12 +168,10 @@ class LinearSearchIssuesBlock(Block):
                 "search_issues": lambda *args, **kwargs: [
                     Issue(
                         id="abc123",
-                        identifier="TST-123",
+                        identifier="abc123",
                         title="Test issue",
                         description="Test description",
                         priority=1,
-                        state=State(id="state1", name="In Progress", type="started"),
-                        createdAt="2026-01-15T10:00:00.000Z",
                     )
                 ]
             },

@@ -201,22 +181,10 @@ class LinearSearchIssuesBlock(Block):
     async def search_issues(
         credentials: OAuth2Credentials | APIKeyCredentials,
         term: str,
-        max_results: int = 10,
-        team_name: str | None = None,
     ) -> list[Issue]:
         client = LinearClient(credentials=credentials)
-
-        # Resolve team name to ID if provided
-        # Raises LinearAPIException with descriptive message if team not found
-        team_id: str | None = None
-        if team_name:
-            team_id = await client.try_get_team_by_name(team_name=team_name)
-
-        return await client.try_search_issues(
-            term=term,
-            max_results=max_results,
-            team_id=team_id,
-        )
+        response: list[Issue] = await client.try_search_issues(term=term)
+        return response

     async def run(
         self,

@@ -228,10 +196,7 @@ class LinearSearchIssuesBlock(Block):
         """Execute the issue search"""
         try:
             issues = await self.search_issues(
-                credentials=credentials,
-                term=input_data.term,
-                max_results=input_data.max_results,
-                team_name=input_data.team_name,
+                credentials=credentials, term=input_data.term
             )
             yield "issues", issues
         except LinearAPIException as e:

@@ -36,21 +36,12 @@ class Project(BaseModel):
     content: str | None = None


-class State(BaseModel):
-    id: str
-    name: str
-    type: str | None = (
-        None  # Workflow state type (e.g., "triage", "backlog", "started", "completed", "canceled")
-    )
-
-
 class Issue(BaseModel):
     id: str
     identifier: str
     title: str
     description: str | None
     priority: int
-    state: State | None = None
     project: Project | None = None
     createdAt: str | None = None
     comments: list[Comment] | None = None
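A small, self-contained sketch of the GraphQL variables the simplified try_search_issues above ends up sending; the search term is a placeholder value:

from typing import Any

term = "Test issue"  # placeholder search term
variables: dict[str, Any] = {
    "term": term,
    "includeComments": True,
}
# These are passed to client.query(query, variables) together with the
# SearchIssues query shown in the Linear client hunk above.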
@@ -182,7 +182,10 @@ class StagehandObserveBlock(Block):
         **kwargs,
     ) -> BlockOutput:

-        logger.debug(f"OBSERVE: Using model provider {model_credentials.provider}")
+        logger.info(f"OBSERVE: Stagehand credentials: {stagehand_credentials}")
+        logger.info(
+            f"OBSERVE: Model credentials: {model_credentials} for provider {model_credentials.provider} secret: {model_credentials.api_key.get_secret_value()}"
+        )

         with disable_signal_handling():
             stagehand = Stagehand(

@@ -279,7 +282,10 @@ class StagehandActBlock(Block):
         **kwargs,
     ) -> BlockOutput:

-        logger.debug(f"ACT: Using model provider {model_credentials.provider}")
+        logger.info(f"ACT: Stagehand credentials: {stagehand_credentials}")
+        logger.info(
+            f"ACT: Model credentials: {model_credentials} for provider {model_credentials.provider} secret: {model_credentials.api_key.get_secret_value()}"
+        )

         with disable_signal_handling():
             stagehand = Stagehand(

@@ -364,7 +370,10 @@ class StagehandExtractBlock(Block):
         **kwargs,
     ) -> BlockOutput:

-        logger.debug(f"EXTRACT: Using model provider {model_credentials.provider}")
+        logger.info(f"EXTRACT: Stagehand credentials: {stagehand_credentials}")
+        logger.info(
+            f"EXTRACT: Model credentials: {model_credentials} for provider {model_credentials.provider} secret: {model_credentials.api_key.get_secret_value()}"
+        )

         with disable_signal_handling():
             stagehand = Stagehand(

@@ -873,13 +873,14 @@ def is_block_auth_configured(


 async def initialize_blocks() -> None:
+    # First, sync all provider costs to blocks
+    # Imported here to avoid circular import
     from backend.sdk.cost_integration import sync_all_provider_costs
-    from backend.util.retry import func_retry

     sync_all_provider_costs()

-    @func_retry
-    async def sync_block_to_db(block: Block) -> None:
+    for cls in get_blocks().values():
+        block = cls()
         existing_block = await AgentBlock.prisma().find_first(
             where={"OR": [{"id": block.id}, {"name": block.name}]}
         )

@@ -892,7 +893,7 @@ async def initialize_blocks() -> None:
                 outputSchema=json.dumps(block.output_schema.jsonschema()),
             )
         )
-        return
+        continue

     input_schema = json.dumps(block.input_schema.jsonschema())
     output_schema = json.dumps(block.output_schema.jsonschema())

@@ -912,25 +913,6 @@ async def initialize_blocks() -> None:
             },
         )

-    failed_blocks: list[str] = []
-    for cls in get_blocks().values():
-        block = cls()
-        try:
-            await sync_block_to_db(block)
-        except Exception as e:
-            logger.warning(
-                f"Failed to sync block {block.name} to database: {e}. "
-                "Block is still available in memory.",
-                exc_info=True,
-            )
-            failed_blocks.append(block.name)
-
-    if failed_blocks:
-        logger.error(
-            f"Failed to sync {len(failed_blocks)} block(s) to database: "
-            f"{', '.join(failed_blocks)}. These blocks are still available in memory."
-        )
-

 # Note on the return type annotation: https://github.com/microsoft/pyright/issues/10281
 def get_block(block_id: str) -> AnyBlockSchema | None:
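As an aside, the per-block error handling dropped in the hunks above follows a simple collect-and-report pattern. A minimal standalone sketch under illustrative names (this is not the real block registry or database call):

def sync_block(name: str) -> None:
    # Stand-in for the real DB upsert; fails for one block to show the fallback path.
    if name == "BlockB":
        raise RuntimeError("db unavailable")

failed: list[str] = []
for name in ["BlockA", "BlockB"]:
    try:
        sync_block(name)
    except Exception as e:
        print(f"Failed to sync block {name}: {e}. Block is still available in memory.")
        failed.append(name)

if failed:
    print(f"Failed to sync {len(failed)} block(s): {', '.join(failed)}")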
@@ -19,6 +19,7 @@ from typing import (
     cast,
     get_args,
 )
+from urllib.parse import urlparse
 from uuid import uuid4

 from prisma.enums import CreditTransactionType, OnboardingStep

@@ -41,7 +42,6 @@ from typing_extensions import TypedDict

 from backend.integrations.providers import ProviderName
 from backend.util.json import loads as json_loads
-from backend.util.request import parse_url
 from backend.util.settings import Secrets

 # Type alias for any provider name (including custom ones)

@@ -397,25 +397,19 @@ class HostScopedCredentials(_BaseCredentials):
     def matches_url(self, url: str) -> bool:
         """Check if this credential should be applied to the given URL."""

-        request_host, request_port = _extract_host_from_url(url)
-        cred_scope_host, cred_scope_port = _extract_host_from_url(self.host)
+        parsed_url = urlparse(url)
+        # Extract hostname without port
+        request_host = parsed_url.hostname
         if not request_host:
             return False

-        # If a port is specified in credential host, the request host port must match
-        if cred_scope_port is not None and request_port != cred_scope_port:
-            return False
-        # Non-standard ports are only allowed if explicitly specified in credential host
-        elif cred_scope_port is None and request_port not in (80, 443, None):
-            return False
-
-        # Simple host matching
-        if cred_scope_host == request_host:
+        # Simple host matching - exact match or wildcard subdomain match
+        if self.host == request_host:
             return True

         # Support wildcard matching (e.g., "*.example.com" matches "api.example.com")
-        if cred_scope_host.startswith("*."):
-            domain = cred_scope_host[2:]  # Remove "*."
+        if self.host.startswith("*."):
+            domain = self.host[2:]  # Remove "*."
             return request_host.endswith(f".{domain}") or request_host == domain

         return False

@@ -557,13 +551,13 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
     )


-def _extract_host_from_url(url: str) -> tuple[str, int | None]:
-    """Extract host and port from URL for grouping host-scoped credentials."""
+def _extract_host_from_url(url: str) -> str:
+    """Extract host from URL for grouping host-scoped credentials."""
     try:
-        parsed = parse_url(url)
-        return parsed.hostname or url, parsed.port
+        parsed = urlparse(url)
+        return parsed.hostname or url
     except Exception:
-        return "", None
+        return ""


 class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):

@@ -612,7 +606,7 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
         providers = frozenset(
             [cast(CP, "http")]
             + [
-                cast(CP, parse_url(str(value)).netloc)
+                cast(CP, _extract_host_from_url(str(value)))
                 for value in field.discriminator_values
             ]
         )
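For illustration, a minimal sketch of the wildcard host matching that both sides of the matches_url hunk above keep; the credential host pattern and request URL are example values:

from urllib.parse import urlparse

host = "*.example.com"  # credential host pattern
request_host = urlparse("https://api.example.com/v1").hostname

# Mirrors the "*." branch: strip the wildcard, then accept the bare domain or any subdomain.
domain = host[2:]
matches = request_host == domain or request_host.endswith(f".{domain}")
assert matches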
@@ -79,23 +79,10 @@ class TestHostScopedCredentials:
             headers={"Authorization": SecretStr("Bearer token")},
         )

-        # Non-standard ports require explicit port in credential host
-        assert not creds.matches_url("http://localhost:8080/api/v1")
+        assert creds.matches_url("http://localhost:8080/api/v1")
         assert creds.matches_url("https://localhost:443/secure/endpoint")
         assert creds.matches_url("http://localhost/simple")

-    def test_matches_url_with_explicit_port(self):
-        """Test URL matching with explicit port in credential host."""
-        creds = HostScopedCredentials(
-            provider="custom",
-            host="localhost:8080",
-            headers={"Authorization": SecretStr("Bearer token")},
-        )
-
-        assert creds.matches_url("http://localhost:8080/api/v1")
-        assert not creds.matches_url("http://localhost:3000/api/v1")
-        assert not creds.matches_url("http://localhost/simple")
-
     def test_empty_headers_dict(self):
         """Test HostScopedCredentials with empty headers."""
         creds = HostScopedCredentials(

@@ -141,20 +128,8 @@ class TestHostScopedCredentials:
             ("*.example.com", "https://sub.api.example.com/test", True),
             ("*.example.com", "https://example.com/test", True),
             ("*.example.com", "https://example.org/test", False),
-            # Non-standard ports require explicit port in credential host
-            ("localhost", "http://localhost:3000/test", False),
-            ("localhost:3000", "http://localhost:3000/test", True),
+            ("localhost", "http://localhost:3000/test", True),
             ("localhost", "http://127.0.0.1:3000/test", False),
-            # IPv6 addresses (frontend stores with brackets via URL.hostname)
-            ("[::1]", "http://[::1]/test", True),
-            ("[::1]", "http://[::1]:80/test", True),
-            ("[::1]", "https://[::1]:443/test", True),
-            ("[::1]", "http://[::1]:8080/test", False),  # Non-standard port
-            ("[::1]:8080", "http://[::1]:8080/test", True),
-            ("[::1]:8080", "http://[::1]:9090/test", False),
-            ("[2001:db8::1]", "http://[2001:db8::1]/path", True),
-            ("[2001:db8::1]", "https://[2001:db8::1]:443/path", True),
-            ("[2001:db8::1]", "http://[2001:db8::ff]/path", False),
         ],
     )
     def test_url_matching_parametrized(self, host: str, test_url: str, expected: bool):

@@ -157,7 +157,12 @@ async def validate_url(
         is_trusted: Boolean indicating if the hostname is in trusted_origins
         ip_addresses: List of IP addresses for the host; empty if the host is trusted
     """
-    parsed = parse_url(url)
+    # Canonicalize URL
+    url = url.strip("/ ").replace("\\", "/")
+    parsed = urlparse(url)
+    if not parsed.scheme:
+        url = f"http://{url}"
+        parsed = urlparse(url)

     # Check scheme
     if parsed.scheme not in ALLOWED_SCHEMES:

@@ -215,17 +220,6 @@ async def validate_url(
     )


-def parse_url(url: str) -> URL:
-    """Canonicalizes and parses a URL string."""
-    url = url.strip("/ ").replace("\\", "/")
-
-    # Ensure scheme is present for proper parsing
-    if not re.match(r"[a-z0-9+.\-]+://", url):
-        url = f"http://{url}"
-
-    return urlparse(url)
-
-
 def pin_url(url: URL, ip_addresses: Optional[list[str]] = None) -> URL:
     """
     Pins a URL to a specific IP address to prevent DNS rebinding attacks.
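As a quick illustration of the canonicalization step both versions of validate_url perform (strip stray slashes and spaces, normalize backslashes, and default to http:// when no scheme is given); the input URL is an example value:

from urllib.parse import urlparse

url = "example.com/path\\to\\resource/"
url = url.strip("/ ").replace("\\", "/")  # -> "example.com/path/to/resource"
if not urlparse(url).scheme:              # no scheme given, assume http
    url = f"http://{url}"
parsed = urlparse(url)
assert parsed.scheme == "http" and parsed.hostname == "example.com"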
@@ -111,7 +111,9 @@ class TestGenerateAgent:
|
|||||||
instructions = {"type": "instructions", "steps": ["Step 1"]}
|
instructions = {"type": "instructions", "steps": ["Step 1"]}
|
||||||
result = await core.generate_agent(instructions)
|
result = await core.generate_agent(instructions)
|
||||||
|
|
||||||
mock_external.assert_called_once_with(instructions, None, None, None)
|
# library_agents defaults to None
|
||||||
|
mock_external.assert_called_once_with(instructions, None)
|
||||||
|
# Result should have id, version, is_active added if not present
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result["name"] == "Test Agent"
|
assert result["name"] == "Test Agent"
|
||||||
assert "id" in result
|
assert "id" in result
|
||||||
@@ -175,9 +177,8 @@ class TestGenerateAgentPatch:
|
|||||||
current_agent = {"nodes": [], "links": []}
|
current_agent = {"nodes": [], "links": []}
|
||||||
result = await core.generate_agent_patch("Add a node", current_agent)
|
result = await core.generate_agent_patch("Add a node", current_agent)
|
||||||
|
|
||||||
mock_external.assert_called_once_with(
|
# library_agents defaults to None
|
||||||
"Add a node", current_agent, None, None, None
|
mock_external.assert_called_once_with("Add a node", current_agent, None)
|
||||||
)
|
|
||||||
assert result == expected_result
|
assert result == expected_result
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
"use client";
|
"use client";
|
||||||
import { getV1OnboardingState } from "@/app/api/__generated__/endpoints/onboarding/onboarding";
|
|
||||||
import { getOnboardingStatus, resolveResponse } from "@/app/api/helpers";
|
|
||||||
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
|
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
|
||||||
import { useRouter } from "next/navigation";
|
import { useRouter } from "next/navigation";
|
||||||
import { useEffect } from "react";
|
import { useEffect } from "react";
|
||||||
|
import { resolveResponse, getOnboardingStatus } from "@/app/api/helpers";
|
||||||
|
import { getV1OnboardingState } from "@/app/api/__generated__/endpoints/onboarding/onboarding";
|
||||||
|
import { getHomepageRoute } from "@/lib/constants";
|
||||||
|
|
||||||
export default function OnboardingPage() {
|
export default function OnboardingPage() {
|
||||||
const router = useRouter();
|
const router = useRouter();
|
||||||
@@ -12,10 +13,12 @@ export default function OnboardingPage() {
|
|||||||
async function redirectToStep() {
|
async function redirectToStep() {
|
||||||
try {
|
try {
|
||||||
// Check if onboarding is enabled (also gets chat flag for redirect)
|
// Check if onboarding is enabled (also gets chat flag for redirect)
|
||||||
const { shouldShowOnboarding } = await getOnboardingStatus();
|
const { shouldShowOnboarding, isChatEnabled } =
|
||||||
|
await getOnboardingStatus();
|
||||||
|
const homepageRoute = getHomepageRoute(isChatEnabled);
|
||||||
|
|
||||||
if (!shouldShowOnboarding) {
|
if (!shouldShowOnboarding) {
|
||||||
router.replace("/");
|
router.replace(homepageRoute);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -23,7 +26,7 @@ export default function OnboardingPage() {
|
|||||||
|
|
||||||
// Handle completed onboarding
|
// Handle completed onboarding
|
||||||
if (onboarding.completedSteps.includes("GET_RESULTS")) {
|
if (onboarding.completedSteps.includes("GET_RESULTS")) {
|
||||||
router.replace("/");
|
router.replace(homepageRoute);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
-import { getOnboardingStatus } from "@/app/api/helpers";
-import BackendAPI from "@/lib/autogpt-server-api";
 import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
-import { revalidatePath } from "next/cache";
+import { getHomepageRoute } from "@/lib/constants";
+import BackendAPI from "@/lib/autogpt-server-api";
 import { NextResponse } from "next/server";
+import { revalidatePath } from "next/cache";
+import { getOnboardingStatus } from "@/app/api/helpers";

 // Handle the callback to complete the user session login
 export async function GET(request: Request) {
@@ -26,12 +27,13 @@ export async function GET(request: Request) {
       await api.createUser();

       // Get onboarding status from backend (includes chat flag evaluated for this user)
-      const { shouldShowOnboarding } = await getOnboardingStatus();
+      const { shouldShowOnboarding, isChatEnabled } =
+        await getOnboardingStatus();
       if (shouldShowOnboarding) {
         next = "/onboarding";
         revalidatePath("/onboarding", "layout");
       } else {
-        next = "/";
+        next = getHomepageRoute(isChatEnabled);
         revalidatePath(next, "layout");
       }
     } catch (createUserError) {
@@ -1,17 +1,6 @@
 import { OAuthPopupResultMessage } from "./types";
 import { NextResponse } from "next/server";

-/**
- * Safely encode a value as JSON for embedding in a script tag.
- * Escapes characters that could break out of the script context to prevent XSS.
- */
-function safeJsonStringify(value: unknown): string {
-  return JSON.stringify(value)
-    .replace(/</g, "\\u003c")
-    .replace(/>/g, "\\u003e")
-    .replace(/&/g, "\\u0026");
-}
-
 // This route is intended to be used as the callback for integration OAuth flows,
 // controlled by the CredentialsInput component. The CredentialsInput opens the login
 // page in a pop-up window, which then redirects to this route to close the loop.
@@ -34,13 +23,12 @@ export async function GET(request: Request) {
   console.debug("Sending message to opener:", message);

   // Return a response with the message as JSON and a script to close the window
-  // Use safeJsonStringify to prevent XSS by escaping <, >, and & characters
   return new NextResponse(
     `
     <html>
       <body>
         <script>
-          window.opener.postMessage(${safeJsonStringify(message)});
+          window.opener.postMessage(${JSON.stringify(message)});
           window.close();
         </script>
       </body>
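For reference, the `safeJsonStringify` helper on the removed side of this hunk neutralises the characters that could terminate the surrounding `<script>` block before the value is inlined into the page. A self-contained sketch of that idea (the sample `message` value is illustrative only):

```typescript
// Sketch: encode a value as JSON that is safe to inline inside a <script> tag.
// Escaping <, >, and & prevents payloads like "</script><script>..." from
// breaking out of the script context (XSS).
function safeJsonStringify(value: unknown): string {
  return JSON.stringify(value)
    .replace(/</g, "\\u003c")
    .replace(/>/g, "\\u003e")
    .replace(/&/g, "\\u0026");
}

// Example: an object containing a closing script tag stays inert after escaping.
const message = { note: "</script><script>alert(1)</script>" };
const html = `<script>window.opener.postMessage(${safeJsonStringify(message)});</script>`;
console.log(html.includes("</script><script>alert")); // false - the payload is escaped
```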
@@ -11,6 +11,7 @@ import { useBreakpoint } from "@/lib/hooks/useBreakpoint";
 import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
 import { useQueryClient } from "@tanstack/react-query";
 import { usePathname, useSearchParams } from "next/navigation";
+import { useRef } from "react";
 import { useCopilotStore } from "../../copilot-page-store";
 import { useCopilotSessionId } from "../../useCopilotSessionId";
 import { useMobileDrawer } from "./components/MobileDrawer/useMobileDrawer";
@@ -69,16 +70,41 @@ export function useCopilotShell() {
   });

   const stopStream = useChatStore((s) => s.stopStream);
+  const onStreamComplete = useChatStore((s) => s.onStreamComplete);
+  const isStreaming = useCopilotStore((s) => s.isStreaming);
   const isCreatingSession = useCopilotStore((s) => s.isCreatingSession);
+  const setIsSwitchingSession = useCopilotStore((s) => s.setIsSwitchingSession);
+  const openInterruptModal = useCopilotStore((s) => s.openInterruptModal);

-  function handleSessionClick(sessionId: string) {
-    if (sessionId === currentSessionId) return;
+  const pendingActionRef = useRef<(() => void) | null>(null);

-    // Stop current stream - SSE reconnection allows resuming later
-    if (currentSessionId) {
-      stopStream(currentSessionId);
-    }
+  async function stopCurrentStream() {
+    if (!currentSessionId) return;
+
+    setIsSwitchingSession(true);
+    await new Promise<void>((resolve) => {
+      const unsubscribe = onStreamComplete((completedId) => {
+        if (completedId === currentSessionId) {
+          clearTimeout(timeout);
+          unsubscribe();
+          resolve();
+        }
+      });
+      const timeout = setTimeout(() => {
+        unsubscribe();
+        resolve();
+      }, 3000);
+      stopStream(currentSessionId);
+    });
+
+    queryClient.invalidateQueries({
+      queryKey: getGetV2GetSessionQueryKey(currentSessionId),
+    });
+    setIsSwitchingSession(false);
+  }
+
+  function selectSession(sessionId: string) {
+    if (sessionId === currentSessionId) return;
     if (recentlyCreatedSessionsRef.current.has(sessionId)) {
       queryClient.invalidateQueries({
         queryKey: getGetV2GetSessionQueryKey(sessionId),
@@ -88,12 +114,7 @@ export function useCopilotShell() {
     if (isMobile) handleCloseDrawer();
   }

-  function handleNewChatClick() {
-    // Stop current stream - SSE reconnection allows resuming later
-    if (currentSessionId) {
-      stopStream(currentSessionId);
-    }
-
+  function startNewChat() {
     resetPagination();
     queryClient.invalidateQueries({
       queryKey: getGetV2ListSessionsQueryKey(),
@@ -102,6 +123,32 @@ export function useCopilotShell() {
     if (isMobile) handleCloseDrawer();
   }

+  function handleSessionClick(sessionId: string) {
+    if (sessionId === currentSessionId) return;
+
+    if (isStreaming) {
+      pendingActionRef.current = async () => {
+        await stopCurrentStream();
+        selectSession(sessionId);
+      };
+      openInterruptModal(pendingActionRef.current);
+    } else {
+      selectSession(sessionId);
+    }
+  }
+
+  function handleNewChatClick() {
+    if (isStreaming) {
+      pendingActionRef.current = async () => {
+        await stopCurrentStream();
+        startNewChat();
+      };
+      openInterruptModal(pendingActionRef.current);
+    } else {
+      startNewChat();
+    }
+  }
+
   return {
     isMobile,
     isDrawerOpen,
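`stopCurrentStream` above waits until the chat store reports that the current session's stream has finished, but falls back to a 3-second timeout so a lost completion event can never hang the UI. A stand-alone sketch of that wait-with-timeout pattern, with `onComplete` and `stop` as illustrative stand-ins for the store's `onStreamComplete` and `stopStream`:

```typescript
// Sketch: resolve when a completion callback fires for the expected id,
// or after a timeout, whichever comes first. Both paths unsubscribe so the
// listener never leaks.
type Unsubscribe = () => void;

function waitForCompletion(
  expectedId: string,
  onComplete: (cb: (completedId: string) => void) => Unsubscribe,
  stop: (id: string) => void,
  timeoutMs = 3000,
): Promise<void> {
  return new Promise<void>((resolve) => {
    const unsubscribe = onComplete((completedId) => {
      if (completedId === expectedId) {
        clearTimeout(timeout);
        unsubscribe();
        resolve();
      }
    });
    const timeout = setTimeout(() => {
      unsubscribe();
      resolve();
    }, timeoutMs);
    stop(expectedId); // trigger the stop only after the listener is installed
  });
}
```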
@@ -26,20 +26,8 @@ export function buildCopilotChatUrl(prompt: string): string {

 export function getQuickActions(): string[] {
   return [
-    "I don't know where to start, just ask me stuff",
-    "I do the same thing every week and it's killing me",
-    "Help me find where I'm wasting my time",
+    "Show me what I can automate",
+    "Design a custom workflow",
+    "Help me with content creation",
   ];
 }
-
-export function getInputPlaceholder(width?: number) {
-  if (!width) return "What's your role and what eats up most of your day?";
-
-  if (width < 500) {
-    return "I'm a chef and I hate...";
-  }
-  if (width <= 1080) {
-    return "What's your role and what eats up most of your day?";
-  }
-  return "What's your role and what eats up most of your day? e.g. 'I'm a recruiter and I hate...'";
-}
@@ -1,13 +1,6 @@
-"use client";
-import { FeatureFlagPage } from "@/services/feature-flags/FeatureFlagPage";
-import { Flag } from "@/services/feature-flags/use-get-flag";
-import { type ReactNode } from "react";
+import type { ReactNode } from "react";
 import { CopilotShell } from "./components/CopilotShell/CopilotShell";

 export default function CopilotLayout({ children }: { children: ReactNode }) {
-  return (
-    <FeatureFlagPage flag={Flag.CHAT} whenDisabled="/library">
-      <CopilotShell>{children}</CopilotShell>
-    </FeatureFlagPage>
-  );
+  return <CopilotShell>{children}</CopilotShell>;
 }
@@ -6,9 +6,7 @@ import { Text } from "@/components/atoms/Text/Text";
|
|||||||
import { Chat } from "@/components/contextual/Chat/Chat";
|
import { Chat } from "@/components/contextual/Chat/Chat";
|
||||||
import { ChatInput } from "@/components/contextual/Chat/components/ChatInput/ChatInput";
|
import { ChatInput } from "@/components/contextual/Chat/components/ChatInput/ChatInput";
|
||||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||||
import { useEffect, useState } from "react";
|
|
||||||
import { useCopilotStore } from "./copilot-page-store";
|
import { useCopilotStore } from "./copilot-page-store";
|
||||||
import { getInputPlaceholder } from "./helpers";
|
|
||||||
import { useCopilotPage } from "./useCopilotPage";
|
import { useCopilotPage } from "./useCopilotPage";
|
||||||
|
|
||||||
export default function CopilotPage() {
|
export default function CopilotPage() {
|
||||||
@@ -16,25 +14,14 @@ export default function CopilotPage() {
|
|||||||
const isInterruptModalOpen = useCopilotStore((s) => s.isInterruptModalOpen);
|
const isInterruptModalOpen = useCopilotStore((s) => s.isInterruptModalOpen);
|
||||||
const confirmInterrupt = useCopilotStore((s) => s.confirmInterrupt);
|
const confirmInterrupt = useCopilotStore((s) => s.confirmInterrupt);
|
||||||
const cancelInterrupt = useCopilotStore((s) => s.cancelInterrupt);
|
const cancelInterrupt = useCopilotStore((s) => s.cancelInterrupt);
|
||||||
|
const {
|
||||||
const [inputPlaceholder, setInputPlaceholder] = useState(
|
greetingName,
|
||||||
getInputPlaceholder(),
|
quickActions,
|
||||||
);
|
isLoading,
|
||||||
|
hasSession,
|
||||||
useEffect(() => {
|
initialPrompt,
|
||||||
const handleResize = () => {
|
isReady,
|
||||||
setInputPlaceholder(getInputPlaceholder(window.innerWidth));
|
} = state;
|
||||||
};
|
|
||||||
|
|
||||||
handleResize();
|
|
||||||
|
|
||||||
window.addEventListener("resize", handleResize);
|
|
||||||
return () => window.removeEventListener("resize", handleResize);
|
|
||||||
}, []);
|
|
||||||
|
|
||||||
const { greetingName, quickActions, isLoading, hasSession, initialPrompt } =
|
|
||||||
state;
|
|
||||||
|
|
||||||
const {
|
const {
|
||||||
handleQuickAction,
|
handleQuickAction,
|
||||||
startChatWithPrompt,
|
startChatWithPrompt,
|
||||||
@@ -42,6 +29,8 @@ export default function CopilotPage() {
|
|||||||
handleStreamingChange,
|
handleStreamingChange,
|
||||||
} = handlers;
|
} = handlers;
|
||||||
|
|
||||||
|
if (!isReady) return null;
|
||||||
|
|
||||||
if (hasSession) {
|
if (hasSession) {
|
||||||
return (
|
return (
|
||||||
<div className="flex h-full flex-col">
|
<div className="flex h-full flex-col">
|
||||||
@@ -92,7 +81,7 @@ export default function CopilotPage() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="flex h-full flex-1 items-center justify-center overflow-y-auto bg-[#f8f8f9] px-3 py-5 md:px-6 md:py-10">
|
<div className="flex h-full flex-1 items-center justify-center overflow-y-auto bg-[#f8f8f9] px-6 py-10">
|
||||||
<div className="w-full text-center">
|
<div className="w-full text-center">
|
||||||
{isLoading ? (
|
{isLoading ? (
|
||||||
<div className="mx-auto max-w-2xl">
|
<div className="mx-auto max-w-2xl">
|
||||||
@@ -109,25 +98,25 @@ export default function CopilotPage() {
|
|||||||
</div>
|
</div>
|
||||||
) : (
|
) : (
|
||||||
<>
|
<>
|
||||||
<div className="mx-auto max-w-3xl">
|
<div className="mx-auto max-w-2xl">
|
||||||
<Text
|
<Text
|
||||||
variant="h3"
|
variant="h3"
|
||||||
className="mb-1 !text-[1.375rem] text-zinc-700"
|
className="mb-3 !text-[1.375rem] text-zinc-700"
|
||||||
>
|
>
|
||||||
Hey, <span className="text-violet-600">{greetingName}</span>
|
Hey, <span className="text-violet-600">{greetingName}</span>
|
||||||
</Text>
|
</Text>
|
||||||
<Text variant="h3" className="mb-8 !font-normal">
|
<Text variant="h3" className="mb-8 !font-normal">
|
||||||
Tell me about your work — I'll find what to automate.
|
What do you want to automate?
|
||||||
</Text>
|
</Text>
|
||||||
|
|
||||||
<div className="mb-6">
|
<div className="mb-6">
|
||||||
<ChatInput
|
<ChatInput
|
||||||
onSend={startChatWithPrompt}
|
onSend={startChatWithPrompt}
|
||||||
placeholder={inputPlaceholder}
|
placeholder='You can search or just ask - e.g. "create a blog post outline"'
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div className="flex flex-wrap items-center justify-center gap-3 overflow-x-auto [-ms-overflow-style:none] [scrollbar-width:none] [&::-webkit-scrollbar]:hidden">
|
<div className="flex flex-nowrap items-center justify-center gap-3 overflow-x-auto [-ms-overflow-style:none] [scrollbar-width:none] [&::-webkit-scrollbar]:hidden">
|
||||||
{quickActions.map((action) => (
|
{quickActions.map((action) => (
|
||||||
<Button
|
<Button
|
||||||
key={action}
|
key={action}
|
||||||
@@ -135,7 +124,7 @@ export default function CopilotPage() {
|
|||||||
variant="outline"
|
variant="outline"
|
||||||
size="small"
|
size="small"
|
||||||
onClick={() => handleQuickAction(action)}
|
onClick={() => handleQuickAction(action)}
|
||||||
className="h-auto shrink-0 border-zinc-300 px-3 py-2 text-[.9rem] text-zinc-600"
|
className="h-auto shrink-0 border-zinc-600 !px-4 !py-2 text-[1rem] text-zinc-600"
|
||||||
>
|
>
|
||||||
{action}
|
{action}
|
||||||
</Button>
|
</Button>
|
||||||
|
|||||||
@@ -3,11 +3,18 @@ import {
   postV2CreateSession,
 } from "@/app/api/__generated__/endpoints/chat/chat";
 import { useToast } from "@/components/molecules/Toast/use-toast";
+import { getHomepageRoute } from "@/lib/constants";
 import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
 import { useOnboarding } from "@/providers/onboarding/onboarding-provider";
+import {
+  Flag,
+  type FlagValues,
+  useGetFlag,
+} from "@/services/feature-flags/use-get-flag";
 import { SessionKey, sessionStorage } from "@/services/storage/session-storage";
 import * as Sentry from "@sentry/nextjs";
 import { useQueryClient } from "@tanstack/react-query";
+import { useFlags } from "launchdarkly-react-client-sdk";
 import { useRouter } from "next/navigation";
 import { useEffect } from "react";
 import { useCopilotStore } from "./copilot-page-store";
@@ -26,6 +33,22 @@ export function useCopilotPage() {
   const isCreating = useCopilotStore((s) => s.isCreatingSession);
   const setIsCreating = useCopilotStore((s) => s.setIsCreatingSession);

+  // Complete VISIT_COPILOT onboarding step to grant $5 welcome bonus
+  useEffect(() => {
+    if (isLoggedIn) {
+      completeStep("VISIT_COPILOT");
+    }
+  }, [completeStep, isLoggedIn]);
+
+  const isChatEnabled = useGetFlag(Flag.CHAT);
+  const flags = useFlags<FlagValues>();
+  const homepageRoute = getHomepageRoute(isChatEnabled);
+  const envEnabled = process.env.NEXT_PUBLIC_LAUNCHDARKLY_ENABLED === "true";
+  const clientId = process.env.NEXT_PUBLIC_LAUNCHDARKLY_CLIENT_ID;
+  const isLaunchDarklyConfigured = envEnabled && Boolean(clientId);
+  const isFlagReady =
+    !isLaunchDarklyConfigured || flags[Flag.CHAT] !== undefined;
+
   const greetingName = getGreetingName(user);
   const quickActions = getQuickActions();

@@ -35,8 +58,11 @@ export function useCopilotPage() {
     : undefined;

   useEffect(() => {
-    if (isLoggedIn) completeStep("VISIT_COPILOT");
-  }, [completeStep, isLoggedIn]);
+    if (!isFlagReady) return;
+    if (isChatEnabled === false) {
+      router.replace(homepageRoute);
+    }
+  }, [homepageRoute, isChatEnabled, isFlagReady, router]);

   async function startChatWithPrompt(prompt: string) {
     if (!prompt?.trim()) return;
@@ -90,6 +116,7 @@ export function useCopilotPage() {
       isLoading: isUserLoading,
       hasSession,
       initialPrompt,
+      isReady: isFlagReady && isChatEnabled !== false && isLoggedIn,
     },
     handlers: {
       handleQuickAction,
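The `isFlagReady` check above exists because `useGetFlag(Flag.CHAT)` can report `undefined` while the LaunchDarkly client is still bootstrapping; redirecting on that transient value would bounce users off the page on every load. A reduced sketch of the readiness rule outside React (environment values here are illustrative, not the project's configuration):

```typescript
// Sketch: decide whether a LaunchDarkly-backed flag can be trusted yet.
// When LaunchDarkly isn't configured, flags fall back to defaults and are
// usable immediately; otherwise wait until the client has reported a value.
function isFlagReady(
  flagValue: boolean | undefined,
  env: { ldEnabled?: string; ldClientId?: string },
): boolean {
  const configured = env.ldEnabled === "true" && Boolean(env.ldClientId);
  return !configured || flagValue !== undefined;
}

// While the client is still bootstrapping, hold off on any redirect:
console.log(isFlagReady(undefined, { ldEnabled: "true", ldClientId: "abc" })); // false
// Once a value arrives (or LaunchDarkly is disabled), it is safe to act on the flag:
console.log(isFlagReady(false, { ldEnabled: "true", ldClientId: "abc" })); // true
console.log(isFlagReady(undefined, { ldEnabled: "false" })); // true
```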
@@ -1,6 +1,8 @@
 "use client";

 import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard";
+import { getHomepageRoute } from "@/lib/constants";
+import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
 import { useSearchParams } from "next/navigation";
 import { Suspense } from "react";
 import { getErrorDetails } from "./helpers";
@@ -9,6 +11,8 @@ function ErrorPageContent() {
   const searchParams = useSearchParams();
   const errorMessage = searchParams.get("message");
   const errorDetails = getErrorDetails(errorMessage);
+  const isChatEnabled = useGetFlag(Flag.CHAT);
+  const homepageRoute = getHomepageRoute(isChatEnabled);

   function handleRetry() {
     // Auth-related errors should redirect to login
@@ -26,7 +30,7 @@ function ErrorPageContent() {
       }, 2000);
     } else {
       // For server/network errors, go to home
-      window.location.href = "/";
+      window.location.href = homepageRoute;
     }
   }

@@ -1,5 +1,6 @@
 "use server";

+import { getHomepageRoute } from "@/lib/constants";
 import BackendAPI from "@/lib/autogpt-server-api";
 import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
 import { loginFormSchema } from "@/types/auth";
@@ -37,8 +38,10 @@ export async function login(email: string, password: string) {
     await api.createUser();

     // Get onboarding status from backend (includes chat flag evaluated for this user)
-    const { shouldShowOnboarding } = await getOnboardingStatus();
-    const next = shouldShowOnboarding ? "/onboarding" : "/";
+    const { shouldShowOnboarding, isChatEnabled } = await getOnboardingStatus();
+    const next = shouldShowOnboarding
+      ? "/onboarding"
+      : getHomepageRoute(isChatEnabled);

     return {
       success: true,
@@ -1,6 +1,8 @@
 import { useToast } from "@/components/molecules/Toast/use-toast";
+import { getHomepageRoute } from "@/lib/constants";
 import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
 import { environment } from "@/services/environment";
+import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
 import { loginFormSchema, LoginProvider } from "@/types/auth";
 import { zodResolver } from "@hookform/resolvers/zod";
 import { useRouter, useSearchParams } from "next/navigation";
@@ -20,15 +22,17 @@ export function useLoginPage() {
   const [isGoogleLoading, setIsGoogleLoading] = useState(false);
   const [showNotAllowedModal, setShowNotAllowedModal] = useState(false);
   const isCloudEnv = environment.isCloud();
+  const isChatEnabled = useGetFlag(Flag.CHAT);
+  const homepageRoute = getHomepageRoute(isChatEnabled);

   // Get redirect destination from 'next' query parameter
   const nextUrl = searchParams.get("next");

   useEffect(() => {
     if (isLoggedIn && !isLoggingIn) {
-      router.push(nextUrl || "/");
+      router.push(nextUrl || homepageRoute);
     }
-  }, [isLoggedIn, isLoggingIn, nextUrl, router]);
+  }, [homepageRoute, isLoggedIn, isLoggingIn, nextUrl, router]);

   const form = useForm<z.infer<typeof loginFormSchema>>({
     resolver: zodResolver(loginFormSchema),
@@ -94,7 +98,7 @@ export function useLoginPage() {
       }

       // Prefer URL's next parameter, then use backend-determined route
-      router.replace(nextUrl || result.next || "/");
+      router.replace(nextUrl || result.next || homepageRoute);
     } catch (error) {
       toast({
         title:
@@ -1,5 +1,6 @@
 "use server";

+import { getHomepageRoute } from "@/lib/constants";
 import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
 import { signupFormSchema } from "@/types/auth";
 import * as Sentry from "@sentry/nextjs";
@@ -58,8 +59,10 @@ export async function signup(
     }

     // Get onboarding status from backend (includes chat flag evaluated for this user)
-    const { shouldShowOnboarding } = await getOnboardingStatus();
-    const next = shouldShowOnboarding ? "/onboarding" : "/";
+    const { shouldShowOnboarding, isChatEnabled } = await getOnboardingStatus();
+    const next = shouldShowOnboarding
+      ? "/onboarding"
+      : getHomepageRoute(isChatEnabled);

     return { success: true, next };
   } catch (err) {
@@ -1,6 +1,8 @@
 import { useToast } from "@/components/molecules/Toast/use-toast";
+import { getHomepageRoute } from "@/lib/constants";
 import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
 import { environment } from "@/services/environment";
+import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
 import { LoginProvider, signupFormSchema } from "@/types/auth";
 import { zodResolver } from "@hookform/resolvers/zod";
 import { useRouter, useSearchParams } from "next/navigation";
@@ -20,15 +22,17 @@ export function useSignupPage() {
   const [isGoogleLoading, setIsGoogleLoading] = useState(false);
   const [showNotAllowedModal, setShowNotAllowedModal] = useState(false);
   const isCloudEnv = environment.isCloud();
+  const isChatEnabled = useGetFlag(Flag.CHAT);
+  const homepageRoute = getHomepageRoute(isChatEnabled);

   // Get redirect destination from 'next' query parameter
   const nextUrl = searchParams.get("next");

   useEffect(() => {
     if (isLoggedIn && !isSigningUp) {
-      router.push(nextUrl || "/");
+      router.push(nextUrl || homepageRoute);
     }
-  }, [isLoggedIn, isSigningUp, nextUrl, router]);
+  }, [homepageRoute, isLoggedIn, isSigningUp, nextUrl, router]);

   const form = useForm<z.infer<typeof signupFormSchema>>({
     resolver: zodResolver(signupFormSchema),
@@ -129,7 +133,7 @@ export function useSignupPage() {
       }

       // Prefer the URL's next parameter, then result.next (for onboarding), then default
-      const redirectTo = nextUrl || result.next || "/";
+      const redirectTo = nextUrl || result.next || homepageRoute;
       router.replace(redirectTo);
     } catch (error) {
       setIsLoading(false);
@@ -1,81 +0,0 @@
|
|||||||
import { environment } from "@/services/environment";
|
|
||||||
import { getServerAuthToken } from "@/lib/autogpt-server-api/helpers";
|
|
||||||
import { NextRequest } from "next/server";
|
|
||||||
|
|
||||||
/**
|
|
||||||
* SSE Proxy for task stream reconnection.
|
|
||||||
*
|
|
||||||
* This endpoint allows clients to reconnect to an ongoing or recently completed
|
|
||||||
* background task's stream. It replays missed messages from Redis Streams and
|
|
||||||
* subscribes to live updates if the task is still running.
|
|
||||||
*
|
|
||||||
* Client contract:
|
|
||||||
* 1. When receiving an operation_started event, store the task_id
|
|
||||||
* 2. To reconnect: GET /api/chat/tasks/{taskId}/stream?last_message_id={idx}
|
|
||||||
* 3. Messages are replayed from the last_message_id position
|
|
||||||
* 4. Stream ends when "finish" event is received
|
|
||||||
*/
|
|
||||||
export async function GET(
|
|
||||||
request: NextRequest,
|
|
||||||
{ params }: { params: Promise<{ taskId: string }> },
|
|
||||||
) {
|
|
||||||
const { taskId } = await params;
|
|
||||||
const searchParams = request.nextUrl.searchParams;
|
|
||||||
const lastMessageId = searchParams.get("last_message_id") || "0-0";
|
|
||||||
|
|
||||||
try {
|
|
||||||
// Get auth token from server-side session
|
|
||||||
const token = await getServerAuthToken();
|
|
||||||
|
|
||||||
// Build backend URL
|
|
||||||
const backendUrl = environment.getAGPTServerBaseUrl();
|
|
||||||
const streamUrl = new URL(`/api/chat/tasks/${taskId}/stream`, backendUrl);
|
|
||||||
streamUrl.searchParams.set("last_message_id", lastMessageId);
|
|
||||||
|
|
||||||
// Forward request to backend with auth header
|
|
||||||
const headers: Record<string, string> = {
|
|
||||||
Accept: "text/event-stream",
|
|
||||||
"Cache-Control": "no-cache",
|
|
||||||
Connection: "keep-alive",
|
|
||||||
};
|
|
||||||
|
|
||||||
if (token) {
|
|
||||||
headers["Authorization"] = `Bearer ${token}`;
|
|
||||||
}
|
|
||||||
|
|
||||||
const response = await fetch(streamUrl.toString(), {
|
|
||||||
method: "GET",
|
|
||||||
headers,
|
|
||||||
});
|
|
||||||
|
|
||||||
if (!response.ok) {
|
|
||||||
const error = await response.text();
|
|
||||||
return new Response(error, {
|
|
||||||
status: response.status,
|
|
||||||
headers: { "Content-Type": "application/json" },
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return the SSE stream directly
|
|
||||||
return new Response(response.body, {
|
|
||||||
headers: {
|
|
||||||
"Content-Type": "text/event-stream",
|
|
||||||
"Cache-Control": "no-cache, no-transform",
|
|
||||||
Connection: "keep-alive",
|
|
||||||
"X-Accel-Buffering": "no",
|
|
||||||
},
|
|
||||||
});
|
|
||||||
} catch (error) {
|
|
||||||
console.error("Task stream proxy error:", error);
|
|
||||||
return new Response(
|
|
||||||
JSON.stringify({
|
|
||||||
error: "Failed to connect to task stream",
|
|
||||||
detail: error instanceof Error ? error.message : String(error),
|
|
||||||
}),
|
|
||||||
{
|
|
||||||
status: 500,
|
|
||||||
headers: { "Content-Type": "application/json" },
|
|
||||||
},
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -181,5 +181,6 @@ export async function getOnboardingStatus() {
   const isCompleted = onboarding.completedSteps.includes("CONGRATS");
   return {
     shouldShowOnboarding: status.is_onboarding_enabled && !isCompleted,
+    isChatEnabled: status.is_chat_enabled,
   };
 }
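Call sites in this diff destructure both fields to pick the post-auth destination: onboarding wins, otherwise the chat flag decides the homepage. A condensed sketch of that decision; `getHomepageRoute` is assumed here (its real implementation in `@/lib/constants` is not shown in this diff), with the `/copilot` and `/library` routes inferred from other hunks:

```typescript
// Sketch of the redirect decision the login/signup actions make after
// calling getOnboardingStatus(). Types are simplified.
interface OnboardingStatus {
  shouldShowOnboarding: boolean;
  isChatEnabled: boolean;
}

// Hypothetical helper - the real one is imported from "@/lib/constants".
function getHomepageRoute(isChatEnabled: boolean): string {
  return isChatEnabled ? "/copilot" : "/library";
}

function pickNextRoute({ shouldShowOnboarding, isChatEnabled }: OnboardingStatus): string {
  return shouldShowOnboarding ? "/onboarding" : getHomepageRoute(isChatEnabled);
}

console.log(pickNextRoute({ shouldShowOnboarding: true, isChatEnabled: true })); // "/onboarding"
console.log(pickNextRoute({ shouldShowOnboarding: false, isChatEnabled: false })); // "/library"
```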
@@ -917,28 +917,6 @@
|
|||||||
"security": [{ "HTTPBearerJWT": [] }]
|
"security": [{ "HTTPBearerJWT": [] }]
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"/api/chat/config/ttl": {
|
|
||||||
"get": {
|
|
||||||
"tags": ["v2", "chat", "chat"],
|
|
||||||
"summary": "Get Ttl Config",
|
|
||||||
"description": "Get the stream TTL configuration.\n\nReturns the Time-To-Live settings for chat streams, which determines\nhow long clients can reconnect to an active stream.\n\nReturns:\n dict: TTL configuration with seconds and milliseconds values.",
|
|
||||||
"operationId": "getV2GetTtlConfig",
|
|
||||||
"responses": {
|
|
||||||
"200": {
|
|
||||||
"description": "Successful Response",
|
|
||||||
"content": {
|
|
||||||
"application/json": {
|
|
||||||
"schema": {
|
|
||||||
"additionalProperties": true,
|
|
||||||
"type": "object",
|
|
||||||
"title": "Response Getv2Getttlconfig"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"/api/chat/health": {
|
"/api/chat/health": {
|
||||||
"get": {
|
"get": {
|
||||||
"tags": ["v2", "chat", "chat"],
|
"tags": ["v2", "chat", "chat"],
|
||||||
@@ -961,63 +939,6 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"/api/chat/operations/{operation_id}/complete": {
|
|
||||||
"post": {
|
|
||||||
"tags": ["v2", "chat", "chat"],
|
|
||||||
"summary": "Complete Operation",
|
|
||||||
"description": "External completion webhook for long-running operations.\n\nCalled by Agent Generator (or other services) when an operation completes.\nThis triggers the stream registry to publish completion and continue LLM generation.\n\nArgs:\n operation_id: The operation ID to complete.\n request: Completion payload with success status and result/error.\n x_api_key: Internal API key for authentication.\n\nReturns:\n dict: Status of the completion.\n\nRaises:\n HTTPException: If API key is invalid or operation not found.",
|
|
||||||
"operationId": "postV2CompleteOperation",
|
|
||||||
"parameters": [
|
|
||||||
{
|
|
||||||
"name": "operation_id",
|
|
||||||
"in": "path",
|
|
||||||
"required": true,
|
|
||||||
"schema": { "type": "string", "title": "Operation Id" }
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "x-api-key",
|
|
||||||
"in": "header",
|
|
||||||
"required": false,
|
|
||||||
"schema": {
|
|
||||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
|
||||||
"title": "X-Api-Key"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"requestBody": {
|
|
||||||
"required": true,
|
|
||||||
"content": {
|
|
||||||
"application/json": {
|
|
||||||
"schema": {
|
|
||||||
"$ref": "#/components/schemas/OperationCompleteRequest"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"responses": {
|
|
||||||
"200": {
|
|
||||||
"description": "Successful Response",
|
|
||||||
"content": {
|
|
||||||
"application/json": {
|
|
||||||
"schema": {
|
|
||||||
"type": "object",
|
|
||||||
"additionalProperties": true,
|
|
||||||
"title": "Response Postv2Completeoperation"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"422": {
|
|
||||||
"description": "Validation Error",
|
|
||||||
"content": {
|
|
||||||
"application/json": {
|
|
||||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"/api/chat/sessions": {
|
"/api/chat/sessions": {
|
||||||
"get": {
|
"get": {
|
||||||
"tags": ["v2", "chat", "chat"],
|
"tags": ["v2", "chat", "chat"],
|
||||||
@@ -1101,7 +1022,7 @@
|
|||||||
"get": {
|
"get": {
|
||||||
"tags": ["v2", "chat", "chat"],
|
"tags": ["v2", "chat", "chat"],
|
||||||
"summary": "Get Session",
|
"summary": "Get Session",
|
||||||
"description": "Retrieve the details of a specific chat session.\n\nLooks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.\nIf there's an active stream for this session, returns the task_id for reconnection.\n\nArgs:\n session_id: The unique identifier for the desired chat session.\n user_id: The optional authenticated user ID, or None for anonymous access.\n\nReturns:\n SessionDetailResponse: Details for the requested session, including active_stream info if applicable.",
|
"description": "Retrieve the details of a specific chat session.\n\nLooks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.\n\nArgs:\n session_id: The unique identifier for the desired chat session.\n user_id: The optional authenticated user ID, or None for anonymous access.\n\nReturns:\n SessionDetailResponse: Details for the requested session, or None if not found.",
|
||||||
"operationId": "getV2GetSession",
|
"operationId": "getV2GetSession",
|
||||||
"security": [{ "HTTPBearerJWT": [] }],
|
"security": [{ "HTTPBearerJWT": [] }],
|
||||||
"parameters": [
|
"parameters": [
|
||||||
@@ -1236,7 +1157,7 @@
|
|||||||
"post": {
|
"post": {
|
||||||
"tags": ["v2", "chat", "chat"],
|
"tags": ["v2", "chat", "chat"],
|
||||||
"summary": "Stream Chat Post",
|
"summary": "Stream Chat Post",
|
||||||
"description": "Stream chat responses for a session (POST with context support).\n\nStreams the AI/completion responses in real time over Server-Sent Events (SSE), including:\n - Text fragments as they are generated\n - Tool call UI elements (if invoked)\n - Tool execution results\n\nThe AI generation runs in a background task that continues even if the client disconnects.\nAll chunks are written to Redis for reconnection support. If the client disconnects,\nthey can reconnect using GET /tasks/{task_id}/stream to resume from where they left off.\n\nArgs:\n session_id: The chat session identifier to associate with the streamed messages.\n request: Request body containing message, is_user_message, and optional context.\n user_id: Optional authenticated user ID.\nReturns:\n StreamingResponse: SSE-formatted response chunks. First chunk is a \"start\" event\n containing the task_id for reconnection.",
|
"description": "Stream chat responses for a session (POST with context support).\n\nStreams the AI/completion responses in real time over Server-Sent Events (SSE), including:\n - Text fragments as they are generated\n - Tool call UI elements (if invoked)\n - Tool execution results\n\nArgs:\n session_id: The chat session identifier to associate with the streamed messages.\n request: Request body containing message, is_user_message, and optional context.\n user_id: Optional authenticated user ID.\nReturns:\n StreamingResponse: SSE-formatted response chunks.",
|
||||||
"operationId": "postV2StreamChatPost",
|
"operationId": "postV2StreamChatPost",
|
||||||
"security": [{ "HTTPBearerJWT": [] }],
|
"security": [{ "HTTPBearerJWT": [] }],
|
||||||
"parameters": [
|
"parameters": [
|
||||||
@@ -1274,94 +1195,6 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"/api/chat/tasks/{task_id}": {
|
|
||||||
"get": {
|
|
||||||
"tags": ["v2", "chat", "chat"],
|
|
||||||
"summary": "Get Task Status",
|
|
||||||
"description": "Get the status of a long-running task.\n\nArgs:\n task_id: The task ID to check.\n user_id: Authenticated user ID for ownership validation.\n\nReturns:\n dict: Task status including task_id, status, tool_name, and operation_id.\n\nRaises:\n NotFoundError: If task_id is not found or user doesn't have access.",
|
|
||||||
"operationId": "getV2GetTaskStatus",
|
|
||||||
"security": [{ "HTTPBearerJWT": [] }],
|
|
||||||
"parameters": [
|
|
||||||
{
|
|
||||||
"name": "task_id",
|
|
||||||
"in": "path",
|
|
||||||
"required": true,
|
|
||||||
"schema": { "type": "string", "title": "Task Id" }
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"responses": {
|
|
||||||
"200": {
|
|
||||||
"description": "Successful Response",
|
|
||||||
"content": {
|
|
||||||
"application/json": {
|
|
||||||
"schema": {
|
|
||||||
"type": "object",
|
|
||||||
"additionalProperties": true,
|
|
||||||
"title": "Response Getv2Gettaskstatus"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"401": {
|
|
||||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
|
||||||
},
|
|
||||||
"422": {
|
|
||||||
"description": "Validation Error",
|
|
||||||
"content": {
|
|
||||||
"application/json": {
|
|
||||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"/api/chat/tasks/{task_id}/stream": {
|
|
||||||
"get": {
|
|
||||||
"tags": ["v2", "chat", "chat"],
|
|
||||||
"summary": "Stream Task",
|
|
||||||
"description": "Reconnect to a long-running task's SSE stream.\n\nWhen a long-running operation (like agent generation) starts, the client\nreceives a task_id. If the connection drops, the client can reconnect\nusing this endpoint to resume receiving updates.\n\nArgs:\n task_id: The task ID from the operation_started response.\n user_id: Authenticated user ID for ownership validation.\n last_message_id: Last Redis Stream message ID received (\"0-0\" for full replay).\n\nReturns:\n StreamingResponse: SSE-formatted response chunks starting after last_message_id.\n\nRaises:\n HTTPException: 404 if task not found, 410 if task expired, 403 if access denied.",
|
|
||||||
"operationId": "getV2StreamTask",
|
|
||||||
"security": [{ "HTTPBearerJWT": [] }],
|
|
||||||
"parameters": [
|
|
||||||
{
|
|
||||||
"name": "task_id",
|
|
||||||
"in": "path",
|
|
||||||
"required": true,
|
|
||||||
"schema": { "type": "string", "title": "Task Id" }
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "last_message_id",
|
|
||||||
"in": "query",
|
|
||||||
"required": false,
|
|
||||||
"schema": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.",
|
|
||||||
"default": "0-0",
|
|
||||||
"title": "Last Message Id"
|
|
||||||
},
|
|
||||||
"description": "Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay."
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"responses": {
|
|
||||||
"200": {
|
|
||||||
"description": "Successful Response",
|
|
||||||
"content": { "application/json": { "schema": {} } }
|
|
||||||
},
|
|
||||||
"401": {
|
|
||||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
|
||||||
},
|
|
||||||
"422": {
|
|
||||||
"description": "Validation Error",
|
|
||||||
"content": {
|
|
||||||
"application/json": {
|
|
||||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"/api/credits": {
|
"/api/credits": {
|
||||||
"get": {
|
"get": {
|
||||||
"tags": ["v1", "credits"],
|
"tags": ["v1", "credits"],
|
||||||
@@ -6335,18 +6168,6 @@
|
|||||||
"title": "AccuracyTrendsResponse",
|
"title": "AccuracyTrendsResponse",
|
||||||
"description": "Response model for accuracy trends and alerts."
|
"description": "Response model for accuracy trends and alerts."
|
||||||
},
|
},
|
||||||
"ActiveStreamInfo": {
|
|
||||||
"properties": {
|
|
||||||
"task_id": { "type": "string", "title": "Task Id" },
|
|
||||||
"last_message_id": { "type": "string", "title": "Last Message Id" },
|
|
||||||
"operation_id": { "type": "string", "title": "Operation Id" },
|
|
||||||
"tool_name": { "type": "string", "title": "Tool Name" }
|
|
||||||
},
|
|
||||||
"type": "object",
|
|
||||||
"required": ["task_id", "last_message_id", "operation_id", "tool_name"],
|
|
||||||
"title": "ActiveStreamInfo",
|
|
||||||
"description": "Information about an active stream for reconnection."
|
|
||||||
},
|
|
||||||
"AddUserCreditsResponse": {
|
"AddUserCreditsResponse": {
|
||||||
"properties": {
|
"properties": {
|
||||||
"new_balance": { "type": "integer", "title": "New Balance" },
|
"new_balance": { "type": "integer", "title": "New Balance" },
|
||||||
@@ -9002,27 +8823,6 @@
|
|||||||
],
|
],
|
||||||
"title": "OnboardingStep"
|
"title": "OnboardingStep"
|
||||||
},
|
},
|
||||||
"OperationCompleteRequest": {
|
|
||||||
"properties": {
|
|
||||||
"success": { "type": "boolean", "title": "Success" },
|
|
||||||
"result": {
|
|
||||||
"anyOf": [
|
|
||||||
{ "additionalProperties": true, "type": "object" },
|
|
||||||
{ "type": "string" },
|
|
||||||
{ "type": "null" }
|
|
||||||
],
|
|
||||||
"title": "Result"
|
|
||||||
},
|
|
||||||
"error": {
|
|
||||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
|
||||||
"title": "Error"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"type": "object",
|
|
||||||
"required": ["success"],
|
|
||||||
"title": "OperationCompleteRequest",
|
|
||||||
"description": "Request model for external completion webhook."
|
|
||||||
},
|
|
||||||
"Pagination": {
|
"Pagination": {
|
||||||
"properties": {
|
"properties": {
|
||||||
"total_items": {
|
"total_items": {
|
||||||
@@ -9878,12 +9678,6 @@
|
|||||||
"items": { "additionalProperties": true, "type": "object" },
|
"items": { "additionalProperties": true, "type": "object" },
|
||||||
"type": "array",
|
"type": "array",
|
||||||
"title": "Messages"
|
"title": "Messages"
|
||||||
},
|
|
||||||
"active_stream": {
|
|
||||||
"anyOf": [
|
|
||||||
{ "$ref": "#/components/schemas/ActiveStreamInfo" },
|
|
||||||
{ "type": "null" }
|
|
||||||
]
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"type": "object",
|
"type": "object",
|
||||||
|
|||||||
@@ -1,15 +1,27 @@
 "use client";

-import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
+import { getHomepageRoute } from "@/lib/constants";
+import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
 import { useRouter } from "next/navigation";
 import { useEffect } from "react";

 export default function Page() {
+  const isChatEnabled = useGetFlag(Flag.CHAT);
   const router = useRouter();
+  const homepageRoute = getHomepageRoute(isChatEnabled);
+  const envEnabled = process.env.NEXT_PUBLIC_LAUNCHDARKLY_ENABLED === "true";
+  const clientId = process.env.NEXT_PUBLIC_LAUNCHDARKLY_CLIENT_ID;
+  const isLaunchDarklyConfigured = envEnabled && Boolean(clientId);
+  const isFlagReady =
+    !isLaunchDarklyConfigured || typeof isChatEnabled === "boolean";

-  useEffect(() => {
-    router.replace("/copilot");
-  }, [router]);
+  useEffect(
+    function redirectToHomepage() {
+      if (!isFlagReady) return;
+      router.replace(homepageRoute);
+    },
+    [homepageRoute, isFlagReady, router],
+  );

-  return <LoadingSpinner size="large" cover />;
+  return null;
 }
@@ -104,7 +104,28 @@ export function FileInput(props: Props) {
     return false;
   }

-  const getFileLabelFromValue = (val: string) => {
+  const getFileLabelFromValue = (val: unknown): string => {
+    // Handle object format from external API: { name, type, size, data }
+    if (val && typeof val === "object") {
+      const obj = val as Record<string, unknown>;
+      if (typeof obj.name === "string") {
+        return getFileLabel(obj.name, (obj.type as string) || "");
+      }
+      if (typeof obj.type === "string") {
+        const mimeParts = obj.type.split("/");
+        if (mimeParts.length > 1) {
+          return `${mimeParts[1].toUpperCase()} file`;
+        }
+        return `${obj.type} file`;
+      }
+      return "File";
+    }
+
+    // Handle string values (data URIs or file paths)
+    if (typeof val !== "string") {
+      return "File";
+    }
+
     if (val.startsWith("data:")) {
       const matches = val.match(/^data:([^;]+);/);
       if (matches?.[1]) {
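The widened signature means an object payload such as `{ name, type, size, data }` no longer breaks the label logic, while plain string values keep flowing into the component's existing data-URI handling. A stripped-down sketch of the same branching; `getFileLabel` is replaced by a trivial stand-in here, and the string branch is reduced to a pass-through because this hunk only shows its beginning:

```typescript
// Stand-in for the component's getFileLabel(name, mime) helper (assumption:
// the real one derives a nicer label; here we just echo the file name).
function getFileLabel(name: string, _mime: string): string {
  return name;
}

function getFileLabelFromValue(val: unknown): string {
  if (val && typeof val === "object") {
    const obj = val as Record<string, unknown>;
    if (typeof obj.name === "string") {
      return getFileLabel(obj.name, (obj.type as string) || "");
    }
    if (typeof obj.type === "string") {
      const mimeParts = obj.type.split("/");
      return mimeParts.length > 1
        ? `${mimeParts[1].toUpperCase()} file`
        : `${obj.type} file`;
    }
    return "File";
  }
  if (typeof val !== "string") return "File";
  // Strings fall through to the existing data-URI / path handling,
  // which is truncated in the hunk above.
  return val;
}

console.log(getFileLabelFromValue({ name: "report.pdf", type: "application/pdf" })); // "report.pdf"
console.log(getFileLabelFromValue({ type: "image/png" })); // "PNG file"
console.log(getFileLabelFromValue(42)); // "File"
```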
@@ -1,6 +1,7 @@
 "use client";

 import { useCopilotSessionId } from "@/app/(platform)/copilot/useCopilotSessionId";
+import { useCopilotStore } from "@/app/(platform)/copilot/copilot-page-store";
 import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
 import { Text } from "@/components/atoms/Text/Text";
 import { cn } from "@/lib/utils";
@@ -24,8 +25,8 @@ export function Chat({
 }: ChatProps) {
   const { urlSessionId } = useCopilotSessionId();
   const hasHandledNotFoundRef = useRef(false);
+  const isSwitchingSession = useCopilotStore((s) => s.isSwitchingSession);
   const {
-    session,
     messages,
     isLoading,
     isCreating,
@@ -37,18 +38,6 @@ export function Chat({
     startPollingForOperation,
   } = useChat({ urlSessionId });

-  // Extract active stream info for reconnection
-  const activeStream = (
-    session as {
-      active_stream?: {
-        task_id: string;
-        last_message_id: string;
-        operation_id: string;
-        tool_name: string;
-      };
-    }
-  )?.active_stream;
-
   useEffect(() => {
     if (!onSessionNotFound) return;
     if (!urlSessionId) return;
@@ -64,7 +53,8 @@ export function Chat({
     isCreating,
   ]);

-  const shouldShowLoader = showLoader && (isLoading || isCreating);
+  const shouldShowLoader =
+    (showLoader && (isLoading || isCreating)) || isSwitchingSession;

   return (
     <div className={cn("flex h-full flex-col", className)}>
@@ -76,19 +66,21 @@ export function Chat({
           <div className="flex flex-col items-center gap-3">
             <LoadingSpinner size="large" className="text-neutral-400" />
             <Text variant="body" className="text-zinc-500">
-              Loading your chat...
+              {isSwitchingSession
+                ? "Switching chat..."
+                : "Loading your chat..."}
             </Text>
           </div>
         </div>
       )}

       {/* Error State */}
-      {error && !isLoading && (
+      {error && !isLoading && !isSwitchingSession && (
         <ChatErrorState error={error} onRetry={createSession} />
       )}

       {/* Session Content */}
-      {sessionId && !isLoading && !error && (
+      {sessionId && !isLoading && !error && !isSwitchingSession && (
         <ChatContainer
           sessionId={sessionId}
           initialMessages={messages}
@@ -96,16 +88,6 @@ export function Chat({
           className="flex-1"
           onStreamingChange={onStreamingChange}
           onOperationStarted={startPollingForOperation}
-          activeStream={
-            activeStream
-              ? {
-                  taskId: activeStream.task_id,
-                  lastMessageId: activeStream.last_message_id,
-                  operationId: activeStream.operation_id,
-                  toolName: activeStream.tool_name,
-                }
-              : undefined
-          }
         />
       )}
     </main>
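`Chat`, `useCopilotShell`, and the copilot page coordinate through `useCopilotStore` selectors (`isSwitchingSession`, `isStreaming`, and the interrupt-modal actions). The store itself is not part of this diff; the following zustand-style slice is a hypothetical sketch of how those selectors could fit together, not the project's actual `copilot-page-store`:

```typescript
import { create } from "zustand";

// Hypothetical shape of the copilot page store used by the hunks above.
interface CopilotPageState {
  isStreaming: boolean;
  isCreatingSession: boolean;
  isSwitchingSession: boolean;
  isInterruptModalOpen: boolean;
  pendingInterruptAction: (() => void) | null;
  setIsSwitchingSession: (value: boolean) => void;
  openInterruptModal: (action: () => void) => void;
  confirmInterrupt: () => void;
  cancelInterrupt: () => void;
}

export const useCopilotStore = create<CopilotPageState>((set, get) => ({
  isStreaming: false,
  isCreatingSession: false,
  isSwitchingSession: false,
  isInterruptModalOpen: false,
  pendingInterruptAction: null,
  setIsSwitchingSession: (value) => set({ isSwitchingSession: value }),
  // Stash the pending action and show the confirmation dialog.
  openInterruptModal: (action) =>
    set({ isInterruptModalOpen: true, pendingInterruptAction: action }),
  // Run the stashed action if the user confirms, then close the dialog.
  confirmInterrupt: () => {
    get().pendingInterruptAction?.();
    set({ isInterruptModalOpen: false, pendingInterruptAction: null });
  },
  cancelInterrupt: () =>
    set({ isInterruptModalOpen: false, pendingInterruptAction: null }),
}));
```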
@@ -1,159 +0,0 @@
# SSE Reconnection Contract for Long-Running Operations

This document describes the client-side contract for handling SSE (Server-Sent Events) disconnections and reconnecting to long-running background tasks.

## Overview

When a user triggers a long-running operation (like agent generation), the backend:

1. Spawns a background task that survives SSE disconnections
2. Returns an `operation_started` response with a `task_id`
3. Stores stream messages in Redis Streams for replay

Clients can reconnect to the task stream at any time to receive missed messages.

## Client-Side Flow

### 1. Receiving Operation Started

When you receive an `operation_started` tool response:

```typescript
// The response includes a task_id for reconnection
{
  type: "operation_started",
  tool_name: "generate_agent",
  operation_id: "uuid-...",
  task_id: "task-uuid-...", // <-- Store this for reconnection
  message: "Operation started. You can close this tab."
}
```

### 2. Storing Task Info

Use the chat store to track the active task:

```typescript
import { useChatStore } from "./chat-store";

// When operation_started is received:
useChatStore.getState().setActiveTask(sessionId, {
  taskId: response.task_id,
  operationId: response.operation_id,
  toolName: response.tool_name,
  lastMessageId: "0",
});
```

### 3. Reconnecting to a Task

To reconnect (e.g., after page refresh or tab reopen):

```typescript
const { reconnectToTask, getActiveTask } = useChatStore.getState();

// Check if there's an active task for this session
const activeTask = getActiveTask(sessionId);

if (activeTask) {
  // Reconnect to the task stream
  await reconnectToTask(
    sessionId,
    activeTask.taskId,
    activeTask.lastMessageId, // Resume from last position
    (chunk) => {
      // Handle incoming chunks
      console.log("Received chunk:", chunk);
    },
  );
}
```

### 4. Tracking Message Position

To enable precise replay, update the last message ID as chunks arrive:

```typescript
const { updateTaskLastMessageId } = useChatStore.getState();

function handleChunk(chunk: StreamChunk) {
  // If chunk has an index/id, track it
  if (chunk.idx !== undefined) {
    updateTaskLastMessageId(sessionId, String(chunk.idx));
  }
}
```

## API Endpoints

### Task Stream Reconnection

```
GET /api/chat/tasks/{taskId}/stream?last_message_id={idx}
```

- `taskId`: The task ID from `operation_started`
- `last_message_id`: Last received message index (default: "0" for full replay)

Returns: SSE stream of missed messages + live updates
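
A minimal sketch of consuming this endpoint directly with `fetch` is shown below. It assumes the stream uses standard `data: <json>` SSE framing; the real client goes through `reconnectToTask` and the stream executor instead, so treat this as illustration only.

```typescript
async function readTaskStream(taskId: string, lastMessageId = "0") {
  const res = await fetch(
    `/api/chat/tasks/${taskId}/stream?last_message_id=${lastMessageId}`,
  );
  if (!res.ok || !res.body) {
    throw new Error(`Stream request failed: ${res.status}`);
  }

  const reader = res.body.pipeThrough(new TextDecoderStream()).getReader();
  let buffer = "";
  while (true) {
    const { done, value } = await reader.read();
    if (done) break;
    buffer += value;
    const lines = buffer.split("\n");
    buffer = lines.pop() ?? ""; // keep any partial line for the next read
    for (const line of lines) {
      if (line.startsWith("data: ")) {
        // Assumption: each data line carries one JSON-encoded chunk
        const chunk = JSON.parse(line.slice("data: ".length));
        console.log("chunk:", chunk);
      }
    }
  }
}
```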

## Chunk Types

The reconnected stream follows the same Vercel AI SDK protocol:

| Type                    | Description             |
| ----------------------- | ----------------------- |
| `start`                 | Message lifecycle start |
| `text-delta`            | Streaming text content  |
| `text-end`              | Text block completed    |
| `tool-output-available` | Tool result available   |
| `finish`                | Stream completed        |
| `error`                 | Error occurred          |
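
A minimal dispatch sketch over these chunk types follows. The `ReplayChunk` shape is simplified for illustration; the `text-delta` fields match `chat-types.ts`, while the other payload fields are assumptions.

```typescript
type ReplayChunk =
  | { type: "start"; messageId: string }
  | { type: "text-delta"; id: string; delta: string }
  | { type: "text-end"; id: string }
  | { type: "tool-output-available"; toolCallId: string; output: unknown } // assumed fields
  | { type: "finish" }
  | { type: "error"; errorText: string }; // assumed field

function applyChunk(chunk: ReplayChunk, text: string[]): void {
  switch (chunk.type) {
    case "text-delta":
      text.push(chunk.delta); // accumulate streaming text
      break;
    case "tool-output-available":
      console.log("tool result:", chunk.output);
      break;
    case "error":
      console.error("stream error:", chunk.errorText);
      break;
    case "start":
    case "text-end":
    case "finish":
      break; // lifecycle markers only
  }
}
```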

## Error Handling

If reconnection fails (a sketch of this fallback flow follows the list):

1. Check if task still exists (may have expired - default TTL: 1 hour)
2. Fall back to polling the session for final state
3. Show appropriate UI message to user
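
One way to wire that fallback, reusing `handleChunk` from section 4 and assuming `reconnectToTask` rejects when the task is gone; `pollSessionForFinalState` and `showExpiredNotice` are hypothetical app-provided helpers, not part of this contract:

```typescript
import { useChatStore, type ActiveTaskInfo } from "./chat-store";

async function reconnectWithFallback(sessionId: string, task: ActiveTaskInfo) {
  const { reconnectToTask, clearActiveTask } = useChatStore.getState();
  try {
    await reconnectToTask(sessionId, task.taskId, task.lastMessageId, handleChunk);
  } catch {
    clearActiveTask(sessionId); // 1. task expired or not found
    await pollSessionForFinalState(sessionId); // 2. fall back to polling (hypothetical helper)
    showExpiredNotice(); // 3. tell the user the live stream is gone (hypothetical helper)
  }
}
```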

## Persistence Considerations

For robust reconnection across browser restarts:

```typescript
// Store in localStorage/sessionStorage
const ACTIVE_TASKS_KEY = "chat_active_tasks";

function persistActiveTask(sessionId: string, task: ActiveTaskInfo) {
  const tasks = JSON.parse(localStorage.getItem(ACTIVE_TASKS_KEY) || "{}");
  tasks[sessionId] = task;
  localStorage.setItem(ACTIVE_TASKS_KEY, JSON.stringify(tasks));
}

function loadPersistedTasks(): Record<string, ActiveTaskInfo> {
  return JSON.parse(localStorage.getItem(ACTIVE_TASKS_KEY) || "{}");
}
```

## Backend Configuration

The following backend settings affect reconnection behavior:

| Setting             | Default | Description                        |
| ------------------- | ------- | ---------------------------------- |
| `stream_ttl`        | 3600s   | How long streams are kept in Redis |
| `stream_max_length` | 1000    | Max messages per stream            |
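
On the client side, a rough guard that mirrors `stream_ttl` can skip reconnect attempts that are almost certain to fail. This is a sketch, not part of the contract; the 1-hour constant matches the default above and `ACTIVE_TASK_TTL_MS` in `chat-constants.ts`.

```typescript
import type { ActiveTaskInfo } from "./chat-store";

const STREAM_TTL_MS = 60 * 60 * 1000; // mirrors the server's stream_ttl default (3600s)

function isTaskLikelyExpired(task: ActiveTaskInfo): boolean {
  // startedAt is recorded by setActiveTask when the operation begins
  return Date.now() - task.startedAt > STREAM_TTL_MS;
}
```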

## Testing

To test reconnection locally:

1. Start a long-running operation (e.g., agent generation)
2. Note the `task_id` from the `operation_started` response
3. Close the browser tab
4. Reopen and call `reconnectToTask` with the saved `task_id`
5. Verify that missed messages are replayed

See the main README for full local development setup.

@@ -1,16 +0,0 @@
/**
 * Constants for the chat system.
 *
 * Centralizes magic strings and values used across chat components.
 */

// LocalStorage keys
export const STORAGE_KEY_ACTIVE_TASKS = "chat_active_tasks";

// Redis Stream IDs
export const INITIAL_MESSAGE_ID = "0";
export const INITIAL_STREAM_ID = "0-0";

// TTL values (in milliseconds)
export const COMPLETED_STREAM_TTL_MS = 5 * 60 * 1000; // 5 minutes
export const ACTIVE_TASK_TTL_MS = 60 * 60 * 1000; // 1 hour

@@ -1,12 +1,6 @@
 "use client";

 import { create } from "zustand";
-import {
-  ACTIVE_TASK_TTL_MS,
-  COMPLETED_STREAM_TTL_MS,
-  INITIAL_STREAM_ID,
-  STORAGE_KEY_ACTIVE_TASKS,
-} from "./chat-constants";
 import type {
   ActiveStream,
   StreamChunk,
@@ -14,59 +8,15 @@ import type {
   StreamResult,
   StreamStatus,
 } from "./chat-types";
-import { executeStream, executeTaskReconnect } from "./stream-executor";
+import { executeStream } from "./stream-executor";

-export interface ActiveTaskInfo {
-  taskId: string;
-  sessionId: string;
-  operationId: string;
-  toolName: string;
-  lastMessageId: string;
-  startedAt: number;
-}
-
-/** Load active tasks from localStorage */
-function loadPersistedTasks(): Map<string, ActiveTaskInfo> {
-  if (typeof window === "undefined") return new Map();
-  try {
-    const stored = localStorage.getItem(STORAGE_KEY_ACTIVE_TASKS);
-    if (!stored) return new Map();
-    const parsed = JSON.parse(stored) as Record<string, ActiveTaskInfo>;
-    const now = Date.now();
-    const tasks = new Map<string, ActiveTaskInfo>();
-    // Filter out expired tasks
-    for (const [sessionId, task] of Object.entries(parsed)) {
-      if (now - task.startedAt < ACTIVE_TASK_TTL_MS) {
-        tasks.set(sessionId, task);
-      }
-    }
-    return tasks;
-  } catch {
-    return new Map();
-  }
-}
-
-/** Save active tasks to localStorage */
-function persistTasks(tasks: Map<string, ActiveTaskInfo>): void {
-  if (typeof window === "undefined") return;
-  try {
-    const obj: Record<string, ActiveTaskInfo> = {};
-    for (const [sessionId, task] of tasks) {
-      obj[sessionId] = task;
-    }
-    localStorage.setItem(STORAGE_KEY_ACTIVE_TASKS, JSON.stringify(obj));
-  } catch {
-    // Ignore storage errors
-  }
-}
+const COMPLETED_STREAM_TTL = 5 * 60 * 1000; // 5 minutes

 interface ChatStoreState {
   activeStreams: Map<string, ActiveStream>;
   completedStreams: Map<string, StreamResult>;
   activeSessions: Set<string>;
   streamCompleteCallbacks: Set<StreamCompleteCallback>;
-  /** Active tasks for SSE reconnection - keyed by sessionId */
-  activeTasks: Map<string, ActiveTaskInfo>;
 }

 interface ChatStoreActions {
@@ -91,24 +41,6 @@ interface ChatStoreActions {
   unregisterActiveSession: (sessionId: string) => void;
   isSessionActive: (sessionId: string) => boolean;
   onStreamComplete: (callback: StreamCompleteCallback) => () => void;
-  /** Track active task for SSE reconnection */
-  setActiveTask: (
-    sessionId: string,
-    taskInfo: Omit<ActiveTaskInfo, "sessionId" | "startedAt">,
-  ) => void;
-  /** Get active task for a session */
-  getActiveTask: (sessionId: string) => ActiveTaskInfo | undefined;
-  /** Clear active task when operation completes */
-  clearActiveTask: (sessionId: string) => void;
-  /** Reconnect to an existing task stream */
-  reconnectToTask: (
-    sessionId: string,
-    taskId: string,
-    lastMessageId?: string,
-    onChunk?: (chunk: StreamChunk) => void,
-  ) => Promise<void>;
-  /** Update last message ID for a task (for tracking replay position) */
-  updateTaskLastMessageId: (sessionId: string, lastMessageId: string) => void;
 }

 type ChatStore = ChatStoreState & ChatStoreActions;
@@ -132,126 +64,18 @@ function cleanupExpiredStreams(
|
|||||||
const now = Date.now();
|
const now = Date.now();
|
||||||
const cleaned = new Map(completedStreams);
|
const cleaned = new Map(completedStreams);
|
||||||
for (const [sessionId, result] of cleaned) {
|
for (const [sessionId, result] of cleaned) {
|
||||||
if (now - result.completedAt > COMPLETED_STREAM_TTL_MS) {
|
if (now - result.completedAt > COMPLETED_STREAM_TTL) {
|
||||||
cleaned.delete(sessionId);
|
cleaned.delete(sessionId);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return cleaned;
|
return cleaned;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Finalize a stream by moving it from activeStreams to completedStreams.
|
|
||||||
* Also handles cleanup and notifications.
|
|
||||||
*/
|
|
||||||
function finalizeStream(
|
|
||||||
sessionId: string,
|
|
||||||
stream: ActiveStream,
|
|
||||||
onChunk: ((chunk: StreamChunk) => void) | undefined,
|
|
||||||
get: () => ChatStoreState & ChatStoreActions,
|
|
||||||
set: (state: Partial<ChatStoreState>) => void,
|
|
||||||
): void {
|
|
||||||
if (onChunk) stream.onChunkCallbacks.delete(onChunk);
|
|
||||||
|
|
||||||
if (stream.status !== "streaming") {
|
|
||||||
const currentState = get();
|
|
||||||
const finalActiveStreams = new Map(currentState.activeStreams);
|
|
||||||
let finalCompletedStreams = new Map(currentState.completedStreams);
|
|
||||||
|
|
||||||
const storedStream = finalActiveStreams.get(sessionId);
|
|
||||||
if (storedStream === stream) {
|
|
||||||
const result: StreamResult = {
|
|
||||||
sessionId,
|
|
||||||
status: stream.status,
|
|
||||||
chunks: stream.chunks,
|
|
||||||
completedAt: Date.now(),
|
|
||||||
error: stream.error,
|
|
||||||
};
|
|
||||||
finalCompletedStreams.set(sessionId, result);
|
|
||||||
finalActiveStreams.delete(sessionId);
|
|
||||||
finalCompletedStreams = cleanupExpiredStreams(finalCompletedStreams);
|
|
||||||
set({
|
|
||||||
activeStreams: finalActiveStreams,
|
|
||||||
completedStreams: finalCompletedStreams,
|
|
||||||
});
|
|
||||||
|
|
||||||
if (stream.status === "completed" || stream.status === "error") {
|
|
||||||
notifyStreamComplete(currentState.streamCompleteCallbacks, sessionId);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Clean up an existing stream for a session and move it to completed streams.
|
|
||||||
* Returns updated maps for both active and completed streams.
|
|
||||||
*/
|
|
||||||
function cleanupExistingStream(
|
|
||||||
sessionId: string,
|
|
||||||
activeStreams: Map<string, ActiveStream>,
|
|
||||||
completedStreams: Map<string, StreamResult>,
|
|
||||||
callbacks: Set<StreamCompleteCallback>,
|
|
||||||
): {
|
|
||||||
activeStreams: Map<string, ActiveStream>;
|
|
||||||
completedStreams: Map<string, StreamResult>;
|
|
||||||
} {
|
|
||||||
const newActiveStreams = new Map(activeStreams);
|
|
||||||
let newCompletedStreams = new Map(completedStreams);
|
|
||||||
|
|
||||||
const existingStream = newActiveStreams.get(sessionId);
|
|
||||||
if (existingStream) {
|
|
||||||
existingStream.abortController.abort();
|
|
||||||
const normalizedStatus =
|
|
||||||
existingStream.status === "streaming"
|
|
||||||
? "completed"
|
|
||||||
: existingStream.status;
|
|
||||||
const result: StreamResult = {
|
|
||||||
sessionId,
|
|
||||||
status: normalizedStatus,
|
|
||||||
chunks: existingStream.chunks,
|
|
||||||
completedAt: Date.now(),
|
|
||||||
error: existingStream.error,
|
|
||||||
};
|
|
||||||
newCompletedStreams.set(sessionId, result);
|
|
||||||
newActiveStreams.delete(sessionId);
|
|
||||||
newCompletedStreams = cleanupExpiredStreams(newCompletedStreams);
|
|
||||||
if (normalizedStatus === "completed" || normalizedStatus === "error") {
|
|
||||||
notifyStreamComplete(callbacks, sessionId);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return {
|
|
||||||
activeStreams: newActiveStreams,
|
|
||||||
completedStreams: newCompletedStreams,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Create a new active stream with initial state.
|
|
||||||
*/
|
|
||||||
function createActiveStream(
|
|
||||||
sessionId: string,
|
|
||||||
onChunk?: (chunk: StreamChunk) => void,
|
|
||||||
): ActiveStream {
|
|
||||||
const abortController = new AbortController();
|
|
||||||
const initialCallbacks = new Set<(chunk: StreamChunk) => void>();
|
|
||||||
if (onChunk) initialCallbacks.add(onChunk);
|
|
||||||
|
|
||||||
return {
|
|
||||||
sessionId,
|
|
||||||
abortController,
|
|
||||||
status: "streaming",
|
|
||||||
startedAt: Date.now(),
|
|
||||||
chunks: [],
|
|
||||||
onChunkCallbacks: initialCallbacks,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
export const useChatStore = create<ChatStore>((set, get) => ({
|
export const useChatStore = create<ChatStore>((set, get) => ({
|
||||||
activeStreams: new Map(),
|
activeStreams: new Map(),
|
||||||
completedStreams: new Map(),
|
completedStreams: new Map(),
|
||||||
activeSessions: new Set(),
|
activeSessions: new Set(),
|
||||||
streamCompleteCallbacks: new Set(),
|
streamCompleteCallbacks: new Set(),
|
||||||
activeTasks: loadPersistedTasks(),
|
|
||||||
|
|
||||||
startStream: async function startStream(
|
startStream: async function startStream(
|
||||||
sessionId,
|
sessionId,
|
||||||
@@ -261,21 +85,45 @@ export const useChatStore = create<ChatStore>((set, get) => ({
|
|||||||
onChunk,
|
onChunk,
|
||||||
) {
|
) {
|
||||||
const state = get();
|
const state = get();
|
||||||
|
const newActiveStreams = new Map(state.activeStreams);
|
||||||
|
let newCompletedStreams = new Map(state.completedStreams);
|
||||||
const callbacks = state.streamCompleteCallbacks;
|
const callbacks = state.streamCompleteCallbacks;
|
||||||
|
|
||||||
// Clean up any existing stream for this session
|
const existingStream = newActiveStreams.get(sessionId);
|
||||||
const {
|
if (existingStream) {
|
||||||
activeStreams: newActiveStreams,
|
existingStream.abortController.abort();
|
||||||
completedStreams: newCompletedStreams,
|
const normalizedStatus =
|
||||||
} = cleanupExistingStream(
|
existingStream.status === "streaming"
|
||||||
sessionId,
|
? "completed"
|
||||||
state.activeStreams,
|
: existingStream.status;
|
||||||
state.completedStreams,
|
const result: StreamResult = {
|
||||||
callbacks,
|
sessionId,
|
||||||
);
|
status: normalizedStatus,
|
||||||
|
chunks: existingStream.chunks,
|
||||||
|
completedAt: Date.now(),
|
||||||
|
error: existingStream.error,
|
||||||
|
};
|
||||||
|
newCompletedStreams.set(sessionId, result);
|
||||||
|
newActiveStreams.delete(sessionId);
|
||||||
|
newCompletedStreams = cleanupExpiredStreams(newCompletedStreams);
|
||||||
|
if (normalizedStatus === "completed" || normalizedStatus === "error") {
|
||||||
|
notifyStreamComplete(callbacks, sessionId);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const abortController = new AbortController();
|
||||||
|
const initialCallbacks = new Set<(chunk: StreamChunk) => void>();
|
||||||
|
if (onChunk) initialCallbacks.add(onChunk);
|
||||||
|
|
||||||
|
const stream: ActiveStream = {
|
||||||
|
sessionId,
|
||||||
|
abortController,
|
||||||
|
status: "streaming",
|
||||||
|
startedAt: Date.now(),
|
||||||
|
chunks: [],
|
||||||
|
onChunkCallbacks: initialCallbacks,
|
||||||
|
};
|
||||||
|
|
||||||
// Create new stream
|
|
||||||
const stream = createActiveStream(sessionId, onChunk);
|
|
||||||
newActiveStreams.set(sessionId, stream);
|
newActiveStreams.set(sessionId, stream);
|
||||||
set({
|
set({
|
||||||
activeStreams: newActiveStreams,
|
activeStreams: newActiveStreams,
|
||||||
@@ -285,7 +133,36 @@ export const useChatStore = create<ChatStore>((set, get) => ({
|
|||||||
try {
|
try {
|
||||||
await executeStream(stream, message, isUserMessage, context);
|
await executeStream(stream, message, isUserMessage, context);
|
||||||
} finally {
|
} finally {
|
||||||
finalizeStream(sessionId, stream, onChunk, get, set);
|
if (onChunk) stream.onChunkCallbacks.delete(onChunk);
|
||||||
|
if (stream.status !== "streaming") {
|
||||||
|
const currentState = get();
|
||||||
|
const finalActiveStreams = new Map(currentState.activeStreams);
|
||||||
|
let finalCompletedStreams = new Map(currentState.completedStreams);
|
||||||
|
|
||||||
|
const storedStream = finalActiveStreams.get(sessionId);
|
||||||
|
if (storedStream === stream) {
|
||||||
|
const result: StreamResult = {
|
||||||
|
sessionId,
|
||||||
|
status: stream.status,
|
||||||
|
chunks: stream.chunks,
|
||||||
|
completedAt: Date.now(),
|
||||||
|
error: stream.error,
|
||||||
|
};
|
||||||
|
finalCompletedStreams.set(sessionId, result);
|
||||||
|
finalActiveStreams.delete(sessionId);
|
||||||
|
finalCompletedStreams = cleanupExpiredStreams(finalCompletedStreams);
|
||||||
|
set({
|
||||||
|
activeStreams: finalActiveStreams,
|
||||||
|
completedStreams: finalCompletedStreams,
|
||||||
|
});
|
||||||
|
if (stream.status === "completed" || stream.status === "error") {
|
||||||
|
notifyStreamComplete(
|
||||||
|
currentState.streamCompleteCallbacks,
|
||||||
|
sessionId,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
@@ -409,93 +286,4 @@ export const useChatStore = create<ChatStore>((set, get) => ({
|
|||||||
set({ streamCompleteCallbacks: cleanedCallbacks });
|
set({ streamCompleteCallbacks: cleanedCallbacks });
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
|
|
||||||
setActiveTask: function setActiveTask(sessionId, taskInfo) {
|
|
||||||
const state = get();
|
|
||||||
const newActiveTasks = new Map(state.activeTasks);
|
|
||||||
newActiveTasks.set(sessionId, {
|
|
||||||
...taskInfo,
|
|
||||||
sessionId,
|
|
||||||
startedAt: Date.now(),
|
|
||||||
});
|
|
||||||
set({ activeTasks: newActiveTasks });
|
|
||||||
persistTasks(newActiveTasks);
|
|
||||||
},
|
|
||||||
|
|
||||||
getActiveTask: function getActiveTask(sessionId) {
|
|
||||||
return get().activeTasks.get(sessionId);
|
|
||||||
},
|
|
||||||
|
|
||||||
clearActiveTask: function clearActiveTask(sessionId) {
|
|
||||||
const state = get();
|
|
||||||
if (!state.activeTasks.has(sessionId)) return;
|
|
||||||
|
|
||||||
const newActiveTasks = new Map(state.activeTasks);
|
|
||||||
newActiveTasks.delete(sessionId);
|
|
||||||
set({ activeTasks: newActiveTasks });
|
|
||||||
persistTasks(newActiveTasks);
|
|
||||||
},
|
|
||||||
|
|
||||||
reconnectToTask: async function reconnectToTask(
|
|
||||||
sessionId,
|
|
||||||
taskId,
|
|
||||||
lastMessageId = INITIAL_STREAM_ID,
|
|
||||||
onChunk,
|
|
||||||
) {
|
|
||||||
const state = get();
|
|
||||||
const callbacks = state.streamCompleteCallbacks;
|
|
||||||
|
|
||||||
// Clean up any existing stream for this session
|
|
||||||
const {
|
|
||||||
activeStreams: newActiveStreams,
|
|
||||||
completedStreams: newCompletedStreams,
|
|
||||||
} = cleanupExistingStream(
|
|
||||||
sessionId,
|
|
||||||
state.activeStreams,
|
|
||||||
state.completedStreams,
|
|
||||||
callbacks,
|
|
||||||
);
|
|
||||||
|
|
||||||
// Create new stream for reconnection
|
|
||||||
const stream = createActiveStream(sessionId, onChunk);
|
|
||||||
newActiveStreams.set(sessionId, stream);
|
|
||||||
set({
|
|
||||||
activeStreams: newActiveStreams,
|
|
||||||
completedStreams: newCompletedStreams,
|
|
||||||
});
|
|
||||||
|
|
||||||
try {
|
|
||||||
await executeTaskReconnect(stream, taskId, lastMessageId);
|
|
||||||
} finally {
|
|
||||||
finalizeStream(sessionId, stream, onChunk, get, set);
|
|
||||||
|
|
||||||
// Clear active task on completion
|
|
||||||
if (stream.status === "completed" || stream.status === "error") {
|
|
||||||
const taskState = get();
|
|
||||||
if (taskState.activeTasks.has(sessionId)) {
|
|
||||||
const newActiveTasks = new Map(taskState.activeTasks);
|
|
||||||
newActiveTasks.delete(sessionId);
|
|
||||||
set({ activeTasks: newActiveTasks });
|
|
||||||
persistTasks(newActiveTasks);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
|
|
||||||
updateTaskLastMessageId: function updateTaskLastMessageId(
|
|
||||||
sessionId,
|
|
||||||
lastMessageId,
|
|
||||||
) {
|
|
||||||
const state = get();
|
|
||||||
const task = state.activeTasks.get(sessionId);
|
|
||||||
if (!task) return;
|
|
||||||
|
|
||||||
const newActiveTasks = new Map(state.activeTasks);
|
|
||||||
newActiveTasks.set(sessionId, {
|
|
||||||
...task,
|
|
||||||
lastMessageId,
|
|
||||||
});
|
|
||||||
set({ activeTasks: newActiveTasks });
|
|
||||||
persistTasks(newActiveTasks);
|
|
||||||
},
|
|
||||||
}));
|
}));
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ export type StreamStatus = "idle" | "streaming" | "completed" | "error";

 export interface StreamChunk {
   type:
-    | "stream_start"
     | "text_chunk"
     | "text_ended"
     | "tool_call"
@@ -16,7 +15,6 @@ export interface StreamChunk {
     | "error"
     | "usage"
     | "stream_end";
-  taskId?: string;
   timestamp?: string;
   content?: string;
   message?: string;
@@ -43,7 +41,7 @@ export interface StreamChunk {
 }

 export type VercelStreamChunk =
-  | { type: "start"; messageId: string; taskId?: string }
+  | { type: "start"; messageId: string }
   | { type: "finish" }
   | { type: "text-start"; id: string }
   | { type: "text-delta"; id: string; delta: string }
@@ -94,70 +92,3 @@ export interface StreamResult {
 }

 export type StreamCompleteCallback = (sessionId: string) => void;
-
-// Type guards for message types
-
-/**
- * Check if a message has a toolId property.
- */
-export function hasToolId<T extends { type: string }>(
-  msg: T,
-): msg is T & { toolId: string } {
-  return (
-    "toolId" in msg &&
-    typeof (msg as Record<string, unknown>).toolId === "string"
-  );
-}
-
-/**
- * Check if a message has an operationId property.
- */
-export function hasOperationId<T extends { type: string }>(
-  msg: T,
-): msg is T & { operationId: string } {
-  return (
-    "operationId" in msg &&
-    typeof (msg as Record<string, unknown>).operationId === "string"
-  );
-}
-
-/**
- * Check if a message has a toolCallId property.
- */
-export function hasToolCallId<T extends { type: string }>(
-  msg: T,
-): msg is T & { toolCallId: string } {
-  return (
-    "toolCallId" in msg &&
-    typeof (msg as Record<string, unknown>).toolCallId === "string"
-  );
-}
-
-/**
- * Check if a message is an operation message type.
- */
-export function isOperationMessage<T extends { type: string }>(
-  msg: T,
-): msg is T & {
-  type: "operation_started" | "operation_pending" | "operation_in_progress";
-} {
-  return (
-    msg.type === "operation_started" ||
-    msg.type === "operation_pending" ||
-    msg.type === "operation_in_progress"
-  );
-}
-
-/**
- * Get the tool ID from a message if available.
- * Checks toolId, operationId, and toolCallId properties.
- */
-export function getToolIdFromMessage<T extends { type: string }>(
-  msg: T,
-): string | undefined {
-  const record = msg as Record<string, unknown>;
-  if (typeof record.toolId === "string") return record.toolId;
-  if (typeof record.operationId === "string") return record.operationId;
-  if (typeof record.toolCallId === "string") return record.toolCallId;
-  return undefined;
-}
@@ -2,6 +2,7 @@ import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessi
 import { Button } from "@/components/atoms/Button/Button";
 import { Text } from "@/components/atoms/Text/Text";
 import { Dialog } from "@/components/molecules/Dialog/Dialog";
+import { useBreakpoint } from "@/lib/hooks/useBreakpoint";
 import { cn } from "@/lib/utils";
 import { GlobeHemisphereEastIcon } from "@phosphor-icons/react";
 import { useEffect } from "react";
@@ -16,13 +17,6 @@ export interface ChatContainerProps {
   className?: string;
   onStreamingChange?: (isStreaming: boolean) => void;
   onOperationStarted?: () => void;
-  /** Active stream info from the server for reconnection */
-  activeStream?: {
-    taskId: string;
-    lastMessageId: string;
-    operationId: string;
-    toolName: string;
-  };
 }

 export function ChatContainer({
@@ -32,7 +26,6 @@ export function ChatContainer({
   className,
   onStreamingChange,
   onOperationStarted,
-  activeStream,
 }: ChatContainerProps) {
   const {
     messages,
@@ -48,13 +41,16 @@ export function ChatContainer({
     initialMessages,
     initialPrompt,
     onOperationStarted,
-    activeStream,
   });

   useEffect(() => {
     onStreamingChange?.(isStreaming);
   }, [isStreaming, onStreamingChange]);

+  const breakpoint = useBreakpoint();
+  const isMobile =
+    breakpoint === "base" || breakpoint === "sm" || breakpoint === "md";
+
   return (
     <div
       className={cn(
@@ -122,7 +118,11 @@ export function ChatContainer({
           disabled={isStreaming || !sessionId}
           isStreaming={isStreaming}
           onStop={stopStreaming}
-          placeholder="What else can I help with?"
+          placeholder={
+            isMobile
+              ? "You can search or just ask"
+              : 'You can search or just ask — e.g. "create a blog post outline"'
+          }
         />
       </div>
     </div>
@@ -2,7 +2,6 @@ import { toast } from "sonner";
 import type { StreamChunk } from "../../chat-types";
 import type { HandlerDependencies } from "./handlers";
 import {
-  getErrorDisplayMessage,
   handleError,
   handleLoginNeeded,
   handleStreamEnd,
@@ -25,22 +24,16 @@ export function createStreamEventDispatcher(
       chunk.type === "need_login" ||
       chunk.type === "error"
     ) {
+      if (!deps.hasResponseRef.current) {
+        console.info("[ChatStream] First response chunk:", {
+          type: chunk.type,
+          sessionId: deps.sessionId,
+        });
+      }
       deps.hasResponseRef.current = true;
     }

     switch (chunk.type) {
-      case "stream_start":
-        // Store task ID for SSE reconnection
-        if (chunk.taskId && deps.onActiveTaskStarted) {
-          deps.onActiveTaskStarted({
-            taskId: chunk.taskId,
-            operationId: chunk.taskId,
-            toolName: "chat",
-            toolCallId: "chat_stream",
-          });
-        }
-        break;
-
       case "text_chunk":
         handleTextChunk(chunk, deps);
         break;
@@ -63,7 +56,11 @@ export function createStreamEventDispatcher(
         break;

       case "stream_end":
-        // Note: "finish" type from backend gets normalized to "stream_end" by normalizeStreamChunk
+        console.info("[ChatStream] Stream ended:", {
+          sessionId: deps.sessionId,
+          hasResponse: deps.hasResponseRef.current,
+          chunkCount: deps.streamingChunksRef.current.length,
+        });
         handleStreamEnd(chunk, deps);
         break;

@@ -73,7 +70,7 @@ export function createStreamEventDispatcher(
         // Show toast at dispatcher level to avoid circular dependencies
         if (!isRegionBlocked) {
           toast.error("Chat Error", {
-            description: getErrorDisplayMessage(chunk),
+            description: chunk.message || chunk.content || "An error occurred",
           });
         }
         break;
@@ -18,19 +18,11 @@ export interface HandlerDependencies {
   setStreamingChunks: Dispatch<SetStateAction<string[]>>;
   streamingChunksRef: MutableRefObject<string[]>;
   hasResponseRef: MutableRefObject<boolean>;
-  textFinalizedRef: MutableRefObject<boolean>;
-  streamEndedRef: MutableRefObject<boolean>;
   setMessages: Dispatch<SetStateAction<ChatMessageData[]>>;
   setIsStreamingInitiated: Dispatch<SetStateAction<boolean>>;
   setIsRegionBlockedModalOpen: Dispatch<SetStateAction<boolean>>;
   sessionId: string;
   onOperationStarted?: () => void;
-  onActiveTaskStarted?: (taskInfo: {
-    taskId: string;
-    operationId: string;
-    toolName: string;
-    toolCallId: string;
-  }) => void;
 }

 export function isRegionBlockedError(chunk: StreamChunk): boolean {
@@ -40,25 +32,6 @@ export function isRegionBlockedError(chunk: StreamChunk): boolean {
   return message.toLowerCase().includes("not available in your region");
 }

-export function getUserFriendlyErrorMessage(
-  code: string | undefined,
-): string | undefined {
-  switch (code) {
-    case "TASK_EXPIRED":
-      return "This operation has expired. Please try again.";
-    case "TASK_NOT_FOUND":
-      return "Could not find the requested operation.";
-    case "ACCESS_DENIED":
-      return "You do not have access to this operation.";
-    case "QUEUE_OVERFLOW":
-      return "Connection was interrupted. Please refresh to continue.";
-    case "MODEL_NOT_AVAILABLE_REGION":
-      return "This model is not available in your region.";
-    default:
-      return undefined;
-  }
-}
-
 export function handleTextChunk(chunk: StreamChunk, deps: HandlerDependencies) {
   if (!chunk.content) return;
   deps.setHasTextChunks(true);
@@ -73,15 +46,10 @@ export function handleTextEnded(
|
|||||||
_chunk: StreamChunk,
|
_chunk: StreamChunk,
|
||||||
deps: HandlerDependencies,
|
deps: HandlerDependencies,
|
||||||
) {
|
) {
|
||||||
if (deps.textFinalizedRef.current) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const completedText = deps.streamingChunksRef.current.join("");
|
const completedText = deps.streamingChunksRef.current.join("");
|
||||||
if (completedText.trim()) {
|
if (completedText.trim()) {
|
||||||
deps.textFinalizedRef.current = true;
|
|
||||||
|
|
||||||
deps.setMessages((prev) => {
|
deps.setMessages((prev) => {
|
||||||
|
// Check if this exact message already exists to prevent duplicates
|
||||||
const exists = prev.some(
|
const exists = prev.some(
|
||||||
(msg) =>
|
(msg) =>
|
||||||
msg.type === "message" &&
|
msg.type === "message" &&
|
||||||
@@ -108,14 +76,9 @@ export function handleToolCallStart(
|
|||||||
chunk: StreamChunk,
|
chunk: StreamChunk,
|
||||||
deps: HandlerDependencies,
|
deps: HandlerDependencies,
|
||||||
) {
|
) {
|
||||||
// Use deterministic fallback instead of Date.now() to ensure same ID on replay
|
|
||||||
const toolId =
|
|
||||||
chunk.tool_id ||
|
|
||||||
`tool-${deps.sessionId}-${chunk.idx ?? "unknown"}-${chunk.tool_name || "unknown"}`;
|
|
||||||
|
|
||||||
const toolCallMessage: Extract<ChatMessageData, { type: "tool_call" }> = {
|
const toolCallMessage: Extract<ChatMessageData, { type: "tool_call" }> = {
|
||||||
type: "tool_call",
|
type: "tool_call",
|
||||||
toolId,
|
toolId: chunk.tool_id || `tool-${Date.now()}-${chunk.idx || 0}`,
|
||||||
toolName: chunk.tool_name || "Executing",
|
toolName: chunk.tool_name || "Executing",
|
||||||
arguments: chunk.arguments || {},
|
arguments: chunk.arguments || {},
|
||||||
timestamp: new Date(),
|
timestamp: new Date(),
|
||||||
@@ -148,29 +111,6 @@ export function handleToolCallStart(
|
|||||||
deps.setMessages(updateToolCallMessages);
|
deps.setMessages(updateToolCallMessages);
|
||||||
}
|
}
|
||||||
|
|
||||||
const TOOL_RESPONSE_TYPES = new Set([
|
|
||||||
"tool_response",
|
|
||||||
"operation_started",
|
|
||||||
"operation_pending",
|
|
||||||
"operation_in_progress",
|
|
||||||
"execution_started",
|
|
||||||
"agent_carousel",
|
|
||||||
"clarification_needed",
|
|
||||||
]);
|
|
||||||
|
|
||||||
function hasResponseForTool(
|
|
||||||
messages: ChatMessageData[],
|
|
||||||
toolId: string,
|
|
||||||
): boolean {
|
|
||||||
return messages.some((msg) => {
|
|
||||||
if (!TOOL_RESPONSE_TYPES.has(msg.type)) return false;
|
|
||||||
const msgToolId =
|
|
||||||
(msg as { toolId?: string }).toolId ||
|
|
||||||
(msg as { toolCallId?: string }).toolCallId;
|
|
||||||
return msgToolId === toolId;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
export function handleToolResponse(
|
export function handleToolResponse(
|
||||||
chunk: StreamChunk,
|
chunk: StreamChunk,
|
||||||
deps: HandlerDependencies,
|
deps: HandlerDependencies,
|
||||||
@@ -212,49 +152,31 @@ export function handleToolResponse(
|
|||||||
) {
|
) {
|
||||||
const inputsMessage = extractInputsNeeded(parsedResult, chunk.tool_name);
|
const inputsMessage = extractInputsNeeded(parsedResult, chunk.tool_name);
|
||||||
if (inputsMessage) {
|
if (inputsMessage) {
|
||||||
deps.setMessages((prev) => {
|
deps.setMessages((prev) => [...prev, inputsMessage]);
|
||||||
// Check for duplicate inputs_needed message
|
|
||||||
const exists = prev.some((msg) => msg.type === "inputs_needed");
|
|
||||||
if (exists) return prev;
|
|
||||||
return [...prev, inputsMessage];
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
const credentialsMessage = extractCredentialsNeeded(
|
const credentialsMessage = extractCredentialsNeeded(
|
||||||
parsedResult,
|
parsedResult,
|
||||||
chunk.tool_name,
|
chunk.tool_name,
|
||||||
);
|
);
|
||||||
if (credentialsMessage) {
|
if (credentialsMessage) {
|
||||||
deps.setMessages((prev) => {
|
deps.setMessages((prev) => [...prev, credentialsMessage]);
|
||||||
// Check for duplicate credentials_needed message
|
|
||||||
const exists = prev.some((msg) => msg.type === "credentials_needed");
|
|
||||||
if (exists) return prev;
|
|
||||||
return [...prev, credentialsMessage];
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
// Trigger polling when operation_started is received
|
||||||
if (responseMessage.type === "operation_started") {
|
if (responseMessage.type === "operation_started") {
|
||||||
deps.onOperationStarted?.();
|
deps.onOperationStarted?.();
|
||||||
const taskId = (responseMessage as { taskId?: string }).taskId;
|
|
||||||
if (taskId && deps.onActiveTaskStarted) {
|
|
||||||
deps.onActiveTaskStarted({
|
|
||||||
taskId,
|
|
||||||
operationId:
|
|
||||||
(responseMessage as { operationId?: string }).operationId || "",
|
|
||||||
toolName: (responseMessage as { toolName?: string }).toolName || "",
|
|
||||||
toolCallId: (responseMessage as { toolId?: string }).toolId || "",
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
deps.setMessages((prev) => {
|
deps.setMessages((prev) => {
|
||||||
const toolCallIndex = prev.findIndex(
|
const toolCallIndex = prev.findIndex(
|
||||||
(msg) => msg.type === "tool_call" && msg.toolId === chunk.tool_id,
|
(msg) => msg.type === "tool_call" && msg.toolId === chunk.tool_id,
|
||||||
);
|
);
|
||||||
if (hasResponseForTool(prev, chunk.tool_id!)) {
|
const hasResponse = prev.some(
|
||||||
return prev;
|
(msg) => msg.type === "tool_response" && msg.toolId === chunk.tool_id,
|
||||||
}
|
);
|
||||||
|
if (hasResponse) return prev;
|
||||||
if (toolCallIndex !== -1) {
|
if (toolCallIndex !== -1) {
|
||||||
const newMessages = [...prev];
|
const newMessages = [...prev];
|
||||||
newMessages.splice(toolCallIndex + 1, 0, responseMessage);
|
newMessages.splice(toolCallIndex + 1, 0, responseMessage);
|
||||||
@@ -276,48 +198,28 @@ export function handleLoginNeeded(
|
|||||||
agentInfo: chunk.agent_info,
|
agentInfo: chunk.agent_info,
|
||||||
timestamp: new Date(),
|
timestamp: new Date(),
|
||||||
};
|
};
|
||||||
deps.setMessages((prev) => {
|
deps.setMessages((prev) => [...prev, loginNeededMessage]);
|
||||||
// Check for duplicate login_needed message
|
|
||||||
const exists = prev.some((msg) => msg.type === "login_needed");
|
|
||||||
if (exists) return prev;
|
|
||||||
return [...prev, loginNeededMessage];
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export function handleStreamEnd(
|
export function handleStreamEnd(
|
||||||
_chunk: StreamChunk,
|
_chunk: StreamChunk,
|
||||||
deps: HandlerDependencies,
|
deps: HandlerDependencies,
|
||||||
) {
|
) {
|
||||||
if (deps.streamEndedRef.current) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
deps.streamEndedRef.current = true;
|
|
||||||
|
|
||||||
const completedContent = deps.streamingChunksRef.current.join("");
|
const completedContent = deps.streamingChunksRef.current.join("");
|
||||||
if (!completedContent.trim() && !deps.hasResponseRef.current) {
|
if (!completedContent.trim() && !deps.hasResponseRef.current) {
|
||||||
deps.setMessages((prev) => {
|
deps.setMessages((prev) => [
|
||||||
const exists = prev.some(
|
...prev,
|
||||||
(msg) =>
|
{
|
||||||
msg.type === "message" &&
|
type: "message",
|
||||||
msg.role === "assistant" &&
|
role: "assistant",
|
||||||
msg.content === "No response received. Please try again.",
|
content: "No response received. Please try again.",
|
||||||
);
|
timestamp: new Date(),
|
||||||
if (exists) return prev;
|
},
|
||||||
return [
|
]);
|
||||||
...prev,
|
|
||||||
{
|
|
||||||
type: "message",
|
|
||||||
role: "assistant",
|
|
||||||
content: "No response received. Please try again.",
|
|
||||||
timestamp: new Date(),
|
|
||||||
},
|
|
||||||
];
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
if (completedContent.trim() && !deps.textFinalizedRef.current) {
|
if (completedContent.trim()) {
|
||||||
deps.textFinalizedRef.current = true;
|
|
||||||
|
|
||||||
deps.setMessages((prev) => {
|
deps.setMessages((prev) => {
|
||||||
|
// Check if this exact message already exists to prevent duplicates
|
||||||
const exists = prev.some(
|
const exists = prev.some(
|
||||||
(msg) =>
|
(msg) =>
|
||||||
msg.type === "message" &&
|
msg.type === "message" &&
|
||||||
@@ -342,6 +244,8 @@ export function handleStreamEnd(
 }

 export function handleError(chunk: StreamChunk, deps: HandlerDependencies) {
+  const errorMessage = chunk.message || chunk.content || "An error occurred";
+  console.error("Stream error:", errorMessage);
   if (isRegionBlockedError(chunk)) {
     deps.setIsRegionBlockedModalOpen(true);
   }
@@ -349,14 +253,4 @@ export function handleError(chunk: StreamChunk, deps: HandlerDependencies) {
   deps.setHasTextChunks(false);
   deps.setStreamingChunks([]);
   deps.streamingChunksRef.current = [];
-  deps.textFinalizedRef.current = false;
-  deps.streamEndedRef.current = true;
-}
-
-export function getErrorDisplayMessage(chunk: StreamChunk): string {
-  const friendlyMessage = getUserFriendlyErrorMessage(chunk.code);
-  if (friendlyMessage) {
-    return friendlyMessage;
-  }
-  return chunk.message || chunk.content || "An error occurred";
 }
@@ -349,7 +349,6 @@ export function parseToolResponse(
       toolName: (parsedResult.tool_name as string) || toolName,
       toolId,
       operationId: (parsedResult.operation_id as string) || "",
-      taskId: (parsedResult.task_id as string) || undefined, // For SSE reconnection
       message:
         (parsedResult.message as string) ||
         "Operation started. You can close this tab.",
@@ -1,17 +1,10 @@
 import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse";
 import { useEffect, useMemo, useRef, useState } from "react";
-import { INITIAL_STREAM_ID } from "../../chat-constants";
 import { useChatStore } from "../../chat-store";
 import { toast } from "sonner";
 import { useChatStream } from "../../useChatStream";
 import { usePageContext } from "../../usePageContext";
 import type { ChatMessageData } from "../ChatMessage/useChatMessage";
-import {
-  getToolIdFromMessage,
-  hasToolId,
-  isOperationMessage,
-  type StreamChunk,
-} from "../../chat-types";
 import { createStreamEventDispatcher } from "./createStreamEventDispatcher";
 import {
   createUserMessage,
@@ -21,13 +14,6 @@ import {
|
|||||||
processInitialMessages,
|
processInitialMessages,
|
||||||
} from "./helpers";
|
} from "./helpers";
|
||||||
|
|
||||||
const TOOL_RESULT_TYPES = new Set([
|
|
||||||
"tool_response",
|
|
||||||
"agent_carousel",
|
|
||||||
"execution_started",
|
|
||||||
"clarification_needed",
|
|
||||||
]);
|
|
||||||
|
|
||||||
// Helper to generate deduplication key for a message
|
// Helper to generate deduplication key for a message
|
||||||
function getMessageKey(msg: ChatMessageData): string {
|
function getMessageKey(msg: ChatMessageData): string {
|
||||||
if (msg.type === "message") {
|
if (msg.type === "message") {
|
||||||
@@ -37,18 +23,14 @@ function getMessageKey(msg: ChatMessageData): string {
|
|||||||
return `msg:${msg.role}:${msg.content}`;
|
return `msg:${msg.role}:${msg.content}`;
|
||||||
} else if (msg.type === "tool_call") {
|
} else if (msg.type === "tool_call") {
|
||||||
return `toolcall:${msg.toolId}`;
|
return `toolcall:${msg.toolId}`;
|
||||||
} else if (TOOL_RESULT_TYPES.has(msg.type)) {
|
} else if (msg.type === "tool_response") {
|
||||||
// Unified key for all tool result types - same toolId with different types
|
return `toolresponse:${(msg as any).toolId}`;
|
||||||
// (tool_response vs agent_carousel) should deduplicate to the same key
|
} else if (
|
||||||
const toolId = getToolIdFromMessage(msg);
|
msg.type === "operation_started" ||
|
||||||
// If no toolId, fall back to content-based key to avoid empty key collisions
|
msg.type === "operation_pending" ||
|
||||||
if (!toolId) {
|
msg.type === "operation_in_progress"
|
||||||
return `toolresult:content:${JSON.stringify(msg).slice(0, 200)}`;
|
) {
|
||||||
}
|
return `op:${(msg as any).toolId || (msg as any).operationId || (msg as any).toolCallId || ""}:${msg.toolName}`;
|
||||||
return `toolresult:${toolId}`;
|
|
||||||
} else if (isOperationMessage(msg)) {
|
|
||||||
const toolId = getToolIdFromMessage(msg) || "";
|
|
||||||
return `op:${toolId}:${msg.toolName}`;
|
|
||||||
} else {
|
} else {
|
||||||
return `${msg.type}:${JSON.stringify(msg).slice(0, 100)}`;
|
return `${msg.type}:${JSON.stringify(msg).slice(0, 100)}`;
|
||||||
}
|
}
|
||||||
@@ -59,13 +41,6 @@ interface Args {
|
|||||||
initialMessages: SessionDetailResponse["messages"];
|
initialMessages: SessionDetailResponse["messages"];
|
||||||
initialPrompt?: string;
|
initialPrompt?: string;
|
||||||
onOperationStarted?: () => void;
|
onOperationStarted?: () => void;
|
||||||
/** Active stream info from the server for reconnection */
|
|
||||||
activeStream?: {
|
|
||||||
taskId: string;
|
|
||||||
lastMessageId: string;
|
|
||||||
operationId: string;
|
|
||||||
toolName: string;
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export function useChatContainer({
|
export function useChatContainer({
|
||||||
@@ -73,7 +48,6 @@ export function useChatContainer({
|
|||||||
initialMessages,
|
initialMessages,
|
||||||
initialPrompt,
|
initialPrompt,
|
||||||
onOperationStarted,
|
onOperationStarted,
|
||||||
activeStream,
|
|
||||||
}: Args) {
|
}: Args) {
|
||||||
const [messages, setMessages] = useState<ChatMessageData[]>([]);
|
const [messages, setMessages] = useState<ChatMessageData[]>([]);
|
||||||
const [streamingChunks, setStreamingChunks] = useState<string[]>([]);
|
const [streamingChunks, setStreamingChunks] = useState<string[]>([]);
|
||||||
@@ -83,8 +57,6 @@ export function useChatContainer({
|
|||||||
useState(false);
|
useState(false);
|
||||||
const hasResponseRef = useRef(false);
|
const hasResponseRef = useRef(false);
|
||||||
const streamingChunksRef = useRef<string[]>([]);
|
const streamingChunksRef = useRef<string[]>([]);
|
||||||
const textFinalizedRef = useRef(false);
|
|
||||||
const streamEndedRef = useRef(false);
|
|
||||||
const previousSessionIdRef = useRef<string | null>(null);
|
const previousSessionIdRef = useRef<string | null>(null);
|
||||||
const {
|
const {
|
||||||
error,
|
error,
|
||||||
@@ -93,182 +65,44 @@ export function useChatContainer({
|
|||||||
} = useChatStream();
|
} = useChatStream();
|
||||||
const activeStreams = useChatStore((s) => s.activeStreams);
|
const activeStreams = useChatStore((s) => s.activeStreams);
|
||||||
const subscribeToStream = useChatStore((s) => s.subscribeToStream);
|
const subscribeToStream = useChatStore((s) => s.subscribeToStream);
|
||||||
const setActiveTask = useChatStore((s) => s.setActiveTask);
|
|
||||||
const getActiveTask = useChatStore((s) => s.getActiveTask);
|
|
||||||
const reconnectToTask = useChatStore((s) => s.reconnectToTask);
|
|
||||||
const isStreaming = isStreamingInitiated || hasTextChunks;
|
const isStreaming = isStreamingInitiated || hasTextChunks;
|
||||||
// Track whether we've already connected to this activeStream to avoid duplicate connections
|
|
||||||
const connectedActiveStreamRef = useRef<string | null>(null);
|
|
||||||
// Track if component is mounted to prevent state updates after unmount
|
|
||||||
const isMountedRef = useRef(true);
|
|
||||||
// Track current dispatcher to prevent multiple dispatchers from adding messages
|
|
||||||
const currentDispatcherIdRef = useRef(0);
|
|
||||||
|
|
||||||
// Set mounted flag - reset on every mount, cleanup on unmount
|
|
||||||
useEffect(function trackMountedState() {
|
|
||||||
isMountedRef.current = true;
|
|
||||||
return function cleanup() {
|
|
||||||
isMountedRef.current = false;
|
|
||||||
};
|
|
||||||
}, []);
|
|
||||||
|
|
||||||
// Callback to store active task info for SSE reconnection
|
|
||||||
function handleActiveTaskStarted(taskInfo: {
|
|
||||||
taskId: string;
|
|
||||||
operationId: string;
|
|
||||||
toolName: string;
|
|
||||||
toolCallId: string;
|
|
||||||
}) {
|
|
||||||
if (!sessionId) return;
|
|
||||||
setActiveTask(sessionId, {
|
|
||||||
taskId: taskInfo.taskId,
|
|
||||||
operationId: taskInfo.operationId,
|
|
||||||
toolName: taskInfo.toolName,
|
|
||||||
lastMessageId: INITIAL_STREAM_ID,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create dispatcher for stream events - stable reference for current sessionId
|
|
||||||
// Each dispatcher gets a unique ID to prevent stale dispatchers from updating state
|
|
||||||
function createDispatcher() {
|
|
||||||
if (!sessionId) return () => {};
|
|
||||||
// Increment dispatcher ID - only the most recent dispatcher should update state
|
|
||||||
const dispatcherId = ++currentDispatcherIdRef.current;
|
|
||||||
|
|
||||||
const baseDispatcher = createStreamEventDispatcher({
|
|
||||||
setHasTextChunks,
|
|
||||||
setStreamingChunks,
|
|
||||||
streamingChunksRef,
|
|
||||||
hasResponseRef,
|
|
||||||
textFinalizedRef,
|
|
||||||
streamEndedRef,
|
|
||||||
setMessages,
|
|
||||||
setIsRegionBlockedModalOpen,
|
|
||||||
sessionId,
|
|
||||||
setIsStreamingInitiated,
|
|
||||||
onOperationStarted,
|
|
||||||
onActiveTaskStarted: handleActiveTaskStarted,
|
|
||||||
});
|
|
||||||
|
|
||||||
// Wrap dispatcher to check if it's still the current one
|
|
||||||
return function guardedDispatcher(chunk: StreamChunk) {
|
|
||||||
// Skip if component unmounted or this is a stale dispatcher
|
|
||||||
if (!isMountedRef.current) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (dispatcherId !== currentDispatcherIdRef.current) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
baseDispatcher(chunk);
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
useEffect(
|
useEffect(
|
||||||
function handleSessionChange() {
|
function handleSessionChange() {
|
||||||
const isSessionChange = sessionId !== previousSessionIdRef.current;
|
if (sessionId === previousSessionIdRef.current) return;
|
||||||
|
|
||||||
// Handle session change - reset state
|
const prevSession = previousSessionIdRef.current;
|
||||||
if (isSessionChange) {
|
if (prevSession) {
|
||||||
const prevSession = previousSessionIdRef.current;
|
stopStreaming(prevSession);
|
||||||
if (prevSession) {
|
|
||||||
stopStreaming(prevSession);
|
|
||||||
}
|
|
||||||
previousSessionIdRef.current = sessionId;
|
|
||||||
connectedActiveStreamRef.current = null;
|
|
||||||
setMessages([]);
|
|
||||||
setStreamingChunks([]);
|
|
||||||
streamingChunksRef.current = [];
|
|
||||||
setHasTextChunks(false);
|
|
||||||
setIsStreamingInitiated(false);
|
|
||||||
hasResponseRef.current = false;
|
|
||||||
textFinalizedRef.current = false;
|
|
||||||
- streamEndedRef.current = false;
}
+ previousSessionIdRef.current = sessionId;
+ setMessages([]);
+ setStreamingChunks([]);
+ streamingChunksRef.current = [];
+ setHasTextChunks(false);
+ setIsStreamingInitiated(false);
+ hasResponseRef.current = false;

if (!sessionId) return;

- // Priority 1: Check if server told us there's an active stream (most authoritative)
- if (activeStream) {
- const streamKey = `${sessionId}:${activeStream.taskId}`;
- if (connectedActiveStreamRef.current === streamKey) {
- return;
- }
- // Skip if there's already an active stream for this session in the store
- const existingStream = activeStreams.get(sessionId);
- if (existingStream && existingStream.status === "streaming") {
- connectedActiveStreamRef.current = streamKey;
- return;
- }
- connectedActiveStreamRef.current = streamKey;
- // Clear all state before reconnection to prevent duplicates
- // Server's initialMessages is authoritative; local state will be rebuilt from SSE replay
- setMessages([]);
- setStreamingChunks([]);
- streamingChunksRef.current = [];
- setHasTextChunks(false);
- textFinalizedRef.current = false;
- streamEndedRef.current = false;
- hasResponseRef.current = false;
- setIsStreamingInitiated(true);
- setActiveTask(sessionId, {
- taskId: activeStream.taskId,
- operationId: activeStream.operationId,
- toolName: activeStream.toolName,
- lastMessageId: activeStream.lastMessageId,
- });
- reconnectToTask(
- sessionId,
- activeStream.taskId,
- activeStream.lastMessageId,
- createDispatcher(),
- );
- // Don't return cleanup here - the guarded dispatcher handles stale events
- // and the stream will complete naturally. Cleanup would prematurely stop
- // the stream when effect re-runs due to activeStreams changing.
- return;
- }
- // Only check localStorage/in-memory on session change
- if (!isSessionChange) return;
- // Priority 2: Check localStorage for active task
- const activeTask = getActiveTask(sessionId);
- if (activeTask) {
- // Clear all state before reconnection to prevent duplicates
- // Server's initialMessages is authoritative; local state will be rebuilt from SSE replay
- setMessages([]);
- setStreamingChunks([]);
- streamingChunksRef.current = [];
- setHasTextChunks(false);
- textFinalizedRef.current = false;
- streamEndedRef.current = false;
- hasResponseRef.current = false;
- setIsStreamingInitiated(true);
- reconnectToTask(
- sessionId,
- activeTask.taskId,
- activeTask.lastMessageId,
- createDispatcher(),
- );
- // Don't return cleanup here - the guarded dispatcher handles stale events
- return;
- }
- // Priority 3: Check for an in-memory active stream (same-tab scenario)
- const inMemoryStream = activeStreams.get(sessionId);
- if (!inMemoryStream || inMemoryStream.status !== "streaming") {
- return;
- }
+ const activeStream = activeStreams.get(sessionId);
+ if (!activeStream || activeStream.status !== "streaming") return;
+ const dispatcher = createStreamEventDispatcher({
+ setHasTextChunks,
+ setStreamingChunks,
+ streamingChunksRef,
+ hasResponseRef,
+ setMessages,
+ setIsRegionBlockedModalOpen,
+ sessionId,
+ setIsStreamingInitiated,
+ onOperationStarted,
+ });

setIsStreamingInitiated(true);
const skipReplay = initialMessages.length > 0;
- return subscribeToStream(sessionId, createDispatcher(), skipReplay);
+ return subscribeToStream(sessionId, dispatcher, skipReplay);
},
[
sessionId,
@@ -276,10 +110,6 @@ export function useChatContainer({
activeStreams,
subscribeToStream,
onOperationStarted,
- getActiveTask,
- reconnectToTask,
- activeStream,
- setActiveTask,
],
);

@@ -294,7 +124,7 @@ export function useChatContainer({
msg.type === "agent_carousel" ||
msg.type === "execution_started"
) {
- const toolId = hasToolId(msg) ? msg.toolId : undefined;
+ const toolId = (msg as any).toolId;
if (toolId) {
ids.add(toolId);
}
@@ -311,8 +141,12 @@ export function useChatContainer({

setMessages((prev) => {
const filtered = prev.filter((msg) => {
- if (isOperationMessage(msg)) {
- const toolId = getToolIdFromMessage(msg);
+ if (
+ msg.type === "operation_started" ||
+ msg.type === "operation_pending" ||
+ msg.type === "operation_in_progress"
+ ) {
+ const toolId = (msg as any).toolId || (msg as any).toolCallId;
if (toolId && completedToolIds.has(toolId)) {
return false; // Remove - operation completed
}

@@ -340,8 +174,12 @@ export function useChatContainer({
// Filter local messages: remove duplicates and completed operation messages
const newLocalMessages = messages.filter((msg) => {
// Remove operation messages for completed tools
- if (isOperationMessage(msg)) {
- const toolId = getToolIdFromMessage(msg);
+ if (
+ msg.type === "operation_started" ||
+ msg.type === "operation_pending" ||
+ msg.type === "operation_in_progress"
+ ) {
+ const toolId = (msg as any).toolId || (msg as any).toolCallId;
if (toolId && completedToolIds.has(toolId)) {
return false;
}
@@ -352,70 +190,7 @@ export function useChatContainer({
});

// Server messages first (correct order), then new local messages
- const combined = [...processedInitial, ...newLocalMessages];
- // Post-processing: Remove duplicate assistant messages that can occur during
- // race conditions (e.g., rapid screen switching during SSE reconnection).
- // Two assistant messages are considered duplicates if:
- // - They are both text messages with role "assistant"
- // - One message's content starts with the other's content (partial vs complete)
- // - Or they have very similar content (>80% overlap at the start)
- const deduplicated: ChatMessageData[] = [];
- for (let i = 0; i < combined.length; i++) {
- const current = combined[i];
- // Check if this is an assistant text message
- if (current.type !== "message" || current.role !== "assistant") {
- deduplicated.push(current);
- continue;
- }
- // Look for duplicate assistant messages in the rest of the array
- let dominated = false;
- for (let j = 0; j < combined.length; j++) {
- if (i === j) continue;
- const other = combined[j];
- if (other.type !== "message" || other.role !== "assistant") continue;
- const currentContent = current.content || "";
- const otherContent = other.content || "";
- // Skip empty messages
- if (!currentContent.trim() || !otherContent.trim()) continue;
- // Check if current is a prefix of other (current is incomplete version)
- if (
- otherContent.length > currentContent.length &&
- otherContent.startsWith(currentContent.slice(0, 100))
- ) {
- // Current is a shorter/incomplete version of other - skip it
- dominated = true;
- break;
- }
- // Check if messages are nearly identical (within a small difference)
- // This catches cases where content differs only slightly
- const minLen = Math.min(currentContent.length, otherContent.length);
- const compareLen = Math.min(minLen, 200); // Compare first 200 chars
- if (
- compareLen > 50 &&
- currentContent.slice(0, compareLen) ===
- otherContent.slice(0, compareLen)
- ) {
- // Same prefix - keep the longer one
- if (otherContent.length > currentContent.length) {
- dominated = true;
- break;
- }
- }
- }
- if (!dominated) {
- deduplicated.push(current);
- }
- }
- return deduplicated;
+ return [...processedInitial, ...newLocalMessages];
}, [initialMessages, messages, completedToolIds]);

async function sendMessage(
@@ -423,8 +198,10 @@ export function useChatContainer({
isUserMessage: boolean = true,
context?: { url: string; content: string },
) {
- if (!sessionId) return;
+ if (!sessionId) {
+ console.error("[useChatContainer] Cannot send message: no session ID");
+ return;
+ }
setIsRegionBlockedModalOpen(false);
if (isUserMessage) {
const userMessage = createUserMessage(content);

@@ -437,19 +214,31 @@ export function useChatContainer({
setHasTextChunks(false);
setIsStreamingInitiated(true);
hasResponseRef.current = false;
- textFinalizedRef.current = false;
- streamEndedRef.current = false;
+ const dispatcher = createStreamEventDispatcher({
+ setHasTextChunks,
+ setStreamingChunks,
+ streamingChunksRef,
+ hasResponseRef,
+ setMessages,
+ setIsRegionBlockedModalOpen,
+ sessionId,
+ setIsStreamingInitiated,
+ onOperationStarted,
+ });

try {
await sendStreamMessage(
sessionId,
content,
- createDispatcher(),
+ dispatcher,
isUserMessage,
context,
);
} catch (err) {
+ console.error("[useChatContainer] Failed to send message:", err);
setIsStreamingInitiated(false);

if (err instanceof Error && err.name === "AbortError") return;

const errorMessage =
@@ -74,20 +74,19 @@ export function ChatInput({
hasMultipleLines ? "rounded-xlarge" : "rounded-full",
)}
>
- {!value && !isRecording && (
- <div
- className="pointer-events-none absolute inset-0 top-0.5 flex items-center justify-start pl-14 text-[1rem] text-zinc-400"
- aria-hidden="true"
- >
- {isTranscribing ? "Transcribing..." : placeholder}
- </div>
- )}
<textarea
id={inputId}
aria-label="Chat message input"
value={value}
onChange={handleChange}
onKeyDown={handleKeyDown}
+ placeholder={
+ isTranscribing
+ ? "Transcribing..."
+ : isRecording
+ ? ""
+ : placeholder
+ }
disabled={isInputDisabled}
rows={1}
className={cn(

@@ -123,14 +122,13 @@ export function ChatInput({
size="icon"
aria-label={isRecording ? "Stop recording" : "Start recording"}
onClick={toggleRecording}
- disabled={disabled || isTranscribing || isStreaming}
+ disabled={disabled || isTranscribing}
className={cn(
isRecording
? "animate-pulse border-red-500 bg-red-500 text-white hover:border-red-600 hover:bg-red-600"
: isTranscribing
? "border-zinc-300 bg-zinc-100 text-zinc-400"
: "border-zinc-300 bg-white text-zinc-500 hover:border-zinc-400 hover:bg-zinc-50 hover:text-zinc-700",
- isStreaming && "opacity-40",
)}
>
{isTranscribing ? (
@@ -38,8 +38,8 @@ export function AudioWaveform({
// Create audio context and analyser
const audioContext = new AudioContext();
const analyser = audioContext.createAnalyser();
- analyser.fftSize = 256;
- analyser.smoothingTimeConstant = 0.3;
+ analyser.fftSize = 512;
+ analyser.smoothingTimeConstant = 0.8;

// Connect the stream to the analyser
const source = audioContext.createMediaStreamSource(stream);

@@ -73,11 +73,10 @@ export function AudioWaveform({
maxAmplitude = Math.max(maxAmplitude, amplitude);
}

- // Normalize amplitude (0-128 range) to 0-1
- const normalized = maxAmplitude / 128;
- // Apply sensitivity boost (multiply by 4) and use sqrt curve to amplify quiet sounds
- const boosted = Math.min(1, Math.sqrt(normalized) * 4);
- const height = minBarHeight + boosted * (maxBarHeight - minBarHeight);
+ // Map amplitude (0-128) to bar height
+ const normalized = (maxAmplitude / 128) * 255;
+ const height =
+ minBarHeight + (normalized / 255) * (maxBarHeight - minBarHeight);
newBars.push(height);
}

@@ -224,7 +224,7 @@ export function useVoiceRecording({
[value, isTranscribing, toggleRecording, baseHandleKeyDown],
);

- const showMicButton = isSupported;
+ const showMicButton = isSupported && !isStreaming;
const isInputDisabled = disabled || isStreaming || isTranscribing;

// Cleanup on unmount
@@ -111,7 -6 @@ export type ChatMessageData =
toolName: string;
toolId: string;
operationId: string;
- taskId?: string; // For SSE reconnection
message: string;
timestamp?: string | Date;
}

@@ -31,6 +31,11 @@ export function MessageList({
isStreaming,
});

+ /**
+ * Keeps this for debugging purposes 💆🏽
+ */
+ console.log(messages);
+
return (
<div className="relative flex min-h-0 flex-1 flex-col">
{/* Top fade shadow */}
@@ -1,4 +1,3 @@
- import { INITIAL_STREAM_ID } from "./chat-constants";
import type {
ActiveStream,
StreamChunk,

@@ -11,14 +10,8 @@ import {
parseSSELine,
} from "./stream-utils";

- function notifySubscribers(
- stream: ActiveStream,
- chunk: StreamChunk,
- skipStore = false,
- ) {
- if (!skipStore) {
- stream.chunks.push(chunk);
- }
+ function notifySubscribers(stream: ActiveStream, chunk: StreamChunk) {
+ stream.chunks.push(chunk);
for (const callback of stream.onChunkCallbacks) {
try {
callback(chunk);

@@ -28,114 +21,36 @@ function notifySubscribers(
}
}
}

- interface StreamExecutionOptions {
- stream: ActiveStream;
- mode: "new" | "reconnect";
- message?: string;
- isUserMessage?: boolean;
- context?: { url: string; content: string };
- taskId?: string;
- lastMessageId?: string;
- retryCount?: number;
- }
-
- async function executeStreamInternal(
- options: StreamExecutionOptions,
+ export async function executeStream(
+ stream: ActiveStream,
+ message: string,
+ isUserMessage: boolean,
+ context?: { url: string; content: string },
+ retryCount: number = 0,
): Promise<void> {
- const {
- stream,
- mode,
- message,
- isUserMessage,
- context,
- taskId,
- lastMessageId = INITIAL_STREAM_ID,
- retryCount = 0,
- } = options;
const { sessionId, abortController } = stream;
- const isReconnect = mode === "reconnect";
-
- if (isReconnect) {
- if (!taskId) {
- throw new Error("taskId is required for reconnect mode");
- }
- if (lastMessageId === null || lastMessageId === undefined) {
- throw new Error("lastMessageId is required for reconnect mode");
- }
- } else {
- if (!message) {
- throw new Error("message is required for new stream mode");
- }
- if (isUserMessage === undefined) {
- throw new Error("isUserMessage is required for new stream mode");
- }
- }

try {
- let url: string;
- let fetchOptions: RequestInit;
- if (isReconnect) {
- url = `/api/chat/tasks/${taskId}/stream?last_message_id=${encodeURIComponent(lastMessageId)}`;
- fetchOptions = {
- method: "GET",
- headers: {
- Accept: "text/event-stream",
- },
+ const url = `/api/chat/sessions/${sessionId}/stream`;
+ const body = JSON.stringify({
+ message,
+ is_user_message: isUserMessage,
+ context: context || null,
+ });
+
+ const response = await fetch(url, {
+ method: "POST",
+ headers: {
+ "Content-Type": "application/json",
+ Accept: "text/event-stream",
+ },
+ body,
signal: abortController.signal,
- };
- } else {
- url = `/api/chat/sessions/${sessionId}/stream`;
- fetchOptions = {
- method: "POST",
- headers: {
- "Content-Type": "application/json",
- Accept: "text/event-stream",
- },
- body: JSON.stringify({
- message,
- is_user_message: isUserMessage,
- context: context || null,
- }),
- signal: abortController.signal,
- };
- }
-
- const response = await fetch(url, fetchOptions);
+ });

if (!response.ok) {
const errorText = await response.text();
- let errorCode: string | undefined;
- let errorMessage = errorText || `HTTP ${response.status}`;
- try {
- const parsed = JSON.parse(errorText);
- if (parsed.detail) {
- const detail =
- typeof parsed.detail === "string"
- ? parsed.detail
- : parsed.detail.message || JSON.stringify(parsed.detail);
- errorMessage = detail;
- errorCode =
- typeof parsed.detail === "object" ? parsed.detail.code : undefined;
- }
- } catch {}
-
- const isPermanentError =
- isReconnect &&
- (response.status === 404 ||
- response.status === 403 ||
- response.status === 410);
-
- const error = new Error(errorMessage) as Error & {
- status?: number;
- isPermanent?: boolean;
- taskErrorCode?: string;
- };
- error.status = response.status;
- error.isPermanent = isPermanentError;
- error.taskErrorCode = errorCode;
- throw error;
+ throw new Error(errorText || `HTTP ${response.status}`);
}

if (!response.body) {

@@ -189,7 +104,9 @@ async function executeStreamInternal(
);
return;
}
- } catch {}
+ } catch (err) {
+ console.warn("[StreamExecutor] Failed to parse SSE chunk:", err);
+ }
}
}
}

@@ -200,17 +117,19 @@ async function executeStreamInternal(
return;
}

- const isPermanentError =
- err instanceof Error &&
- (err as Error & { isPermanent?: boolean }).isPermanent;
-
- if (!isPermanentError && retryCount < MAX_RETRIES) {
+ if (retryCount < MAX_RETRIES) {
const retryDelay = INITIAL_RETRY_DELAY * Math.pow(2, retryCount);
+ console.log(
+ `[StreamExecutor] Retrying in ${retryDelay}ms (attempt ${retryCount + 1}/${MAX_RETRIES})`,
+ );
await new Promise((resolve) => setTimeout(resolve, retryDelay));
- return executeStreamInternal({
- ...options,
- retryCount: retryCount + 1,
- });
+ return executeStream(
+ stream,
+ message,
+ isUserMessage,
+ context,
+ retryCount + 1,
+ );
}

stream.status = "error";

@@ -221,35 +140,3 @@ async function executeStreamInternal(
});
}
}

- export async function executeStream(
- stream: ActiveStream,
- message: string,
- isUserMessage: boolean,
- context?: { url: string; content: string },
- retryCount: number = 0,
- ): Promise<void> {
- return executeStreamInternal({
- stream,
- mode: "new",
- message,
- isUserMessage,
- context,
- retryCount,
- });
- }
-
- export async function executeTaskReconnect(
- stream: ActiveStream,
- taskId: string,
- lastMessageId: string = INITIAL_STREAM_ID,
- retryCount: number = 0,
- ): Promise<void> {
- return executeStreamInternal({
- stream,
- mode: "reconnect",
- taskId,
- lastMessageId,
- retryCount,
- });
- }
@@ -28,7 -6 @@ export function normalizeStreamChunk(

switch (chunk.type) {
case "text-delta":
- // Vercel AI SDK sends "delta" for text content
return { type: "text_chunk", content: chunk.delta };
case "text-end":
return { type: "text_ended" };

@@ -64,10 +63,6 @@ export function normalizeStreamChunk(
case "finish":
return { type: "stream_end" };
- case "start":
- // Start event with optional taskId for reconnection
- return chunk.taskId
- ? { type: "stream_start", taskId: chunk.taskId }
- : null;
case "text-start":
return null;
case "tool-input-start":
@@ -41,17 +41,7 @@ export function HostScopedCredentialsModal({
const currentHost = currentUrl ? getHostFromUrl(currentUrl) : "";

const formSchema = z.object({
- host: z
- .string()
- .min(1, "Host is required")
- .refine((val) => !/^[a-zA-Z][a-zA-Z\d+\-.]*:\/\//.test(val), {
- message: "Enter only the host (e.g. api.example.com), not a full URL",
- })
- .refine((val) => !val.includes("/"), {
- message:
- "Enter only the host (e.g. api.example.com), without a trailing path. " +
- "You may specify a port (e.g. api.example.com:8080) if needed.",
- }),
+ host: z.string().min(1, "Host is required"),
title: z.string().optional(),
headers: z.record(z.string()).optional(),
});
@@ -1,6 +1,7 @@
"use client";

import { IconLaptop } from "@/components/__legacy__/ui/icons";
+ import { getHomepageRoute } from "@/lib/constants";
import { cn } from "@/lib/utils";
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
import { ListChecksIcon } from "@phosphor-icons/react/dist/ssr";

@@ -23,11 +24,11 @@ interface Props {
export function NavbarLink({ name, href }: Props) {
const pathname = usePathname();
const isChatEnabled = useGetFlag(Flag.CHAT);
- const expectedHomeRoute = isChatEnabled ? "/copilot" : "/library";
+ const homepageRoute = getHomepageRoute(isChatEnabled);

const isActive =
- href === expectedHomeRoute
- ? pathname === "/" || pathname.startsWith(expectedHomeRoute)
+ href === homepageRoute
+ ? pathname === "/" || pathname.startsWith(homepageRoute)
: pathname.includes(href);

return (
@@ -15,6 +15,7 @@ import {
import { cn } from "@/lib/utils";
import { useOnboarding } from "@/providers/onboarding/onboarding-provider";
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
+ import { storage, Key as StorageKey } from "@/services/storage/local-storage";
import { WalletIcon } from "@phosphor-icons/react";
import { PopoverClose } from "@radix-ui/react-popover";
import { X } from "lucide-react";

@@ -174,6 +175,7 @@ export function Wallet() {
const [prevCredits, setPrevCredits] = useState<number | null>(credits);
const [flash, setFlash] = useState(false);
const [walletOpen, setWalletOpen] = useState(false);
+ const [lastSeenCredits, setLastSeenCredits] = useState<number | null>(null);

const totalCount = useMemo(() => {
return groups.reduce((acc, group) => acc + group.tasks.length, 0);

@@ -198,6 +200,38 @@ export function Wallet() {
setCompletedCount(completed);
}, [groups, state?.completedSteps]);

+ // Load last seen credits from localStorage once on mount
+ useEffect(() => {
+ const stored = storage.get(StorageKey.WALLET_LAST_SEEN_CREDITS);
+ if (stored !== undefined && stored !== null) {
+ const parsed = parseFloat(stored);
+ if (!Number.isNaN(parsed)) setLastSeenCredits(parsed);
+ else setLastSeenCredits(0);
+ } else {
+ setLastSeenCredits(0);
+ }
+ }, []);
+
+ // Auto-open once if never shown, otherwise open only when credits increase beyond last seen
+ useEffect(() => {
+ if (typeof credits !== "number") return;
+ // Open once for first-time users
+ if (state && state.walletShown === false) {
+ requestAnimationFrame(() => setWalletOpen(true));
+ // Mark as shown so it won't reopen on every reload
+ updateState({ walletShown: true });
+ return;
+ }
+ // Open if user gained more credits than last acknowledged
+ if (
+ lastSeenCredits !== null &&
+ credits > lastSeenCredits &&
+ walletOpen === false
+ ) {
+ requestAnimationFrame(() => setWalletOpen(true));
+ }
+ }, [credits, lastSeenCredits, state?.walletShown, updateState, walletOpen]);

const onWalletOpen = useCallback(async () => {
if (!state?.walletShown) {
updateState({ walletShown: true });

@@ -290,7 +324,19 @@ export function Wallet() {
if (credits === null || !state) return null;

return (
- <Popover open={walletOpen} onOpenChange={(open) => setWalletOpen(open)}>
+ <Popover
+ open={walletOpen}
+ onOpenChange={(open) => {
+ setWalletOpen(open);
+ if (!open) {
+ // Persist the latest acknowledged credits so we only auto-open on future gains
+ if (typeof credits === "number") {
+ storage.set(StorageKey.WALLET_LAST_SEEN_CREDITS, String(credits));
+ setLastSeenCredits(credits);
+ }
+ }
+ }}
+ >
<PopoverTrigger asChild>
<div className="relative inline-block">
<button
@@ -66,7 +66,7 @@ export default function useAgentGraph(
>(null);
const [xyNodes, setXYNodes] = useState<CustomNode[]>([]);
const [xyEdges, setXYEdges] = useState<CustomEdge[]>([]);
- const betaBlocks = useGetFlag(Flag.BETA_BLOCKS) as string[];
+ const betaBlocks = useGetFlag(Flag.BETA_BLOCKS);

// Filter blocks based on beta flags
const availableBlocks = useMemo(() => {
@@ -11,3 +11,10 @@ export const API_KEY_HEADER_NAME = "X-API-Key";

// Layout
export const NAVBAR_HEIGHT_PX = 60;
+
+ // Routes
+ export function getHomepageRoute(isChatEnabled?: boolean | null): string {
+ if (isChatEnabled === true) return "/copilot";
+ if (isChatEnabled === false) return "/library";
+ return "/";
+ }
@@ -1,3 +1,4 @@
+ import { getHomepageRoute } from "@/lib/constants";
import { environment } from "@/services/environment";
import { Key, storage } from "@/services/storage/local-storage";
import { type CookieOptions } from "@supabase/ssr";

@@ -70,7 +71,7 @@ export function getRedirectPath(
}

if (isAdminPage(path) && userRole !== "admin") {
- return "/";
+ return getHomepageRoute();
}

return null;
@@ -1,3 +1,4 @@
+ import { getHomepageRoute } from "@/lib/constants";
import { environment } from "@/services/environment";
import { createServerClient } from "@supabase/ssr";
import { NextResponse, type NextRequest } from "next/server";

@@ -66,7 +67,7 @@ export async function updateSession(request: NextRequest) {

// 2. Check if user is authenticated but lacks admin role when accessing admin pages
if (user && userRole !== "admin" && isAdminPage(pathname)) {
- url.pathname = "/";
+ url.pathname = getHomepageRoute();
return NextResponse.redirect(url);
}
@@ -23,7 +23,9 @@ import {
WebSocketNotification,
} from "@/lib/autogpt-server-api";
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
+ import { getHomepageRoute } from "@/lib/constants";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
+ import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
import Link from "next/link";
import { usePathname, useRouter } from "next/navigation";
import {

@@ -102,6 +104,8 @@ export default function OnboardingProvider({
const pathname = usePathname();
const router = useRouter();
const { isLoggedIn } = useSupabase();
+ const isChatEnabled = useGetFlag(Flag.CHAT);
+ const homepageRoute = getHomepageRoute(isChatEnabled);

useOnboardingTimezoneDetection();

@@ -146,7 +150,7 @@ export default function OnboardingProvider({
if (isOnOnboardingRoute) {
const enabled = await resolveResponse(getV1IsOnboardingEnabled());
if (!enabled) {
- router.push("/");
+ router.push(homepageRoute);
return;
}
}

@@ -158,7 +162,7 @@ export default function OnboardingProvider({
isOnOnboardingRoute &&
shouldRedirectFromOnboarding(onboarding.completedSteps, pathname)
) {
- router.push("/");
+ router.push(homepageRoute);
}
} catch (error) {
console.error("Failed to initialize onboarding:", error);

@@ -173,7 +177,7 @@ export default function OnboardingProvider({
}

initializeOnboarding();
- }, [api, isOnOnboardingRoute, router, isLoggedIn, pathname]);
+ }, [api, homepageRoute, isOnOnboardingRoute, router, isLoggedIn, pathname]);

const handleOnboardingNotification = useCallback(
(notification: WebSocketNotification) => {
@@ -83,10 -6 @@ function getPostHogCredentials() {
};
}

- function getLaunchDarklyClientId() {
- return process.env.NEXT_PUBLIC_LAUNCHDARKLY_CLIENT_ID;
- }
-
function isProductionBuild() {
return process.env.NODE_ENV === "production";
}

@@ -124,10 +120,7 @@ function isVercelPreview() {
}

function areFeatureFlagsEnabled() {
- return (
- process.env.NEXT_PUBLIC_LAUNCHDARKLY_ENABLED === "true" &&
- Boolean(process.env.NEXT_PUBLIC_LAUNCHDARKLY_CLIENT_ID)
- );
+ return process.env.NEXT_PUBLIC_LAUNCHDARKLY_ENABLED === "enabled";
}

function isPostHogEnabled() {

@@ -150,7 +143,6 @@ export const environment = {
getSupabaseAnonKey,
getPreviewStealingDev,
getPostHogCredentials,
- getLaunchDarklyClientId,
// Assertions
isServerSide,
isClientSide,
@@ -1,59 +0,0 @@
"use client";

import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
import { useLDClient } from "launchdarkly-react-client-sdk";
import { useRouter } from "next/navigation";
import { ReactNode, useEffect, useState } from "react";
import { environment } from "../environment";
import { Flag, useGetFlag } from "./use-get-flag";

interface FeatureFlagRedirectProps {
flag: Flag;
whenDisabled: string;
children: ReactNode;
}

export function FeatureFlagPage({
flag,
whenDisabled,
children,
}: FeatureFlagRedirectProps) {
const [isLoading, setIsLoading] = useState(true);
const router = useRouter();
const flagValue = useGetFlag(flag);
const ldClient = useLDClient();
const ldEnabled = environment.areFeatureFlagsEnabled();
const ldReady = Boolean(ldClient);
const flagEnabled = Boolean(flagValue);

useEffect(() => {
const initialize = async () => {
if (!ldEnabled) {
router.replace(whenDisabled);
setIsLoading(false);
return;
}

// Wait for LaunchDarkly to initialize when enabled to prevent race conditions
if (ldEnabled && !ldReady) return;

try {
await ldClient?.waitForInitialization();
if (!flagEnabled) router.replace(whenDisabled);
} catch (error) {
console.error(error);
router.replace(whenDisabled);
} finally {
setIsLoading(false);
}
};

initialize();
}, [ldReady, flagEnabled]);

return isLoading || !flagEnabled ? (
<LoadingSpinner size="large" cover />
) : (
<>{children}</>
);
}
@@ -1,51 +0,0 @@
"use client";

import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
import { useLDClient } from "launchdarkly-react-client-sdk";
import { useRouter } from "next/navigation";
import { useEffect } from "react";
import { environment } from "../environment";
import { Flag, useGetFlag } from "./use-get-flag";

interface FeatureFlagRedirectProps {
flag: Flag;
whenEnabled: string;
whenDisabled: string;
}

export function FeatureFlagRedirect({
flag,
whenEnabled,
whenDisabled,
}: FeatureFlagRedirectProps) {
const router = useRouter();
const flagValue = useGetFlag(flag);
const ldEnabled = environment.areFeatureFlagsEnabled();
const ldClient = useLDClient();
const ldReady = Boolean(ldClient);
const flagEnabled = Boolean(flagValue);

useEffect(() => {
const initialize = async () => {
if (!ldEnabled) {
router.replace(whenDisabled);
return;
}

// Wait for LaunchDarkly to initialize when enabled to prevent race conditions
if (ldEnabled && !ldReady) return;

try {
await ldClient?.waitForInitialization();
router.replace(flagEnabled ? whenEnabled : whenDisabled);
} catch (error) {
console.error(error);
router.replace(whenDisabled);
}
};

initialize();
}, [ldReady, flagEnabled]);

return <LoadingSpinner size="large" cover />;
}
@@ -1,6 +1,5 @@
"use client";

- import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import * as Sentry from "@sentry/nextjs";
import { LDProvider } from "launchdarkly-react-client-sdk";

@@ -8,17 +7,17 @@ import type { ReactNode } from "react";
import { useMemo } from "react";
import { environment } from "../environment";

+ const clientId = process.env.NEXT_PUBLIC_LAUNCHDARKLY_CLIENT_ID;
+ const envEnabled = process.env.NEXT_PUBLIC_LAUNCHDARKLY_ENABLED === "true";
const LAUNCHDARKLY_INIT_TIMEOUT_MS = 5000;

export function LaunchDarklyProvider({ children }: { children: ReactNode }) {
const { user, isUserLoading } = useSupabase();
- const envEnabled = environment.areFeatureFlagsEnabled();
- const clientId = environment.getLaunchDarklyClientId();
+ const isCloud = environment.isCloud();
+ const isLaunchDarklyConfigured = isCloud && envEnabled && clientId;

const context = useMemo(() => {
- if (isUserLoading) return;
- if (!user) {
+ if (isUserLoading || !user) {
return {
kind: "user" as const,
key: "anonymous",

@@ -37,17 +36,15 @@ export function LaunchDarklyProvider({ children }: { children: ReactNode }) {
};
}, [user, isUserLoading]);

- if (!envEnabled) {
+ if (!isLaunchDarklyConfigured) {
return <>{children}</>;
}

- if (isUserLoading) {
- return <LoadingSpinner size="large" cover />;
- }
-
return (
<LDProvider
- clientSideID={clientId ?? ""}
+ // Add this key prop. It will be 'anonymous' when logged out,
+ key={context.key}
+ clientSideID={clientId}
context={context}
timeout={LAUNCHDARKLY_INIT_TIMEOUT_MS}
reactOptions={{ useCamelCaseFlagKeys: false }}
|||||||
@@ -1,7 +1,6 @@
|
|||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
import { DEFAULT_SEARCH_TERMS } from "@/app/(platform)/marketplace/components/HeroSection/helpers";
|
import { DEFAULT_SEARCH_TERMS } from "@/app/(platform)/marketplace/components/HeroSection/helpers";
|
||||||
import { environment } from "@/services/environment";
|
|
||||||
import { useFlags } from "launchdarkly-react-client-sdk";
|
import { useFlags } from "launchdarkly-react-client-sdk";
|
||||||
|
|
||||||
export enum Flag {
|
export enum Flag {
|
||||||
@@ -19,9 +18,24 @@ export enum Flag {
|
|||||||
CHAT = "chat",
|
CHAT = "chat",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export type FlagValues = {
|
||||||
|
[Flag.BETA_BLOCKS]: string[];
|
||||||
|
[Flag.NEW_BLOCK_MENU]: boolean;
|
||||||
|
[Flag.NEW_AGENT_RUNS]: boolean;
|
||||||
|
[Flag.GRAPH_SEARCH]: boolean;
|
||||||
|
[Flag.ENABLE_ENHANCED_OUTPUT_HANDLING]: boolean;
|
||||||
|
[Flag.NEW_FLOW_EDITOR]: boolean;
|
||||||
|
[Flag.BUILDER_VIEW_SWITCH]: boolean;
|
||||||
|
[Flag.SHARE_EXECUTION_RESULTS]: boolean;
|
||||||
|
[Flag.AGENT_FAVORITING]: boolean;
|
||||||
|
[Flag.MARKETPLACE_SEARCH_TERMS]: string[];
|
||||||
|
[Flag.ENABLE_PLATFORM_PAYMENT]: boolean;
|
||||||
|
[Flag.CHAT]: boolean;
|
||||||
|
};
|
||||||
|
|
||||||
const isPwMockEnabled = process.env.NEXT_PUBLIC_PW_TEST === "true";
|
const isPwMockEnabled = process.env.NEXT_PUBLIC_PW_TEST === "true";
|
||||||
|
|
||||||
const defaultFlags = {
|
const mockFlags = {
|
||||||
[Flag.BETA_BLOCKS]: [],
|
[Flag.BETA_BLOCKS]: [],
|
||||||
[Flag.NEW_BLOCK_MENU]: false,
|
[Flag.NEW_BLOCK_MENU]: false,
|
||||||
[Flag.NEW_AGENT_RUNS]: false,
|
[Flag.NEW_AGENT_RUNS]: false,
|
||||||
@@ -36,16 +50,17 @@ const defaultFlags = {
|
|||||||
[Flag.CHAT]: false,
|
[Flag.CHAT]: false,
|
||||||
};
|
};
|
||||||
|
|
||||||
type FlagValues = typeof defaultFlags;
|
export function useGetFlag<T extends Flag>(flag: T): FlagValues[T] | null {
|
||||||
|
|
||||||
export function useGetFlag<T extends Flag>(flag: T): FlagValues[T] {
|
|
||||||
const currentFlags = useFlags<FlagValues>();
|
const currentFlags = useFlags<FlagValues>();
|
||||||
const flagValue = currentFlags[flag];
|
const flagValue = currentFlags[flag];
|
||||||
const areFlagsEnabled = environment.areFeatureFlagsEnabled();
|
|
||||||
|
|
||||||
if (!areFlagsEnabled || isPwMockEnabled) {
|
const envEnabled = process.env.NEXT_PUBLIC_LAUNCHDARKLY_ENABLED === "true";
|
||||||
return defaultFlags[flag];
|
const clientId = process.env.NEXT_PUBLIC_LAUNCHDARKLY_CLIENT_ID;
|
||||||
|
const isLaunchDarklyConfigured = envEnabled && Boolean(clientId);
|
||||||
|
|
||||||
|
if (!isLaunchDarklyConfigured || isPwMockEnabled) {
|
||||||
|
return mockFlags[flag];
|
||||||
}
|
}
|
||||||
|
|
||||||
return flagValue ?? defaultFlags[flag];
|
return flagValue ?? mockFlags[flag];
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,15 +1,12 @@
[flake8]
max-line-length = 88
- extend-ignore = E203
exclude =
.tox,
__pycache__,
*.pyc,
- .env,
- venv*,
- .venv,
- reports,
- dist,
- data,
- .benchmark_workspaces,
- .autogpt,
+ .env
+ venv*/*,
+ .venv/*,
+ reports/*,
+ dist/*,
+ data/*,
@@ -1,291 +0,0 @@
|
|||||||
# CLAUDE.md
|
|
||||||
|
|
||||||
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
|
||||||
|
|
||||||
## Project Overview
|
|
||||||
|
|
||||||
AutoGPT Classic is an experimental, **unsupported** project demonstrating autonomous GPT-4 operation. Dependencies will not be updated, and the codebase contains known vulnerabilities. This is preserved for educational/historical purposes.
|
|
||||||
|
|
||||||
## Repository Structure
|
|
||||||
|
|
||||||
```
|
|
||||||
classic/
|
|
||||||
├── pyproject.toml # Single consolidated Poetry project
|
|
||||||
├── poetry.lock # Single lock file
|
|
||||||
├── forge/
|
|
||||||
│ └── forge/ # Core agent framework package
|
|
||||||
├── original_autogpt/
|
|
||||||
│ └── autogpt/ # AutoGPT agent package
|
|
||||||
├── direct_benchmark/
|
|
||||||
│ └── direct_benchmark/ # Benchmark harness package
|
|
||||||
└── benchmark/ # Challenge definitions (data, not code)
|
|
||||||
```
|
|
||||||
|
|
||||||
All packages are managed by a single `pyproject.toml` at the classic/ root.
|
|
||||||
|
|
||||||
## Common Commands
|
|
||||||
|
|
||||||
### Setup & Install
|
|
||||||
```bash
|
|
||||||
# Install everything from classic/ directory
|
|
||||||
cd classic
|
|
||||||
poetry install
|
|
||||||
```
|
|
||||||
|
|
||||||
### Running Agents
|
|
||||||
```bash
|
|
||||||
# Run forge agent
|
|
||||||
poetry run python -m forge
|
|
||||||
|
|
||||||
# Run original autogpt server
|
|
||||||
poetry run serve --debug
|
|
||||||
|
|
||||||
# Run autogpt CLI
|
|
||||||
poetry run autogpt
|
|
||||||
```
|
|
||||||
|
|
||||||
Agents run on `http://localhost:8000` by default.
|
|
||||||
|
|
||||||
### Benchmarking
|
|
||||||
```bash
|
|
||||||
# Run benchmarks
|
|
||||||
poetry run direct-benchmark run
|
|
||||||
|
|
||||||
# Run specific strategies and models
|
|
||||||
poetry run direct-benchmark run \
|
|
||||||
--strategies one_shot,rewoo \
|
|
||||||
--models claude \
|
|
||||||
--parallel 4
|
|
||||||
|
|
||||||
# Run a single test
|
|
||||||
poetry run direct-benchmark run --tests ReadFile
|
|
||||||
|
|
||||||
# List available commands
|
|
||||||
poetry run direct-benchmark --help
|
|
||||||
```
|
|
||||||
|
|
||||||
### Testing
|
|
||||||
```bash
|
|
||||||
poetry run pytest # All tests
|
|
||||||
poetry run pytest forge/tests/ # Forge tests only
|
|
||||||
poetry run pytest original_autogpt/tests/ # AutoGPT tests only
|
|
||||||
poetry run pytest -k test_name # Single test by name
|
|
||||||
poetry run pytest path/to/test.py # Specific test file
|
|
||||||
poetry run pytest --cov # With coverage
|
|
||||||
```
|
|
||||||
|
|
||||||
### Linting & Formatting
|
|
||||||
|
|
||||||
Run from the classic/ directory:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Format everything (recommended to run together)
|
|
||||||
poetry run black . && poetry run isort .
|
|
||||||
|
|
||||||
# Check formatting (CI-style, no changes)
|
|
||||||
poetry run black --check . && poetry run isort --check-only .
|
|
||||||
|
|
||||||
# Lint
|
|
||||||
poetry run flake8 # Style linting
|
|
||||||
|
|
||||||
# Type check
|
|
||||||
poetry run pyright # Type checking (some errors are expected in infrastructure code)
|
|
||||||
```
|
|
||||||
|
|
||||||
Note: Always run linters over the entire directory, not specific files, for best results.
|
|
||||||
|
|
||||||
## Architecture
|
|
||||||
|
|
||||||
### Forge (Core Framework)
|
|
||||||
The `forge` package is the foundation that other components depend on:
|
|
||||||
- `forge/agent/` - Agent implementation and protocols
|
|
||||||
- `forge/llm/` - Multi-provider LLM integrations (OpenAI, Anthropic, Groq, LiteLLM)
|
|
||||||
- `forge/components/` - Reusable agent components
|
|
||||||
- `forge/file_storage/` - File system abstraction
|
|
||||||
- `forge/config/` - Configuration management
|
|
||||||
|
|
||||||
### Original AutoGPT
|
|
||||||
- `original_autogpt/autogpt/app/` - CLI application entry points
|
|
||||||
- `original_autogpt/autogpt/agents/` - Agent implementations
|
|
||||||
- `original_autogpt/autogpt/agent_factory/` - Agent creation logic
|
|
||||||
|
|
||||||
### Direct Benchmark
|
|
||||||
Benchmark harness for testing agent performance:
|
|
||||||
- `direct_benchmark/direct_benchmark/` - CLI and harness code
|
|
||||||
- `benchmark/agbenchmark/challenges/` - Test cases organized by category (code, retrieval, data, etc.)
|
|
||||||
- Reports generated in `direct_benchmark/reports/`
|
|
||||||
|
|
||||||
### Package Structure
|
|
||||||
All three packages are included in a single Poetry project. Imports are fully qualified:
|
|
||||||
- `from forge.agent.base import BaseAgent`
|
|
||||||
- `from autogpt.agents.agent import Agent`
|
|
||||||
- `from direct_benchmark.harness import BenchmarkHarness`
|
|
||||||
|
|
||||||
## Code Style
|
|
||||||
|
|
||||||
- Python 3.12 target
|
|
||||||
- Line length: 88 characters (Black default)
|
|
||||||
- Black for formatting, isort for imports (profile="black")
|
|
||||||
- Type hints with Pyright checking
|
|
||||||
|
|
||||||
## Testing Patterns
|
|
||||||
|
|
||||||
- Async support via pytest-asyncio
|
|
||||||
- Fixtures defined in `conftest.py` files provide: `tmp_project_root`, `storage`, `config`, `llm_provider`, `agent`
|
|
||||||
- Tests requiring API keys (OPENAI_API_KEY, ANTHROPIC_API_KEY) will skip if not set
|
|
||||||
|
|
||||||
## Environment Setup
|
|
||||||
|
|
||||||
Copy `.env.example` to `.env` in the relevant directory and add your API keys:
|
|
||||||
```bash
|
|
||||||
cp .env.example .env
|
|
||||||
# Edit .env with your OPENAI_API_KEY, etc.
|
|
||||||
```
|
|
||||||
|
|
||||||
## Workspaces
|
|
||||||
|
|
||||||
Agents operate within a **workspace** - a directory containing all agent data and files. The workspace root defaults to the current working directory.
|
|
||||||
|
|
||||||
### Workspace Structure
|
|
||||||
|
|
||||||
```
|
|
||||||
{workspace}/
|
|
||||||
├── .autogpt/
|
|
||||||
│ ├── autogpt.yaml # Workspace-level permissions
|
|
||||||
│ ├── ap_server.db # Agent Protocol database (server mode)
|
|
||||||
│ └── agents/
|
|
||||||
│ └── AutoGPT-{agent_id}/
|
|
||||||
│ ├── state.json # Agent profile, directives, action history
|
|
||||||
│ ├── permissions.yaml # Agent-specific permission overrides
|
|
||||||
│ └── workspace/ # Agent's sandboxed working directory
|
|
||||||
```
|
|
||||||
|
|
||||||
### Key Concepts
|
|
||||||
|
|
||||||
- **Multiple agents** can coexist in the same workspace (each gets its own subdirectory)
|
|
||||||
- **File access** is sandboxed to the agent's `workspace/` directory by default
|
|
||||||
- **State persistence** - agent state saves to `state.json` and survives across sessions
|
|
||||||
- **Storage backends** - supports local filesystem, S3, and GCS (via `FILE_STORAGE_BACKEND` env var)
|
|
||||||
|
|
||||||
### Specifying a Workspace
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Default: uses current directory
|
|
||||||
cd /path/to/my/project && poetry run autogpt
|
|
||||||
|
|
||||||
# Or specify explicitly via CLI (if supported)
|
|
||||||
poetry run autogpt --workspace /path/to/workspace
|
|
||||||
```
|
|
||||||
|
|
||||||
## Settings Location
|
|
||||||
|
|
||||||
Configuration uses a **layered system** with three levels (in order of precedence):
|
|
||||||
|
|
||||||
### 1. Environment Variables (Global)
|
|
||||||
|
|
||||||
Loaded from `.env` file in the working directory:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Required
|
|
||||||
OPENAI_API_KEY=sk-...
|
|
||||||
|
|
||||||
# Optional LLM settings
|
|
||||||
SMART_LLM=gpt-4o # Model for complex reasoning
|
|
||||||
FAST_LLM=gpt-4o-mini # Model for simple tasks
|
|
||||||
EMBEDDING_MODEL=text-embedding-3-small
|
|
||||||
|
|
||||||
# Optional search providers (for web search component)
|
|
||||||
TAVILY_API_KEY=tvly-...
|
|
||||||
SERPER_API_KEY=...
|
|
||||||
GOOGLE_API_KEY=...
|
|
||||||
GOOGLE_CUSTOM_SEARCH_ENGINE_ID=...
|
|
||||||
|
|
||||||
# Optional infrastructure
|
|
||||||
LOG_LEVEL=DEBUG # DEBUG, INFO, WARNING, ERROR
|
|
||||||
DATABASE_STRING=sqlite:///agent.db # Agent Protocol database
|
|
||||||
PORT=8000 # Server port
|
|
||||||
FILE_STORAGE_BACKEND=local # local, s3, or gcs
|
|
||||||
```
|
|
||||||
|
|
||||||
### 2. Workspace Settings (`{workspace}/.autogpt/autogpt.yaml`)
|
|
||||||
|
|
||||||
Workspace-wide permissions that apply to **all agents** in this workspace:
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
allow:
|
|
||||||
- read_file({workspace}/**)
|
|
||||||
- write_to_file({workspace}/**)
|
|
||||||
- list_folder({workspace}/**)
|
|
||||||
- web_search(*)
|
|
||||||
|
|
||||||
deny:
|
|
||||||
- read_file(**.env)
|
|
||||||
- read_file(**.env.*)
|
|
||||||
- read_file(**.key)
|
|
||||||
- read_file(**.pem)
|
|
||||||
- execute_shell(rm -rf:*)
|
|
||||||
- execute_shell(sudo:*)
|
|
||||||
```
|
|
||||||
|
|
||||||
Auto-generated with sensible defaults if missing.
|
|
||||||
|
|
||||||
### 3. Agent Settings (`{workspace}/.autogpt/agents/{id}/permissions.yaml`)
|
|
||||||
|
|
||||||
Agent-specific permission overrides:
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
allow:
|
|
||||||
- execute_python(*)
|
|
||||||
- web_search(*)
|
|
||||||
|
|
||||||
deny:
|
|
||||||
- execute_shell(*)
|
|
||||||
```
|
|
||||||
|
|
||||||
## Permissions
|
|
||||||
|
|
||||||
The permission system uses **pattern matching** with a **first-match-wins** evaluation order.
|
|
||||||
|
|
||||||
### Permission Check Order
|
|
||||||
|
|
||||||
1. Agent deny list → **Block**
|
|
||||||
2. Workspace deny list → **Block**
|
|
||||||
3. Agent allow list → **Allow**
|
|
||||||
4. Workspace allow list → **Allow**
|
|
||||||
5. Session denied list → **Block** (commands denied during this session)
|
|
||||||
6. **Prompt user** → Interactive approval (if in interactive mode)
|
|
||||||

### Pattern Syntax

Format: `command_name(glob_pattern)`

| Pattern | Description |
|---------|-------------|
| `read_file({workspace}/**)` | Read any file in workspace (recursive) |
| `write_to_file({workspace}/*.txt)` | Write only .txt files in workspace root |
| `execute_shell(python:**)` | Execute Python commands only |
| `execute_shell(git:*)` | Execute any git command |
| `web_search(*)` | Allow all web searches |

Special tokens:
- `{workspace}` - Replaced with actual workspace path
- `**` - Matches any path including `/`
- `*` - Matches any characters except `/`
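
As a rough illustration of how these tokens could be expanded, the helper below substitutes `{workspace}` and translates `**`/`*` into a regular expression. The function name and the regex translation are assumptions for illustration, not the project's actual matcher:

```python
import re


def pattern_to_regex(pattern: str, workspace: str) -> re.Pattern:
    """Expand a `command_name(glob_pattern)` permission pattern (sketch)."""
    pattern = pattern.replace("{workspace}", workspace)
    parts = []
    i = 0
    while i < len(pattern):
        if pattern.startswith("**", i):
            parts.append(".*")        # '**' may cross '/'
            i += 2
        elif pattern[i] == "*":
            parts.append("[^/]*")     # '*' stops at '/'
            i += 1
        else:
            parts.append(re.escape(pattern[i]))
            i += 1
    return re.compile("^" + "".join(parts) + "$")


# e.g. pattern_to_regex("read_file({workspace}/**)", "/home/me/project")
#      matches "read_file(/home/me/project/notes/todo.txt)"
#      but not  "read_file(/etc/passwd)"
```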

### Interactive Approval Scopes

When prompted for permission, users can choose:

| Scope | Effect |
|-------|--------|
| **Once** | Allow this one time only (not saved) |
| **Agent** | Always allow for this agent (saves to agent `permissions.yaml`) |
| **Workspace** | Always allow for all agents (saves to `autogpt.yaml`) |
| **Deny** | Deny this command (saves to appropriate deny list) |

### Default Security

Out of the box, the following are **denied by default**:
- Reading sensitive files (`.env`, `.key`, `.pem`)
- Destructive shell commands (`rm -rf`, `sudo`)
- Operations outside the workspace directory

@@ -2,7 +2,7 @@
 ARG BUILD_TYPE=dev
 
 # Use an official Python base image from the Docker Hub
-FROM python:3.12-slim AS autogpt-base
+FROM python:3.10-slim AS autogpt-base
 
 # Install browsers
 RUN apt-get update && apt-get install -y \
@@ -34,6 +34,9 @@ COPY original_autogpt/pyproject.toml original_autogpt/poetry.lock ./
 # Include forge so it can be used as a path dependency
 COPY forge/ ../forge
 
+# Include frontend
+COPY frontend/ ../frontend
+
 # Set the entrypoint
 ENTRYPOINT ["poetry", "run", "autogpt"]
 CMD []
@@ -4,7 +4,7 @@ AutoGPT Classic was an experimental project to demonstrate autonomous GPT-4 oper
 
 ## Project Status
 
-**This project is unsupported, and dependencies will not be updated.** It was an experiment that has concluded its initial research phase. If you want to use AutoGPT, you should use the [AutoGPT Platform](/autogpt_platform).
+⚠️ **This project is unsupported, and dependencies will not be updated. It was an experiment that has concluded its initial research phase. If you want to use AutoGPT, you should use the [AutoGPT Platform](/autogpt_platform)**
 
 For those interested in autonomous AI agents, we recommend exploring more actively maintained alternatives or referring to this codebase for educational purposes only.
 
@@ -16,171 +16,37 @@ AutoGPT Classic was one of the first implementations of autonomous AI agents - A
 - Learn from the results and adjust its approach
 - Chain multiple actions together to achieve an objective
 
+## Key Features
+
+- 🔄 Autonomous task chaining
+- 🛠 Tool and API integration capabilities
+- 💾 Memory management for context retention
+- 🔍 Web browsing and information gathering
+- 📝 File operations and content creation
+- 🔄 Self-prompting and task breakdown
+
 ## Structure
 
-```
-classic/
-├── pyproject.toml          # Single consolidated Poetry project
-├── poetry.lock             # Single lock file
-├── forge/                  # Core autonomous agent framework
-├── original_autogpt/       # Original implementation
-├── direct_benchmark/       # Benchmark harness
-└── benchmark/              # Challenge definitions (data)
-```
+The project is organized into several key components:
+- `/benchmark` - Performance testing tools
+- `/forge` - Core autonomous agent framework
+- `/frontend` - User interface components
+- `/original_autogpt` - Original implementation
 
 ## Getting Started
 
-### Prerequisites
-
-- Python 3.12+
-- [Poetry](https://python-poetry.org/docs/#installation)
-
-### Installation
-
+While this project is no longer actively maintained, you can still explore the codebase:
+
+1. Clone the repository:
 ```bash
-# Clone the repository
 git clone https://github.com/Significant-Gravitas/AutoGPT.git
 cd classic
-
-# Install everything
-poetry install
 ```
 
-### Configuration
-
-Configuration uses a layered system:
-
-1. **Environment variables** (`.env` file)
-2. **Workspace settings** (`.autogpt/autogpt.yaml`)
-3. **Agent settings** (`.autogpt/agents/{id}/permissions.yaml`)
-
-Copy the example environment file and add your API keys:
-
-```bash
-cp .env.example .env
-```
-
-Key environment variables:
-```bash
-# Required
-OPENAI_API_KEY=sk-...
-
-# Optional LLM settings
-SMART_LLM=gpt-4o              # Model for complex reasoning
-FAST_LLM=gpt-4o-mini          # Model for simple tasks
-
-# Optional search providers
-TAVILY_API_KEY=tvly-...
-SERPER_API_KEY=...
-
-# Optional infrastructure
-LOG_LEVEL=DEBUG
-PORT=8000
-FILE_STORAGE_BACKEND=local    # local, s3, or gcs
-```
-
-### Running
-
-All commands run from the `classic/` directory:
-
-```bash
-# Run forge agent
-poetry run python -m forge
-
-# Run original autogpt server
-poetry run serve --debug
-
-# Run autogpt CLI
-poetry run autogpt
-```
-
-Agents run on `http://localhost:8000` by default.
-
-### Benchmarking
-
-```bash
-poetry run direct-benchmark run
-```
-
-### Testing
-
-```bash
-poetry run pytest                          # All tests
-poetry run pytest forge/tests/             # Forge tests only
-poetry run pytest original_autogpt/tests/  # AutoGPT tests only
-```
-
-## Workspaces
-
-Agents operate within a **workspace** directory that contains all agent data and files:
-
-```
-{workspace}/
-├── .autogpt/
-│   ├── autogpt.yaml              # Workspace-level permissions
-│   ├── ap_server.db              # Agent Protocol database (server mode)
-│   └── agents/
-│       └── AutoGPT-{agent_id}/
-│           ├── state.json        # Agent profile, directives, history
-│           ├── permissions.yaml  # Agent-specific permissions
-│           └── workspace/        # Agent's sandboxed working directory
-```
-
-- The workspace defaults to the current working directory
-- Multiple agents can coexist in the same workspace
-- Agent file access is sandboxed to their `workspace/` subdirectory
-- State persists across sessions via `state.json`
-
-## Permissions
-
-AutoGPT uses a **layered permission system** with pattern matching:
-
-### Permission Files
-
-| File | Scope | Location |
-|------|-------|----------|
-| `autogpt.yaml` | All agents in workspace | `.autogpt/autogpt.yaml` |
-| `permissions.yaml` | Single agent | `.autogpt/agents/{id}/permissions.yaml` |
-
-### Permission Format
-
-```yaml
-allow:
-  - read_file({workspace}/**)      # Read any file in workspace
-  - write_to_file({workspace}/**)  # Write any file in workspace
-  - web_search(*)                  # All web searches
-
-deny:
-  - read_file(**.env)              # Block .env files
-  - execute_shell(sudo:*)          # Block sudo commands
-```
-
-### Check Order (First Match Wins)
-
-1. Agent deny → Block
-2. Workspace deny → Block
-3. Agent allow → Allow
-4. Workspace allow → Allow
-5. Prompt user → Interactive approval
-
-### Interactive Approval
-
-When prompted, users can approve commands with different scopes:
-- **Once** - Allow this one time only
-- **Agent** - Always allow for this agent
-- **Workspace** - Always allow for all agents
-- **Deny** - Block this command
-
-### Default Security
-
-Denied by default:
-- Sensitive files (`.env`, `.key`, `.pem`)
-- Destructive commands (`rm -rf`, `sudo`)
-- Operations outside the workspace
-
-## Security Notice
-
-This codebase has **known vulnerabilities** and issues with its dependencies. It will not be updated to new dependencies. Use for educational purposes only.
+2. Review the documentation:
+- For reference, see the [documentation](https://docs.agpt.co). You can browse at the same point in time as this commit so the docs don't change.
+- Check `CLI-USAGE.md` for command-line interface details
+- Refer to `TROUBLESHOOTING.md` for common issues
 
 ## License
 
@@ -189,3 +55,27 @@ This project segment is licensed under the MIT License - see the [LICENSE](LICEN
 
 ## Documentation
 
 Please refer to the [documentation](https://docs.agpt.co) for more detailed information about the project's architecture and concepts.
+You can browse at the same point in time as this commit so the docs don't change.
+
+## Historical Impact
+
+AutoGPT Classic played a significant role in advancing the field of autonomous AI agents:
+- Demonstrated practical implementation of AI autonomy
+- Inspired numerous derivative projects and research
+- Contributed to the development of AI agent architectures
+- Helped identify key challenges in AI autonomy
+
+## Security Notice
+
+If you're studying this codebase, please understand this has KNOWN vulnerabilities and issues with its dependencies. It will not be updated to new dependencies.
+
+## Community & Support
+
+While active development has concluded:
+- The codebase remains available for study and reference
+- Historical discussions can be found in project issues
+- Related research and developments continue in the broader AI agent community
+
+## Acknowledgments
+
+Thanks to all contributors who participated in this experimental project and helped advance the field of autonomous AI agents.

27 classic/direct_benchmark/.gitignore vendored
@@ -1,27 +0,0 @@
# Benchmark outputs
reports/
.benchmark_workspaces/

# Python
__pycache__/
*.py[cod]
*$py.class
*.egg-info/
.eggs/
dist/
build/

# Environment
.env
.venv/
venv/

# IDE
.idea/
.vscode/
*.swp
*.swo

# OS
.DS_Store
Thumbs.db

@@ -1,297 +0,0 @@
# CLAUDE.md - Direct Benchmark Harness

This file provides guidance to Claude Code when working with the direct benchmark harness.

## Overview

The Direct Benchmark Harness is a high-performance testing framework for AutoGPT that directly instantiates agents without HTTP server overhead. It enables parallel execution of multiple strategy/model configurations.

## Quick Reference

All commands run from the `classic/` directory (parent of this directory):

```bash
# Install (one-time setup)
cd classic
poetry install

# Run benchmarks
poetry run direct-benchmark run

# Run specific strategies and models
poetry run direct-benchmark run \
  --strategies one_shot,rewoo \
  --models claude,openai \
  --parallel 4

# Run a single test
poetry run direct-benchmark run \
  --strategies one_shot \
  --tests ReadFile

# List available challenges
poetry run direct-benchmark list-challenges

# List model presets
poetry run direct-benchmark list-models

# List strategies
poetry run direct-benchmark list-strategies
```

## CLI Options

### Run Command

| Option | Short | Description |
|--------|-------|-------------|
| `--strategies` | `-s` | Comma-separated strategies (one_shot, rewoo, plan_execute, reflexion, tree_of_thoughts) |
| `--models` | `-m` | Comma-separated model presets (claude, openai, etc.) |
| `--categories` | `-c` | Filter by challenge categories |
| `--skip-category` | `-S` | Exclude categories |
| `--tests` | `-t` | Filter by test names |
| `--attempts` | `-N` | Number of times to run each challenge |
| `--parallel` | `-p` | Maximum parallel runs (default: 4) |
| `--timeout` | | Per-challenge timeout in seconds (default: 300) |
| `--cutoff` | | Alias for --timeout |
| `--no-cutoff` | `--nc` | Disable time limit |
| `--max-steps` | | Maximum steps per challenge (default: 50) |
| `--maintain` | | Run only regression tests |
| `--improve` | | Run only non-regression tests |
| `--explore` | | Run only never-beaten challenges |
| `--no-dep` | | Ignore challenge dependencies |
| `--workspace` | | Workspace root directory |
| `--challenges-dir` | | Path to challenges directory |
| `--reports-dir` | | Path to reports directory |
| `--keep-answers` | | Keep answer files for debugging |
| `--quiet` | `-q` | Minimal output |
| `--verbose` | `-v` | Detailed per-challenge output |
| `--json` | | JSON output for CI/scripting |
| `--ci` | | CI mode: no live display, shows completion blocks (auto-enabled when CI env var is set or not a TTY) |
| `--fresh` | | Clear all saved state and start fresh (don't resume) |
| `--retry-failures` | | Re-run only the challenges that failed in previous run |
| `--reset-strategy` | | Reset saved results for specific strategy (can repeat) |
| `--reset-model` | | Reset saved results for specific model (can repeat) |
| `--reset-challenge` | | Reset saved results for specific challenge (can repeat) |
| `--debug` | | Enable debug output |

### State Management Commands
```bash
# Show current state
poetry run direct-benchmark state show

# Clear all state
poetry run direct-benchmark state clear

# Reset specific strategy/model/challenge
poetry run direct-benchmark state reset --strategy reflexion
poetry run direct-benchmark state reset --model claude-thinking-25k
poetry run direct-benchmark state reset --challenge ThreeSum
```

## Available Strategies

- `one_shot` - Single-pass reasoning (default)
- `rewoo` - Reasoning with observations
- `plan_execute` - Plan then execute
- `reflexion` - Self-reflection loop
- `tree_of_thoughts` - Multiple reasoning paths

## Available Model Presets

### Claude
- `claude` - sonnet-4 smart, haiku fast
- `claude-smart` - sonnet-4 for both
- `claude-fast` - haiku for both
- `claude-opus` - opus smart, sonnet fast
- `claude-opus-only` - opus for both

### Claude with Extended Thinking
- `claude-thinking-10k` - 10k thinking tokens
- `claude-thinking-25k` - 25k thinking tokens
- `claude-thinking-50k` - 50k thinking tokens
- `claude-opus-thinking` - opus with 25k thinking
- `claude-opus-thinking-50k` - opus with 50k thinking

### OpenAI
- `openai` - gpt-4o smart, gpt-4o-mini fast
- `openai-smart` - gpt-4o for both
- `openai-fast` - gpt-4o-mini for both
- `gpt5` - gpt-5 smart, gpt-4o fast
- `gpt5-only` - gpt-5 for both

### OpenAI Reasoning Models
- `o1`, `o1-mini` - o1 variants
- `o1-low`, `o1-medium`, `o1-high` - o1 with reasoning effort
- `o3-low`, `o3-medium`, `o3-high` - o3 with reasoning effort
- `gpt5-low`, `gpt5-medium`, `gpt5-high` - gpt-5 with reasoning effort

## Directory Structure

```
direct_benchmark/
├── pyproject.toml             # Poetry config
├── README.md                  # User documentation
├── CLAUDE.md                  # This file
├── .gitignore
└── direct_benchmark/
    ├── __init__.py
    ├── __main__.py            # CLI entry point
    ├── models.py              # Pydantic models, presets
    ├── harness.py             # Main orchestrator
    ├── runner.py              # AgentRunner (single agent lifecycle)
    ├── parallel.py            # ParallelExecutor (concurrent runs)
    ├── challenge_loader.py    # Load challenges from JSON
    ├── evaluator.py           # Evaluate outputs vs ground truth
    ├── report.py              # Report generation
    └── ui.py                  # Rich UI components
```

## Architecture

### Execution Flow

```
CLI args → HarnessConfig
     ↓
BenchmarkHarness.run()
     ↓
ChallengeLoader.load_all() → list[Challenge]
     ↓
ParallelExecutor.execute_matrix(configs × challenges × attempts)
     ↓
[Parallel with semaphore limiting to N concurrent]
     ↓
AgentRunner.run_challenge():
  1. Create temp workspace
  2. Copy input artifacts to agent workspace
  3. Create AppConfig with strategy/model
  4. create_agent() - direct instantiation
  5. Run agent loop until finish/timeout
  6. Collect output files
     ↓
Evaluator.evaluate() - check against ground truth
     ↓
ReportGenerator - write reports
```

### Key Components

**AgentRunner** (`runner.py`)
- Manages single agent lifecycle for one challenge
- Creates isolated temp workspace per run
- Copies input artifacts to `{workspace}/.autogpt/agents/{agent_id}/workspace/`
- Instantiates agent directly via `create_agent()`
- Runs agent loop: `propose_action()` → `execute()` until finish/timeout
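
A minimal sketch of that loop, assuming an agent object that exposes async `propose_action()` and `execute()` as described above; the step cap, the `finish` signal, and the return values are illustrative assumptions:

```python
import asyncio


async def run_agent_loop(agent, max_steps: int = 50, timeout: float = 300.0):
    """Drive one agent until it finishes or hits a limit (sketch)."""

    async def _loop():
        for step in range(1, max_steps + 1):
            proposal = await agent.propose_action()
            # Assumption: the proposal names the tool it wants to call and
            # finishing is signalled by a dedicated "finish" tool.
            if proposal.use_tool.name == "finish":
                return step
            await agent.execute(proposal)
        return max_steps

    try:
        return await asyncio.wait_for(_loop(), timeout=timeout)
    except asyncio.TimeoutError:
        return None  # reached cutoff
```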

**ParallelExecutor** (`parallel.py`)
- Manages concurrent execution with asyncio semaphore
- Supports multiple attempts per challenge
- Reports progress via callbacks
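
The semaphore-bounded fan-out can be sketched as follows; the function signature and the idea that each run is an async callable are assumptions, the point is the `asyncio.Semaphore` pattern:

```python
import asyncio


async def execute_matrix(runs, max_parallel: int = 4, on_done=None):
    """Run every (config, challenge, attempt) combination concurrently,
    never more than `max_parallel` at a time (sketch)."""
    semaphore = asyncio.Semaphore(max_parallel)

    async def _run_one(run):
        async with semaphore:        # blocks while max_parallel runs are active
            result = await run()     # assumption: a run is an async callable
        if on_done is not None:
            on_done(result)          # progress callback
        return result

    return await asyncio.gather(*(_run_one(r) for r in runs))
```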

**Evaluator** (`evaluator.py`)
- String matching (should_contain/should_not_contain)
- Python script execution
- Pytest execution
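
For the string-matching mode, a hedged sketch of checking `should_contain` / `should_not_contain` against an output string (the field names follow the challenge format named above; everything else is an assumption):

```python
def evaluate_strings(output: str,
                     should_contain: list[str] | None = None,
                     should_not_contain: list[str] | None = None) -> bool:
    """String-matching evaluation (sketch): every required snippet must
    appear in the output and every forbidden snippet must be absent."""
    required = should_contain or []
    forbidden = should_not_contain or []
    return (all(s in output for s in required)
            and not any(s in output for s in forbidden))
```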

**ReportGenerator** (`report.py`)
- Per-config `report.json` files (compatible with agbenchmark format)
- Comparison reports across all configs

## Report Format

Reports are generated in `./reports/` with format:
```
reports/
├── {timestamp}_{strategy}_{model}/
│   └── report.json
└── strategy_comparison_{timestamp}.json
```

## Dependencies

- `autogpt-forge` - Core agent framework
- `autogpt` - Original AutoGPT agent
- `click` - CLI framework
- `pydantic` - Data models
- `rich` - Terminal UI

## Key Differences from agbenchmark

| agbenchmark | direct_benchmark |
|-------------|------------------|
| `subprocess.Popen` + HTTP server | Direct `create_agent()` |
| HTTP/REST via Agent Protocol | Direct `propose_action()`/`execute()` |
| Sequential (one config at a time) | Parallel via asyncio semaphore |
| Port-based isolation | Workspace-based isolation |
| `agbenchmark run` CLI | Direct JSON parsing |

## Common Tasks

### Run Full Benchmark Suite
```bash
poetry run direct-benchmark run \
  --strategies one_shot,rewoo,plan_execute \
  --models claude \
  --parallel 8
```

### Compare Strategies
```bash
poetry run direct-benchmark run \
  --strategies one_shot,rewoo,plan_execute,reflexion \
  --models claude \
  --tests ReadFile,WriteFile,ThreeSum
```

### Debug a Failing Test
```bash
poetry run direct-benchmark run \
  --strategies one_shot \
  --tests FailingTest \
  --keep-answers \
  --verbose
```

### Resume / Incremental Runs
The benchmark automatically saves progress and resumes from where it left off.
State is saved to `.benchmark_state.json` in the reports directory.

```bash
# Run benchmarks - will resume from last run automatically
poetry run direct-benchmark run \
  --strategies one_shot,reflexion \
  --models claude

# Start fresh (clear all saved state)
poetry run direct-benchmark run --fresh \
  --strategies one_shot,reflexion \
  --models claude

# Reset specific strategy and re-run
poetry run direct-benchmark run \
  --reset-strategy reflexion \
  --strategies one_shot,reflexion \
  --models claude

# Reset specific model and re-run
poetry run direct-benchmark run \
  --reset-model claude-thinking-25k \
  --strategies one_shot \
  --models claude,claude-thinking-25k

# Retry only the failures from the last run
poetry run direct-benchmark run --retry-failures \
  --strategies one_shot,reflexion \
  --models claude
```
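
A sketch of how resuming could work on top of such a state file; the JSON layout and key names here are assumptions for illustration, not the harness's actual schema:

```python
import json
from pathlib import Path


def load_completed(state_file: Path) -> set[tuple[str, str, str]]:
    """Load (strategy, model, challenge) keys already finished (sketch)."""
    if not state_file.exists():
        return set()
    state = json.loads(state_file.read_text())
    # Assumption: the file stores a "results" list of per-run records.
    return {
        (r["strategy"], r["model"], r["challenge"])
        for r in state.get("results", [])
    }


def pending_runs(planned, state_file: Path, fresh: bool = False):
    """Yield only the runs that still need to execute (resume / --fresh)."""
    done = set() if fresh else load_completed(state_file)
    for run in planned:  # run = (strategy, model, challenge)
        if run not in done:
            yield run
```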

### CI/Scripting Mode
```bash
# JSON output (parseable)
poetry run direct-benchmark run --json

# CI mode - shows completion blocks without Live display
# Auto-enabled when CI=true env var is set or stdout is not a TTY
poetry run direct-benchmark run --ci
```
@@ -1,154 +0,0 @@
|
|||||||
# Direct Benchmark Harness
|
|
||||||
|
|
||||||
High-performance benchmark harness for AutoGPT that directly instantiates agents without HTTP server overhead, enabling parallel execution of multiple configurations.
|
|
||||||
|
|
||||||
## Features
|
|
||||||
|
|
||||||
- **Direct Agent Instantiation**: No HTTP server, no Agent Protocol overhead
|
|
||||||
- **Parallel Execution**: Run multiple strategy/model combinations concurrently
|
|
||||||
- **Multiple Attempts**: Run each challenge multiple times for statistical reliability
|
|
||||||
- **Rich UI**: Live progress display with Rich library
|
|
||||||
- **Multiple Output Modes**: Default (rich), quiet, verbose, JSON for CI
|
|
||||||
- **Full CLI Compatibility**: All flags from the original agbenchmark supported
|
|
||||||
|
|
||||||
## Installation
|
|
||||||
|
|
||||||
All commands run from the `classic/` directory (parent of this directory):
|
|
||||||
|
|
||||||
```bash
|
|
||||||
cd classic
|
|
||||||
poetry install
|
|
||||||
```
|
|
||||||
|
|
||||||
## Usage
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Run benchmarks with default settings
|
|
||||||
poetry run direct-benchmark run
|
|
||||||
|
|
||||||
# Run specific strategies and models
|
|
||||||
poetry run direct-benchmark run \
|
|
||||||
--strategies one_shot,rewoo \
|
|
||||||
--models claude,openai \
|
|
||||||
--parallel 4
|
|
||||||
|
|
||||||
# Run a single test
|
|
||||||
poetry run direct-benchmark run \
|
|
||||||
--strategies one_shot \
|
|
||||||
--tests ReadFile
|
|
||||||
|
|
||||||
# Run multiple attempts per challenge
|
|
||||||
poetry run direct-benchmark run \
|
|
||||||
--strategies one_shot \
|
|
||||||
--attempts 3
|
|
||||||
|
|
||||||
# Run only regression tests (previously beaten)
|
|
||||||
poetry run direct-benchmark run --maintain
|
|
||||||
|
|
||||||
# Run only non-regression tests (not consistently beaten)
|
|
||||||
poetry run direct-benchmark run --improve
|
|
||||||
|
|
||||||
# Run only never-beaten challenges
|
|
||||||
poetry run direct-benchmark run --explore
|
|
||||||
|
|
||||||
# List available challenges
|
|
||||||
poetry run direct-benchmark list-challenges
|
|
||||||
|
|
||||||
# List model presets
|
|
||||||
poetry run direct-benchmark list-models
|
|
||||||
|
|
||||||
# List strategies
|
|
||||||
poetry run direct-benchmark list-strategies
|
|
||||||
```
|
|
||||||
|
|
||||||
## CLI Options
|
|
||||||
|
|
||||||
### Challenge Selection
|
|
||||||
- `--strategies, -s`: Comma-separated strategies (one_shot, rewoo, plan_execute, reflexion, tree_of_thoughts)
|
|
||||||
- `--models, -m`: Comma-separated model presets (claude, openai, etc.)
|
|
||||||
- `--categories, -c`: Filter by challenge categories
|
|
||||||
- `--skip-category, -S`: Exclude categories
|
|
||||||
- `--tests, -t`: Filter by test names
|
|
||||||
|
|
||||||
### Execution Control
|
|
||||||
- `--attempts, -N`: Number of times to run each challenge
|
|
||||||
- `--parallel, -p`: Maximum parallel runs (default: 4)
|
|
||||||
- `--timeout`: Per-challenge timeout in seconds (default: 300)
|
|
||||||
- `--cutoff`: Alias for --timeout
|
|
||||||
- `--no-cutoff, --nc`: Disable time limit
|
|
||||||
- `--max-steps`: Maximum steps per challenge (default: 50)
|
|
||||||
|
|
||||||
### Challenge Filtering Modes
|
|
||||||
- `--maintain`: Run only regression tests (previously beaten consistently)
|
|
||||||
- `--improve`: Run only non-regression tests (not consistently beaten)
|
|
||||||
- `--explore`: Run only challenges that have never been beaten
|
|
||||||
- `--no-dep`: Run all challenges regardless of dependency success/failure
|
|
||||||
|
|
||||||
### Output & Debug
|
|
||||||
- `--quiet, -q`: Minimal output
|
|
||||||
- `--verbose, -v`: Detailed per-challenge output
|
|
||||||
- `--json`: JSON output for CI/scripting
|
|
||||||
- `--debug`: Enable debug output
|
|
||||||
- `--keep-answers`: Keep answer files for debugging
|
|
||||||
|
|
||||||
### Paths
|
|
||||||
- `--workspace`: Workspace root directory
|
|
||||||
- `--challenges-dir`: Path to challenges directory
|
|
||||||
- `--reports-dir`: Path to reports directory
|
|
||||||
|
|
||||||
## Available Strategies
|
|
||||||
|
|
||||||
| Strategy | Description |
|
|
||||||
|----------|-------------|
|
|
||||||
| `one_shot` | Single-pass reasoning (default, most reliable) |
|
|
||||||
| `rewoo` | Reasoning with observations |
|
|
||||||
| `plan_execute` | Plan then execute |
|
|
||||||
| `reflexion` | Self-reflection loop |
|
|
||||||
| `tree_of_thoughts` | Multiple reasoning paths |
|
|
||||||
|
|
||||||
## Available Model Presets
|
|
||||||
|
|
||||||
### Claude
|
|
||||||
- `claude`: sonnet-4 smart, haiku fast (default)
|
|
||||||
- `claude-smart`: sonnet-4 for both
|
|
||||||
- `claude-fast`: haiku for both
|
|
||||||
- `claude-opus`: opus smart, sonnet fast
|
|
||||||
- `claude-opus-only`: opus for both
|
|
||||||
|
|
||||||
### Claude with Extended Thinking
|
|
||||||
- `claude-thinking-10k`: 10k thinking tokens
|
|
||||||
- `claude-thinking-25k`: 25k thinking tokens
|
|
||||||
- `claude-thinking-50k`: 50k thinking tokens
|
|
||||||
- `claude-opus-thinking`: opus with 25k thinking
|
|
||||||
- `claude-opus-thinking-50k`: opus with 50k thinking
|
|
||||||
|
|
||||||
### OpenAI
|
|
||||||
- `openai`: gpt-4o smart, gpt-4o-mini fast
|
|
||||||
- `openai-smart`: gpt-4o for both
|
|
||||||
- `openai-fast`: gpt-4o-mini for both
|
|
||||||
- `gpt5`: gpt-5 smart, gpt-4o fast
|
|
||||||
- `gpt5-only`: gpt-5 for both
|
|
||||||
|
|
||||||
### OpenAI Reasoning Models
|
|
||||||
- `o1`, `o1-mini`: o1 variants
|
|
||||||
- `o1-low`, `o1-medium`, `o1-high`: o1 with reasoning effort
|
|
||||||
- `o3-low`, `o3-medium`, `o3-high`: o3 with reasoning effort
|
|
||||||
|
|
||||||
## Reports
|
|
||||||
|
|
||||||
Reports are generated in `./reports/` with format:
|
|
||||||
```
|
|
||||||
reports/
|
|
||||||
├── {timestamp}_{strategy}_{model}/
|
|
||||||
│ └── report.json
|
|
||||||
└── strategy_comparison_{timestamp}.json
|
|
||||||
```
|
|
||||||
|
|
||||||
## Key Differences from agbenchmark
|
|
||||||
|
|
||||||
| agbenchmark | direct_benchmark |
|
|
||||||
|-------------|------------------|
|
|
||||||
| `subprocess.Popen` + HTTP server | Direct `create_agent()` |
|
|
||||||
| HTTP/REST via Agent Protocol | Direct `propose_action()`/`execute()` |
|
|
||||||
| Sequential (one config at a time) | Parallel via asyncio semaphore |
|
|
||||||
| Port-based isolation | Workspace-based isolation |
|
|
||||||
@@ -1,842 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Strategy Failure Analysis Tool
|
|
||||||
|
|
||||||
Analyzes why prompt strategies fail on benchmark tests, identifies patterns,
|
|
||||||
and provides actionable insights for improvement.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
# Full analysis with LLM summaries (default)
|
|
||||||
poetry run python agbenchmark_config/analyze_failures.py
|
|
||||||
|
|
||||||
# Disable LLM analysis (just print raw pattern data)
|
|
||||||
poetry run python agbenchmark_config/analyze_failures.py --no-analysis
|
|
||||||
|
|
||||||
# Focus on specific strategy
|
|
||||||
poetry run python agbenchmark_config/analyze_failures.py --strategy rewoo
|
|
||||||
|
|
||||||
# Compare one test across strategies (interactive)
|
|
||||||
poetry run python agbenchmark_config/analyze_failures.py --test Battleship
|
|
||||||
|
|
||||||
# Interactive drill-down mode
|
|
||||||
poetry run python agbenchmark_config/analyze_failures.py --interactive
|
|
||||||
|
|
||||||
# Export to markdown
|
|
||||||
poetry run python agbenchmark_config/analyze_failures.py --markdown
|
|
||||||
"""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import json
|
|
||||||
import sys
|
|
||||||
from collections import Counter, defaultdict
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from datetime import datetime
|
|
||||||
from enum import Enum
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Optional
|
|
||||||
|
|
||||||
# Type hints for optional rich imports
|
|
||||||
Console: Any = None
|
|
||||||
Markdown: Any = None
|
|
||||||
Panel: Any = None
|
|
||||||
Progress: Any = None
|
|
||||||
SpinnerColumn: Any = None
|
|
||||||
TextColumn: Any = None
|
|
||||||
Confirm: Any = None
|
|
||||||
Prompt: Any = None
|
|
||||||
Table: Any = None
|
|
||||||
Text: Any = None
|
|
||||||
Tree: Any = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
from rich.console import Console
|
|
||||||
from rich.markdown import Markdown # noqa: F401
|
|
||||||
from rich.panel import Panel
|
|
||||||
from rich.progress import Progress, SpinnerColumn, TextColumn
|
|
||||||
from rich.prompt import Confirm, Prompt # noqa: F401
|
|
||||||
from rich.table import Table
|
|
||||||
from rich.text import Text
|
|
||||||
from rich.tree import Tree
|
|
||||||
|
|
||||||
RICH_AVAILABLE = True
|
|
||||||
except ImportError:
|
|
||||||
RICH_AVAILABLE = False
|
|
||||||
|
|
||||||
|
|
||||||
class FailurePattern(Enum):
|
|
||||||
"""Categories of failure patterns."""
|
|
||||||
|
|
||||||
OVER_PLANNING = "over_planning" # Too many planning steps, not enough execution
|
|
||||||
TOOL_LOOP = "tool_loop" # Repeating same tool without progress
|
|
||||||
MISSING_CRITICAL = "missing_critical" # Didn't complete key action
|
|
||||||
TIMEOUT = "timeout" # Hit step limit before completion
|
|
||||||
ERROR_UNRECOVERED = "error_unrecovered" # Hit error and couldn't recover
|
|
||||||
WRONG_APPROACH = "wrong_approach" # Fundamentally wrong solution
|
|
||||||
UNKNOWN = "unknown"
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class StepInfo:
|
|
||||||
"""Information about a single execution step."""
|
|
||||||
|
|
||||||
step_num: int
|
|
||||||
tool_name: str
|
|
||||||
tool_args: dict
|
|
||||||
tool_result: Optional[dict]
|
|
||||||
thoughts: dict
|
|
||||||
cumulative_cost: float
|
|
||||||
output: str
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class TestResult:
|
|
||||||
"""Analysis of a single test execution."""
|
|
||||||
|
|
||||||
test_name: str
|
|
||||||
strategy: str
|
|
||||||
task: str
|
|
||||||
success: bool
|
|
||||||
fail_reason: Optional[str]
|
|
||||||
reached_cutoff: bool
|
|
||||||
n_steps: int
|
|
||||||
steps: list[StepInfo]
|
|
||||||
total_cost: float
|
|
||||||
run_time: str
|
|
||||||
tool_distribution: Counter = field(default_factory=Counter)
|
|
||||||
patterns_detected: list[FailurePattern] = field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class StrategyAnalysis:
|
|
||||||
"""Analysis results for a strategy."""
|
|
||||||
|
|
||||||
strategy_name: str
|
|
||||||
total_tests: int
|
|
||||||
passed: int
|
|
||||||
failed: int
|
|
||||||
success_rate: float
|
|
||||||
total_cost: float
|
|
||||||
avg_steps: float
|
|
||||||
failed_tests: list[TestResult]
|
|
||||||
pattern_distribution: Counter = field(default_factory=Counter)
|
|
||||||
|
|
||||||
|
|
||||||
class FailureAnalyzer:
|
|
||||||
"""Main analysis engine."""
|
|
||||||
|
|
||||||
def __init__(self, reports_dir: Path, use_llm: bool = True):
|
|
||||||
self.reports_dir = reports_dir
|
|
||||||
self.use_llm = use_llm
|
|
||||||
self._console_instance = Console() if RICH_AVAILABLE else None
|
|
||||||
self.strategies: dict[str, StrategyAnalysis] = {}
|
|
||||||
self.test_comparison: dict[str, dict[str, TestResult]] = defaultdict(dict)
|
|
||||||
self._llm_provider = None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def console(self) -> Any:
|
|
||||||
"""Get console instance (only call when RICH_AVAILABLE is True)."""
|
|
||||||
assert self._console_instance is not None
|
|
||||||
return self._console_instance
|
|
||||||
|
|
||||||
def _print(self, *args: Any, **kwargs: Any) -> None:
|
|
||||||
"""Print with Rich if available, otherwise standard print."""
|
|
||||||
if self._console_instance:
|
|
||||||
self._console_instance.print(*args, **kwargs)
|
|
||||||
else:
|
|
||||||
print(*args, **kwargs)
|
|
||||||
|
|
||||||
def find_reports(self) -> list[tuple[str, Path]]:
|
|
||||||
"""Find all strategy-specific reports."""
|
|
||||||
reports = []
|
|
||||||
for report_dir in self.reports_dir.iterdir():
|
|
||||||
if not report_dir.is_dir():
|
|
||||||
continue
|
|
||||||
report_file = report_dir / "report.json"
|
|
||||||
if not report_file.exists():
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Extract strategy from directory name
|
|
||||||
name = report_dir.name
|
|
||||||
strategy = None
|
|
||||||
for s in [
|
|
||||||
"one_shot",
|
|
||||||
"rewoo",
|
|
||||||
"plan_execute",
|
|
||||||
"reflexion",
|
|
||||||
"tree_of_thoughts",
|
|
||||||
]:
|
|
||||||
if s in name:
|
|
||||||
strategy = s
|
|
||||||
break
|
|
||||||
|
|
||||||
if strategy:
|
|
||||||
reports.append((strategy, report_file))
|
|
||||||
|
|
||||||
return sorted(reports, key=lambda x: x[1].stat().st_mtime, reverse=True)
|
|
||||||
|
|
||||||
def parse_report(self, strategy: str, report_path: Path) -> StrategyAnalysis:
|
|
||||||
"""Parse a benchmark report file."""
|
|
||||||
with open(report_path) as f:
|
|
||||||
data = json.load(f)
|
|
||||||
|
|
||||||
tests_data = data.get("tests", {})
|
|
||||||
failed_tests = []
|
|
||||||
total_cost = 0.0
|
|
||||||
total_steps = 0
|
|
||||||
passed = 0
|
|
||||||
failed = 0
|
|
||||||
|
|
||||||
for test_name, test_data in tests_data.items():
|
|
||||||
results = test_data.get("results", [])
|
|
||||||
if not results:
|
|
||||||
continue
|
|
||||||
|
|
||||||
result = results[0]
|
|
||||||
success = result.get("success", False)
|
|
||||||
n_steps = result.get("n_steps", 0)
|
|
||||||
cost = result.get("cost", 0)
|
|
||||||
|
|
||||||
total_steps += n_steps
|
|
||||||
total_cost += cost or 0
|
|
||||||
|
|
||||||
if success:
|
|
||||||
passed += 1
|
|
||||||
else:
|
|
||||||
failed += 1
|
|
||||||
test_result = self._parse_test_result(
|
|
||||||
test_name, strategy, test_data, result
|
|
||||||
)
|
|
||||||
failed_tests.append(test_result)
|
|
||||||
self.test_comparison[test_name][strategy] = test_result
|
|
||||||
|
|
||||||
total_tests = passed + failed
|
|
||||||
return StrategyAnalysis(
|
|
||||||
strategy_name=strategy,
|
|
||||||
total_tests=total_tests,
|
|
||||||
passed=passed,
|
|
||||||
failed=failed,
|
|
||||||
success_rate=(passed / total_tests * 100) if total_tests > 0 else 0,
|
|
||||||
total_cost=total_cost,
|
|
||||||
avg_steps=total_steps / total_tests if total_tests > 0 else 0,
|
|
||||||
failed_tests=failed_tests,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _parse_test_result(
|
|
||||||
self, test_name: str, strategy: str, test_data: dict, result: dict
|
|
||||||
) -> TestResult:
|
|
||||||
"""Parse a single test result."""
|
|
||||||
steps_data = result.get("steps", [])
|
|
||||||
steps = []
|
|
||||||
tool_distribution = Counter()
|
|
||||||
|
|
||||||
for i, step in enumerate(steps_data):
|
|
||||||
ao = step.get("additional_output") or {}
|
|
||||||
use_tool = ao.get("use_tool") or {}
|
|
||||||
last_action = ao.get("last_action") or {}
|
|
||||||
thoughts = ao.get("thoughts") or {}
|
|
||||||
|
|
||||||
tool_name = use_tool.get("name", "none")
|
|
||||||
tool_distribution[tool_name] += 1
|
|
||||||
|
|
||||||
step_info = StepInfo(
|
|
||||||
step_num=i + 1,
|
|
||||||
tool_name=tool_name,
|
|
||||||
tool_args=use_tool.get("arguments", {}),
|
|
||||||
tool_result=last_action.get("result") if last_action else None,
|
|
||||||
thoughts=thoughts,
|
|
||||||
cumulative_cost=ao.get("task_cumulative_cost", 0),
|
|
||||||
output=step.get("output", ""),
|
|
||||||
)
|
|
||||||
steps.append(step_info)
|
|
||||||
|
|
||||||
test_result = TestResult(
|
|
||||||
test_name=test_name,
|
|
||||||
strategy=strategy,
|
|
||||||
task=test_data.get("task", ""),
|
|
||||||
success=False,
|
|
||||||
fail_reason=result.get("fail_reason"),
|
|
||||||
reached_cutoff=result.get("reached_cutoff", False),
|
|
||||||
n_steps=result.get("n_steps", 0),
|
|
||||||
steps=steps,
|
|
||||||
total_cost=result.get("cost", 0),
|
|
||||||
run_time=result.get("run_time", ""),
|
|
||||||
tool_distribution=tool_distribution,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Detect patterns
|
|
||||||
test_result.patterns_detected = self._detect_patterns(test_result)
|
|
||||||
return test_result
|
|
||||||
|
|
||||||
def _detect_patterns(self, test: TestResult) -> list[FailurePattern]:
|
|
||||||
"""Detect failure patterns in a test result."""
|
|
||||||
patterns = []
|
|
||||||
|
|
||||||
# Pattern 1: Over-planning
|
|
||||||
planning_tools = {"todo_write", "todo_read", "think", "plan"}
|
|
||||||
execution_tools = {
|
|
||||||
"write_file",
|
|
||||||
"execute_python",
|
|
||||||
"execute_shell",
|
|
||||||
"read_file",
|
|
||||||
}
|
|
||||||
|
|
||||||
planning_count = sum(test.tool_distribution.get(t, 0) for t in planning_tools)
|
|
||||||
_execution_count = sum( # noqa: F841
|
|
||||||
test.tool_distribution.get(t, 0) for t in execution_tools
|
|
||||||
)
|
|
||||||
|
|
||||||
if test.n_steps > 0:
|
|
||||||
planning_ratio = planning_count / test.n_steps
|
|
||||||
if planning_ratio > 0.5 and test.n_steps > 1:
|
|
||||||
patterns.append(FailurePattern.OVER_PLANNING)
|
|
||||||
|
|
||||||
# Pattern 2: Tool loops (same tool used 3+ times consecutively)
|
|
||||||
if len(test.steps) >= 3:
|
|
||||||
for i in range(len(test.steps) - 2):
|
|
||||||
if (
|
|
||||||
test.steps[i].tool_name
|
|
||||||
== test.steps[i + 1].tool_name
|
|
||||||
== test.steps[i + 2].tool_name
|
|
||||||
):
|
|
||||||
patterns.append(FailurePattern.TOOL_LOOP)
|
|
||||||
break
|
|
||||||
|
|
||||||
# Pattern 3: Missing critical action
|
|
||||||
# If task mentions "write" or "create" but no write_file was used
|
|
||||||
task_lower = test.task.lower()
|
|
||||||
if any(word in task_lower for word in ["write", "create", "generate", "build"]):
|
|
||||||
if test.tool_distribution.get("write_file", 0) == 0:
|
|
||||||
patterns.append(FailurePattern.MISSING_CRITICAL)
|
|
||||||
|
|
||||||
# Pattern 4: Timeout
|
|
||||||
if test.reached_cutoff:
|
|
||||||
patterns.append(FailurePattern.TIMEOUT)
|
|
||||||
|
|
||||||
# Pattern 5: Error unrecovered
|
|
||||||
error_count = 0
|
|
||||||
for step in test.steps:
|
|
||||||
if step.tool_result and step.tool_result.get("status") == "error":
|
|
||||||
error_count += 1
|
|
||||||
if error_count > 0 and error_count == len(test.steps) - 1:
|
|
||||||
patterns.append(FailurePattern.ERROR_UNRECOVERED)
|
|
||||||
|
|
||||||
if not patterns:
|
|
||||||
patterns.append(FailurePattern.UNKNOWN)
|
|
||||||
|
|
||||||
return patterns
|
|
||||||
|
|
||||||
def analyze_all(self) -> None:
|
|
||||||
"""Analyze all available reports."""
|
|
||||||
reports = self.find_reports()
|
|
||||||
|
|
||||||
# Keep only most recent report per strategy
|
|
||||||
latest_reports = {}
|
|
||||||
for strategy, path in reports:
|
|
||||||
if strategy not in latest_reports:
|
|
||||||
latest_reports[strategy] = path
|
|
||||||
|
|
||||||
if RICH_AVAILABLE:
|
|
||||||
with Progress(
|
|
||||||
SpinnerColumn(),
|
|
||||||
TextColumn("[progress.description]{task.description}"),
|
|
||||||
console=self.console,
|
|
||||||
) as progress:
|
|
||||||
task = progress.add_task(
|
|
||||||
"Analyzing reports...", total=len(latest_reports)
|
|
||||||
)
|
|
||||||
for strategy, path in latest_reports.items():
|
|
||||||
progress.update(task, description=f"Analyzing {strategy}...")
|
|
||||||
self.strategies[strategy] = self.parse_report(strategy, path)
|
|
||||||
progress.advance(task)
|
|
||||||
else:
|
|
||||||
for strategy, path in latest_reports.items():
|
|
||||||
print(f"Analyzing {strategy}...")
|
|
||||||
self.strategies[strategy] = self.parse_report(strategy, path)
|
|
||||||
|
|
||||||
def _get_llm_provider(self) -> Any:
|
|
||||||
"""Lazy-load the LLM provider."""
|
|
||||||
if self._llm_provider is None:
|
|
||||||
try:
|
|
||||||
# Add parent paths to find forge
|
|
||||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent / "forge"))
|
|
||||||
from forge.llm.providers import MultiProvider
|
|
||||||
|
|
||||||
self._llm_provider = MultiProvider()
|
|
||||||
except ImportError as e:
|
|
||||||
self._print(
|
|
||||||
f"[yellow]Warning: Could not load LLM provider: {e}[/yellow]"
|
|
||||||
if RICH_AVAILABLE
|
|
||||||
else f"Warning: Could not load LLM provider: {e}"
|
|
||||||
)
|
|
||||||
self._llm_provider = False
|
|
||||||
return self._llm_provider if self._llm_provider else None
|
|
||||||
|
|
||||||
async def _get_llm_analysis(self, test: TestResult) -> Optional[str]:
|
|
||||||
"""Get LLM-powered analysis of a failure.
|
|
||||||
|
|
||||||
Note: This is a placeholder for future LLM-powered analysis.
|
|
||||||
Currently disabled to avoid dependency issues.
|
|
||||||
"""
|
|
||||||
# LLM analysis disabled for now - patterns provide sufficient insights
|
|
||||||
return None
|
|
||||||
|
|
||||||
def print_summary(self) -> None:
|
|
||||||
"""Print overall summary."""
|
|
||||||
if RICH_AVAILABLE:
|
|
||||||
table = Table(title="Strategy Comparison Summary")
|
|
||||||
table.add_column("Strategy", style="cyan")
|
|
||||||
table.add_column("Tests", justify="right")
|
|
||||||
table.add_column("Passed", justify="right", style="green")
|
|
||||||
table.add_column("Failed", justify="right", style="red")
|
|
||||||
table.add_column("Success %", justify="right")
|
|
||||||
table.add_column("Avg Steps", justify="right")
|
|
||||||
table.add_column("Cost", justify="right")
|
|
||||||
|
|
||||||
for name, analysis in sorted(
|
|
||||||
self.strategies.items(), key=lambda x: x[1].success_rate, reverse=True
|
|
||||||
):
|
|
||||||
table.add_row(
|
|
||||||
name,
|
|
||||||
str(analysis.total_tests),
|
|
||||||
str(analysis.passed),
|
|
||||||
str(analysis.failed),
|
|
||||||
f"{analysis.success_rate:.1f}%",
|
|
||||||
f"{analysis.avg_steps:.1f}",
|
|
||||||
f"${analysis.total_cost:.4f}",
|
|
||||||
)
|
|
||||||
|
|
||||||
self.console.print(table)
|
|
||||||
else:
|
|
||||||
print("\n=== Strategy Comparison Summary ===")
|
|
||||||
hdr = (
|
|
||||||
f"{'Strategy':<20} {'Tests':>6} {'Passed':>7} "
|
|
||||||
f"{'Failed':>7} {'Success%':>10} {'AvgSteps':>9} {'Cost':>10}"
|
|
||||||
)
|
|
||||||
print(hdr)
|
|
||||||
print("-" * 80)
|
|
||||||
for name, analysis in sorted(
|
|
||||||
self.strategies.items(), key=lambda x: x[1].success_rate, reverse=True
|
|
||||||
):
|
|
||||||
row = (
|
|
||||||
f"{name:<20} {analysis.total_tests:>6} "
|
|
||||||
f"{analysis.passed:>7} {analysis.failed:>7} "
|
|
||||||
f"{analysis.success_rate:>9.1f}% {analysis.avg_steps:>9.1f} "
|
|
||||||
f"${analysis.total_cost:>9.4f}"
|
|
||||||
)
|
|
||||||
print(row)
|
|
||||||
|
|
||||||
def print_pattern_analysis(self) -> None:
|
|
||||||
"""Print failure pattern analysis."""
|
|
||||||
all_patterns = Counter()
|
|
||||||
for analysis in self.strategies.values():
|
|
||||||
for test in analysis.failed_tests:
|
|
||||||
for pattern in test.patterns_detected:
|
|
||||||
all_patterns[pattern] += 1
|
|
||||||
|
|
||||||
self._print("\n")
|
|
||||||
if RICH_AVAILABLE:
|
|
||||||
table = Table(title="Failure Pattern Distribution")
|
|
||||||
table.add_column("Pattern", style="yellow")
|
|
||||||
table.add_column("Count", justify="right")
|
|
||||||
table.add_column("Description")
|
|
||||||
|
|
||||||
pattern_descriptions = {
|
|
||||||
FailurePattern.OVER_PLANNING: "Too much planning, not enough action",
|
|
||||||
FailurePattern.TOOL_LOOP: "Repeats same tool 3+ times consecutively",
|
|
||||||
FailurePattern.MISSING_CRITICAL: "Never performed key action",
|
|
||||||
FailurePattern.TIMEOUT: "Hit step limit before completing task",
|
|
||||||
FailurePattern.ERROR_UNRECOVERED: "Hit errors and couldn't recover",
|
|
||||||
FailurePattern.WRONG_APPROACH: "Took fundamentally wrong approach",
|
|
||||||
FailurePattern.UNKNOWN: "Pattern not categorized",
|
|
||||||
}
|
|
||||||
|
|
||||||
for pattern, count in all_patterns.most_common():
|
|
||||||
table.add_row(
|
|
||||||
pattern.value, str(count), pattern_descriptions.get(pattern, "")
|
|
||||||
)
|
|
||||||
|
|
||||||
self.console.print(table)
|
|
||||||
else:
|
|
||||||
print("\n=== Failure Pattern Distribution ===")
|
|
||||||
for pattern, count in all_patterns.most_common():
|
|
||||||
print(f" {pattern.value}: {count}")
|
|
||||||
|
|
||||||
def print_failed_tests(self, strategy: Optional[str] = None) -> None:
|
|
||||||
"""Print detailed failure analysis."""
|
|
||||||
strategies_to_show = (
|
|
||||||
[self.strategies[strategy]] if strategy else self.strategies.values()
|
|
||||||
)
|
|
||||||
|
|
||||||
for analysis in strategies_to_show:
|
|
||||||
self._print("\n")
|
|
||||||
if RICH_AVAILABLE:
|
|
||||||
msg = (
|
|
||||||
f"[bold]{analysis.strategy_name}[/bold] - "
|
|
||||||
f"{analysis.failed} failures out of {analysis.total_tests} tests"
|
|
||||||
)
|
|
||||||
self.console.print(Panel(msg, title="Strategy Analysis"))
|
|
||||||
else:
|
|
||||||
print(f"\n=== {analysis.strategy_name} ===")
|
|
||||||
print(f"Failures: {analysis.failed}/{analysis.total_tests}")
|
|
||||||
|
|
||||||
for test in analysis.failed_tests:
|
|
||||||
self._print_test_failure(test)
|
|
||||||
|
|
||||||
def _print_test_failure(self, test: TestResult) -> None:
|
|
||||||
"""Print a single test failure."""
|
|
||||||
if RICH_AVAILABLE:
|
|
||||||
tree = Tree(f"[red]{test.test_name}[/red]")
|
|
||||||
tree.add(f"[dim]Task:[/dim] {test.task[:80]}...")
|
|
||||||
tree.add(f"[dim]Steps:[/dim] {test.n_steps}")
|
|
||||||
tree.add(f"[dim]Cost:[/dim] ${test.total_cost:.4f}")
|
|
||||||
patterns = ", ".join(p.value for p in test.patterns_detected)
|
|
||||||
tree.add(f"[dim]Patterns:[/dim] {patterns}")
|
|
||||||
|
|
||||||
tools = tree.add("[dim]Tool sequence:[/dim]")
|
|
||||||
tool_seq = [s.tool_name for s in test.steps[:10]]
|
|
||||||
tools.add(" -> ".join(tool_seq) + ("..." if len(test.steps) > 10 else ""))
|
|
||||||
|
|
||||||
if test.fail_reason:
|
|
||||||
reason = tree.add("[dim]Fail reason:[/dim]")
|
|
||||||
reason.add(Text(test.fail_reason[:200], style="red"))
|
|
||||||
|
|
||||||
self.console.print(tree)
|
|
||||||
else:
|
|
||||||
print(f"\n {test.test_name}")
|
|
||||||
print(f" Task: {test.task[:80]}...")
|
|
||||||
print(f" Steps: {test.n_steps}, Cost: ${test.total_cost:.4f}")
|
|
||||||
print(f" Patterns: {', '.join(p.value for p in test.patterns_detected)}")
|
|
||||||
tool_seq = [s.tool_name for s in test.steps[:10]]
|
|
||||||
print(f" Tools: {' -> '.join(tool_seq)}")
|
|
||||||
if test.fail_reason:
|
|
||||||
print(f" Fail reason: {test.fail_reason[:200]}")
|
|
||||||
|
|
||||||
def compare_test(self, test_name: str) -> None:
|
|
||||||
"""Compare a single test across all strategies."""
|
|
||||||
if test_name not in self.test_comparison:
|
|
||||||
self._print(
|
|
||||||
f"[red]Test '{test_name}' not found in failed tests[/red]"
|
|
||||||
if RICH_AVAILABLE
|
|
||||||
else f"Test '{test_name}' not found in failed tests"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
results = self.test_comparison[test_name]
|
|
||||||
self._print("\n")
|
|
||||||
if RICH_AVAILABLE:
|
|
||||||
self.console.print(Panel(f"[bold]Comparing: {test_name}[/bold]"))
|
|
||||||
else:
|
|
||||||
print(f"\n=== Comparing: {test_name} ===")
|
|
||||||
|
|
||||||
for strategy, test in sorted(results.items()):
|
|
||||||
self._print("\n")
|
|
||||||
if RICH_AVAILABLE:
|
|
||||||
self.console.print(f"[cyan]--- {strategy} ---[/cyan]")
|
|
||||||
else:
|
|
||||||
print(f"\n--- {strategy} ---")
|
|
||||||
self._print_test_failure(test)
|
|
||||||
|
|
||||||
def interactive_mode(self) -> None:
|
|
||||||
"""Run interactive exploration mode."""
|
|
||||||
if not RICH_AVAILABLE:
|
|
||||||
print("Interactive mode requires the 'rich' library.")
|
|
||||||
print("Install with: pip install rich")
|
|
||||||
return
|
|
||||||
|
|
||||||
while True:
|
|
||||||
self.console.print("\n[bold]Interactive Failure Analysis[/bold]")
|
|
||||||
self.console.print("Commands:")
|
|
||||||
self.console.print(" [cyan]summary[/cyan] - Show overall summary")
|
|
||||||
self.console.print(" [cyan]patterns[/cyan] - Show pattern analysis")
|
|
||||||
self.console.print(
|
|
||||||
" [cyan]strategy <name>[/cyan] - Show failures for a strategy"
|
|
||||||
)
|
|
||||||
self.console.print(
|
|
||||||
" [cyan]test <name>[/cyan] - Compare test across strategies"
|
|
||||||
)
|
|
||||||
self.console.print(
|
|
||||||
" [cyan]step <strategy> <test> <n>[/cyan] - Show step details"
|
|
||||||
)
|
|
||||||
            self.console.print(" [cyan]list tests[/cyan] - List all failed tests")
            self.console.print(" [cyan]list strategies[/cyan] - List strategies")
            self.console.print(" [cyan]quit[/cyan] - Exit")

            cmd = Prompt.ask("\n[bold]>>[/bold]").strip().lower()

            if cmd == "quit" or cmd == "q":
                break
            elif cmd == "summary":
                self.print_summary()
            elif cmd == "patterns":
                self.print_pattern_analysis()
            elif cmd.startswith("strategy "):
                strategy = cmd.split(" ", 1)[1]
                if strategy in self.strategies:
                    self.print_failed_tests(strategy)
                else:
                    self.console.print(f"[red]Unknown strategy: {strategy}[/red]")
            elif cmd.startswith("test "):
                test_name = cmd.split(" ", 1)[1]
                self.compare_test(test_name)
            elif cmd.startswith("step "):
                parts = cmd.split()
                if len(parts) >= 4:
                    strategy = parts[1]
                    test_name = parts[2]
                    step_num = int(parts[3])
                    self._show_step_detail(strategy, test_name, step_num)
                else:
                    self.console.print(
                        "[red]Usage: step <strategy> <test> <step_num>[/red]"
                    )
            elif cmd == "list tests":
                self._list_tests()
            elif cmd == "list strategies":
                self.console.print(", ".join(self.strategies.keys()))
            else:
                self.console.print(f"[red]Unknown command: {cmd}[/red]")

    def _list_tests(self) -> None:
        """List all failed tests."""
        all_tests = set()
        for analysis in self.strategies.values():
            for test in analysis.failed_tests:
                all_tests.add(test.test_name)

        if RICH_AVAILABLE:
            table = Table(title="Failed Tests Across Strategies")
            table.add_column("Test", style="cyan")
            for strategy in self.strategies.keys():
                table.add_column(strategy, justify="center")

            for test_name in sorted(all_tests):
                row = [test_name]
                for strategy in self.strategies.keys():
                    if (
                        test_name in self.test_comparison
                        and strategy in self.test_comparison[test_name]
                    ):
                        row.append("[red]FAIL[/red]")
                    else:
                        row.append("[green]PASS[/green]")
                table.add_row(*row)

            self.console.print(table)
        else:
            print("\n=== Failed Tests ===")
            for test_name in sorted(all_tests):
                print(f" {test_name}")

    def _show_step_detail(self, strategy: str, test_name: str, step_num: int) -> None:
        """Show detailed information about a specific step."""
        if strategy not in self.strategies:
            self._print(
                f"[red]Unknown strategy: {strategy}[/red]"
                if RICH_AVAILABLE
                else f"Unknown strategy: {strategy}"
            )
            return

        test = None
        for t in self.strategies[strategy].failed_tests:
            if t.test_name == test_name:
                test = t
                break

        if not test:
            self._print(
                f"[red]Test '{test_name}' not found in {strategy}[/red]"
                if RICH_AVAILABLE
                else f"Test '{test_name}' not found in {strategy}"
            )
            return

        if step_num < 1 or step_num > len(test.steps):
            self._print(
                f"[red]Step {step_num} out of range (1-{len(test.steps)})[/red]"
                if RICH_AVAILABLE
                else f"Step {step_num} out of range (1-{len(test.steps)})"
            )
            return

        step = test.steps[step_num - 1]

        if RICH_AVAILABLE:
            self.console.print(Panel(f"[bold]Step {step_num} Details[/bold]"))
            self.console.print(f"[cyan]Tool:[/cyan] {step.tool_name}")
            self.console.print(
                f"[cyan]Arguments:[/cyan] {json.dumps(step.tool_args, indent=2)}"
            )

            if step.thoughts:
                self.console.print("\n[cyan]Thoughts:[/cyan]")
                for key, value in step.thoughts.items():
                    self.console.print(f" [dim]{key}:[/dim] {value}")

            if step.tool_result:
                result_str = json.dumps(step.tool_result, indent=2)[:500]
                self.console.print(f"\n[cyan]Result:[/cyan] {result_str}")

            self.console.print(
                f"\n[cyan]Cumulative Cost:[/cyan] ${step.cumulative_cost:.4f}"
            )
        else:
            print(f"\n=== Step {step_num} Details ===")
            print(f"Tool: {step.tool_name}")
            print(f"Arguments: {json.dumps(step.tool_args, indent=2)}")
            if step.thoughts:
                print("\nThoughts:")
                for key, value in step.thoughts.items():
                    print(f" {key}: {value}")
            if step.tool_result:
                print(f"\nResult: {json.dumps(step.tool_result, indent=2)[:500]}")
            print(f"\nCumulative Cost: ${step.cumulative_cost:.4f}")

    def export_markdown(self, output_path: Optional[Path] = None) -> str:
        """Export analysis to markdown format."""
        lines = []
        lines.append("# Benchmark Failure Analysis Report")
        lines.append(f"\nGenerated: {datetime.now().isoformat()}\n")

        # Summary table
        lines.append("## Strategy Comparison\n")
        lines.append(
            "| Strategy | Tests | Passed | Failed | Success % | Avg Steps | Cost |"
        )
        lines.append(
            "|----------|-------|--------|--------|-----------|-----------|------|"
        )
        for name, analysis in sorted(
            self.strategies.items(), key=lambda x: x[1].success_rate, reverse=True
        ):
            row = (
                f"| {name} | {analysis.total_tests} | {analysis.passed} "
                f"| {analysis.failed} | {analysis.success_rate:.1f}% "
                f"| {analysis.avg_steps:.1f} | ${analysis.total_cost:.4f} |"
            )
            lines.append(row)

        # Pattern analysis
        lines.append("\n## Failure Patterns\n")
        all_patterns = Counter()
        for analysis in self.strategies.values():
            for test in analysis.failed_tests:
                for pattern in test.patterns_detected:
                    all_patterns[pattern] += 1

        for pattern, count in all_patterns.most_common():
            lines.append(f"- **{pattern.value}**: {count} occurrences")

        # Failed tests by strategy
        lines.append("\n## Failed Tests by Strategy\n")
        for name, analysis in self.strategies.items():
            if not analysis.failed_tests:
                continue
            lines.append(f"\n### {name}\n")
            for test in analysis.failed_tests:
                lines.append(f"#### {test.test_name}\n")
                lines.append(f"- **Task**: {test.task[:100]}...")
                lines.append(f"- **Steps**: {test.n_steps}")
                patterns = ", ".join(p.value for p in test.patterns_detected)
                lines.append(f"- **Patterns**: {patterns}")
                tools = " -> ".join(s.tool_name for s in test.steps[:8])
                lines.append(f"- **Tool sequence**: {tools}")
                if test.fail_reason:
                    lines.append(f"- **Fail reason**: {test.fail_reason[:150]}...")
                lines.append("")

        content = "\n".join(lines)

        if output_path:
            output_path.write_text(content)
            self._print(
                f"Markdown report saved to: {output_path}"
                if not RICH_AVAILABLE
                else f"[green]Markdown report saved to: {output_path}[/green]"
            )

        return content


async def main():
    parser = argparse.ArgumentParser(
        description="Analyze benchmark failures across prompt strategies"
    )
    parser.add_argument(
        "--no-analysis",
        action="store_true",
        help="Disable LLM-powered analysis",
    )
    parser.add_argument(
        "--strategy",
        type=str,
        help="Focus on a specific strategy",
    )
    parser.add_argument(
        "--test",
        type=str,
        help="Compare a specific test across strategies",
    )
    parser.add_argument(
        "--interactive",
        "-i",
        action="store_true",
        help="Run in interactive mode",
    )
    parser.add_argument(
        "--markdown",
        type=str,
        nargs="?",
        const="failure_analysis.md",
        help="Export to markdown (optionally specify output file)",
    )
    parser.add_argument(
        "--reports-dir",
        type=str,
        default=None,
        help="Path to reports directory",
    )

    args = parser.parse_args()

    # Find reports directory
    if args.reports_dir:
        reports_dir = Path(args.reports_dir)
    else:
        # Try to find it relative to this script
        script_dir = Path(__file__).parent
        reports_dir = script_dir / "reports"
        if not reports_dir.exists():
            reports_dir = Path.cwd() / "agbenchmark_config" / "reports"

    if not reports_dir.exists():
        print(f"Reports directory not found: {reports_dir}")
        sys.exit(1)

    analyzer = FailureAnalyzer(reports_dir, use_llm=not args.no_analysis)
    analyzer.analyze_all()

    if not analyzer.strategies:
        print("No strategy reports found.")
        sys.exit(1)

    if args.interactive:
        analyzer.interactive_mode()
    elif args.test:
        analyzer.compare_test(args.test)
    elif args.strategy:
        analyzer.print_failed_tests(args.strategy)
    else:
        analyzer.print_summary()
        analyzer.print_pattern_analysis()
        analyzer.print_failed_tests()

    if args.markdown:
        output_path = Path(args.markdown)
        analyzer.export_markdown(output_path)


if __name__ == "__main__":
    import asyncio

    asyncio.run(main())
@@ -1,162 +0,0 @@
#!/usr/bin/env python3

import json
import logging
import re
import sys
from collections import defaultdict
from pathlib import Path

from tabulate import tabulate

info = "-v" in sys.argv
debug = "-vv" in sys.argv
granular = "--granular" in sys.argv

logging.basicConfig(
    level=logging.DEBUG if debug else logging.INFO if info else logging.WARNING
)
logger = logging.getLogger(__name__)

# Get a list of all JSON files in the directory
reports_dir = Path(__file__).parent / "reports"
if not reports_dir.exists():
    print(f"No reports directory found at {reports_dir}")
    sys.exit(1)

report_files = [
    report_file
    for dir in reports_dir.iterdir()
    if re.match(r"^\d{8}T\d{6}_", dir.name)
    and (report_file := dir / "report.json").is_file()
]

labels = list[str]()
runs_per_label = defaultdict[str, int](lambda: 0)
suite_names = list[str]()
test_names = list[str]()

# Create a dictionary to store grouped success values by suffix and test
grouped_success_values = defaultdict[str, list[str]](list[str])

# Loop through each JSON file to collect suffixes and success values
for report_file in sorted(report_files):
    with open(report_file) as f:
        logger.info(f"Loading {report_file}...")

        data = json.load(f)
        if "tests" in data:
            test_tree = data["tests"]
            # Handle old format (agent_git_commit_sha) and new (config_name)
            if "config" in data and "config_name" in data["config"]:
                label = data["config"]["config_name"]
            elif "agent_git_commit_sha" in data and "/" in data["agent_git_commit_sha"]:
                label = data["agent_git_commit_sha"].rsplit("/", 1)[1][
                    :7
                ]  # commit hash
            else:
                label = report_file.parent.name.split("_", 1)[1]
        else:
            # Benchmark run still in progress
            test_tree = data
            label = report_file.parent.name.split("_", 1)[1]
            logger.info(f"Run '{label}' seems to be in progress")

        runs_per_label[label] += 1

        def process_test(test_name: str, test_data: dict):
            result_group = grouped_success_values[f"{label}|{test_name}"]

            if "tests" in test_data:
                logger.debug(f"{test_name} is a test suite")

                # Test suite
                suite_attempted = any(
                    test["metrics"]["attempted"] for test in test_data["tests"].values()
                )
                logger.debug(f"suite_attempted: {suite_attempted}")
                if not suite_attempted:
                    return

                if test_name not in test_names:
                    test_names.append(test_name)

                if test_data["metrics"]["percentage"] == 0:
                    result_indicator = "❌"
                else:
                    highest_difficulty = test_data["metrics"]["highest_difficulty"]
                    result_indicator = {
                        "interface": "🔌",
                        "novice": "🌑",
                        "basic": "🌒",
                        "intermediate": "🌓",
                        "advanced": "🌔",
                        "hard": "🌕",
                    }[highest_difficulty]

                logger.debug(f"result group: {result_group}")
                logger.debug(f"runs_per_label: {runs_per_label[label]}")
                if len(result_group) + 1 < runs_per_label[label]:
                    result_group.extend(
                        ["❔"] * (runs_per_label[label] - len(result_group) - 1)
                    )
                result_group.append(result_indicator)
                logger.debug(f"result group (after): {result_group}")

                if granular:
                    for test_name, test in test_data["tests"].items():
                        process_test(test_name, test)
                return

            test_metrics = test_data["metrics"]
            result_indicator = "❔"

            if "attempted" not in test_metrics:
                return
            elif test_metrics["attempted"]:
                if test_name not in test_names:
                    test_names.append(test_name)

                # Handle old format (success: bool) and new (success_percentage)
                if "success" in test_metrics:
                    success_value = test_metrics["success"]
                elif "success_percentage" in test_metrics:
                    success_value = test_metrics["success_percentage"] >= 100.0
                else:
                    success_value = False
                result_indicator = {True: "✅", False: "❌"}[success_value]

            if len(result_group) + 1 < runs_per_label[label]:
                result_group.extend(
                    [" "] * (runs_per_label[label] - len(result_group) - 1)
                )
            result_group.append(result_indicator)

        for test_name, suite in test_tree.items():
            try:
                process_test(test_name, suite)
            except KeyError:
                print(f"{test_name}.metrics: {suite['metrics']}")
                raise

        if label not in labels:
            labels.append(label)

# Create headers
headers = ["Test Name"] + list(labels)

# Prepare data for tabulation
table_data = list[list[str]]()
for test_name in test_names:
    row = [test_name]
    for label in labels:
        results = grouped_success_values.get(f"{label}|{test_name}", ["❔"])
        if len(results) < runs_per_label[label]:
            results.extend(["❔"] * (runs_per_label[label] - len(results)))
        if len(results) > 1 and all(r == "❔" for r in results):
            results.clear()
        row.append(" ".join(results))
    table_data.append(row)

# Print tabulated data
print(tabulate(table_data, headers=headers, tablefmt="grid"))
@@ -1,85 +0,0 @@
# Challenges Data Schema of Benchmark

## General challenges

Input:

- **name** (str): Name of the challenge.
- **category** (str[]): Category of the challenge, such as 'basic', 'retrieval', 'comprehension', etc. _This is not currently used; it may be needed in the future._
- **task** (str): The task that the agent needs to solve.
- **dependencies** (str[]): The dependencies that the challenge needs to run. Needs to be the full node to the test function.
- **ground** (dict): The ground truth.
  - **answer** (str): The raw text of the ground truth answer.
  - **should_contain** (list): The exact strings that are required in the final answer.
  - **should_not_contain** (list): The exact strings that should not be in the final answer.
  - **files** (list): Files that are used for retrieval. Can specify a file here or an extension.
- **mock** (dict): Mock response for testing.
  - **mock_func** (str): Function to mock the agent's response. This is used for testing purposes.
  - **mock_task** (str): Task to provide for the mock function.
- **info** (dict): Additional info about the challenge.
  - **difficulty** (str): The difficulty of this query.
  - **description** (str): Description of the challenge.
  - **side_effects** (str[]): Describes the side effects of the challenge.

Example:

```json
{
  "category": ["basic"],
  "task": "Print the capital of America to a .txt file",
  "dependencies": ["TestWriteFile"], // the class name of the test
  "ground": {
    "answer": "Washington",
    "should_contain": ["Washington"],
    "should_not_contain": ["New York", "Los Angeles", "San Francisco"],
    "files": [".txt"],
    "eval": {
      "type": "llm" or "file" or "python",
      "scoring": "percentage" or "scale" or "binary", // only if the type is llm
      "template": "rubric" or "reference" or "custom" // only if the type is llm
    }
  },
  "info": {
    "difficulty": "basic",
    "description": "Tests the writing to file",
    "side_effects": ["tests if there is in fact an LLM attached"]
  }
}
```

## Evals

This is the method of evaluation for a challenge.

### file

This is the default method of evaluation. It compares the files specified in the "files" field against the "should_contain" and "should_not_contain" ground truths.

### python

This runs a Python function in the specified "files" and captures its print output, which is then scored using the "should_contain" and "should_not_contain" ground truths.

### llm

This uses a language model to evaluate the answer, as in the sketch after this list.

- There are 3 different templates: "rubric", "reference", and "custom". "rubric" evaluates based on a rubric you provide in the "answer" field. "reference" evaluates based on the ideal reference response in "answer". "custom" does not use any predefined scoring method; the prompt is whatever you put in "answer".
- The "scoring" field determines how the answer is scored. "percentage" assigns a percentage out of 100. "scale" scores the answer 1-10. "binary" scores the answer based on whether it is correct or not.
- You can still use the "should_contain" and "should_not_contain" fields to directly match the answer alongside the llm eval.

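As a rough illustration, a ground block using the llm eval might look like the sketch below. Only the keys documented in this schema are assumed; the answer text, file name, and scoring choices are made-up values, not taken from an actual challenge.

```json
"ground": {
  "answer": "Rubric: output.txt exists and names Washington as the capital.", // hypothetical rubric text
  "should_contain": ["Washington"],
  "should_not_contain": [],
  "files": ["output.txt"], // hypothetical file name
  "eval": {
    "type": "llm",
    "scoring": "binary",
    "template": "rubric"
  }
}
```
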
## Add files to challenges:

### artifacts_in

This folder contains all the files you want the agent to have in its workspace BEFORE the challenge starts.

### artifacts_out

This folder contains all the files you would like the agent to generate. This folder is used to mock the agent.
This allows you to run `agbenchmark --test=TestExample --mock` and make sure the challenge actually works.

### custom_python

This folder contains files that will be copied into the agent's workspace and run after the challenge is completed.
For example, we can have a test.py in it and run this file in the workspace to easily import code generated by the agent; a sketch follows below.
Example: the TestBasicCodeGeneration challenge.
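A minimal, hypothetical sketch of such a test.py is shown below. It is not taken from an actual challenge; the module name `sample_code`, the function `multiply_int`, and the expected doubling behaviour are assumptions made purely for illustration.

```python
# custom_python/test.py -- hypothetical example, not an actual challenge file.
# Assumes the agent was asked to write `sample_code.py` in its workspace,
# exposing multiply_int(num: int) -> int that doubles its input.
from sample_code import multiply_int


def test_multiply_int(num: int, expected_result: int) -> None:
    result = multiply_int(num)
    print(result)  # the "python" eval scores the captured print output
    assert result == expected_result, f"expected {expected_result}, got {result}"


if __name__ == "__main__":
    # The benchmark runs this file inside the agent's workspace after the run.
    test_multiply_int(4, 8)
```
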
@@ -1,13 +0,0 @@
# This is the official challenge library for https://github.com/Significant-Gravitas/Auto-GPT-Benchmarks

The goal of this repo is to provide easy challenge creation for test-driven development with the Auto-GPT-Benchmarks package. This is essentially a library for crafting challenges using a DSL (JSON files, in this case).

This is the up-to-date dependency graph: https://sapphire-denys-23.tiiny.site/

### How to use

Make sure you have the package installed with `pip install agbenchmark`.

If you would just like to use the default challenges, don't worry about this repo. Just install the package and you will have access to the default challenges.

To add new challenges as you develop, add this repo as a submodule to your `project/agbenchmark` folder. Any new challenges you add within the submodule will get registered automatically.
@@ -1,56 +0,0 @@
import glob
import json
import logging
from pathlib import Path

from .base import BaseChallenge, ChallengeInfo
from .builtin import OPTIONAL_CATEGORIES

logger = logging.getLogger(__name__)


def get_challenge_from_source_uri(source_uri: str) -> type[BaseChallenge]:
    from .builtin import BuiltinChallenge
    from .webarena import WebArenaChallenge

    provider_prefix = source_uri.split("/", 1)[0]

    if provider_prefix == BuiltinChallenge.SOURCE_URI_PREFIX:
        return BuiltinChallenge.from_source_uri(source_uri)

    if provider_prefix == WebArenaChallenge.SOURCE_URI_PREFIX:
        return WebArenaChallenge.from_source_uri(source_uri)

    raise ValueError(f"Cannot resolve source_uri '{source_uri}'")


def get_unique_categories() -> set[str]:
    """
    Reads all challenge spec files and returns a set of all their categories.
    """
    categories = set()

    challenges_dir = Path(__file__).parent
    glob_path = f"{challenges_dir}/**/data.json"

    for data_file in glob.glob(glob_path, recursive=True):
        with open(data_file, "r") as f:
            try:
                challenge_data = json.load(f)
                categories.update(challenge_data.get("category", []))
            except json.JSONDecodeError:
                logger.error(f"Error: {data_file} is not a valid JSON file.")
                continue
            except IOError:
                logger.error(f"IOError: file could not be read: {data_file}")
                continue

    return categories


__all__ = [
    "BaseChallenge",
    "ChallengeInfo",
    "get_unique_categories",
    "OPTIONAL_CATEGORIES",
]
@@ -1 +0,0 @@
Hello World!
@@ -1 +0,0 @@
Hello World!
@@ -1 +0,0 @@
Hello World!
Some files were not shown because too many files have changed in this diff.