mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-10 23:05:17 -05:00
Compare commits
42 Commits
make-old-w
...
fix/copilo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6573d987ea | ||
|
|
ae8ce8b4ca | ||
|
|
81c1524658 | ||
|
|
f2ead70f3d | ||
|
|
7d4c020a9b | ||
|
|
e596ea87cb | ||
|
|
81f8290f01 | ||
|
|
6467f6734f | ||
|
|
5a30d11416 | ||
|
|
1f4105e8f9 | ||
|
|
caf9ff34e6 | ||
|
|
e8fc8ee623 | ||
|
|
1a16e203b8 | ||
|
|
5dae303ce0 | ||
|
|
6cbfbdd013 | ||
|
|
0c6fa60436 | ||
|
|
b04e916c23 | ||
|
|
1a32ba7d9a | ||
|
|
deccc26f1f | ||
|
|
9e38bd5b78 | ||
|
|
a329831b0b | ||
|
|
98dd1a9480 | ||
|
|
9c7c598c7d | ||
|
|
728c40def5 | ||
|
|
cd64562e1b | ||
|
|
8fddc9d71f | ||
|
|
3d1cd03fc8 | ||
|
|
e7ebe42306 | ||
|
|
e0fab7e34e | ||
|
|
29ee85c86f | ||
|
|
85b6520710 | ||
|
|
bfa942e032 | ||
|
|
11256076d8 | ||
|
|
3ca2387631 | ||
|
|
ed07f02738 | ||
|
|
b121030c94 | ||
|
|
c22c18374d | ||
|
|
e40233a3ac | ||
|
|
3ae5eabf9d | ||
|
|
a077ba9f03 | ||
|
|
5401d54eaa | ||
|
|
5ac89d7c0b |
73
.github/workflows/classic-autogpt-ci.yml
vendored
73
.github/workflows/classic-autogpt-ci.yml
vendored
@@ -6,15 +6,11 @@ on:
|
|||||||
paths:
|
paths:
|
||||||
- '.github/workflows/classic-autogpt-ci.yml'
|
- '.github/workflows/classic-autogpt-ci.yml'
|
||||||
- 'classic/original_autogpt/**'
|
- 'classic/original_autogpt/**'
|
||||||
- 'classic/direct_benchmark/**'
|
|
||||||
- 'classic/forge/**'
|
|
||||||
pull_request:
|
pull_request:
|
||||||
branches: [ master, dev, release-* ]
|
branches: [ master, dev, release-* ]
|
||||||
paths:
|
paths:
|
||||||
- '.github/workflows/classic-autogpt-ci.yml'
|
- '.github/workflows/classic-autogpt-ci.yml'
|
||||||
- 'classic/original_autogpt/**'
|
- 'classic/original_autogpt/**'
|
||||||
- 'classic/direct_benchmark/**'
|
|
||||||
- 'classic/forge/**'
|
|
||||||
|
|
||||||
concurrency:
|
concurrency:
|
||||||
group: ${{ format('classic-autogpt-ci-{0}', github.head_ref && format('{0}-{1}', github.event_name, github.event.pull_request.number) || github.sha) }}
|
group: ${{ format('classic-autogpt-ci-{0}', github.head_ref && format('{0}-{1}', github.event_name, github.event.pull_request.number) || github.sha) }}
|
||||||
@@ -23,22 +19,47 @@ concurrency:
|
|||||||
defaults:
|
defaults:
|
||||||
run:
|
run:
|
||||||
shell: bash
|
shell: bash
|
||||||
working-directory: classic
|
working-directory: classic/original_autogpt
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
test:
|
test:
|
||||||
permissions:
|
permissions:
|
||||||
contents: read
|
contents: read
|
||||||
timeout-minutes: 30
|
timeout-minutes: 30
|
||||||
runs-on: ubuntu-latest
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
python-version: ["3.10"]
|
||||||
|
platform-os: [ubuntu, macos, macos-arm64, windows]
|
||||||
|
runs-on: ${{ matrix.platform-os != 'macos-arm64' && format('{0}-latest', matrix.platform-os) || 'macos-14' }}
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Start MinIO service
|
# Quite slow on macOS (2~4 minutes to set up Docker)
|
||||||
|
# - name: Set up Docker (macOS)
|
||||||
|
# if: runner.os == 'macOS'
|
||||||
|
# uses: crazy-max/ghaction-setup-docker@v3
|
||||||
|
|
||||||
|
- name: Start MinIO service (Linux)
|
||||||
|
if: runner.os == 'Linux'
|
||||||
working-directory: '.'
|
working-directory: '.'
|
||||||
run: |
|
run: |
|
||||||
docker pull minio/minio:edge-cicd
|
docker pull minio/minio:edge-cicd
|
||||||
docker run -d -p 9000:9000 minio/minio:edge-cicd
|
docker run -d -p 9000:9000 minio/minio:edge-cicd
|
||||||
|
|
||||||
|
- name: Start MinIO service (macOS)
|
||||||
|
if: runner.os == 'macOS'
|
||||||
|
working-directory: ${{ runner.temp }}
|
||||||
|
run: |
|
||||||
|
brew install minio/stable/minio
|
||||||
|
mkdir data
|
||||||
|
minio server ./data &
|
||||||
|
|
||||||
|
# No MinIO on Windows:
|
||||||
|
# - Windows doesn't support running Linux Docker containers
|
||||||
|
# - It doesn't seem possible to start background processes on Windows. They are
|
||||||
|
# killed after the step returns.
|
||||||
|
# See: https://github.com/actions/runner/issues/598#issuecomment-2011890429
|
||||||
|
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
@@ -50,23 +71,41 @@ jobs:
|
|||||||
git config --global user.name "Auto-GPT-Bot"
|
git config --global user.name "Auto-GPT-Bot"
|
||||||
git config --global user.email "github-bot@agpt.co"
|
git config --global user.email "github-bot@agpt.co"
|
||||||
|
|
||||||
- name: Set up Python 3.12
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: "3.12"
|
python-version: ${{ matrix.python-version }}
|
||||||
|
|
||||||
- id: get_date
|
- id: get_date
|
||||||
name: Get date
|
name: Get date
|
||||||
run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT
|
run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT
|
||||||
|
|
||||||
- name: Set up Python dependency cache
|
- name: Set up Python dependency cache
|
||||||
|
# On Windows, unpacking cached dependencies takes longer than just installing them
|
||||||
|
if: runner.os != 'Windows'
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: ~/.cache/pypoetry
|
path: ${{ runner.os == 'macOS' && '~/Library/Caches/pypoetry' || '~/.cache/pypoetry' }}
|
||||||
key: poetry-${{ runner.os }}-${{ hashFiles('classic/poetry.lock') }}
|
key: poetry-${{ runner.os }}-${{ hashFiles('classic/original_autogpt/poetry.lock') }}
|
||||||
|
|
||||||
- name: Install Poetry
|
- name: Install Poetry (Unix)
|
||||||
run: curl -sSL https://install.python-poetry.org | python3 -
|
if: runner.os != 'Windows'
|
||||||
|
run: |
|
||||||
|
curl -sSL https://install.python-poetry.org | python3 -
|
||||||
|
|
||||||
|
if [ "${{ runner.os }}" = "macOS" ]; then
|
||||||
|
PATH="$HOME/.local/bin:$PATH"
|
||||||
|
echo "$HOME/.local/bin" >> $GITHUB_PATH
|
||||||
|
fi
|
||||||
|
|
||||||
|
- name: Install Poetry (Windows)
|
||||||
|
if: runner.os == 'Windows'
|
||||||
|
shell: pwsh
|
||||||
|
run: |
|
||||||
|
(Invoke-WebRequest -Uri https://install.python-poetry.org -UseBasicParsing).Content | python -
|
||||||
|
|
||||||
|
$env:PATH += ";$env:APPDATA\Python\Scripts"
|
||||||
|
echo "$env:APPDATA\Python\Scripts" >> $env:GITHUB_PATH
|
||||||
|
|
||||||
- name: Install Python dependencies
|
- name: Install Python dependencies
|
||||||
run: poetry install
|
run: poetry install
|
||||||
@@ -77,12 +116,12 @@ jobs:
|
|||||||
--cov=autogpt --cov-branch --cov-report term-missing --cov-report xml \
|
--cov=autogpt --cov-branch --cov-report term-missing --cov-report xml \
|
||||||
--numprocesses=logical --durations=10 \
|
--numprocesses=logical --durations=10 \
|
||||||
--junitxml=junit.xml -o junit_family=legacy \
|
--junitxml=junit.xml -o junit_family=legacy \
|
||||||
original_autogpt/tests/unit original_autogpt/tests/integration
|
tests/unit tests/integration
|
||||||
env:
|
env:
|
||||||
CI: true
|
CI: true
|
||||||
PLAIN_OUTPUT: True
|
PLAIN_OUTPUT: True
|
||||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||||
S3_ENDPOINT_URL: http://127.0.0.1:9000
|
S3_ENDPOINT_URL: ${{ runner.os != 'Windows' && 'http://127.0.0.1:9000' || '' }}
|
||||||
AWS_ACCESS_KEY_ID: minioadmin
|
AWS_ACCESS_KEY_ID: minioadmin
|
||||||
AWS_SECRET_ACCESS_KEY: minioadmin
|
AWS_SECRET_ACCESS_KEY: minioadmin
|
||||||
|
|
||||||
@@ -96,11 +135,11 @@ jobs:
|
|||||||
uses: codecov/codecov-action@v5
|
uses: codecov/codecov-action@v5
|
||||||
with:
|
with:
|
||||||
token: ${{ secrets.CODECOV_TOKEN }}
|
token: ${{ secrets.CODECOV_TOKEN }}
|
||||||
flags: autogpt-agent
|
flags: autogpt-agent,${{ runner.os }}
|
||||||
|
|
||||||
- name: Upload logs to artifact
|
- name: Upload logs to artifact
|
||||||
if: always()
|
if: always()
|
||||||
uses: actions/upload-artifact@v4
|
uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: test-logs
|
name: test-logs
|
||||||
path: classic/logs/
|
path: classic/original_autogpt/logs/
|
||||||
|
|||||||
36
.github/workflows/classic-autogpts-ci.yml
vendored
36
.github/workflows/classic-autogpts-ci.yml
vendored
@@ -11,6 +11,9 @@ on:
|
|||||||
- 'classic/original_autogpt/**'
|
- 'classic/original_autogpt/**'
|
||||||
- 'classic/forge/**'
|
- 'classic/forge/**'
|
||||||
- 'classic/benchmark/**'
|
- 'classic/benchmark/**'
|
||||||
|
- 'classic/run'
|
||||||
|
- 'classic/cli.py'
|
||||||
|
- 'classic/setup.py'
|
||||||
- '!**/*.md'
|
- '!**/*.md'
|
||||||
pull_request:
|
pull_request:
|
||||||
branches: [ master, dev, release-* ]
|
branches: [ master, dev, release-* ]
|
||||||
@@ -19,6 +22,9 @@ on:
|
|||||||
- 'classic/original_autogpt/**'
|
- 'classic/original_autogpt/**'
|
||||||
- 'classic/forge/**'
|
- 'classic/forge/**'
|
||||||
- 'classic/benchmark/**'
|
- 'classic/benchmark/**'
|
||||||
|
- 'classic/run'
|
||||||
|
- 'classic/cli.py'
|
||||||
|
- 'classic/setup.py'
|
||||||
- '!**/*.md'
|
- '!**/*.md'
|
||||||
|
|
||||||
defaults:
|
defaults:
|
||||||
@@ -29,9 +35,13 @@ defaults:
|
|||||||
jobs:
|
jobs:
|
||||||
serve-agent-protocol:
|
serve-agent-protocol:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
agent-name: [ original_autogpt ]
|
||||||
|
fail-fast: false
|
||||||
timeout-minutes: 20
|
timeout-minutes: 20
|
||||||
env:
|
env:
|
||||||
min-python-version: '3.12'
|
min-python-version: '3.10'
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
@@ -45,22 +55,22 @@ jobs:
|
|||||||
python-version: ${{ env.min-python-version }}
|
python-version: ${{ env.min-python-version }}
|
||||||
|
|
||||||
- name: Install Poetry
|
- name: Install Poetry
|
||||||
|
working-directory: ./classic/${{ matrix.agent-name }}/
|
||||||
run: |
|
run: |
|
||||||
curl -sSL https://install.python-poetry.org | python -
|
curl -sSL https://install.python-poetry.org | python -
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Run regression tests
|
||||||
run: poetry install
|
|
||||||
|
|
||||||
- name: Run smoke tests with direct-benchmark
|
|
||||||
run: |
|
run: |
|
||||||
poetry run direct-benchmark run \
|
./run agent start ${{ matrix.agent-name }}
|
||||||
--strategies one_shot \
|
cd ${{ matrix.agent-name }}
|
||||||
--models claude \
|
poetry run agbenchmark --mock --test=BasicRetrieval --test=Battleship --test=WebArenaTask_0
|
||||||
--tests ReadFile,WriteFile \
|
poetry run agbenchmark --test=WriteFile
|
||||||
--json
|
|
||||||
env:
|
env:
|
||||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
AGENT_NAME: ${{ matrix.agent-name }}
|
||||||
REQUESTS_CA_BUNDLE: /etc/ssl/certs/ca-certificates.crt
|
REQUESTS_CA_BUNDLE: /etc/ssl/certs/ca-certificates.crt
|
||||||
NONINTERACTIVE_MODE: "true"
|
HELICONE_CACHE_ENABLED: false
|
||||||
CI: true
|
HELICONE_PROPERTY_AGENT: ${{ matrix.agent-name }}
|
||||||
|
REPORTS_FOLDER: ${{ format('../../reports/{0}', matrix.agent-name) }}
|
||||||
|
TELEMETRY_ENVIRONMENT: autogpt-ci
|
||||||
|
TELEMETRY_OPT_IN: ${{ github.ref_name == 'master' }}
|
||||||
|
|||||||
194
.github/workflows/classic-benchmark-ci.yml
vendored
194
.github/workflows/classic-benchmark-ci.yml
vendored
@@ -1,21 +1,17 @@
|
|||||||
name: Classic - Direct Benchmark CI
|
name: Classic - AGBenchmark CI
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
branches: [ master, dev, ci-test* ]
|
branches: [ master, dev, ci-test* ]
|
||||||
paths:
|
paths:
|
||||||
- 'classic/direct_benchmark/**'
|
- 'classic/benchmark/**'
|
||||||
- 'classic/benchmark/agbenchmark/challenges/**'
|
- '!classic/benchmark/reports/**'
|
||||||
- 'classic/original_autogpt/**'
|
|
||||||
- 'classic/forge/**'
|
|
||||||
- .github/workflows/classic-benchmark-ci.yml
|
- .github/workflows/classic-benchmark-ci.yml
|
||||||
pull_request:
|
pull_request:
|
||||||
branches: [ master, dev, release-* ]
|
branches: [ master, dev, release-* ]
|
||||||
paths:
|
paths:
|
||||||
- 'classic/direct_benchmark/**'
|
- 'classic/benchmark/**'
|
||||||
- 'classic/benchmark/agbenchmark/challenges/**'
|
- '!classic/benchmark/reports/**'
|
||||||
- 'classic/original_autogpt/**'
|
|
||||||
- 'classic/forge/**'
|
|
||||||
- .github/workflows/classic-benchmark-ci.yml
|
- .github/workflows/classic-benchmark-ci.yml
|
||||||
|
|
||||||
concurrency:
|
concurrency:
|
||||||
@@ -27,16 +23,23 @@ defaults:
|
|||||||
shell: bash
|
shell: bash
|
||||||
|
|
||||||
env:
|
env:
|
||||||
min-python-version: '3.12'
|
min-python-version: '3.10'
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
benchmark-tests:
|
test:
|
||||||
runs-on: ubuntu-latest
|
permissions:
|
||||||
|
contents: read
|
||||||
timeout-minutes: 30
|
timeout-minutes: 30
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
python-version: ["3.10"]
|
||||||
|
platform-os: [ubuntu, macos, macos-arm64, windows]
|
||||||
|
runs-on: ${{ matrix.platform-os != 'macos-arm64' && format('{0}-latest', matrix.platform-os) || 'macos-14' }}
|
||||||
defaults:
|
defaults:
|
||||||
run:
|
run:
|
||||||
shell: bash
|
shell: bash
|
||||||
working-directory: classic
|
working-directory: classic/benchmark
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
@@ -44,88 +47,71 @@ jobs:
|
|||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
submodules: true
|
submodules: true
|
||||||
|
|
||||||
- name: Set up Python ${{ env.min-python-version }}
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: ${{ env.min-python-version }}
|
python-version: ${{ matrix.python-version }}
|
||||||
|
|
||||||
- name: Set up Python dependency cache
|
- name: Set up Python dependency cache
|
||||||
|
# On Windows, unpacking cached dependencies takes longer than just installing them
|
||||||
|
if: runner.os != 'Windows'
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: ~/.cache/pypoetry
|
path: ${{ runner.os == 'macOS' && '~/Library/Caches/pypoetry' || '~/.cache/pypoetry' }}
|
||||||
key: poetry-${{ runner.os }}-${{ hashFiles('classic/poetry.lock') }}
|
key: poetry-${{ runner.os }}-${{ hashFiles('classic/benchmark/poetry.lock') }}
|
||||||
|
|
||||||
- name: Install Poetry
|
- name: Install Poetry (Unix)
|
||||||
|
if: runner.os != 'Windows'
|
||||||
run: |
|
run: |
|
||||||
curl -sSL https://install.python-poetry.org | python3 -
|
curl -sSL https://install.python-poetry.org | python3 -
|
||||||
|
|
||||||
- name: Install dependencies
|
if [ "${{ runner.os }}" = "macOS" ]; then
|
||||||
|
PATH="$HOME/.local/bin:$PATH"
|
||||||
|
echo "$HOME/.local/bin" >> $GITHUB_PATH
|
||||||
|
fi
|
||||||
|
|
||||||
|
- name: Install Poetry (Windows)
|
||||||
|
if: runner.os == 'Windows'
|
||||||
|
shell: pwsh
|
||||||
|
run: |
|
||||||
|
(Invoke-WebRequest -Uri https://install.python-poetry.org -UseBasicParsing).Content | python -
|
||||||
|
|
||||||
|
$env:PATH += ";$env:APPDATA\Python\Scripts"
|
||||||
|
echo "$env:APPDATA\Python\Scripts" >> $env:GITHUB_PATH
|
||||||
|
|
||||||
|
- name: Install Python dependencies
|
||||||
run: poetry install
|
run: poetry install
|
||||||
|
|
||||||
- name: Run basic benchmark tests
|
- name: Run pytest with coverage
|
||||||
run: |
|
run: |
|
||||||
echo "Testing ReadFile challenge with one_shot strategy..."
|
poetry run pytest -vv \
|
||||||
poetry run direct-benchmark run \
|
--cov=agbenchmark --cov-branch --cov-report term-missing --cov-report xml \
|
||||||
--fresh \
|
--durations=10 \
|
||||||
--strategies one_shot \
|
--junitxml=junit.xml -o junit_family=legacy \
|
||||||
--models claude \
|
tests
|
||||||
--tests ReadFile \
|
|
||||||
--json
|
|
||||||
|
|
||||||
echo "Testing WriteFile challenge..."
|
|
||||||
poetry run direct-benchmark run \
|
|
||||||
--fresh \
|
|
||||||
--strategies one_shot \
|
|
||||||
--models claude \
|
|
||||||
--tests WriteFile \
|
|
||||||
--json
|
|
||||||
env:
|
env:
|
||||||
CI: true
|
CI: true
|
||||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
|
||||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||||
NONINTERACTIVE_MODE: "true"
|
|
||||||
|
|
||||||
- name: Test category filtering
|
- name: Upload test results to Codecov
|
||||||
run: |
|
if: ${{ !cancelled() }} # Run even if tests fail
|
||||||
echo "Testing coding category..."
|
uses: codecov/test-results-action@v1
|
||||||
poetry run direct-benchmark run \
|
with:
|
||||||
--fresh \
|
token: ${{ secrets.CODECOV_TOKEN }}
|
||||||
--strategies one_shot \
|
|
||||||
--models claude \
|
|
||||||
--categories coding \
|
|
||||||
--tests ReadFile,WriteFile \
|
|
||||||
--json
|
|
||||||
env:
|
|
||||||
CI: true
|
|
||||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
|
||||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
|
||||||
NONINTERACTIVE_MODE: "true"
|
|
||||||
|
|
||||||
- name: Test multiple strategies
|
- name: Upload coverage reports to Codecov
|
||||||
run: |
|
uses: codecov/codecov-action@v5
|
||||||
echo "Testing multiple strategies..."
|
with:
|
||||||
poetry run direct-benchmark run \
|
token: ${{ secrets.CODECOV_TOKEN }}
|
||||||
--fresh \
|
flags: agbenchmark,${{ runner.os }}
|
||||||
--strategies one_shot,plan_execute \
|
|
||||||
--models claude \
|
|
||||||
--tests ReadFile \
|
|
||||||
--parallel 2 \
|
|
||||||
--json
|
|
||||||
env:
|
|
||||||
CI: true
|
|
||||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
|
||||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
|
||||||
NONINTERACTIVE_MODE: "true"
|
|
||||||
|
|
||||||
# Run regression tests on maintain challenges
|
self-test-with-agent:
|
||||||
regression-tests:
|
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
timeout-minutes: 45
|
strategy:
|
||||||
if: github.ref == 'refs/heads/master' || github.ref == 'refs/heads/dev'
|
matrix:
|
||||||
defaults:
|
agent-name: [forge]
|
||||||
run:
|
fail-fast: false
|
||||||
shell: bash
|
timeout-minutes: 20
|
||||||
working-directory: classic
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
@@ -140,23 +126,51 @@ jobs:
|
|||||||
|
|
||||||
- name: Install Poetry
|
- name: Install Poetry
|
||||||
run: |
|
run: |
|
||||||
curl -sSL https://install.python-poetry.org | python3 -
|
curl -sSL https://install.python-poetry.org | python -
|
||||||
|
|
||||||
- name: Install dependencies
|
|
||||||
run: poetry install
|
|
||||||
|
|
||||||
- name: Run regression tests
|
- name: Run regression tests
|
||||||
|
working-directory: classic
|
||||||
run: |
|
run: |
|
||||||
echo "Running regression tests (previously beaten challenges)..."
|
./run agent start ${{ matrix.agent-name }}
|
||||||
poetry run direct-benchmark run \
|
cd ${{ matrix.agent-name }}
|
||||||
--fresh \
|
|
||||||
--strategies one_shot \
|
set +e # Ignore non-zero exit codes and continue execution
|
||||||
--models claude \
|
echo "Running the following command: poetry run agbenchmark --maintain --mock"
|
||||||
--maintain \
|
poetry run agbenchmark --maintain --mock
|
||||||
--parallel 4 \
|
EXIT_CODE=$?
|
||||||
--json
|
set -e # Stop ignoring non-zero exit codes
|
||||||
|
# Check if the exit code was 5, and if so, exit with 0 instead
|
||||||
|
if [ $EXIT_CODE -eq 5 ]; then
|
||||||
|
echo "regression_tests.json is empty."
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "Running the following command: poetry run agbenchmark --mock"
|
||||||
|
poetry run agbenchmark --mock
|
||||||
|
|
||||||
|
echo "Running the following command: poetry run agbenchmark --mock --category=data"
|
||||||
|
poetry run agbenchmark --mock --category=data
|
||||||
|
|
||||||
|
echo "Running the following command: poetry run agbenchmark --mock --category=coding"
|
||||||
|
poetry run agbenchmark --mock --category=coding
|
||||||
|
|
||||||
|
# echo "Running the following command: poetry run agbenchmark --test=WriteFile"
|
||||||
|
# poetry run agbenchmark --test=WriteFile
|
||||||
|
cd ../benchmark
|
||||||
|
poetry install
|
||||||
|
echo "Adding the BUILD_SKILL_TREE environment variable. This will attempt to add new elements in the skill tree. If new elements are added, the CI fails because they should have been pushed"
|
||||||
|
export BUILD_SKILL_TREE=true
|
||||||
|
|
||||||
|
# poetry run agbenchmark --mock
|
||||||
|
|
||||||
|
# CHANGED=$(git diff --name-only | grep -E '(agbenchmark/challenges)|(../classic/frontend/assets)') || echo "No diffs"
|
||||||
|
# if [ ! -z "$CHANGED" ]; then
|
||||||
|
# echo "There are unstaged changes please run agbenchmark and commit those changes since they are needed."
|
||||||
|
# echo "$CHANGED"
|
||||||
|
# exit 1
|
||||||
|
# else
|
||||||
|
# echo "No unstaged changes."
|
||||||
|
# fi
|
||||||
env:
|
env:
|
||||||
CI: true
|
|
||||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
|
||||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||||
NONINTERACTIVE_MODE: "true"
|
TELEMETRY_ENVIRONMENT: autogpt-benchmark-ci
|
||||||
|
TELEMETRY_OPT_IN: ${{ github.ref_name == 'master' }}
|
||||||
|
|||||||
182
.github/workflows/classic-forge-ci.yml
vendored
182
.github/workflows/classic-forge-ci.yml
vendored
@@ -6,11 +6,13 @@ on:
|
|||||||
paths:
|
paths:
|
||||||
- '.github/workflows/classic-forge-ci.yml'
|
- '.github/workflows/classic-forge-ci.yml'
|
||||||
- 'classic/forge/**'
|
- 'classic/forge/**'
|
||||||
|
- '!classic/forge/tests/vcr_cassettes'
|
||||||
pull_request:
|
pull_request:
|
||||||
branches: [ master, dev, release-* ]
|
branches: [ master, dev, release-* ]
|
||||||
paths:
|
paths:
|
||||||
- '.github/workflows/classic-forge-ci.yml'
|
- '.github/workflows/classic-forge-ci.yml'
|
||||||
- 'classic/forge/**'
|
- 'classic/forge/**'
|
||||||
|
- '!classic/forge/tests/vcr_cassettes'
|
||||||
|
|
||||||
concurrency:
|
concurrency:
|
||||||
group: ${{ format('forge-ci-{0}', github.head_ref && format('{0}-{1}', github.event_name, github.event.pull_request.number) || github.sha) }}
|
group: ${{ format('forge-ci-{0}', github.head_ref && format('{0}-{1}', github.event_name, github.event.pull_request.number) || github.sha) }}
|
||||||
@@ -19,38 +21,115 @@ concurrency:
|
|||||||
defaults:
|
defaults:
|
||||||
run:
|
run:
|
||||||
shell: bash
|
shell: bash
|
||||||
working-directory: classic
|
working-directory: classic/forge
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
test:
|
test:
|
||||||
permissions:
|
permissions:
|
||||||
contents: read
|
contents: read
|
||||||
timeout-minutes: 30
|
timeout-minutes: 30
|
||||||
runs-on: ubuntu-latest
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
python-version: ["3.10"]
|
||||||
|
platform-os: [ubuntu, macos, macos-arm64, windows]
|
||||||
|
runs-on: ${{ matrix.platform-os != 'macos-arm64' && format('{0}-latest', matrix.platform-os) || 'macos-14' }}
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Start MinIO service
|
# Quite slow on macOS (2~4 minutes to set up Docker)
|
||||||
|
# - name: Set up Docker (macOS)
|
||||||
|
# if: runner.os == 'macOS'
|
||||||
|
# uses: crazy-max/ghaction-setup-docker@v3
|
||||||
|
|
||||||
|
- name: Start MinIO service (Linux)
|
||||||
|
if: runner.os == 'Linux'
|
||||||
working-directory: '.'
|
working-directory: '.'
|
||||||
run: |
|
run: |
|
||||||
docker pull minio/minio:edge-cicd
|
docker pull minio/minio:edge-cicd
|
||||||
docker run -d -p 9000:9000 minio/minio:edge-cicd
|
docker run -d -p 9000:9000 minio/minio:edge-cicd
|
||||||
|
|
||||||
|
- name: Start MinIO service (macOS)
|
||||||
|
if: runner.os == 'macOS'
|
||||||
|
working-directory: ${{ runner.temp }}
|
||||||
|
run: |
|
||||||
|
brew install minio/stable/minio
|
||||||
|
mkdir data
|
||||||
|
minio server ./data &
|
||||||
|
|
||||||
|
# No MinIO on Windows:
|
||||||
|
# - Windows doesn't support running Linux Docker containers
|
||||||
|
# - It doesn't seem possible to start background processes on Windows. They are
|
||||||
|
# killed after the step returns.
|
||||||
|
# See: https://github.com/actions/runner/issues/598#issuecomment-2011890429
|
||||||
|
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
submodules: true
|
||||||
|
|
||||||
- name: Set up Python 3.12
|
- name: Checkout cassettes
|
||||||
|
if: ${{ startsWith(github.event_name, 'pull_request') }}
|
||||||
|
env:
|
||||||
|
PR_BASE: ${{ github.event.pull_request.base.ref }}
|
||||||
|
PR_BRANCH: ${{ github.event.pull_request.head.ref }}
|
||||||
|
PR_AUTHOR: ${{ github.event.pull_request.user.login }}
|
||||||
|
run: |
|
||||||
|
cassette_branch="${PR_AUTHOR}-${PR_BRANCH}"
|
||||||
|
cassette_base_branch="${PR_BASE}"
|
||||||
|
cd tests/vcr_cassettes
|
||||||
|
|
||||||
|
if ! git ls-remote --exit-code --heads origin $cassette_base_branch ; then
|
||||||
|
cassette_base_branch="master"
|
||||||
|
fi
|
||||||
|
|
||||||
|
if git ls-remote --exit-code --heads origin $cassette_branch ; then
|
||||||
|
git fetch origin $cassette_branch
|
||||||
|
git fetch origin $cassette_base_branch
|
||||||
|
|
||||||
|
git checkout $cassette_branch
|
||||||
|
|
||||||
|
# Pick non-conflicting cassette updates from the base branch
|
||||||
|
git merge --no-commit --strategy-option=ours origin/$cassette_base_branch
|
||||||
|
echo "Using cassettes from mirror branch '$cassette_branch'," \
|
||||||
|
"synced to upstream branch '$cassette_base_branch'."
|
||||||
|
else
|
||||||
|
git checkout -b $cassette_branch
|
||||||
|
echo "Branch '$cassette_branch' does not exist in cassette submodule." \
|
||||||
|
"Using cassettes from '$cassette_base_branch'."
|
||||||
|
fi
|
||||||
|
|
||||||
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: "3.12"
|
python-version: ${{ matrix.python-version }}
|
||||||
|
|
||||||
- name: Set up Python dependency cache
|
- name: Set up Python dependency cache
|
||||||
|
# On Windows, unpacking cached dependencies takes longer than just installing them
|
||||||
|
if: runner.os != 'Windows'
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: ~/.cache/pypoetry
|
path: ${{ runner.os == 'macOS' && '~/Library/Caches/pypoetry' || '~/.cache/pypoetry' }}
|
||||||
key: poetry-${{ runner.os }}-${{ hashFiles('classic/poetry.lock') }}
|
key: poetry-${{ runner.os }}-${{ hashFiles('classic/forge/poetry.lock') }}
|
||||||
|
|
||||||
- name: Install Poetry
|
- name: Install Poetry (Unix)
|
||||||
run: curl -sSL https://install.python-poetry.org | python3 -
|
if: runner.os != 'Windows'
|
||||||
|
run: |
|
||||||
|
curl -sSL https://install.python-poetry.org | python3 -
|
||||||
|
|
||||||
|
if [ "${{ runner.os }}" = "macOS" ]; then
|
||||||
|
PATH="$HOME/.local/bin:$PATH"
|
||||||
|
echo "$HOME/.local/bin" >> $GITHUB_PATH
|
||||||
|
fi
|
||||||
|
|
||||||
|
- name: Install Poetry (Windows)
|
||||||
|
if: runner.os == 'Windows'
|
||||||
|
shell: pwsh
|
||||||
|
run: |
|
||||||
|
(Invoke-WebRequest -Uri https://install.python-poetry.org -UseBasicParsing).Content | python -
|
||||||
|
|
||||||
|
$env:PATH += ";$env:APPDATA\Python\Scripts"
|
||||||
|
echo "$env:APPDATA\Python\Scripts" >> $env:GITHUB_PATH
|
||||||
|
|
||||||
- name: Install Python dependencies
|
- name: Install Python dependencies
|
||||||
run: poetry install
|
run: poetry install
|
||||||
@@ -61,15 +140,12 @@ jobs:
|
|||||||
--cov=forge --cov-branch --cov-report term-missing --cov-report xml \
|
--cov=forge --cov-branch --cov-report term-missing --cov-report xml \
|
||||||
--durations=10 \
|
--durations=10 \
|
||||||
--junitxml=junit.xml -o junit_family=legacy \
|
--junitxml=junit.xml -o junit_family=legacy \
|
||||||
forge/forge forge/tests
|
forge
|
||||||
env:
|
env:
|
||||||
CI: true
|
CI: true
|
||||||
PLAIN_OUTPUT: True
|
PLAIN_OUTPUT: True
|
||||||
# API keys - tests that need these will skip if not available
|
|
||||||
# Secrets are not available to fork PRs (GitHub security feature)
|
|
||||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
S3_ENDPOINT_URL: ${{ runner.os != 'Windows' && 'http://127.0.0.1:9000' || '' }}
|
||||||
S3_ENDPOINT_URL: http://127.0.0.1:9000
|
|
||||||
AWS_ACCESS_KEY_ID: minioadmin
|
AWS_ACCESS_KEY_ID: minioadmin
|
||||||
AWS_SECRET_ACCESS_KEY: minioadmin
|
AWS_SECRET_ACCESS_KEY: minioadmin
|
||||||
|
|
||||||
@@ -83,11 +159,85 @@ jobs:
|
|||||||
uses: codecov/codecov-action@v5
|
uses: codecov/codecov-action@v5
|
||||||
with:
|
with:
|
||||||
token: ${{ secrets.CODECOV_TOKEN }}
|
token: ${{ secrets.CODECOV_TOKEN }}
|
||||||
flags: forge
|
flags: forge,${{ runner.os }}
|
||||||
|
|
||||||
|
- id: setup_git_auth
|
||||||
|
name: Set up git token authentication
|
||||||
|
# Cassettes may be pushed even when tests fail
|
||||||
|
if: success() || failure()
|
||||||
|
run: |
|
||||||
|
config_key="http.${{ github.server_url }}/.extraheader"
|
||||||
|
if [ "${{ runner.os }}" = 'macOS' ]; then
|
||||||
|
base64_pat=$(echo -n "pat:${{ secrets.PAT_REVIEW }}" | base64)
|
||||||
|
else
|
||||||
|
base64_pat=$(echo -n "pat:${{ secrets.PAT_REVIEW }}" | base64 -w0)
|
||||||
|
fi
|
||||||
|
|
||||||
|
git config "$config_key" \
|
||||||
|
"Authorization: Basic $base64_pat"
|
||||||
|
|
||||||
|
cd tests/vcr_cassettes
|
||||||
|
git config "$config_key" \
|
||||||
|
"Authorization: Basic $base64_pat"
|
||||||
|
|
||||||
|
echo "config_key=$config_key" >> $GITHUB_OUTPUT
|
||||||
|
|
||||||
|
- id: push_cassettes
|
||||||
|
name: Push updated cassettes
|
||||||
|
# For pull requests, push updated cassettes even when tests fail
|
||||||
|
if: github.event_name == 'push' || (! github.event.pull_request.head.repo.fork && (success() || failure()))
|
||||||
|
env:
|
||||||
|
PR_BRANCH: ${{ github.event.pull_request.head.ref }}
|
||||||
|
PR_AUTHOR: ${{ github.event.pull_request.user.login }}
|
||||||
|
run: |
|
||||||
|
if [ "${{ startsWith(github.event_name, 'pull_request') }}" = "true" ]; then
|
||||||
|
is_pull_request=true
|
||||||
|
cassette_branch="${PR_AUTHOR}-${PR_BRANCH}"
|
||||||
|
else
|
||||||
|
cassette_branch="${{ github.ref_name }}"
|
||||||
|
fi
|
||||||
|
|
||||||
|
cd tests/vcr_cassettes
|
||||||
|
# Commit & push changes to cassettes if any
|
||||||
|
if ! git diff --quiet; then
|
||||||
|
git add .
|
||||||
|
git commit -m "Auto-update cassettes"
|
||||||
|
git push origin HEAD:$cassette_branch
|
||||||
|
if [ ! $is_pull_request ]; then
|
||||||
|
cd ../..
|
||||||
|
git add tests/vcr_cassettes
|
||||||
|
git commit -m "Update cassette submodule"
|
||||||
|
git push origin HEAD:$cassette_branch
|
||||||
|
fi
|
||||||
|
echo "updated=true" >> $GITHUB_OUTPUT
|
||||||
|
else
|
||||||
|
echo "updated=false" >> $GITHUB_OUTPUT
|
||||||
|
echo "No cassette changes to commit"
|
||||||
|
fi
|
||||||
|
|
||||||
|
- name: Post Set up git token auth
|
||||||
|
if: steps.setup_git_auth.outcome == 'success'
|
||||||
|
run: |
|
||||||
|
git config --unset-all '${{ steps.setup_git_auth.outputs.config_key }}'
|
||||||
|
git submodule foreach git config --unset-all '${{ steps.setup_git_auth.outputs.config_key }}'
|
||||||
|
|
||||||
|
- name: Apply "behaviour change" label and comment on PR
|
||||||
|
if: ${{ startsWith(github.event_name, 'pull_request') }}
|
||||||
|
run: |
|
||||||
|
PR_NUMBER="${{ github.event.pull_request.number }}"
|
||||||
|
TOKEN="${{ secrets.PAT_REVIEW }}"
|
||||||
|
REPO="${{ github.repository }}"
|
||||||
|
|
||||||
|
if [[ "${{ steps.push_cassettes.outputs.updated }}" == "true" ]]; then
|
||||||
|
echo "Adding label and comment..."
|
||||||
|
echo $TOKEN | gh auth login --with-token
|
||||||
|
gh issue edit $PR_NUMBER --add-label "behaviour change"
|
||||||
|
gh issue comment $PR_NUMBER --body "You changed AutoGPT's behaviour on ${{ runner.os }}. The cassettes have been updated and will be merged to the submodule when this Pull Request gets merged."
|
||||||
|
fi
|
||||||
|
|
||||||
- name: Upload logs to artifact
|
- name: Upload logs to artifact
|
||||||
if: always()
|
if: always()
|
||||||
uses: actions/upload-artifact@v4
|
uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: test-logs
|
name: test-logs
|
||||||
path: classic/logs/
|
path: classic/forge/logs/
|
||||||
|
|||||||
2
.github/workflows/classic-frontend-ci.yml
vendored
2
.github/workflows/classic-frontend-ci.yml
vendored
@@ -49,7 +49,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Create PR ${{ env.BUILD_BRANCH }} -> ${{ github.ref_name }}
|
- name: Create PR ${{ env.BUILD_BRANCH }} -> ${{ github.ref_name }}
|
||||||
if: github.event_name == 'push'
|
if: github.event_name == 'push'
|
||||||
uses: peter-evans/create-pull-request@v7
|
uses: peter-evans/create-pull-request@v8
|
||||||
with:
|
with:
|
||||||
add-paths: classic/frontend/build/web
|
add-paths: classic/frontend/build/web
|
||||||
base: ${{ github.ref_name }}
|
base: ${{ github.ref_name }}
|
||||||
|
|||||||
67
.github/workflows/classic-python-checks.yml
vendored
67
.github/workflows/classic-python-checks.yml
vendored
@@ -7,9 +7,7 @@ on:
|
|||||||
- '.github/workflows/classic-python-checks-ci.yml'
|
- '.github/workflows/classic-python-checks-ci.yml'
|
||||||
- 'classic/original_autogpt/**'
|
- 'classic/original_autogpt/**'
|
||||||
- 'classic/forge/**'
|
- 'classic/forge/**'
|
||||||
- 'classic/direct_benchmark/**'
|
- 'classic/benchmark/**'
|
||||||
- 'classic/pyproject.toml'
|
|
||||||
- 'classic/poetry.lock'
|
|
||||||
- '**.py'
|
- '**.py'
|
||||||
- '!classic/forge/tests/vcr_cassettes'
|
- '!classic/forge/tests/vcr_cassettes'
|
||||||
pull_request:
|
pull_request:
|
||||||
@@ -18,9 +16,7 @@ on:
|
|||||||
- '.github/workflows/classic-python-checks-ci.yml'
|
- '.github/workflows/classic-python-checks-ci.yml'
|
||||||
- 'classic/original_autogpt/**'
|
- 'classic/original_autogpt/**'
|
||||||
- 'classic/forge/**'
|
- 'classic/forge/**'
|
||||||
- 'classic/direct_benchmark/**'
|
- 'classic/benchmark/**'
|
||||||
- 'classic/pyproject.toml'
|
|
||||||
- 'classic/poetry.lock'
|
|
||||||
- '**.py'
|
- '**.py'
|
||||||
- '!classic/forge/tests/vcr_cassettes'
|
- '!classic/forge/tests/vcr_cassettes'
|
||||||
|
|
||||||
@@ -31,13 +27,44 @@ concurrency:
|
|||||||
defaults:
|
defaults:
|
||||||
run:
|
run:
|
||||||
shell: bash
|
shell: bash
|
||||||
working-directory: classic
|
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
|
get-changed-parts:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout repository
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- id: changes-in
|
||||||
|
name: Determine affected subprojects
|
||||||
|
uses: dorny/paths-filter@v3
|
||||||
|
with:
|
||||||
|
filters: |
|
||||||
|
original_autogpt:
|
||||||
|
- classic/original_autogpt/autogpt/**
|
||||||
|
- classic/original_autogpt/tests/**
|
||||||
|
- classic/original_autogpt/poetry.lock
|
||||||
|
forge:
|
||||||
|
- classic/forge/forge/**
|
||||||
|
- classic/forge/tests/**
|
||||||
|
- classic/forge/poetry.lock
|
||||||
|
benchmark:
|
||||||
|
- classic/benchmark/agbenchmark/**
|
||||||
|
- classic/benchmark/tests/**
|
||||||
|
- classic/benchmark/poetry.lock
|
||||||
|
outputs:
|
||||||
|
changed-parts: ${{ steps.changes-in.outputs.changes }}
|
||||||
|
|
||||||
lint:
|
lint:
|
||||||
|
needs: get-changed-parts
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
env:
|
env:
|
||||||
min-python-version: "3.12"
|
min-python-version: "3.10"
|
||||||
|
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
sub-package: ${{ fromJson(needs.get-changed-parts.outputs.changed-parts) }}
|
||||||
|
fail-fast: false
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
@@ -54,31 +81,42 @@ jobs:
|
|||||||
uses: actions/cache@v4
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: ~/.cache/pypoetry
|
path: ~/.cache/pypoetry
|
||||||
key: ${{ runner.os }}-poetry-${{ hashFiles('classic/poetry.lock') }}
|
key: ${{ runner.os }}-poetry-${{ hashFiles(format('{0}/poetry.lock', matrix.sub-package)) }}
|
||||||
|
|
||||||
- name: Install Poetry
|
- name: Install Poetry
|
||||||
run: curl -sSL https://install.python-poetry.org | python3 -
|
run: curl -sSL https://install.python-poetry.org | python3 -
|
||||||
|
|
||||||
|
# Install dependencies
|
||||||
|
|
||||||
- name: Install Python dependencies
|
- name: Install Python dependencies
|
||||||
run: poetry install
|
run: poetry -C classic/${{ matrix.sub-package }} install
|
||||||
|
|
||||||
# Lint
|
# Lint
|
||||||
|
|
||||||
- name: Lint (isort)
|
- name: Lint (isort)
|
||||||
run: poetry run isort --check .
|
run: poetry run isort --check .
|
||||||
|
working-directory: classic/${{ matrix.sub-package }}
|
||||||
|
|
||||||
- name: Lint (Black)
|
- name: Lint (Black)
|
||||||
if: success() || failure()
|
if: success() || failure()
|
||||||
run: poetry run black --check .
|
run: poetry run black --check .
|
||||||
|
working-directory: classic/${{ matrix.sub-package }}
|
||||||
|
|
||||||
- name: Lint (Flake8)
|
- name: Lint (Flake8)
|
||||||
if: success() || failure()
|
if: success() || failure()
|
||||||
run: poetry run flake8 .
|
run: poetry run flake8 .
|
||||||
|
working-directory: classic/${{ matrix.sub-package }}
|
||||||
|
|
||||||
types:
|
types:
|
||||||
|
needs: get-changed-parts
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
env:
|
env:
|
||||||
min-python-version: "3.12"
|
min-python-version: "3.10"
|
||||||
|
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
sub-package: ${{ fromJson(needs.get-changed-parts.outputs.changed-parts) }}
|
||||||
|
fail-fast: false
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
@@ -95,16 +133,19 @@ jobs:
|
|||||||
uses: actions/cache@v4
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: ~/.cache/pypoetry
|
path: ~/.cache/pypoetry
|
||||||
key: ${{ runner.os }}-poetry-${{ hashFiles('classic/poetry.lock') }}
|
key: ${{ runner.os }}-poetry-${{ hashFiles(format('{0}/poetry.lock', matrix.sub-package)) }}
|
||||||
|
|
||||||
- name: Install Poetry
|
- name: Install Poetry
|
||||||
run: curl -sSL https://install.python-poetry.org | python3 -
|
run: curl -sSL https://install.python-poetry.org | python3 -
|
||||||
|
|
||||||
|
# Install dependencies
|
||||||
|
|
||||||
- name: Install Python dependencies
|
- name: Install Python dependencies
|
||||||
run: poetry install
|
run: poetry -C classic/${{ matrix.sub-package }} install
|
||||||
|
|
||||||
# Typecheck
|
# Typecheck
|
||||||
|
|
||||||
- name: Typecheck
|
- name: Typecheck
|
||||||
if: success() || failure()
|
if: success() || failure()
|
||||||
run: poetry run pyright
|
run: poetry run pyright
|
||||||
|
working-directory: classic/${{ matrix.sub-package }}
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Get CI failure details
|
- name: Get CI failure details
|
||||||
id: failure_details
|
id: failure_details
|
||||||
uses: actions/github-script@v7
|
uses: actions/github-script@v8
|
||||||
with:
|
with:
|
||||||
script: |
|
script: |
|
||||||
const run = await github.rest.actions.getWorkflowRun({
|
const run = await github.rest.actions.getWorkflowRun({
|
||||||
|
|||||||
9
.github/workflows/claude-dependabot.yml
vendored
9
.github/workflows/claude-dependabot.yml
vendored
@@ -41,7 +41,7 @@ jobs:
|
|||||||
python-version: "3.11" # Use standard version matching CI
|
python-version: "3.11" # Use standard version matching CI
|
||||||
|
|
||||||
- name: Set up Python dependency cache
|
- name: Set up Python dependency cache
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v5
|
||||||
with:
|
with:
|
||||||
path: ~/.cache/pypoetry
|
path: ~/.cache/pypoetry
|
||||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||||
@@ -78,7 +78,7 @@ jobs:
|
|||||||
|
|
||||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||||
- name: Set up Node.js
|
- name: Set up Node.js
|
||||||
uses: actions/setup-node@v4
|
uses: actions/setup-node@v6
|
||||||
with:
|
with:
|
||||||
node-version: "22"
|
node-version: "22"
|
||||||
|
|
||||||
@@ -91,7 +91,7 @@ jobs:
|
|||||||
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
|
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Cache frontend dependencies
|
- name: Cache frontend dependencies
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v5
|
||||||
with:
|
with:
|
||||||
path: ~/.pnpm-store
|
path: ~/.pnpm-store
|
||||||
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
|
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
|
||||||
@@ -124,7 +124,7 @@ jobs:
|
|||||||
# Phase 1: Cache and load Docker images for faster setup
|
# Phase 1: Cache and load Docker images for faster setup
|
||||||
- name: Set up Docker image cache
|
- name: Set up Docker image cache
|
||||||
id: docker-cache
|
id: docker-cache
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v5
|
||||||
with:
|
with:
|
||||||
path: ~/docker-cache
|
path: ~/docker-cache
|
||||||
# Use a versioned key for cache invalidation when image list changes
|
# Use a versioned key for cache invalidation when image list changes
|
||||||
@@ -309,6 +309,7 @@ jobs:
|
|||||||
uses: anthropics/claude-code-action@v1
|
uses: anthropics/claude-code-action@v1
|
||||||
with:
|
with:
|
||||||
claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }}
|
claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }}
|
||||||
|
allowed_bots: "dependabot[bot]"
|
||||||
claude_args: |
|
claude_args: |
|
||||||
--allowedTools "Bash(npm:*),Bash(pnpm:*),Bash(poetry:*),Bash(git:*),Edit,Replace,NotebookEditCell,mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*), Bash(gh pr diff:*), Bash(gh pr view:*)"
|
--allowedTools "Bash(npm:*),Bash(pnpm:*),Bash(poetry:*),Bash(git:*),Edit,Replace,NotebookEditCell,mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*), Bash(gh pr diff:*), Bash(gh pr view:*)"
|
||||||
prompt: |
|
prompt: |
|
||||||
|
|||||||
8
.github/workflows/claude.yml
vendored
8
.github/workflows/claude.yml
vendored
@@ -57,7 +57,7 @@ jobs:
|
|||||||
python-version: "3.11" # Use standard version matching CI
|
python-version: "3.11" # Use standard version matching CI
|
||||||
|
|
||||||
- name: Set up Python dependency cache
|
- name: Set up Python dependency cache
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v5
|
||||||
with:
|
with:
|
||||||
path: ~/.cache/pypoetry
|
path: ~/.cache/pypoetry
|
||||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||||
@@ -94,7 +94,7 @@ jobs:
|
|||||||
|
|
||||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||||
- name: Set up Node.js
|
- name: Set up Node.js
|
||||||
uses: actions/setup-node@v4
|
uses: actions/setup-node@v6
|
||||||
with:
|
with:
|
||||||
node-version: "22"
|
node-version: "22"
|
||||||
|
|
||||||
@@ -107,7 +107,7 @@ jobs:
|
|||||||
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
|
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Cache frontend dependencies
|
- name: Cache frontend dependencies
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v5
|
||||||
with:
|
with:
|
||||||
path: ~/.pnpm-store
|
path: ~/.pnpm-store
|
||||||
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
|
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
|
||||||
@@ -140,7 +140,7 @@ jobs:
|
|||||||
# Phase 1: Cache and load Docker images for faster setup
|
# Phase 1: Cache and load Docker images for faster setup
|
||||||
- name: Set up Docker image cache
|
- name: Set up Docker image cache
|
||||||
id: docker-cache
|
id: docker-cache
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v5
|
||||||
with:
|
with:
|
||||||
path: ~/docker-cache
|
path: ~/docker-cache
|
||||||
# Use a versioned key for cache invalidation when image list changes
|
# Use a versioned key for cache invalidation when image list changes
|
||||||
|
|||||||
8
.github/workflows/copilot-setup-steps.yml
vendored
8
.github/workflows/copilot-setup-steps.yml
vendored
@@ -39,7 +39,7 @@ jobs:
|
|||||||
python-version: "3.11" # Use standard version matching CI
|
python-version: "3.11" # Use standard version matching CI
|
||||||
|
|
||||||
- name: Set up Python dependency cache
|
- name: Set up Python dependency cache
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v5
|
||||||
with:
|
with:
|
||||||
path: ~/.cache/pypoetry
|
path: ~/.cache/pypoetry
|
||||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||||
@@ -76,7 +76,7 @@ jobs:
|
|||||||
|
|
||||||
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
# Frontend Node.js/pnpm setup (mirrors platform-frontend-ci.yml)
|
||||||
- name: Set up Node.js
|
- name: Set up Node.js
|
||||||
uses: actions/setup-node@v4
|
uses: actions/setup-node@v6
|
||||||
with:
|
with:
|
||||||
node-version: "22"
|
node-version: "22"
|
||||||
|
|
||||||
@@ -89,7 +89,7 @@ jobs:
|
|||||||
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
|
echo "PNPM_HOME=$HOME/.pnpm-store" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Cache frontend dependencies
|
- name: Cache frontend dependencies
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v5
|
||||||
with:
|
with:
|
||||||
path: ~/.pnpm-store
|
path: ~/.pnpm-store
|
||||||
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
|
key: ${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}
|
||||||
@@ -132,7 +132,7 @@ jobs:
|
|||||||
# Phase 1: Cache and load Docker images for faster setup
|
# Phase 1: Cache and load Docker images for faster setup
|
||||||
- name: Set up Docker image cache
|
- name: Set up Docker image cache
|
||||||
id: docker-cache
|
id: docker-cache
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v5
|
||||||
with:
|
with:
|
||||||
path: ~/docker-cache
|
path: ~/docker-cache
|
||||||
# Use a versioned key for cache invalidation when image list changes
|
# Use a versioned key for cache invalidation when image list changes
|
||||||
|
|||||||
2
.github/workflows/docs-block-sync.yml
vendored
2
.github/workflows/docs-block-sync.yml
vendored
@@ -33,7 +33,7 @@ jobs:
|
|||||||
python-version: "3.11"
|
python-version: "3.11"
|
||||||
|
|
||||||
- name: Set up Python dependency cache
|
- name: Set up Python dependency cache
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v5
|
||||||
with:
|
with:
|
||||||
path: ~/.cache/pypoetry
|
path: ~/.cache/pypoetry
|
||||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||||
|
|||||||
2
.github/workflows/docs-claude-review.yml
vendored
2
.github/workflows/docs-claude-review.yml
vendored
@@ -33,7 +33,7 @@ jobs:
|
|||||||
python-version: "3.11"
|
python-version: "3.11"
|
||||||
|
|
||||||
- name: Set up Python dependency cache
|
- name: Set up Python dependency cache
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v5
|
||||||
with:
|
with:
|
||||||
path: ~/.cache/pypoetry
|
path: ~/.cache/pypoetry
|
||||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||||
|
|||||||
2
.github/workflows/docs-enhance.yml
vendored
2
.github/workflows/docs-enhance.yml
vendored
@@ -38,7 +38,7 @@ jobs:
|
|||||||
python-version: "3.11"
|
python-version: "3.11"
|
||||||
|
|
||||||
- name: Set up Python dependency cache
|
- name: Set up Python dependency cache
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v5
|
||||||
with:
|
with:
|
||||||
path: ~/.cache/pypoetry
|
path: ~/.cache/pypoetry
|
||||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||||
|
|||||||
2
.github/workflows/platform-backend-ci.yml
vendored
2
.github/workflows/platform-backend-ci.yml
vendored
@@ -88,7 +88,7 @@ jobs:
|
|||||||
run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT
|
run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT
|
||||||
|
|
||||||
- name: Set up Python dependency cache
|
- name: Set up Python dependency cache
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v5
|
||||||
with:
|
with:
|
||||||
path: ~/.cache/pypoetry
|
path: ~/.cache/pypoetry
|
||||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/backend/poetry.lock') }}
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ jobs:
|
|||||||
- name: Check comment permissions and deployment status
|
- name: Check comment permissions and deployment status
|
||||||
id: check_status
|
id: check_status
|
||||||
if: github.event_name == 'issue_comment' && github.event.issue.pull_request
|
if: github.event_name == 'issue_comment' && github.event.issue.pull_request
|
||||||
uses: actions/github-script@v7
|
uses: actions/github-script@v8
|
||||||
with:
|
with:
|
||||||
script: |
|
script: |
|
||||||
const commentBody = context.payload.comment.body.trim();
|
const commentBody = context.payload.comment.body.trim();
|
||||||
@@ -55,7 +55,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Post permission denied comment
|
- name: Post permission denied comment
|
||||||
if: steps.check_status.outputs.permission_denied == 'true'
|
if: steps.check_status.outputs.permission_denied == 'true'
|
||||||
uses: actions/github-script@v7
|
uses: actions/github-script@v8
|
||||||
with:
|
with:
|
||||||
script: |
|
script: |
|
||||||
await github.rest.issues.createComment({
|
await github.rest.issues.createComment({
|
||||||
@@ -68,7 +68,7 @@ jobs:
|
|||||||
- name: Get PR details for deployment
|
- name: Get PR details for deployment
|
||||||
id: pr_details
|
id: pr_details
|
||||||
if: steps.check_status.outputs.should_deploy == 'true' || steps.check_status.outputs.should_undeploy == 'true'
|
if: steps.check_status.outputs.should_deploy == 'true' || steps.check_status.outputs.should_undeploy == 'true'
|
||||||
uses: actions/github-script@v7
|
uses: actions/github-script@v8
|
||||||
with:
|
with:
|
||||||
script: |
|
script: |
|
||||||
const pr = await github.rest.pulls.get({
|
const pr = await github.rest.pulls.get({
|
||||||
@@ -98,7 +98,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Post deploy success comment
|
- name: Post deploy success comment
|
||||||
if: steps.check_status.outputs.should_deploy == 'true'
|
if: steps.check_status.outputs.should_deploy == 'true'
|
||||||
uses: actions/github-script@v7
|
uses: actions/github-script@v8
|
||||||
with:
|
with:
|
||||||
script: |
|
script: |
|
||||||
await github.rest.issues.createComment({
|
await github.rest.issues.createComment({
|
||||||
@@ -126,7 +126,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Post undeploy success comment
|
- name: Post undeploy success comment
|
||||||
if: steps.check_status.outputs.should_undeploy == 'true'
|
if: steps.check_status.outputs.should_undeploy == 'true'
|
||||||
uses: actions/github-script@v7
|
uses: actions/github-script@v8
|
||||||
with:
|
with:
|
||||||
script: |
|
script: |
|
||||||
await github.rest.issues.createComment({
|
await github.rest.issues.createComment({
|
||||||
@@ -139,7 +139,7 @@ jobs:
|
|||||||
- name: Check deployment status on PR close
|
- name: Check deployment status on PR close
|
||||||
id: check_pr_close
|
id: check_pr_close
|
||||||
if: github.event_name == 'pull_request' && github.event.action == 'closed'
|
if: github.event_name == 'pull_request' && github.event.action == 'closed'
|
||||||
uses: actions/github-script@v7
|
uses: actions/github-script@v8
|
||||||
with:
|
with:
|
||||||
script: |
|
script: |
|
||||||
const comments = await github.rest.issues.listComments({
|
const comments = await github.rest.issues.listComments({
|
||||||
@@ -187,7 +187,7 @@ jobs:
|
|||||||
github.event_name == 'pull_request' &&
|
github.event_name == 'pull_request' &&
|
||||||
github.event.action == 'closed' &&
|
github.event.action == 'closed' &&
|
||||||
steps.check_pr_close.outputs.should_undeploy == 'true'
|
steps.check_pr_close.outputs.should_undeploy == 'true'
|
||||||
uses: actions/github-script@v7
|
uses: actions/github-script@v8
|
||||||
with:
|
with:
|
||||||
script: |
|
script: |
|
||||||
await github.rest.issues.createComment({
|
await github.rest.issues.createComment({
|
||||||
|
|||||||
38
.github/workflows/platform-frontend-ci.yml
vendored
38
.github/workflows/platform-frontend-ci.yml
vendored
@@ -27,13 +27,22 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
outputs:
|
outputs:
|
||||||
cache-key: ${{ steps.cache-key.outputs.key }}
|
cache-key: ${{ steps.cache-key.outputs.key }}
|
||||||
|
components-changed: ${{ steps.filter.outputs.components }}
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Check for component changes
|
||||||
|
uses: dorny/paths-filter@v3
|
||||||
|
id: filter
|
||||||
|
with:
|
||||||
|
filters: |
|
||||||
|
components:
|
||||||
|
- 'autogpt_platform/frontend/src/components/**'
|
||||||
|
|
||||||
- name: Set up Node.js
|
- name: Set up Node.js
|
||||||
uses: actions/setup-node@v4
|
uses: actions/setup-node@v6
|
||||||
with:
|
with:
|
||||||
node-version: "22.18.0"
|
node-version: "22.18.0"
|
||||||
|
|
||||||
@@ -45,7 +54,7 @@ jobs:
|
|||||||
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
|
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
|
||||||
|
|
||||||
- name: Cache dependencies
|
- name: Cache dependencies
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v5
|
||||||
with:
|
with:
|
||||||
path: ~/.pnpm-store
|
path: ~/.pnpm-store
|
||||||
key: ${{ steps.cache-key.outputs.key }}
|
key: ${{ steps.cache-key.outputs.key }}
|
||||||
@@ -65,7 +74,7 @@ jobs:
|
|||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Set up Node.js
|
- name: Set up Node.js
|
||||||
uses: actions/setup-node@v4
|
uses: actions/setup-node@v6
|
||||||
with:
|
with:
|
||||||
node-version: "22.18.0"
|
node-version: "22.18.0"
|
||||||
|
|
||||||
@@ -73,7 +82,7 @@ jobs:
|
|||||||
run: corepack enable
|
run: corepack enable
|
||||||
|
|
||||||
- name: Restore dependencies cache
|
- name: Restore dependencies cache
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v5
|
||||||
with:
|
with:
|
||||||
path: ~/.pnpm-store
|
path: ~/.pnpm-store
|
||||||
key: ${{ needs.setup.outputs.cache-key }}
|
key: ${{ needs.setup.outputs.cache-key }}
|
||||||
@@ -90,8 +99,11 @@ jobs:
|
|||||||
chromatic:
|
chromatic:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
needs: setup
|
needs: setup
|
||||||
# Only run on dev branch pushes or PRs targeting dev
|
# Disabled: to re-enable, remove 'false &&' from the condition below
|
||||||
if: github.ref == 'refs/heads/dev' || github.base_ref == 'dev'
|
if: >-
|
||||||
|
false
|
||||||
|
&& (github.ref == 'refs/heads/dev' || github.base_ref == 'dev')
|
||||||
|
&& needs.setup.outputs.components-changed == 'true'
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
@@ -100,7 +112,7 @@ jobs:
|
|||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
|
|
||||||
- name: Set up Node.js
|
- name: Set up Node.js
|
||||||
uses: actions/setup-node@v4
|
uses: actions/setup-node@v6
|
||||||
with:
|
with:
|
||||||
node-version: "22.18.0"
|
node-version: "22.18.0"
|
||||||
|
|
||||||
@@ -108,7 +120,7 @@ jobs:
|
|||||||
run: corepack enable
|
run: corepack enable
|
||||||
|
|
||||||
- name: Restore dependencies cache
|
- name: Restore dependencies cache
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v5
|
||||||
with:
|
with:
|
||||||
path: ~/.pnpm-store
|
path: ~/.pnpm-store
|
||||||
key: ${{ needs.setup.outputs.cache-key }}
|
key: ${{ needs.setup.outputs.cache-key }}
|
||||||
@@ -141,7 +153,7 @@ jobs:
|
|||||||
submodules: recursive
|
submodules: recursive
|
||||||
|
|
||||||
- name: Set up Node.js
|
- name: Set up Node.js
|
||||||
uses: actions/setup-node@v4
|
uses: actions/setup-node@v6
|
||||||
with:
|
with:
|
||||||
node-version: "22.18.0"
|
node-version: "22.18.0"
|
||||||
|
|
||||||
@@ -164,7 +176,7 @@ jobs:
|
|||||||
uses: docker/setup-buildx-action@v3
|
uses: docker/setup-buildx-action@v3
|
||||||
|
|
||||||
- name: Cache Docker layers
|
- name: Cache Docker layers
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v5
|
||||||
with:
|
with:
|
||||||
path: /tmp/.buildx-cache
|
path: /tmp/.buildx-cache
|
||||||
key: ${{ runner.os }}-buildx-frontend-test-${{ hashFiles('autogpt_platform/docker-compose.yml', 'autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/pyproject.toml', 'autogpt_platform/backend/poetry.lock') }}
|
key: ${{ runner.os }}-buildx-frontend-test-${{ hashFiles('autogpt_platform/docker-compose.yml', 'autogpt_platform/backend/Dockerfile', 'autogpt_platform/backend/pyproject.toml', 'autogpt_platform/backend/poetry.lock') }}
|
||||||
@@ -219,7 +231,7 @@ jobs:
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
- name: Restore dependencies cache
|
- name: Restore dependencies cache
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v5
|
||||||
with:
|
with:
|
||||||
path: ~/.pnpm-store
|
path: ~/.pnpm-store
|
||||||
key: ${{ needs.setup.outputs.cache-key }}
|
key: ${{ needs.setup.outputs.cache-key }}
|
||||||
@@ -270,7 +282,7 @@ jobs:
|
|||||||
submodules: recursive
|
submodules: recursive
|
||||||
|
|
||||||
- name: Set up Node.js
|
- name: Set up Node.js
|
||||||
uses: actions/setup-node@v4
|
uses: actions/setup-node@v6
|
||||||
with:
|
with:
|
||||||
node-version: "22.18.0"
|
node-version: "22.18.0"
|
||||||
|
|
||||||
@@ -278,7 +290,7 @@ jobs:
|
|||||||
run: corepack enable
|
run: corepack enable
|
||||||
|
|
||||||
- name: Restore dependencies cache
|
- name: Restore dependencies cache
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v5
|
||||||
with:
|
with:
|
||||||
path: ~/.pnpm-store
|
path: ~/.pnpm-store
|
||||||
key: ${{ needs.setup.outputs.cache-key }}
|
key: ${{ needs.setup.outputs.cache-key }}
|
||||||
|
|||||||
12
.github/workflows/platform-fullstack-ci.yml
vendored
12
.github/workflows/platform-fullstack-ci.yml
vendored
@@ -32,7 +32,7 @@ jobs:
|
|||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Set up Node.js
|
- name: Set up Node.js
|
||||||
uses: actions/setup-node@v4
|
uses: actions/setup-node@v6
|
||||||
with:
|
with:
|
||||||
node-version: "22.18.0"
|
node-version: "22.18.0"
|
||||||
|
|
||||||
@@ -44,7 +44,7 @@ jobs:
|
|||||||
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
|
run: echo "key=${{ runner.os }}-pnpm-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml', 'autogpt_platform/frontend/package.json') }}" >> $GITHUB_OUTPUT
|
||||||
|
|
||||||
- name: Cache dependencies
|
- name: Cache dependencies
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v5
|
||||||
with:
|
with:
|
||||||
path: ~/.pnpm-store
|
path: ~/.pnpm-store
|
||||||
key: ${{ steps.cache-key.outputs.key }}
|
key: ${{ steps.cache-key.outputs.key }}
|
||||||
@@ -56,7 +56,7 @@ jobs:
|
|||||||
run: pnpm install --frozen-lockfile
|
run: pnpm install --frozen-lockfile
|
||||||
|
|
||||||
types:
|
types:
|
||||||
runs-on: ubuntu-latest
|
runs-on: big-boi
|
||||||
needs: setup
|
needs: setup
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
@@ -68,7 +68,7 @@ jobs:
|
|||||||
submodules: recursive
|
submodules: recursive
|
||||||
|
|
||||||
- name: Set up Node.js
|
- name: Set up Node.js
|
||||||
uses: actions/setup-node@v4
|
uses: actions/setup-node@v6
|
||||||
with:
|
with:
|
||||||
node-version: "22.18.0"
|
node-version: "22.18.0"
|
||||||
|
|
||||||
@@ -85,10 +85,10 @@ jobs:
|
|||||||
|
|
||||||
- name: Run docker compose
|
- name: Run docker compose
|
||||||
run: |
|
run: |
|
||||||
docker compose -f ../docker-compose.yml --profile local --profile deps_backend up -d
|
docker compose -f ../docker-compose.yml --profile local up -d deps_backend
|
||||||
|
|
||||||
- name: Restore dependencies cache
|
- name: Restore dependencies cache
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v5
|
||||||
with:
|
with:
|
||||||
path: ~/.pnpm-store
|
path: ~/.pnpm-store
|
||||||
key: ${{ needs.setup.outputs.cache-key }}
|
key: ${{ needs.setup.outputs.cache-key }}
|
||||||
|
|||||||
13
.gitignore
vendored
13
.gitignore
vendored
@@ -3,7 +3,6 @@
|
|||||||
classic/original_autogpt/keys.py
|
classic/original_autogpt/keys.py
|
||||||
classic/original_autogpt/*.json
|
classic/original_autogpt/*.json
|
||||||
auto_gpt_workspace/*
|
auto_gpt_workspace/*
|
||||||
.autogpt/
|
|
||||||
*.mpeg
|
*.mpeg
|
||||||
.env
|
.env
|
||||||
# Root .env files
|
# Root .env files
|
||||||
@@ -160,10 +159,6 @@ CURRENT_BULLETIN.md
|
|||||||
|
|
||||||
# AgBenchmark
|
# AgBenchmark
|
||||||
classic/benchmark/agbenchmark/reports/
|
classic/benchmark/agbenchmark/reports/
|
||||||
classic/reports/
|
|
||||||
classic/direct_benchmark/reports/
|
|
||||||
classic/.benchmark_workspaces/
|
|
||||||
classic/direct_benchmark/.benchmark_workspaces/
|
|
||||||
|
|
||||||
# Nodejs
|
# Nodejs
|
||||||
package-lock.json
|
package-lock.json
|
||||||
@@ -182,13 +177,7 @@ autogpt_platform/backend/settings.py
|
|||||||
|
|
||||||
*.ign.*
|
*.ign.*
|
||||||
.test-contents
|
.test-contents
|
||||||
**/.claude/settings.local.json
|
|
||||||
.claude/settings.local.json
|
.claude/settings.local.json
|
||||||
CLAUDE.local.md
|
CLAUDE.local.md
|
||||||
/autogpt_platform/backend/logs
|
/autogpt_platform/backend/logs
|
||||||
|
.next
|
||||||
# Test database
|
|
||||||
test.db
|
|
||||||
|
|
||||||
# Next.js
|
|
||||||
.next
|
|
||||||
@@ -43,10 +43,29 @@ repos:
|
|||||||
pass_filenames: false
|
pass_filenames: false
|
||||||
|
|
||||||
- id: poetry-install
|
- id: poetry-install
|
||||||
name: Check & Install dependencies - Classic
|
name: Check & Install dependencies - Classic - AutoGPT
|
||||||
alias: poetry-install-classic
|
alias: poetry-install-classic-autogpt
|
||||||
entry: poetry -C classic install
|
entry: poetry -C classic/original_autogpt install
|
||||||
files: ^classic/poetry\.lock$
|
# include forge source (since it's a path dependency)
|
||||||
|
files: ^classic/(original_autogpt|forge)/poetry\.lock$
|
||||||
|
types: [file]
|
||||||
|
language: system
|
||||||
|
pass_filenames: false
|
||||||
|
|
||||||
|
- id: poetry-install
|
||||||
|
name: Check & Install dependencies - Classic - Forge
|
||||||
|
alias: poetry-install-classic-forge
|
||||||
|
entry: poetry -C classic/forge install
|
||||||
|
files: ^classic/forge/poetry\.lock$
|
||||||
|
types: [file]
|
||||||
|
language: system
|
||||||
|
pass_filenames: false
|
||||||
|
|
||||||
|
- id: poetry-install
|
||||||
|
name: Check & Install dependencies - Classic - Benchmark
|
||||||
|
alias: poetry-install-classic-benchmark
|
||||||
|
entry: poetry -C classic/benchmark install
|
||||||
|
files: ^classic/benchmark/poetry\.lock$
|
||||||
types: [file]
|
types: [file]
|
||||||
language: system
|
language: system
|
||||||
pass_filenames: false
|
pass_filenames: false
|
||||||
@@ -97,10 +116,26 @@ repos:
|
|||||||
language: system
|
language: system
|
||||||
|
|
||||||
- id: isort
|
- id: isort
|
||||||
name: Lint (isort) - Classic
|
name: Lint (isort) - Classic - AutoGPT
|
||||||
alias: isort-classic
|
alias: isort-classic-autogpt
|
||||||
entry: bash -c 'cd classic && poetry run isort $(echo "$@" | sed "s|classic/||g")' --
|
entry: poetry -P classic/original_autogpt run isort -p autogpt
|
||||||
files: ^classic/(original_autogpt|forge|direct_benchmark)/
|
files: ^classic/original_autogpt/
|
||||||
|
types: [file, python]
|
||||||
|
language: system
|
||||||
|
|
||||||
|
- id: isort
|
||||||
|
name: Lint (isort) - Classic - Forge
|
||||||
|
alias: isort-classic-forge
|
||||||
|
entry: poetry -P classic/forge run isort -p forge
|
||||||
|
files: ^classic/forge/
|
||||||
|
types: [file, python]
|
||||||
|
language: system
|
||||||
|
|
||||||
|
- id: isort
|
||||||
|
name: Lint (isort) - Classic - Benchmark
|
||||||
|
alias: isort-classic-benchmark
|
||||||
|
entry: poetry -P classic/benchmark run isort -p agbenchmark
|
||||||
|
files: ^classic/benchmark/
|
||||||
types: [file, python]
|
types: [file, python]
|
||||||
language: system
|
language: system
|
||||||
|
|
||||||
@@ -114,13 +149,26 @@ repos:
|
|||||||
|
|
||||||
- repo: https://github.com/PyCQA/flake8
|
- repo: https://github.com/PyCQA/flake8
|
||||||
rev: 7.0.0
|
rev: 7.0.0
|
||||||
# Use consolidated flake8 config at classic/.flake8
|
# To have flake8 load the config of the individual subprojects, we have to call
|
||||||
|
# them separately.
|
||||||
hooks:
|
hooks:
|
||||||
- id: flake8
|
- id: flake8
|
||||||
name: Lint (Flake8) - Classic
|
name: Lint (Flake8) - Classic - AutoGPT
|
||||||
alias: flake8-classic
|
alias: flake8-classic-autogpt
|
||||||
files: ^classic/(original_autogpt|forge|direct_benchmark)/
|
files: ^classic/original_autogpt/(autogpt|scripts|tests)/
|
||||||
args: [--config=classic/.flake8]
|
args: [--config=classic/original_autogpt/.flake8]
|
||||||
|
|
||||||
|
- id: flake8
|
||||||
|
name: Lint (Flake8) - Classic - Forge
|
||||||
|
alias: flake8-classic-forge
|
||||||
|
files: ^classic/forge/(forge|tests)/
|
||||||
|
args: [--config=classic/forge/.flake8]
|
||||||
|
|
||||||
|
- id: flake8
|
||||||
|
name: Lint (Flake8) - Classic - Benchmark
|
||||||
|
alias: flake8-classic-benchmark
|
||||||
|
files: ^classic/benchmark/(agbenchmark|tests)/((?!reports).)*[/.]
|
||||||
|
args: [--config=classic/benchmark/.flake8]
|
||||||
|
|
||||||
- repo: local
|
- repo: local
|
||||||
hooks:
|
hooks:
|
||||||
@@ -156,10 +204,29 @@ repos:
|
|||||||
pass_filenames: false
|
pass_filenames: false
|
||||||
|
|
||||||
- id: pyright
|
- id: pyright
|
||||||
name: Typecheck - Classic
|
name: Typecheck - Classic - AutoGPT
|
||||||
alias: pyright-classic
|
alias: pyright-classic-autogpt
|
||||||
entry: poetry -C classic run pyright
|
entry: poetry -C classic/original_autogpt run pyright
|
||||||
files: ^classic/(original_autogpt|forge|direct_benchmark)/.*\.py$|^classic/poetry\.lock$
|
# include forge source (since it's a path dependency) but exclude *_test.py files:
|
||||||
|
files: ^(classic/original_autogpt/((autogpt|scripts|tests)/|poetry\.lock$)|classic/forge/(forge/.*(?<!_test)\.py|poetry\.lock)$)
|
||||||
|
types: [file]
|
||||||
|
language: system
|
||||||
|
pass_filenames: false
|
||||||
|
|
||||||
|
- id: pyright
|
||||||
|
name: Typecheck - Classic - Forge
|
||||||
|
alias: pyright-classic-forge
|
||||||
|
entry: poetry -C classic/forge run pyright
|
||||||
|
files: ^classic/forge/(forge/|poetry\.lock$)
|
||||||
|
types: [file]
|
||||||
|
language: system
|
||||||
|
pass_filenames: false
|
||||||
|
|
||||||
|
- id: pyright
|
||||||
|
name: Typecheck - Classic - Benchmark
|
||||||
|
alias: pyright-classic-benchmark
|
||||||
|
entry: poetry -C classic/benchmark run pyright
|
||||||
|
files: ^classic/benchmark/(agbenchmark/|tests/|poetry\.lock$)
|
||||||
types: [file]
|
types: [file]
|
||||||
language: system
|
language: system
|
||||||
pass_filenames: false
|
pass_filenames: false
|
||||||
|
|||||||
1862
autogpt_platform/autogpt_libs/poetry.lock
generated
1862
autogpt_platform/autogpt_libs/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -9,25 +9,25 @@ packages = [{ include = "autogpt_libs" }]
|
|||||||
[tool.poetry.dependencies]
|
[tool.poetry.dependencies]
|
||||||
python = ">=3.10,<4.0"
|
python = ">=3.10,<4.0"
|
||||||
colorama = "^0.4.6"
|
colorama = "^0.4.6"
|
||||||
cryptography = "^45.0"
|
cryptography = "^46.0"
|
||||||
expiringdict = "^1.2.2"
|
expiringdict = "^1.2.2"
|
||||||
fastapi = "^0.116.1"
|
fastapi = "^0.128.0"
|
||||||
google-cloud-logging = "^3.12.1"
|
google-cloud-logging = "^3.13.0"
|
||||||
launchdarkly-server-sdk = "^9.12.0"
|
launchdarkly-server-sdk = "^9.14.1"
|
||||||
pydantic = "^2.11.7"
|
pydantic = "^2.12.5"
|
||||||
pydantic-settings = "^2.10.1"
|
pydantic-settings = "^2.12.0"
|
||||||
pyjwt = { version = "^2.10.1", extras = ["crypto"] }
|
pyjwt = { version = "^2.11.0", extras = ["crypto"] }
|
||||||
redis = "^6.2.0"
|
redis = "^6.2.0"
|
||||||
supabase = "^2.16.0"
|
supabase = "^2.27.2"
|
||||||
uvicorn = "^0.35.0"
|
uvicorn = "^0.40.0"
|
||||||
|
|
||||||
[tool.poetry.group.dev.dependencies]
|
[tool.poetry.group.dev.dependencies]
|
||||||
pyright = "^1.1.404"
|
pyright = "^1.1.408"
|
||||||
pytest = "^8.4.1"
|
pytest = "^8.4.1"
|
||||||
pytest-asyncio = "^1.1.0"
|
pytest-asyncio = "^1.3.0"
|
||||||
pytest-mock = "^3.14.1"
|
pytest-mock = "^3.15.1"
|
||||||
pytest-cov = "^6.2.1"
|
pytest-cov = "^7.0.0"
|
||||||
ruff = "^0.12.11"
|
ruff = "^0.15.0"
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["poetry-core"]
|
requires = ["poetry-core"]
|
||||||
|
|||||||
@@ -152,6 +152,7 @@ REPLICATE_API_KEY=
|
|||||||
REVID_API_KEY=
|
REVID_API_KEY=
|
||||||
SCREENSHOTONE_API_KEY=
|
SCREENSHOTONE_API_KEY=
|
||||||
UNREAL_SPEECH_API_KEY=
|
UNREAL_SPEECH_API_KEY=
|
||||||
|
ELEVENLABS_API_KEY=
|
||||||
|
|
||||||
# Data & Search Services
|
# Data & Search Services
|
||||||
E2B_API_KEY=
|
E2B_API_KEY=
|
||||||
|
|||||||
3
autogpt_platform/backend/.gitignore
vendored
3
autogpt_platform/backend/.gitignore
vendored
@@ -19,3 +19,6 @@ load-tests/*.json
|
|||||||
load-tests/*.log
|
load-tests/*.log
|
||||||
load-tests/node_modules/*
|
load-tests/node_modules/*
|
||||||
migrations/*/rollback*.sql
|
migrations/*/rollback*.sql
|
||||||
|
|
||||||
|
# Workspace files
|
||||||
|
workspaces/
|
||||||
|
|||||||
@@ -62,10 +62,12 @@ ENV POETRY_HOME=/opt/poetry \
|
|||||||
DEBIAN_FRONTEND=noninteractive
|
DEBIAN_FRONTEND=noninteractive
|
||||||
ENV PATH=/opt/poetry/bin:$PATH
|
ENV PATH=/opt/poetry/bin:$PATH
|
||||||
|
|
||||||
# Install Python without upgrading system-managed packages
|
# Install Python, FFmpeg, and ImageMagick (required for video processing blocks)
|
||||||
RUN apt-get update && apt-get install -y \
|
RUN apt-get update && apt-get install -y \
|
||||||
python3.13 \
|
python3.13 \
|
||||||
python3-pip \
|
python3-pip \
|
||||||
|
ffmpeg \
|
||||||
|
imagemagick \
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
# Copy only necessary files from builder
|
# Copy only necessary files from builder
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ class ChatConfig(BaseSettings):
|
|||||||
|
|
||||||
# OpenAI API Configuration
|
# OpenAI API Configuration
|
||||||
model: str = Field(
|
model: str = Field(
|
||||||
default="anthropic/claude-opus-4.5", description="Default model to use"
|
default="anthropic/claude-opus-4.6", description="Default model to use"
|
||||||
)
|
)
|
||||||
title_model: str = Field(
|
title_model: str = Field(
|
||||||
default="openai/gpt-4o-mini",
|
default="openai/gpt-4o-mini",
|
||||||
|
|||||||
@@ -45,10 +45,7 @@ async def create_chat_session(
|
|||||||
successfulAgentRuns=SafeJson({}),
|
successfulAgentRuns=SafeJson({}),
|
||||||
successfulAgentSchedules=SafeJson({}),
|
successfulAgentSchedules=SafeJson({}),
|
||||||
)
|
)
|
||||||
return await PrismaChatSession.prisma().create(
|
return await PrismaChatSession.prisma().create(data=data)
|
||||||
data=data,
|
|
||||||
include={"Messages": True},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def update_chat_session(
|
async def update_chat_session(
|
||||||
|
|||||||
@@ -18,6 +18,10 @@ class ResponseType(str, Enum):
|
|||||||
START = "start"
|
START = "start"
|
||||||
FINISH = "finish"
|
FINISH = "finish"
|
||||||
|
|
||||||
|
# Step lifecycle (one LLM API call within a message)
|
||||||
|
START_STEP = "start-step"
|
||||||
|
FINISH_STEP = "finish-step"
|
||||||
|
|
||||||
# Text streaming
|
# Text streaming
|
||||||
TEXT_START = "text-start"
|
TEXT_START = "text-start"
|
||||||
TEXT_DELTA = "text-delta"
|
TEXT_DELTA = "text-delta"
|
||||||
@@ -57,6 +61,16 @@ class StreamStart(StreamBaseResponse):
|
|||||||
description="Task ID for SSE reconnection. Clients can reconnect using GET /tasks/{taskId}/stream",
|
description="Task ID for SSE reconnection. Clients can reconnect using GET /tasks/{taskId}/stream",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def to_sse(self) -> str:
|
||||||
|
"""Convert to SSE format, excluding non-protocol fields like taskId."""
|
||||||
|
import json
|
||||||
|
|
||||||
|
data: dict[str, Any] = {
|
||||||
|
"type": self.type.value,
|
||||||
|
"messageId": self.messageId,
|
||||||
|
}
|
||||||
|
return f"data: {json.dumps(data)}\n\n"
|
||||||
|
|
||||||
|
|
||||||
class StreamFinish(StreamBaseResponse):
|
class StreamFinish(StreamBaseResponse):
|
||||||
"""End of message/stream."""
|
"""End of message/stream."""
|
||||||
@@ -64,6 +78,26 @@ class StreamFinish(StreamBaseResponse):
|
|||||||
type: ResponseType = ResponseType.FINISH
|
type: ResponseType = ResponseType.FINISH
|
||||||
|
|
||||||
|
|
||||||
|
class StreamStartStep(StreamBaseResponse):
|
||||||
|
"""Start of a step (one LLM API call within a message).
|
||||||
|
|
||||||
|
The AI SDK uses this to add a step-start boundary to message.parts,
|
||||||
|
enabling visual separation between multiple LLM calls in a single message.
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: ResponseType = ResponseType.START_STEP
|
||||||
|
|
||||||
|
|
||||||
|
class StreamFinishStep(StreamBaseResponse):
|
||||||
|
"""End of a step (one LLM API call within a message).
|
||||||
|
|
||||||
|
The AI SDK uses this to reset activeTextParts and activeReasoningParts,
|
||||||
|
so the next LLM call in a tool-call continuation starts with clean state.
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: ResponseType = ResponseType.FINISH_STEP
|
||||||
|
|
||||||
|
|
||||||
# ========== Text Streaming ==========
|
# ========== Text Streaming ==========
|
||||||
|
|
||||||
|
|
||||||
@@ -117,7 +151,7 @@ class StreamToolOutputAvailable(StreamBaseResponse):
|
|||||||
type: ResponseType = ResponseType.TOOL_OUTPUT_AVAILABLE
|
type: ResponseType = ResponseType.TOOL_OUTPUT_AVAILABLE
|
||||||
toolCallId: str = Field(..., description="Tool call ID this responds to")
|
toolCallId: str = Field(..., description="Tool call ID this responds to")
|
||||||
output: str | dict[str, Any] = Field(..., description="Tool execution output")
|
output: str | dict[str, Any] = Field(..., description="Tool execution output")
|
||||||
# Additional fields for internal use (not part of AI SDK spec but useful)
|
# Keep these for internal backend use
|
||||||
toolName: str | None = Field(
|
toolName: str | None = Field(
|
||||||
default=None, description="Name of the tool that was executed"
|
default=None, description="Name of the tool that was executed"
|
||||||
)
|
)
|
||||||
@@ -125,6 +159,17 @@ class StreamToolOutputAvailable(StreamBaseResponse):
|
|||||||
default=True, description="Whether the tool execution succeeded"
|
default=True, description="Whether the tool execution succeeded"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def to_sse(self) -> str:
|
||||||
|
"""Convert to SSE format, excluding non-spec fields."""
|
||||||
|
import json
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"type": self.type.value,
|
||||||
|
"toolCallId": self.toolCallId,
|
||||||
|
"output": self.output,
|
||||||
|
}
|
||||||
|
return f"data: {json.dumps(data)}\n\n"
|
||||||
|
|
||||||
|
|
||||||
# ========== Other ==========
|
# ========== Other ==========
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from collections.abc import AsyncGenerator
|
|||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
from autogpt_libs import auth
|
from autogpt_libs import auth
|
||||||
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Security
|
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Response, Security
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
@@ -17,7 +17,29 @@ from . import stream_registry
|
|||||||
from .completion_handler import process_operation_failure, process_operation_success
|
from .completion_handler import process_operation_failure, process_operation_success
|
||||||
from .config import ChatConfig
|
from .config import ChatConfig
|
||||||
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
|
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
|
||||||
from .response_model import StreamFinish, StreamHeartbeat, StreamStart
|
from .response_model import StreamFinish, StreamHeartbeat
|
||||||
|
from .tools.models import (
|
||||||
|
AgentDetailsResponse,
|
||||||
|
AgentOutputResponse,
|
||||||
|
AgentPreviewResponse,
|
||||||
|
AgentSavedResponse,
|
||||||
|
AgentsFoundResponse,
|
||||||
|
BlockListResponse,
|
||||||
|
BlockOutputResponse,
|
||||||
|
ClarificationNeededResponse,
|
||||||
|
DocPageResponse,
|
||||||
|
DocSearchResultsResponse,
|
||||||
|
ErrorResponse,
|
||||||
|
ExecutionStartedResponse,
|
||||||
|
InputValidationErrorResponse,
|
||||||
|
NeedLoginResponse,
|
||||||
|
NoResultsResponse,
|
||||||
|
OperationInProgressResponse,
|
||||||
|
OperationPendingResponse,
|
||||||
|
OperationStartedResponse,
|
||||||
|
SetupRequirementsResponse,
|
||||||
|
UnderstandingUpdatedResponse,
|
||||||
|
)
|
||||||
|
|
||||||
config = ChatConfig()
|
config = ChatConfig()
|
||||||
|
|
||||||
@@ -266,12 +288,36 @@ async def stream_chat_post(
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import time
|
||||||
|
|
||||||
|
stream_start_time = time.perf_counter()
|
||||||
|
log_meta = {"component": "ChatStream", "session_id": session_id}
|
||||||
|
if user_id:
|
||||||
|
log_meta["user_id"] = user_id
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] stream_chat_post STARTED, session={session_id}, "
|
||||||
|
f"user={user_id}, message_len={len(request.message)}",
|
||||||
|
extra={"json_fields": log_meta},
|
||||||
|
)
|
||||||
|
|
||||||
session = await _validate_and_get_session(session_id, user_id)
|
session = await _validate_and_get_session(session_id, user_id)
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] session validated in {(time.perf_counter() - stream_start_time)*1000:.1f}ms",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"duration_ms": (time.perf_counter() - stream_start_time) * 1000,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
# Create a task in the stream registry for reconnection support
|
# Create a task in the stream registry for reconnection support
|
||||||
task_id = str(uuid_module.uuid4())
|
task_id = str(uuid_module.uuid4())
|
||||||
operation_id = str(uuid_module.uuid4())
|
operation_id = str(uuid_module.uuid4())
|
||||||
|
log_meta["task_id"] = task_id
|
||||||
|
|
||||||
|
task_create_start = time.perf_counter()
|
||||||
await stream_registry.create_task(
|
await stream_registry.create_task(
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
@@ -280,14 +326,28 @@ async def stream_chat_post(
|
|||||||
tool_name="chat",
|
tool_name="chat",
|
||||||
operation_id=operation_id,
|
operation_id=operation_id,
|
||||||
)
|
)
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] create_task completed in {(time.perf_counter() - task_create_start)*1000:.1f}ms",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"duration_ms": (time.perf_counter() - task_create_start) * 1000,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
# Background task that runs the AI generation independently of SSE connection
|
# Background task that runs the AI generation independently of SSE connection
|
||||||
async def run_ai_generation():
|
async def run_ai_generation():
|
||||||
try:
|
import time as time_module
|
||||||
# Emit a start event with task_id for reconnection
|
|
||||||
start_chunk = StreamStart(messageId=task_id, taskId=task_id)
|
|
||||||
await stream_registry.publish_chunk(task_id, start_chunk)
|
|
||||||
|
|
||||||
|
gen_start_time = time_module.perf_counter()
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] run_ai_generation STARTED, task={task_id}, session={session_id}, user={user_id}",
|
||||||
|
extra={"json_fields": log_meta},
|
||||||
|
)
|
||||||
|
first_chunk_time, ttfc = None, None
|
||||||
|
chunk_count = 0
|
||||||
|
try:
|
||||||
async for chunk in chat_service.stream_chat_completion(
|
async for chunk in chat_service.stream_chat_completion(
|
||||||
session_id,
|
session_id,
|
||||||
request.message,
|
request.message,
|
||||||
@@ -295,25 +355,79 @@ async def stream_chat_post(
|
|||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
session=session, # Pass pre-fetched session to avoid double-fetch
|
session=session, # Pass pre-fetched session to avoid double-fetch
|
||||||
context=request.context,
|
context=request.context,
|
||||||
|
_task_id=task_id, # Pass task_id so service emits start with taskId for reconnection
|
||||||
):
|
):
|
||||||
|
chunk_count += 1
|
||||||
|
if first_chunk_time is None:
|
||||||
|
first_chunk_time = time_module.perf_counter()
|
||||||
|
ttfc = first_chunk_time - gen_start_time
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] FIRST AI CHUNK at {ttfc:.2f}s, type={type(chunk).__name__}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"chunk_type": type(chunk).__name__,
|
||||||
|
"time_to_first_chunk_ms": ttfc * 1000,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
# Write to Redis (subscribers will receive via XREAD)
|
# Write to Redis (subscribers will receive via XREAD)
|
||||||
await stream_registry.publish_chunk(task_id, chunk)
|
await stream_registry.publish_chunk(task_id, chunk)
|
||||||
|
|
||||||
# Mark task as completed
|
gen_end_time = time_module.perf_counter()
|
||||||
|
total_time = (gen_end_time - gen_start_time) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] run_ai_generation FINISHED in {total_time/1000:.1f}s; "
|
||||||
|
f"task={task_id}, session={session_id}, "
|
||||||
|
f"ttfc={ttfc or -1:.2f}s, n_chunks={chunk_count}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"total_time_ms": total_time,
|
||||||
|
"time_to_first_chunk_ms": (
|
||||||
|
ttfc * 1000 if ttfc is not None else None
|
||||||
|
),
|
||||||
|
"n_chunks": chunk_count,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
await stream_registry.mark_task_completed(task_id, "completed")
|
await stream_registry.mark_task_completed(task_id, "completed")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
elapsed = time_module.perf_counter() - gen_start_time
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Error in background AI generation for session {session_id}: {e}"
|
f"[TIMING] run_ai_generation ERROR after {elapsed:.2f}s: {e}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"elapsed_ms": elapsed * 1000,
|
||||||
|
"error": str(e),
|
||||||
|
}
|
||||||
|
},
|
||||||
)
|
)
|
||||||
await stream_registry.mark_task_completed(task_id, "failed")
|
await stream_registry.mark_task_completed(task_id, "failed")
|
||||||
|
|
||||||
# Start the AI generation in a background task
|
# Start the AI generation in a background task
|
||||||
bg_task = asyncio.create_task(run_ai_generation())
|
bg_task = asyncio.create_task(run_ai_generation())
|
||||||
await stream_registry.set_task_asyncio_task(task_id, bg_task)
|
await stream_registry.set_task_asyncio_task(task_id, bg_task)
|
||||||
|
setup_time = (time.perf_counter() - stream_start_time) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] Background task started, setup={setup_time:.1f}ms",
|
||||||
|
extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
|
||||||
|
)
|
||||||
|
|
||||||
# SSE endpoint that subscribes to the task's stream
|
# SSE endpoint that subscribes to the task's stream
|
||||||
async def event_generator() -> AsyncGenerator[str, None]:
|
async def event_generator() -> AsyncGenerator[str, None]:
|
||||||
|
import time as time_module
|
||||||
|
|
||||||
|
event_gen_start = time_module.perf_counter()
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] event_generator STARTED, task={task_id}, session={session_id}, "
|
||||||
|
f"user={user_id}",
|
||||||
|
extra={"json_fields": log_meta},
|
||||||
|
)
|
||||||
subscriber_queue = None
|
subscriber_queue = None
|
||||||
|
first_chunk_yielded = False
|
||||||
|
chunks_yielded = 0
|
||||||
try:
|
try:
|
||||||
# Subscribe to the task stream (this replays existing messages + live updates)
|
# Subscribe to the task stream (this replays existing messages + live updates)
|
||||||
subscriber_queue = await stream_registry.subscribe_to_task(
|
subscriber_queue = await stream_registry.subscribe_to_task(
|
||||||
@@ -328,22 +442,70 @@ async def stream_chat_post(
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Read from the subscriber queue and yield to SSE
|
# Read from the subscriber queue and yield to SSE
|
||||||
|
logger.info(
|
||||||
|
"[TIMING] Starting to read from subscriber_queue",
|
||||||
|
extra={"json_fields": log_meta},
|
||||||
|
)
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
|
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
|
||||||
|
chunks_yielded += 1
|
||||||
|
|
||||||
|
if not first_chunk_yielded:
|
||||||
|
first_chunk_yielded = True
|
||||||
|
elapsed = time_module.perf_counter() - event_gen_start
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] FIRST CHUNK from queue at {elapsed:.2f}s, "
|
||||||
|
f"type={type(chunk).__name__}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"chunk_type": type(chunk).__name__,
|
||||||
|
"elapsed_ms": elapsed * 1000,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
yield chunk.to_sse()
|
yield chunk.to_sse()
|
||||||
|
|
||||||
# Check for finish signal
|
# Check for finish signal
|
||||||
if isinstance(chunk, StreamFinish):
|
if isinstance(chunk, StreamFinish):
|
||||||
|
total_time = time_module.perf_counter() - event_gen_start
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] StreamFinish received in {total_time:.2f}s; "
|
||||||
|
f"n_chunks={chunks_yielded}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"chunks_yielded": chunks_yielded,
|
||||||
|
"total_time_ms": total_time * 1000,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
break
|
break
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
# Send heartbeat to keep connection alive
|
|
||||||
yield StreamHeartbeat().to_sse()
|
yield StreamHeartbeat().to_sse()
|
||||||
|
|
||||||
except GeneratorExit:
|
except GeneratorExit:
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] GeneratorExit (client disconnected), chunks={chunks_yielded}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"chunks_yielded": chunks_yielded,
|
||||||
|
"reason": "client_disconnect",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
pass # Client disconnected - background task continues
|
pass # Client disconnected - background task continues
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in SSE stream for task {task_id}: {e}")
|
elapsed = (time_module.perf_counter() - event_gen_start) * 1000
|
||||||
|
logger.error(
|
||||||
|
f"[TIMING] event_generator ERROR after {elapsed:.1f}ms: {e}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}
|
||||||
|
},
|
||||||
|
)
|
||||||
finally:
|
finally:
|
||||||
# Unsubscribe when client disconnects or stream ends to prevent resource leak
|
# Unsubscribe when client disconnects or stream ends to prevent resource leak
|
||||||
if subscriber_queue is not None:
|
if subscriber_queue is not None:
|
||||||
@@ -357,6 +519,18 @@ async def stream_chat_post(
|
|||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
# AI SDK protocol termination - always yield even if unsubscribe fails
|
# AI SDK protocol termination - always yield even if unsubscribe fails
|
||||||
|
total_time = time_module.perf_counter() - event_gen_start
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] event_generator FINISHED in {total_time:.2f}s; "
|
||||||
|
f"task={task_id}, session={session_id}, n_chunks={chunks_yielded}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"total_time_ms": total_time * 1000,
|
||||||
|
"chunks_yielded": chunks_yielded,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
yield "data: [DONE]\n\n"
|
yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
@@ -374,63 +548,90 @@ async def stream_chat_post(
|
|||||||
@router.get(
|
@router.get(
|
||||||
"/sessions/{session_id}/stream",
|
"/sessions/{session_id}/stream",
|
||||||
)
|
)
|
||||||
async def stream_chat_get(
|
async def resume_session_stream(
|
||||||
session_id: str,
|
session_id: str,
|
||||||
message: Annotated[str, Query(min_length=1, max_length=10000)],
|
|
||||||
user_id: str | None = Depends(auth.get_user_id),
|
user_id: str | None = Depends(auth.get_user_id),
|
||||||
is_user_message: bool = Query(default=True),
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Stream chat responses for a session (GET - legacy endpoint).
|
Resume an active stream for a session.
|
||||||
|
|
||||||
Streams the AI/completion responses in real time over Server-Sent Events (SSE), including:
|
Called by the AI SDK's ``useChat(resume: true)`` on page load.
|
||||||
- Text fragments as they are generated
|
Checks for an active (in-progress) task on the session and either replays
|
||||||
- Tool call UI elements (if invoked)
|
the full SSE stream or returns 204 No Content if nothing is running.
|
||||||
- Tool execution results
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
session_id: The chat session identifier to associate with the streamed messages.
|
session_id: The chat session identifier.
|
||||||
message: The user's new message to process.
|
|
||||||
user_id: Optional authenticated user ID.
|
user_id: Optional authenticated user ID.
|
||||||
is_user_message: Whether the message is a user message.
|
|
||||||
Returns:
|
|
||||||
StreamingResponse: SSE-formatted response chunks.
|
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
StreamingResponse (SSE) when an active stream exists,
|
||||||
|
or 204 No Content when there is nothing to resume.
|
||||||
"""
|
"""
|
||||||
session = await _validate_and_get_session(session_id, user_id)
|
import asyncio
|
||||||
|
|
||||||
|
active_task, _last_id = await stream_registry.get_active_task_for_session(
|
||||||
|
session_id, user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
if not active_task:
|
||||||
|
return Response(status_code=204)
|
||||||
|
|
||||||
|
subscriber_queue = await stream_registry.subscribe_to_task(
|
||||||
|
task_id=active_task.task_id,
|
||||||
|
user_id=user_id,
|
||||||
|
last_message_id="0-0", # Full replay so useChat rebuilds the message
|
||||||
|
)
|
||||||
|
|
||||||
|
if subscriber_queue is None:
|
||||||
|
return Response(status_code=204)
|
||||||
|
|
||||||
async def event_generator() -> AsyncGenerator[str, None]:
|
async def event_generator() -> AsyncGenerator[str, None]:
|
||||||
chunk_count = 0
|
chunk_count = 0
|
||||||
first_chunk_type: str | None = None
|
first_chunk_type: str | None = None
|
||||||
async for chunk in chat_service.stream_chat_completion(
|
try:
|
||||||
session_id,
|
while True:
|
||||||
message,
|
try:
|
||||||
is_user_message=is_user_message,
|
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
|
||||||
user_id=user_id,
|
if chunk_count < 3:
|
||||||
session=session, # Pass pre-fetched session to avoid double-fetch
|
logger.info(
|
||||||
):
|
"Resume stream chunk",
|
||||||
if chunk_count < 3:
|
extra={
|
||||||
logger.info(
|
"session_id": session_id,
|
||||||
"Chat stream chunk",
|
"chunk_type": str(chunk.type),
|
||||||
extra={
|
},
|
||||||
"session_id": session_id,
|
)
|
||||||
"chunk_type": str(chunk.type),
|
if not first_chunk_type:
|
||||||
},
|
first_chunk_type = str(chunk.type)
|
||||||
|
chunk_count += 1
|
||||||
|
yield chunk.to_sse()
|
||||||
|
|
||||||
|
if isinstance(chunk, StreamFinish):
|
||||||
|
break
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
yield StreamHeartbeat().to_sse()
|
||||||
|
except GeneratorExit:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in resume stream for session {session_id}: {e}")
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
await stream_registry.unsubscribe_from_task(
|
||||||
|
active_task.task_id, subscriber_queue
|
||||||
)
|
)
|
||||||
if not first_chunk_type:
|
except Exception as unsub_err:
|
||||||
first_chunk_type = str(chunk.type)
|
logger.error(
|
||||||
chunk_count += 1
|
f"Error unsubscribing from task {active_task.task_id}: {unsub_err}",
|
||||||
yield chunk.to_sse()
|
exc_info=True,
|
||||||
logger.info(
|
)
|
||||||
"Chat stream completed",
|
logger.info(
|
||||||
extra={
|
"Resume stream completed",
|
||||||
"session_id": session_id,
|
extra={
|
||||||
"chunk_count": chunk_count,
|
"session_id": session_id,
|
||||||
"first_chunk_type": first_chunk_type,
|
"n_chunks": chunk_count,
|
||||||
},
|
"first_chunk_type": first_chunk_type,
|
||||||
)
|
},
|
||||||
# AI SDK protocol termination
|
)
|
||||||
yield "data: [DONE]\n\n"
|
yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
event_generator(),
|
event_generator(),
|
||||||
@@ -438,8 +639,8 @@ async def stream_chat_get(
|
|||||||
headers={
|
headers={
|
||||||
"Cache-Control": "no-cache",
|
"Cache-Control": "no-cache",
|
||||||
"Connection": "keep-alive",
|
"Connection": "keep-alive",
|
||||||
"X-Accel-Buffering": "no", # Disable nginx buffering
|
"X-Accel-Buffering": "no",
|
||||||
"x-vercel-ai-ui-message-stream": "v1", # AI SDK protocol header
|
"x-vercel-ai-ui-message-stream": "v1",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -751,3 +952,42 @@ async def health_check() -> dict:
|
|||||||
"service": "chat",
|
"service": "chat",
|
||||||
"version": "0.1.0",
|
"version": "0.1.0",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ========== Schema Export (for OpenAPI / Orval codegen) ==========
|
||||||
|
|
||||||
|
ToolResponseUnion = (
|
||||||
|
AgentsFoundResponse
|
||||||
|
| NoResultsResponse
|
||||||
|
| AgentDetailsResponse
|
||||||
|
| SetupRequirementsResponse
|
||||||
|
| ExecutionStartedResponse
|
||||||
|
| NeedLoginResponse
|
||||||
|
| ErrorResponse
|
||||||
|
| InputValidationErrorResponse
|
||||||
|
| AgentOutputResponse
|
||||||
|
| UnderstandingUpdatedResponse
|
||||||
|
| AgentPreviewResponse
|
||||||
|
| AgentSavedResponse
|
||||||
|
| ClarificationNeededResponse
|
||||||
|
| BlockListResponse
|
||||||
|
| BlockOutputResponse
|
||||||
|
| DocSearchResultsResponse
|
||||||
|
| DocPageResponse
|
||||||
|
| OperationStartedResponse
|
||||||
|
| OperationPendingResponse
|
||||||
|
| OperationInProgressResponse
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/schema/tool-responses",
|
||||||
|
response_model=ToolResponseUnion,
|
||||||
|
include_in_schema=True,
|
||||||
|
summary="[Dummy] Tool response type export for codegen",
|
||||||
|
description="This endpoint is not meant to be called. It exists solely to "
|
||||||
|
"expose tool response models in the OpenAPI schema for frontend codegen.",
|
||||||
|
)
|
||||||
|
async def _tool_response_schema() -> ToolResponseUnion: # type: ignore[return]
|
||||||
|
"""Never called at runtime. Exists only so Orval generates TS types."""
|
||||||
|
raise HTTPException(status_code=501, detail="Schema-only endpoint")
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ from backend.data.understanding import (
|
|||||||
get_business_understanding,
|
get_business_understanding,
|
||||||
)
|
)
|
||||||
from backend.util.exceptions import NotFoundError
|
from backend.util.exceptions import NotFoundError
|
||||||
from backend.util.settings import Settings
|
from backend.util.settings import AppEnvironment, Settings
|
||||||
|
|
||||||
from . import db as chat_db
|
from . import db as chat_db
|
||||||
from . import stream_registry
|
from . import stream_registry
|
||||||
@@ -52,8 +52,10 @@ from .response_model import (
|
|||||||
StreamBaseResponse,
|
StreamBaseResponse,
|
||||||
StreamError,
|
StreamError,
|
||||||
StreamFinish,
|
StreamFinish,
|
||||||
|
StreamFinishStep,
|
||||||
StreamHeartbeat,
|
StreamHeartbeat,
|
||||||
StreamStart,
|
StreamStart,
|
||||||
|
StreamStartStep,
|
||||||
StreamTextDelta,
|
StreamTextDelta,
|
||||||
StreamTextEnd,
|
StreamTextEnd,
|
||||||
StreamTextStart,
|
StreamTextStart,
|
||||||
@@ -222,8 +224,18 @@ async def _get_system_prompt_template(context: str) -> str:
|
|||||||
try:
|
try:
|
||||||
# cache_ttl_seconds=0 disables SDK caching to always get the latest prompt
|
# cache_ttl_seconds=0 disables SDK caching to always get the latest prompt
|
||||||
# Use asyncio.to_thread to avoid blocking the event loop
|
# Use asyncio.to_thread to avoid blocking the event loop
|
||||||
|
# In non-production environments, fetch the latest prompt version
|
||||||
|
# instead of the production-labeled version for easier testing
|
||||||
|
label = (
|
||||||
|
None
|
||||||
|
if settings.config.app_env == AppEnvironment.PRODUCTION
|
||||||
|
else "latest"
|
||||||
|
)
|
||||||
prompt = await asyncio.to_thread(
|
prompt = await asyncio.to_thread(
|
||||||
langfuse.get_prompt, config.langfuse_prompt_name, cache_ttl_seconds=0
|
langfuse.get_prompt,
|
||||||
|
config.langfuse_prompt_name,
|
||||||
|
label=label,
|
||||||
|
cache_ttl_seconds=0,
|
||||||
)
|
)
|
||||||
return prompt.compile(users_information=context)
|
return prompt.compile(users_information=context)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -341,6 +353,10 @@ async def stream_chat_completion(
|
|||||||
retry_count: int = 0,
|
retry_count: int = 0,
|
||||||
session: ChatSession | None = None,
|
session: ChatSession | None = None,
|
||||||
context: dict[str, str] | None = None, # {url: str, content: str}
|
context: dict[str, str] | None = None, # {url: str, content: str}
|
||||||
|
_continuation_message_id: (
|
||||||
|
str | None
|
||||||
|
) = None, # Internal: reuse message ID for tool call continuations
|
||||||
|
_task_id: str | None = None, # Internal: task ID for SSE reconnection support
|
||||||
) -> AsyncGenerator[StreamBaseResponse, None]:
|
) -> AsyncGenerator[StreamBaseResponse, None]:
|
||||||
"""Main entry point for streaming chat completions with database handling.
|
"""Main entry point for streaming chat completions with database handling.
|
||||||
|
|
||||||
@@ -361,21 +377,45 @@ async def stream_chat_completion(
|
|||||||
ValueError: If max_context_messages is exceeded
|
ValueError: If max_context_messages is exceeded
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
completion_start = time.monotonic()
|
||||||
|
|
||||||
|
# Build log metadata for structured logging
|
||||||
|
log_meta = {"component": "ChatService", "session_id": session_id}
|
||||||
|
if user_id:
|
||||||
|
log_meta["user_id"] = user_id
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Streaming chat completion for session {session_id} for message {message} and user id {user_id}. Message is user message: {is_user_message}"
|
f"[TIMING] stream_chat_completion STARTED, session={session_id}, user={user_id}, "
|
||||||
|
f"message_len={len(message) if message else 0}, is_user={is_user_message}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"message_len": len(message) if message else 0,
|
||||||
|
"is_user_message": is_user_message,
|
||||||
|
}
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Only fetch from Redis if session not provided (initial call)
|
# Only fetch from Redis if session not provided (initial call)
|
||||||
if session is None:
|
if session is None:
|
||||||
|
fetch_start = time.monotonic()
|
||||||
session = await get_chat_session(session_id, user_id)
|
session = await get_chat_session(session_id, user_id)
|
||||||
|
fetch_time = (time.monotonic() - fetch_start) * 1000
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Fetched session from Redis: {session.session_id if session else 'None'}, "
|
f"[TIMING] get_chat_session took {fetch_time:.1f}ms, "
|
||||||
f"message_count={len(session.messages) if session else 0}"
|
f"n_messages={len(session.messages) if session else 0}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"duration_ms": fetch_time,
|
||||||
|
"n_messages": len(session.messages) if session else 0,
|
||||||
|
}
|
||||||
|
},
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Using provided session object: {session.session_id}, "
|
f"[TIMING] Using provided session, messages={len(session.messages)}",
|
||||||
f"message_count={len(session.messages)}"
|
extra={"json_fields": {**log_meta, "n_messages": len(session.messages)}},
|
||||||
)
|
)
|
||||||
|
|
||||||
if not session:
|
if not session:
|
||||||
@@ -396,17 +436,25 @@ async def stream_chat_completion(
|
|||||||
|
|
||||||
# Track user message in PostHog
|
# Track user message in PostHog
|
||||||
if is_user_message:
|
if is_user_message:
|
||||||
|
posthog_start = time.monotonic()
|
||||||
track_user_message(
|
track_user_message(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
message_length=len(message),
|
message_length=len(message),
|
||||||
)
|
)
|
||||||
|
posthog_time = (time.monotonic() - posthog_start) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] track_user_message took {posthog_time:.1f}ms",
|
||||||
|
extra={"json_fields": {**log_meta, "duration_ms": posthog_time}},
|
||||||
|
)
|
||||||
|
|
||||||
logger.info(
|
upsert_start = time.monotonic()
|
||||||
f"Upserting session: {session.session_id} with user id {session.user_id}, "
|
|
||||||
f"message_count={len(session.messages)}"
|
|
||||||
)
|
|
||||||
session = await upsert_chat_session(session)
|
session = await upsert_chat_session(session)
|
||||||
|
upsert_time = (time.monotonic() - upsert_start) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] upsert_chat_session took {upsert_time:.1f}ms",
|
||||||
|
extra={"json_fields": {**log_meta, "duration_ms": upsert_time}},
|
||||||
|
)
|
||||||
assert session, "Session not found"
|
assert session, "Session not found"
|
||||||
|
|
||||||
# Generate title for new sessions on first user message (non-blocking)
|
# Generate title for new sessions on first user message (non-blocking)
|
||||||
@@ -444,7 +492,13 @@ async def stream_chat_completion(
|
|||||||
asyncio.create_task(_update_title())
|
asyncio.create_task(_update_title())
|
||||||
|
|
||||||
# Build system prompt with business understanding
|
# Build system prompt with business understanding
|
||||||
|
prompt_start = time.monotonic()
|
||||||
system_prompt, understanding = await _build_system_prompt(user_id)
|
system_prompt, understanding = await _build_system_prompt(user_id)
|
||||||
|
prompt_time = (time.monotonic() - prompt_start) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] _build_system_prompt took {prompt_time:.1f}ms",
|
||||||
|
extra={"json_fields": {**log_meta, "duration_ms": prompt_time}},
|
||||||
|
)
|
||||||
|
|
||||||
# Initialize variables for streaming
|
# Initialize variables for streaming
|
||||||
assistant_response = ChatMessage(
|
assistant_response = ChatMessage(
|
||||||
@@ -469,13 +523,27 @@ async def stream_chat_completion(
|
|||||||
# Generate unique IDs for AI SDK protocol
|
# Generate unique IDs for AI SDK protocol
|
||||||
import uuid as uuid_module
|
import uuid as uuid_module
|
||||||
|
|
||||||
message_id = str(uuid_module.uuid4())
|
is_continuation = _continuation_message_id is not None
|
||||||
|
message_id = _continuation_message_id or str(uuid_module.uuid4())
|
||||||
text_block_id = str(uuid_module.uuid4())
|
text_block_id = str(uuid_module.uuid4())
|
||||||
|
|
||||||
# Yield message start
|
# Only yield message start for the initial call, not for continuations.
|
||||||
yield StreamStart(messageId=message_id)
|
setup_time = (time.monotonic() - completion_start) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] Setup complete, yielding StreamStart at {setup_time:.1f}ms",
|
||||||
|
extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
|
||||||
|
)
|
||||||
|
if not is_continuation:
|
||||||
|
yield StreamStart(messageId=message_id, taskId=_task_id)
|
||||||
|
|
||||||
|
# Emit start-step before each LLM call (AI SDK uses this to add step boundaries)
|
||||||
|
yield StreamStartStep()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
logger.info(
|
||||||
|
"[TIMING] Calling _stream_chat_chunks",
|
||||||
|
extra={"json_fields": log_meta},
|
||||||
|
)
|
||||||
async for chunk in _stream_chat_chunks(
|
async for chunk in _stream_chat_chunks(
|
||||||
session=session,
|
session=session,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
@@ -575,6 +643,10 @@ async def stream_chat_completion(
|
|||||||
)
|
)
|
||||||
yield chunk
|
yield chunk
|
||||||
elif isinstance(chunk, StreamFinish):
|
elif isinstance(chunk, StreamFinish):
|
||||||
|
if has_done_tool_call:
|
||||||
|
# Tool calls happened — close the step but don't send message-level finish.
|
||||||
|
# The continuation will open a new step, and finish will come at the end.
|
||||||
|
yield StreamFinishStep()
|
||||||
if not has_done_tool_call:
|
if not has_done_tool_call:
|
||||||
# Emit text-end before finish if we received text but haven't closed it
|
# Emit text-end before finish if we received text but haven't closed it
|
||||||
if has_received_text and not text_streaming_ended:
|
if has_received_text and not text_streaming_ended:
|
||||||
@@ -606,6 +678,8 @@ async def stream_chat_completion(
|
|||||||
has_saved_assistant_message = True
|
has_saved_assistant_message = True
|
||||||
|
|
||||||
has_yielded_end = True
|
has_yielded_end = True
|
||||||
|
# Emit finish-step before finish (resets AI SDK text/reasoning state)
|
||||||
|
yield StreamFinishStep()
|
||||||
yield chunk
|
yield chunk
|
||||||
elif isinstance(chunk, StreamError):
|
elif isinstance(chunk, StreamError):
|
||||||
has_yielded_error = True
|
has_yielded_error = True
|
||||||
@@ -618,6 +692,9 @@ async def stream_chat_completion(
|
|||||||
total_tokens=chunk.totalTokens,
|
total_tokens=chunk.totalTokens,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
elif isinstance(chunk, StreamHeartbeat):
|
||||||
|
# Pass through heartbeat to keep SSE connection alive
|
||||||
|
yield chunk
|
||||||
else:
|
else:
|
||||||
logger.error(f"Unknown chunk type: {type(chunk)}", exc_info=True)
|
logger.error(f"Unknown chunk type: {type(chunk)}", exc_info=True)
|
||||||
|
|
||||||
@@ -652,6 +729,10 @@ async def stream_chat_completion(
|
|||||||
logger.info(
|
logger.info(
|
||||||
f"Retryable error encountered. Attempt {retry_count + 1}/{config.max_retries}"
|
f"Retryable error encountered. Attempt {retry_count + 1}/{config.max_retries}"
|
||||||
)
|
)
|
||||||
|
# Close the current step before retrying so the recursive call's
|
||||||
|
# StreamStartStep doesn't produce unbalanced step events.
|
||||||
|
if not has_yielded_end:
|
||||||
|
yield StreamFinishStep()
|
||||||
should_retry = True
|
should_retry = True
|
||||||
else:
|
else:
|
||||||
# Non-retryable error or max retries exceeded
|
# Non-retryable error or max retries exceeded
|
||||||
@@ -687,6 +768,7 @@ async def stream_chat_completion(
|
|||||||
error_response = StreamError(errorText=error_message)
|
error_response = StreamError(errorText=error_message)
|
||||||
yield error_response
|
yield error_response
|
||||||
if not has_yielded_end:
|
if not has_yielded_end:
|
||||||
|
yield StreamFinishStep()
|
||||||
yield StreamFinish()
|
yield StreamFinish()
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -701,6 +783,8 @@ async def stream_chat_completion(
|
|||||||
retry_count=retry_count + 1,
|
retry_count=retry_count + 1,
|
||||||
session=session,
|
session=session,
|
||||||
context=context,
|
context=context,
|
||||||
|
_continuation_message_id=message_id, # Reuse message ID since start was already sent
|
||||||
|
_task_id=_task_id,
|
||||||
):
|
):
|
||||||
yield chunk
|
yield chunk
|
||||||
return # Exit after retry to avoid double-saving in finally block
|
return # Exit after retry to avoid double-saving in finally block
|
||||||
@@ -770,6 +854,8 @@ async def stream_chat_completion(
|
|||||||
session=session, # Pass session object to avoid Redis refetch
|
session=session, # Pass session object to avoid Redis refetch
|
||||||
context=context,
|
context=context,
|
||||||
tool_call_response=str(tool_response_messages),
|
tool_call_response=str(tool_response_messages),
|
||||||
|
_continuation_message_id=message_id, # Reuse message ID to avoid duplicates
|
||||||
|
_task_id=_task_id,
|
||||||
):
|
):
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
@@ -880,9 +966,21 @@ async def _stream_chat_chunks(
|
|||||||
SSE formatted JSON response objects
|
SSE formatted JSON response objects
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
import time as time_module
|
||||||
|
|
||||||
|
stream_chunks_start = time_module.perf_counter()
|
||||||
model = config.model
|
model = config.model
|
||||||
|
|
||||||
logger.info("Starting pure chat stream")
|
# Build log metadata for structured logging
|
||||||
|
log_meta = {"component": "ChatService", "session_id": session.session_id}
|
||||||
|
if session.user_id:
|
||||||
|
log_meta["user_id"] = session.user_id
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] _stream_chat_chunks STARTED, session={session.session_id}, "
|
||||||
|
f"user={session.user_id}, n_messages={len(session.messages)}",
|
||||||
|
extra={"json_fields": {**log_meta, "n_messages": len(session.messages)}},
|
||||||
|
)
|
||||||
|
|
||||||
messages = session.to_openai_messages()
|
messages = session.to_openai_messages()
|
||||||
if system_prompt:
|
if system_prompt:
|
||||||
@@ -893,12 +991,18 @@ async def _stream_chat_chunks(
|
|||||||
messages = [system_message] + messages
|
messages = [system_message] + messages
|
||||||
|
|
||||||
# Apply context window management
|
# Apply context window management
|
||||||
|
context_start = time_module.perf_counter()
|
||||||
context_result = await _manage_context_window(
|
context_result = await _manage_context_window(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
model=model,
|
model=model,
|
||||||
api_key=config.api_key,
|
api_key=config.api_key,
|
||||||
base_url=config.base_url,
|
base_url=config.base_url,
|
||||||
)
|
)
|
||||||
|
context_time = (time_module.perf_counter() - context_start) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] _manage_context_window took {context_time:.1f}ms",
|
||||||
|
extra={"json_fields": {**log_meta, "duration_ms": context_time}},
|
||||||
|
)
|
||||||
|
|
||||||
if context_result.error:
|
if context_result.error:
|
||||||
if "System prompt dropped" in context_result.error:
|
if "System prompt dropped" in context_result.error:
|
||||||
@@ -933,9 +1037,19 @@ async def _stream_chat_chunks(
|
|||||||
|
|
||||||
while retry_count <= MAX_RETRIES:
|
while retry_count <= MAX_RETRIES:
|
||||||
try:
|
try:
|
||||||
|
elapsed = (time_module.perf_counter() - stream_chunks_start) * 1000
|
||||||
|
retry_info = (
|
||||||
|
f" (retry {retry_count}/{MAX_RETRIES})" if retry_count > 0 else ""
|
||||||
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Creating OpenAI chat completion stream..."
|
f"[TIMING] Creating OpenAI stream at {elapsed:.1f}ms{retry_info}",
|
||||||
f"{f' (retry {retry_count}/{MAX_RETRIES})' if retry_count > 0 else ''}"
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"elapsed_ms": elapsed,
|
||||||
|
"retry_count": retry_count,
|
||||||
|
}
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build extra_body for OpenRouter tracing and PostHog analytics
|
# Build extra_body for OpenRouter tracing and PostHog analytics
|
||||||
@@ -952,6 +1066,7 @@ async def _stream_chat_chunks(
|
|||||||
:128
|
:128
|
||||||
] # OpenRouter limit
|
] # OpenRouter limit
|
||||||
|
|
||||||
|
api_call_start = time_module.perf_counter()
|
||||||
stream = await client.chat.completions.create(
|
stream = await client.chat.completions.create(
|
||||||
model=model,
|
model=model,
|
||||||
messages=cast(list[ChatCompletionMessageParam], messages),
|
messages=cast(list[ChatCompletionMessageParam], messages),
|
||||||
@@ -961,6 +1076,11 @@ async def _stream_chat_chunks(
|
|||||||
stream_options=ChatCompletionStreamOptionsParam(include_usage=True),
|
stream_options=ChatCompletionStreamOptionsParam(include_usage=True),
|
||||||
extra_body=extra_body,
|
extra_body=extra_body,
|
||||||
)
|
)
|
||||||
|
api_init_time = (time_module.perf_counter() - api_call_start) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] OpenAI stream object returned in {api_init_time:.1f}ms",
|
||||||
|
extra={"json_fields": {**log_meta, "duration_ms": api_init_time}},
|
||||||
|
)
|
||||||
|
|
||||||
# Variables to accumulate tool calls
|
# Variables to accumulate tool calls
|
||||||
tool_calls: list[dict[str, Any]] = []
|
tool_calls: list[dict[str, Any]] = []
|
||||||
@@ -971,10 +1091,13 @@ async def _stream_chat_chunks(
|
|||||||
|
|
||||||
# Track if we've started the text block
|
# Track if we've started the text block
|
||||||
text_started = False
|
text_started = False
|
||||||
|
first_content_chunk = True
|
||||||
|
chunk_count = 0
|
||||||
|
|
||||||
# Process the stream
|
# Process the stream
|
||||||
chunk: ChatCompletionChunk
|
chunk: ChatCompletionChunk
|
||||||
async for chunk in stream:
|
async for chunk in stream:
|
||||||
|
chunk_count += 1
|
||||||
if chunk.usage:
|
if chunk.usage:
|
||||||
yield StreamUsage(
|
yield StreamUsage(
|
||||||
promptTokens=chunk.usage.prompt_tokens,
|
promptTokens=chunk.usage.prompt_tokens,
|
||||||
@@ -997,6 +1120,23 @@ async def _stream_chat_chunks(
|
|||||||
if not text_started and text_block_id:
|
if not text_started and text_block_id:
|
||||||
yield StreamTextStart(id=text_block_id)
|
yield StreamTextStart(id=text_block_id)
|
||||||
text_started = True
|
text_started = True
|
||||||
|
# Log timing for first content chunk
|
||||||
|
if first_content_chunk:
|
||||||
|
first_content_chunk = False
|
||||||
|
ttfc = (
|
||||||
|
time_module.perf_counter() - api_call_start
|
||||||
|
) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] FIRST CONTENT CHUNK at {ttfc:.1f}ms "
|
||||||
|
f"(since API call), n_chunks={chunk_count}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"time_to_first_chunk_ms": ttfc,
|
||||||
|
"n_chunks": chunk_count,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
# Stream the text delta
|
# Stream the text delta
|
||||||
text_response = StreamTextDelta(
|
text_response = StreamTextDelta(
|
||||||
id=text_block_id or "",
|
id=text_block_id or "",
|
||||||
@@ -1053,7 +1193,21 @@ async def _stream_chat_chunks(
|
|||||||
toolName=tool_calls[idx]["function"]["name"],
|
toolName=tool_calls[idx]["function"]["name"],
|
||||||
)
|
)
|
||||||
emitted_start_for_idx.add(idx)
|
emitted_start_for_idx.add(idx)
|
||||||
logger.info(f"Stream complete. Finish reason: {finish_reason}")
|
stream_duration = time_module.perf_counter() - api_call_start
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] OpenAI stream COMPLETE, finish_reason={finish_reason}, "
|
||||||
|
f"duration={stream_duration:.2f}s, "
|
||||||
|
f"n_chunks={chunk_count}, n_tool_calls={len(tool_calls)}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"stream_duration_ms": stream_duration * 1000,
|
||||||
|
"finish_reason": finish_reason,
|
||||||
|
"n_chunks": chunk_count,
|
||||||
|
"n_tool_calls": len(tool_calls),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
# Yield all accumulated tool calls after the stream is complete
|
# Yield all accumulated tool calls after the stream is complete
|
||||||
# This ensures all tool call arguments have been fully received
|
# This ensures all tool call arguments have been fully received
|
||||||
@@ -1073,6 +1227,12 @@ async def _stream_chat_chunks(
|
|||||||
# Re-raise to trigger retry logic in the parent function
|
# Re-raise to trigger retry logic in the parent function
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
total_time = (time_module.perf_counter() - stream_chunks_start) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] _stream_chat_chunks COMPLETED in {total_time/1000:.1f}s; "
|
||||||
|
f"session={session.session_id}, user={session.user_id}",
|
||||||
|
extra={"json_fields": {**log_meta, "total_time_ms": total_time}},
|
||||||
|
)
|
||||||
yield StreamFinish()
|
yield StreamFinish()
|
||||||
return
|
return
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -1552,6 +1712,7 @@ async def _execute_long_running_tool_with_streaming(
|
|||||||
task_id,
|
task_id,
|
||||||
StreamError(errorText=str(e)),
|
StreamError(errorText=str(e)),
|
||||||
)
|
)
|
||||||
|
await stream_registry.publish_chunk(task_id, StreamFinishStep())
|
||||||
await stream_registry.publish_chunk(task_id, StreamFinish())
|
await stream_registry.publish_chunk(task_id, StreamFinish())
|
||||||
|
|
||||||
await _update_pending_operation(
|
await _update_pending_operation(
|
||||||
@@ -1809,6 +1970,7 @@ async def _generate_llm_continuation_with_streaming(
|
|||||||
|
|
||||||
# Publish start event
|
# Publish start event
|
||||||
await stream_registry.publish_chunk(task_id, StreamStart(messageId=message_id))
|
await stream_registry.publish_chunk(task_id, StreamStart(messageId=message_id))
|
||||||
|
await stream_registry.publish_chunk(task_id, StreamStartStep())
|
||||||
await stream_registry.publish_chunk(task_id, StreamTextStart(id=text_block_id))
|
await stream_registry.publish_chunk(task_id, StreamTextStart(id=text_block_id))
|
||||||
|
|
||||||
# Stream the response
|
# Stream the response
|
||||||
@@ -1832,6 +1994,7 @@ async def _generate_llm_continuation_with_streaming(
|
|||||||
|
|
||||||
# Publish end events
|
# Publish end events
|
||||||
await stream_registry.publish_chunk(task_id, StreamTextEnd(id=text_block_id))
|
await stream_registry.publish_chunk(task_id, StreamTextEnd(id=text_block_id))
|
||||||
|
await stream_registry.publish_chunk(task_id, StreamFinishStep())
|
||||||
|
|
||||||
if assistant_content:
|
if assistant_content:
|
||||||
# Reload session from DB to avoid race condition with user messages
|
# Reload session from DB to avoid race condition with user messages
|
||||||
@@ -1873,4 +2036,5 @@ async def _generate_llm_continuation_with_streaming(
|
|||||||
task_id,
|
task_id,
|
||||||
StreamError(errorText=f"Failed to generate response: {e}"),
|
StreamError(errorText=f"Failed to generate response: {e}"),
|
||||||
)
|
)
|
||||||
|
await stream_registry.publish_chunk(task_id, StreamFinishStep())
|
||||||
await stream_registry.publish_chunk(task_id, StreamFinish())
|
await stream_registry.publish_chunk(task_id, StreamFinish())
|
||||||
|
|||||||
@@ -104,6 +104,24 @@ async def create_task(
|
|||||||
Returns:
|
Returns:
|
||||||
The created ActiveTask instance (metadata only)
|
The created ActiveTask instance (metadata only)
|
||||||
"""
|
"""
|
||||||
|
import time
|
||||||
|
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
|
# Build log metadata for structured logging
|
||||||
|
log_meta = {
|
||||||
|
"component": "StreamRegistry",
|
||||||
|
"task_id": task_id,
|
||||||
|
"session_id": session_id,
|
||||||
|
}
|
||||||
|
if user_id:
|
||||||
|
log_meta["user_id"] = user_id
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] create_task STARTED, task={task_id}, session={session_id}, user={user_id}",
|
||||||
|
extra={"json_fields": log_meta},
|
||||||
|
)
|
||||||
|
|
||||||
task = ActiveTask(
|
task = ActiveTask(
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
@@ -114,10 +132,18 @@ async def create_task(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Store metadata in Redis
|
# Store metadata in Redis
|
||||||
|
redis_start = time.perf_counter()
|
||||||
redis = await get_redis_async()
|
redis = await get_redis_async()
|
||||||
|
redis_time = (time.perf_counter() - redis_start) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] get_redis_async took {redis_time:.1f}ms",
|
||||||
|
extra={"json_fields": {**log_meta, "duration_ms": redis_time}},
|
||||||
|
)
|
||||||
|
|
||||||
meta_key = _get_task_meta_key(task_id)
|
meta_key = _get_task_meta_key(task_id)
|
||||||
op_key = _get_operation_mapping_key(operation_id)
|
op_key = _get_operation_mapping_key(operation_id)
|
||||||
|
|
||||||
|
hset_start = time.perf_counter()
|
||||||
await redis.hset( # type: ignore[misc]
|
await redis.hset( # type: ignore[misc]
|
||||||
meta_key,
|
meta_key,
|
||||||
mapping={
|
mapping={
|
||||||
@@ -131,12 +157,22 @@ async def create_task(
|
|||||||
"created_at": task.created_at.isoformat(),
|
"created_at": task.created_at.isoformat(),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
hset_time = (time.perf_counter() - hset_start) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] redis.hset took {hset_time:.1f}ms",
|
||||||
|
extra={"json_fields": {**log_meta, "duration_ms": hset_time}},
|
||||||
|
)
|
||||||
|
|
||||||
await redis.expire(meta_key, config.stream_ttl)
|
await redis.expire(meta_key, config.stream_ttl)
|
||||||
|
|
||||||
# Create operation_id -> task_id mapping for webhook lookups
|
# Create operation_id -> task_id mapping for webhook lookups
|
||||||
await redis.set(op_key, task_id, ex=config.stream_ttl)
|
await redis.set(op_key, task_id, ex=config.stream_ttl)
|
||||||
|
|
||||||
logger.debug(f"Created task {task_id} for session {session_id}")
|
total_time = (time.perf_counter() - start_time) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] create_task COMPLETED in {total_time:.1f}ms; task={task_id}, session={session_id}",
|
||||||
|
extra={"json_fields": {**log_meta, "total_time_ms": total_time}},
|
||||||
|
)
|
||||||
|
|
||||||
return task
|
return task
|
||||||
|
|
||||||
@@ -156,26 +192,60 @@ async def publish_chunk(
|
|||||||
Returns:
|
Returns:
|
||||||
The Redis Stream message ID
|
The Redis Stream message ID
|
||||||
"""
|
"""
|
||||||
|
import time
|
||||||
|
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
chunk_type = type(chunk).__name__
|
||||||
chunk_json = chunk.model_dump_json()
|
chunk_json = chunk.model_dump_json()
|
||||||
message_id = "0-0"
|
message_id = "0-0"
|
||||||
|
|
||||||
|
# Build log metadata
|
||||||
|
log_meta = {
|
||||||
|
"component": "StreamRegistry",
|
||||||
|
"task_id": task_id,
|
||||||
|
"chunk_type": chunk_type,
|
||||||
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
redis = await get_redis_async()
|
redis = await get_redis_async()
|
||||||
stream_key = _get_task_stream_key(task_id)
|
stream_key = _get_task_stream_key(task_id)
|
||||||
|
|
||||||
# Write to Redis Stream for persistence and real-time delivery
|
# Write to Redis Stream for persistence and real-time delivery
|
||||||
|
xadd_start = time.perf_counter()
|
||||||
raw_id = await redis.xadd(
|
raw_id = await redis.xadd(
|
||||||
stream_key,
|
stream_key,
|
||||||
{"data": chunk_json},
|
{"data": chunk_json},
|
||||||
maxlen=config.stream_max_length,
|
maxlen=config.stream_max_length,
|
||||||
)
|
)
|
||||||
|
xadd_time = (time.perf_counter() - xadd_start) * 1000
|
||||||
message_id = raw_id if isinstance(raw_id, str) else raw_id.decode()
|
message_id = raw_id if isinstance(raw_id, str) else raw_id.decode()
|
||||||
|
|
||||||
# Set TTL on stream to match task metadata TTL
|
# Set TTL on stream to match task metadata TTL
|
||||||
await redis.expire(stream_key, config.stream_ttl)
|
await redis.expire(stream_key, config.stream_ttl)
|
||||||
|
|
||||||
|
total_time = (time.perf_counter() - start_time) * 1000
|
||||||
|
# Only log timing for significant chunks or slow operations
|
||||||
|
if (
|
||||||
|
chunk_type
|
||||||
|
in ("StreamStart", "StreamFinish", "StreamTextStart", "StreamTextEnd")
|
||||||
|
or total_time > 50
|
||||||
|
):
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] publish_chunk {chunk_type} in {total_time:.1f}ms (xadd={xadd_time:.1f}ms)",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"total_time_ms": total_time,
|
||||||
|
"xadd_time_ms": xadd_time,
|
||||||
|
"message_id": message_id,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
elapsed = (time.perf_counter() - start_time) * 1000
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Failed to publish chunk for task {task_id}: {e}",
|
f"[TIMING] Failed to publish chunk {chunk_type} after {elapsed:.1f}ms: {e}",
|
||||||
|
extra={"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}},
|
||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -200,24 +270,61 @@ async def subscribe_to_task(
|
|||||||
An asyncio Queue that will receive stream chunks, or None if task not found
|
An asyncio Queue that will receive stream chunks, or None if task not found
|
||||||
or user doesn't have access
|
or user doesn't have access
|
||||||
"""
|
"""
|
||||||
|
import time
|
||||||
|
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
|
# Build log metadata
|
||||||
|
log_meta = {"component": "StreamRegistry", "task_id": task_id}
|
||||||
|
if user_id:
|
||||||
|
log_meta["user_id"] = user_id
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] subscribe_to_task STARTED, task={task_id}, user={user_id}, last_msg={last_message_id}",
|
||||||
|
extra={"json_fields": {**log_meta, "last_message_id": last_message_id}},
|
||||||
|
)
|
||||||
|
|
||||||
|
redis_start = time.perf_counter()
|
||||||
redis = await get_redis_async()
|
redis = await get_redis_async()
|
||||||
meta_key = _get_task_meta_key(task_id)
|
meta_key = _get_task_meta_key(task_id)
|
||||||
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
|
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
|
||||||
|
hgetall_time = (time.perf_counter() - redis_start) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] Redis hgetall took {hgetall_time:.1f}ms",
|
||||||
|
extra={"json_fields": {**log_meta, "duration_ms": hgetall_time}},
|
||||||
|
)
|
||||||
|
|
||||||
if not meta:
|
if not meta:
|
||||||
logger.debug(f"Task {task_id} not found in Redis")
|
elapsed = (time.perf_counter() - start_time) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] Task not found in Redis after {elapsed:.1f}ms",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"elapsed_ms": elapsed,
|
||||||
|
"reason": "task_not_found",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Note: Redis client uses decode_responses=True, so keys are strings
|
# Note: Redis client uses decode_responses=True, so keys are strings
|
||||||
task_status = meta.get("status", "")
|
task_status = meta.get("status", "")
|
||||||
task_user_id = meta.get("user_id", "") or None
|
task_user_id = meta.get("user_id", "") or None
|
||||||
|
log_meta["session_id"] = meta.get("session_id", "")
|
||||||
|
|
||||||
# Validate ownership - if task has an owner, requester must match
|
# Validate ownership - if task has an owner, requester must match
|
||||||
if task_user_id:
|
if task_user_id:
|
||||||
if user_id != task_user_id:
|
if user_id != task_user_id:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"User {user_id} denied access to task {task_id} "
|
f"[TIMING] Access denied: user {user_id} tried to access task owned by {task_user_id}",
|
||||||
f"owned by {task_user_id}"
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"task_owner": task_user_id,
|
||||||
|
"reason": "access_denied",
|
||||||
|
}
|
||||||
|
},
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -225,7 +332,19 @@ async def subscribe_to_task(
|
|||||||
stream_key = _get_task_stream_key(task_id)
|
stream_key = _get_task_stream_key(task_id)
|
||||||
|
|
||||||
# Step 1: Replay messages from Redis Stream
|
# Step 1: Replay messages from Redis Stream
|
||||||
|
xread_start = time.perf_counter()
|
||||||
messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000)
|
messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000)
|
||||||
|
xread_time = (time.perf_counter() - xread_start) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] Redis xread (replay) took {xread_time:.1f}ms, status={task_status}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"duration_ms": xread_time,
|
||||||
|
"task_status": task_status,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
replayed_count = 0
|
replayed_count = 0
|
||||||
replay_last_id = last_message_id
|
replay_last_id = last_message_id
|
||||||
@@ -244,19 +363,48 @@ async def subscribe_to_task(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to replay message: {e}")
|
logger.warning(f"Failed to replay message: {e}")
|
||||||
|
|
||||||
logger.debug(f"Task {task_id}: replayed {replayed_count} messages")
|
logger.info(
|
||||||
|
f"[TIMING] Replayed {replayed_count} messages, last_id={replay_last_id}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"n_messages_replayed": replayed_count,
|
||||||
|
"replay_last_id": replay_last_id,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
# Step 2: If task is still running, start stream listener for live updates
|
# Step 2: If task is still running, start stream listener for live updates
|
||||||
if task_status == "running":
|
if task_status == "running":
|
||||||
|
logger.info(
|
||||||
|
"[TIMING] Task still running, starting _stream_listener",
|
||||||
|
extra={"json_fields": {**log_meta, "task_status": task_status}},
|
||||||
|
)
|
||||||
listener_task = asyncio.create_task(
|
listener_task = asyncio.create_task(
|
||||||
_stream_listener(task_id, subscriber_queue, replay_last_id)
|
_stream_listener(task_id, subscriber_queue, replay_last_id, log_meta)
|
||||||
)
|
)
|
||||||
# Track listener task for cleanup on unsubscribe
|
# Track listener task for cleanup on unsubscribe
|
||||||
_listener_tasks[id(subscriber_queue)] = (task_id, listener_task)
|
_listener_tasks[id(subscriber_queue)] = (task_id, listener_task)
|
||||||
else:
|
else:
|
||||||
# Task is completed/failed - add finish marker
|
# Task is completed/failed - add finish marker
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] Task already {task_status}, adding StreamFinish",
|
||||||
|
extra={"json_fields": {**log_meta, "task_status": task_status}},
|
||||||
|
)
|
||||||
await subscriber_queue.put(StreamFinish())
|
await subscriber_queue.put(StreamFinish())
|
||||||
|
|
||||||
|
total_time = (time.perf_counter() - start_time) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] subscribe_to_task COMPLETED in {total_time:.1f}ms; task={task_id}, "
|
||||||
|
f"n_messages_replayed={replayed_count}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"total_time_ms": total_time,
|
||||||
|
"n_messages_replayed": replayed_count,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
return subscriber_queue
|
return subscriber_queue
|
||||||
|
|
||||||
|
|
||||||
@@ -264,6 +412,7 @@ async def _stream_listener(
|
|||||||
task_id: str,
|
task_id: str,
|
||||||
subscriber_queue: asyncio.Queue[StreamBaseResponse],
|
subscriber_queue: asyncio.Queue[StreamBaseResponse],
|
||||||
last_replayed_id: str,
|
last_replayed_id: str,
|
||||||
|
log_meta: dict | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Listen to Redis Stream for new messages using blocking XREAD.
|
"""Listen to Redis Stream for new messages using blocking XREAD.
|
||||||
|
|
||||||
@@ -274,10 +423,27 @@ async def _stream_listener(
|
|||||||
task_id: Task ID to listen for
|
task_id: Task ID to listen for
|
||||||
subscriber_queue: Queue to deliver messages to
|
subscriber_queue: Queue to deliver messages to
|
||||||
last_replayed_id: Last message ID from replay (continue from here)
|
last_replayed_id: Last message ID from replay (continue from here)
|
||||||
|
log_meta: Structured logging metadata
|
||||||
"""
|
"""
|
||||||
|
import time
|
||||||
|
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
|
# Use provided log_meta or build minimal one
|
||||||
|
if log_meta is None:
|
||||||
|
log_meta = {"component": "StreamRegistry", "task_id": task_id}
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] _stream_listener STARTED, task={task_id}, last_id={last_replayed_id}",
|
||||||
|
extra={"json_fields": {**log_meta, "last_replayed_id": last_replayed_id}},
|
||||||
|
)
|
||||||
|
|
||||||
queue_id = id(subscriber_queue)
|
queue_id = id(subscriber_queue)
|
||||||
# Track the last successfully delivered message ID for recovery hints
|
# Track the last successfully delivered message ID for recovery hints
|
||||||
last_delivered_id = last_replayed_id
|
last_delivered_id = last_replayed_id
|
||||||
|
messages_delivered = 0
|
||||||
|
first_message_time = None
|
||||||
|
xread_count = 0
|
||||||
|
|
||||||
try:
|
try:
|
||||||
redis = await get_redis_async()
|
redis = await get_redis_async()
|
||||||
@@ -287,9 +453,39 @@ async def _stream_listener(
|
|||||||
while True:
|
while True:
|
||||||
# Block for up to 30 seconds waiting for new messages
|
# Block for up to 30 seconds waiting for new messages
|
||||||
# This allows periodic checking if task is still running
|
# This allows periodic checking if task is still running
|
||||||
|
xread_start = time.perf_counter()
|
||||||
|
xread_count += 1
|
||||||
messages = await redis.xread(
|
messages = await redis.xread(
|
||||||
{stream_key: current_id}, block=30000, count=100
|
{stream_key: current_id}, block=30000, count=100
|
||||||
)
|
)
|
||||||
|
xread_time = (time.perf_counter() - xread_start) * 1000
|
||||||
|
|
||||||
|
if messages:
|
||||||
|
msg_count = sum(len(msgs) for _, msgs in messages)
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] xread #{xread_count} returned {msg_count} messages in {xread_time:.1f}ms",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"xread_count": xread_count,
|
||||||
|
"n_messages": msg_count,
|
||||||
|
"duration_ms": xread_time,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
elif xread_time > 1000:
|
||||||
|
# Only log timeouts (30s blocking)
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] xread #{xread_count} timeout after {xread_time:.1f}ms",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"xread_count": xread_count,
|
||||||
|
"duration_ms": xread_time,
|
||||||
|
"reason": "timeout",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
if not messages:
|
if not messages:
|
||||||
# Timeout - check if task is still running
|
# Timeout - check if task is still running
|
||||||
@@ -326,10 +522,30 @@ async def _stream_listener(
|
|||||||
)
|
)
|
||||||
# Update last delivered ID on successful delivery
|
# Update last delivered ID on successful delivery
|
||||||
last_delivered_id = current_id
|
last_delivered_id = current_id
|
||||||
|
messages_delivered += 1
|
||||||
|
if first_message_time is None:
|
||||||
|
first_message_time = time.perf_counter()
|
||||||
|
elapsed = (first_message_time - start_time) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] FIRST live message at {elapsed:.1f}ms, type={type(chunk).__name__}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"elapsed_ms": elapsed,
|
||||||
|
"chunk_type": type(chunk).__name__,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Subscriber queue full for task {task_id}, "
|
f"[TIMING] Subscriber queue full, delivery timed out after {QUEUE_PUT_TIMEOUT}s",
|
||||||
f"message delivery timed out after {QUEUE_PUT_TIMEOUT}s"
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"timeout_s": QUEUE_PUT_TIMEOUT,
|
||||||
|
"reason": "queue_full",
|
||||||
|
}
|
||||||
|
},
|
||||||
)
|
)
|
||||||
# Send overflow error with recovery info
|
# Send overflow error with recovery info
|
||||||
try:
|
try:
|
||||||
@@ -351,15 +567,44 @@ async def _stream_listener(
|
|||||||
|
|
||||||
# Stop listening on finish
|
# Stop listening on finish
|
||||||
if isinstance(chunk, StreamFinish):
|
if isinstance(chunk, StreamFinish):
|
||||||
|
total_time = (time.perf_counter() - start_time) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] StreamFinish received in {total_time/1000:.1f}s; delivered={messages_delivered}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"total_time_ms": total_time,
|
||||||
|
"messages_delivered": messages_delivered,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
return
|
return
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Error processing stream message: {e}")
|
logger.warning(
|
||||||
|
f"Error processing stream message: {e}",
|
||||||
|
extra={"json_fields": {**log_meta, "error": str(e)}},
|
||||||
|
)
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
logger.debug(f"Stream listener cancelled for task {task_id}")
|
elapsed = (time.perf_counter() - start_time) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] _stream_listener CANCELLED after {elapsed:.1f}ms, delivered={messages_delivered}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"elapsed_ms": elapsed,
|
||||||
|
"messages_delivered": messages_delivered,
|
||||||
|
"reason": "cancelled",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
raise # Re-raise to propagate cancellation
|
raise # Re-raise to propagate cancellation
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Stream listener error for task {task_id}: {e}")
|
elapsed = (time.perf_counter() - start_time) * 1000
|
||||||
|
logger.error(
|
||||||
|
f"[TIMING] _stream_listener ERROR after {elapsed:.1f}ms: {e}",
|
||||||
|
extra={"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}},
|
||||||
|
)
|
||||||
# On error, send finish to unblock subscriber
|
# On error, send finish to unblock subscriber
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(
|
await asyncio.wait_for(
|
||||||
@@ -368,10 +613,24 @@ async def _stream_listener(
|
|||||||
)
|
)
|
||||||
except (asyncio.TimeoutError, asyncio.QueueFull):
|
except (asyncio.TimeoutError, asyncio.QueueFull):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Could not deliver finish event for task {task_id} after error"
|
"Could not deliver finish event after error",
|
||||||
|
extra={"json_fields": log_meta},
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
# Clean up listener task mapping on exit
|
# Clean up listener task mapping on exit
|
||||||
|
total_time = (time.perf_counter() - start_time) * 1000
|
||||||
|
logger.info(
|
||||||
|
f"[TIMING] _stream_listener FINISHED in {total_time/1000:.1f}s; task={task_id}, "
|
||||||
|
f"delivered={messages_delivered}, xread_count={xread_count}",
|
||||||
|
extra={
|
||||||
|
"json_fields": {
|
||||||
|
**log_meta,
|
||||||
|
"total_time_ms": total_time,
|
||||||
|
"messages_delivered": messages_delivered,
|
||||||
|
"xread_count": xread_count,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
_listener_tasks.pop(queue_id, None)
|
_listener_tasks.pop(queue_id, None)
|
||||||
|
|
||||||
|
|
||||||
@@ -598,8 +857,10 @@ def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
|
|||||||
ResponseType,
|
ResponseType,
|
||||||
StreamError,
|
StreamError,
|
||||||
StreamFinish,
|
StreamFinish,
|
||||||
|
StreamFinishStep,
|
||||||
StreamHeartbeat,
|
StreamHeartbeat,
|
||||||
StreamStart,
|
StreamStart,
|
||||||
|
StreamStartStep,
|
||||||
StreamTextDelta,
|
StreamTextDelta,
|
||||||
StreamTextEnd,
|
StreamTextEnd,
|
||||||
StreamTextStart,
|
StreamTextStart,
|
||||||
@@ -613,6 +874,8 @@ def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
|
|||||||
type_to_class: dict[str, type[StreamBaseResponse]] = {
|
type_to_class: dict[str, type[StreamBaseResponse]] = {
|
||||||
ResponseType.START.value: StreamStart,
|
ResponseType.START.value: StreamStart,
|
||||||
ResponseType.FINISH.value: StreamFinish,
|
ResponseType.FINISH.value: StreamFinish,
|
||||||
|
ResponseType.START_STEP.value: StreamStartStep,
|
||||||
|
ResponseType.FINISH_STEP.value: StreamFinishStep,
|
||||||
ResponseType.TEXT_START.value: StreamTextStart,
|
ResponseType.TEXT_START.value: StreamTextStart,
|
||||||
ResponseType.TEXT_DELTA.value: StreamTextDelta,
|
ResponseType.TEXT_DELTA.value: StreamTextDelta,
|
||||||
ResponseType.TEXT_END.value: StreamTextEnd,
|
ResponseType.TEXT_END.value: StreamTextEnd,
|
||||||
|
|||||||
@@ -7,15 +7,7 @@ from typing import Any, NotRequired, TypedDict
|
|||||||
|
|
||||||
from backend.api.features.library import db as library_db
|
from backend.api.features.library import db as library_db
|
||||||
from backend.api.features.store import db as store_db
|
from backend.api.features.store import db as store_db
|
||||||
from backend.data.graph import (
|
from backend.data.graph import Graph, Link, Node, get_graph, get_store_listed_graphs
|
||||||
Graph,
|
|
||||||
Link,
|
|
||||||
Node,
|
|
||||||
create_graph,
|
|
||||||
get_graph,
|
|
||||||
get_graph_all_versions,
|
|
||||||
get_store_listed_graphs,
|
|
||||||
)
|
|
||||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||||
|
|
||||||
from .service import (
|
from .service import (
|
||||||
@@ -28,8 +20,6 @@ from .service import (
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
AGENT_EXECUTOR_BLOCK_ID = "e189baac-8c20-45a1-94a7-55177ea42565"
|
|
||||||
|
|
||||||
|
|
||||||
class ExecutionSummary(TypedDict):
|
class ExecutionSummary(TypedDict):
|
||||||
"""Summary of a single execution for quality assessment."""
|
"""Summary of a single execution for quality assessment."""
|
||||||
@@ -669,45 +659,6 @@ def json_to_graph(agent_json: dict[str, Any]) -> Graph:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _reassign_node_ids(graph: Graph) -> None:
|
|
||||||
"""Reassign all node and link IDs to new UUIDs.
|
|
||||||
|
|
||||||
This is needed when creating a new version to avoid unique constraint violations.
|
|
||||||
"""
|
|
||||||
id_map = {node.id: str(uuid.uuid4()) for node in graph.nodes}
|
|
||||||
|
|
||||||
for node in graph.nodes:
|
|
||||||
node.id = id_map[node.id]
|
|
||||||
|
|
||||||
for link in graph.links:
|
|
||||||
link.id = str(uuid.uuid4())
|
|
||||||
if link.source_id in id_map:
|
|
||||||
link.source_id = id_map[link.source_id]
|
|
||||||
if link.sink_id in id_map:
|
|
||||||
link.sink_id = id_map[link.sink_id]
|
|
||||||
|
|
||||||
|
|
||||||
def _populate_agent_executor_user_ids(agent_json: dict[str, Any], user_id: str) -> None:
|
|
||||||
"""Populate user_id in AgentExecutorBlock nodes.
|
|
||||||
|
|
||||||
The external agent generator creates AgentExecutorBlock nodes with empty user_id.
|
|
||||||
This function fills in the actual user_id so sub-agents run with correct permissions.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_json: Agent JSON dict (modified in place)
|
|
||||||
user_id: User ID to set
|
|
||||||
"""
|
|
||||||
for node in agent_json.get("nodes", []):
|
|
||||||
if node.get("block_id") == AGENT_EXECUTOR_BLOCK_ID:
|
|
||||||
input_default = node.get("input_default") or {}
|
|
||||||
if not input_default.get("user_id"):
|
|
||||||
input_default["user_id"] = user_id
|
|
||||||
node["input_default"] = input_default
|
|
||||||
logger.debug(
|
|
||||||
f"Set user_id for AgentExecutorBlock node {node.get('id')}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def save_agent_to_library(
|
async def save_agent_to_library(
|
||||||
agent_json: dict[str, Any], user_id: str, is_update: bool = False
|
agent_json: dict[str, Any], user_id: str, is_update: bool = False
|
||||||
) -> tuple[Graph, Any]:
|
) -> tuple[Graph, Any]:
|
||||||
@@ -721,35 +672,10 @@ async def save_agent_to_library(
|
|||||||
Returns:
|
Returns:
|
||||||
Tuple of (created Graph, LibraryAgent)
|
Tuple of (created Graph, LibraryAgent)
|
||||||
"""
|
"""
|
||||||
# Populate user_id in AgentExecutorBlock nodes before conversion
|
|
||||||
_populate_agent_executor_user_ids(agent_json, user_id)
|
|
||||||
|
|
||||||
graph = json_to_graph(agent_json)
|
graph = json_to_graph(agent_json)
|
||||||
|
|
||||||
if is_update:
|
if is_update:
|
||||||
if graph.id:
|
return await library_db.update_graph_in_library(graph, user_id)
|
||||||
existing_versions = await get_graph_all_versions(graph.id, user_id)
|
return await library_db.create_graph_in_library(graph, user_id)
|
||||||
if existing_versions:
|
|
||||||
latest_version = max(v.version for v in existing_versions)
|
|
||||||
graph.version = latest_version + 1
|
|
||||||
_reassign_node_ids(graph)
|
|
||||||
logger.info(f"Updating agent {graph.id} to version {graph.version}")
|
|
||||||
else:
|
|
||||||
graph.id = str(uuid.uuid4())
|
|
||||||
graph.version = 1
|
|
||||||
_reassign_node_ids(graph)
|
|
||||||
logger.info(f"Creating new agent with ID {graph.id}")
|
|
||||||
|
|
||||||
created_graph = await create_graph(graph, user_id)
|
|
||||||
|
|
||||||
library_agents = await library_db.create_library_agent(
|
|
||||||
graph=created_graph,
|
|
||||||
user_id=user_id,
|
|
||||||
sensitive_action_safe_mode=True,
|
|
||||||
create_library_agents_for_sub_graphs=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
return created_graph, library_agents[0]
|
|
||||||
|
|
||||||
|
|
||||||
def graph_to_json(graph: Graph) -> dict[str, Any]:
|
def graph_to_json(graph: Graph) -> dict[str, Any]:
|
||||||
|
|||||||
@@ -206,9 +206,9 @@ async def search_agents(
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
no_results_msg = (
|
no_results_msg = (
|
||||||
f"No agents found matching '{query}'. Try different keywords or browse the marketplace."
|
f"No agents found matching '{query}'. Let the user know they can try different keywords or browse the marketplace. Also let them know you can create a custom agent for them based on their needs."
|
||||||
if source == "marketplace"
|
if source == "marketplace"
|
||||||
else f"No agents matching '{query}' found in your library."
|
else f"No agents matching '{query}' found in your library. Let the user know you can create a custom agent for them based on their needs."
|
||||||
)
|
)
|
||||||
return NoResultsResponse(
|
return NoResultsResponse(
|
||||||
message=no_results_msg, session_id=session_id, suggestions=suggestions
|
message=no_results_msg, session_id=session_id, suggestions=suggestions
|
||||||
@@ -224,10 +224,10 @@ async def search_agents(
|
|||||||
message = (
|
message = (
|
||||||
"Now you have found some options for the user to choose from. "
|
"Now you have found some options for the user to choose from. "
|
||||||
"You can add a link to a recommended agent at: /marketplace/agent/agent_id "
|
"You can add a link to a recommended agent at: /marketplace/agent/agent_id "
|
||||||
"Please ask the user if they would like to use any of these agents."
|
"Please ask the user if they would like to use any of these agents. Let the user know we can create a custom agent for them based on their needs."
|
||||||
if source == "marketplace"
|
if source == "marketplace"
|
||||||
else "Found agents in the user's library. You can provide a link to view an agent at: "
|
else "Found agents in the user's library. You can provide a link to view an agent at: "
|
||||||
"/library/agents/{agent_id}. Use agent_output to get execution results, or run_agent to execute."
|
"/library/agents/{agent_id}. Use agent_output to get execution results, or run_agent to execute. Let the user know we can create a custom agent for them based on their needs."
|
||||||
)
|
)
|
||||||
|
|
||||||
return AgentsFoundResponse(
|
return AgentsFoundResponse(
|
||||||
|
|||||||
@@ -13,10 +13,32 @@ from backend.api.features.chat.tools.models import (
|
|||||||
NoResultsResponse,
|
NoResultsResponse,
|
||||||
)
|
)
|
||||||
from backend.api.features.store.hybrid_search import unified_hybrid_search
|
from backend.api.features.store.hybrid_search import unified_hybrid_search
|
||||||
from backend.data.block import get_block
|
from backend.data.block import BlockType, get_block
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_TARGET_RESULTS = 10
|
||||||
|
# Over-fetch to compensate for post-hoc filtering of graph-only blocks.
|
||||||
|
# 40 is 2x current removed; speed of query 10 vs 40 is minimial
|
||||||
|
_OVERFETCH_PAGE_SIZE = 40
|
||||||
|
|
||||||
|
# Block types that only work within graphs and cannot run standalone in CoPilot.
|
||||||
|
COPILOT_EXCLUDED_BLOCK_TYPES = {
|
||||||
|
BlockType.INPUT, # Graph interface definition - data enters via chat, not graph inputs
|
||||||
|
BlockType.OUTPUT, # Graph interface definition - data exits via chat, not graph outputs
|
||||||
|
BlockType.WEBHOOK, # Wait for external events - would hang forever in CoPilot
|
||||||
|
BlockType.WEBHOOK_MANUAL, # Same as WEBHOOK
|
||||||
|
BlockType.NOTE, # Visual annotation only - no runtime behavior
|
||||||
|
BlockType.HUMAN_IN_THE_LOOP, # Pauses for human approval - CoPilot IS human-in-the-loop
|
||||||
|
BlockType.AGENT, # AgentExecutorBlock requires execution_context - use run_agent tool
|
||||||
|
}
|
||||||
|
|
||||||
|
# Specific block IDs excluded from CoPilot (STANDARD type but still require graph context)
|
||||||
|
COPILOT_EXCLUDED_BLOCK_IDS = {
|
||||||
|
# SmartDecisionMakerBlock - dynamically discovers downstream blocks via graph topology
|
||||||
|
"3b191d9f-356f-482d-8238-ba04b6d18381",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class FindBlockTool(BaseTool):
|
class FindBlockTool(BaseTool):
|
||||||
"""Tool for searching available blocks."""
|
"""Tool for searching available blocks."""
|
||||||
@@ -88,7 +110,7 @@ class FindBlockTool(BaseTool):
|
|||||||
query=query,
|
query=query,
|
||||||
content_types=[ContentType.BLOCK],
|
content_types=[ContentType.BLOCK],
|
||||||
page=1,
|
page=1,
|
||||||
page_size=10,
|
page_size=_OVERFETCH_PAGE_SIZE,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not results:
|
if not results:
|
||||||
@@ -108,60 +130,90 @@ class FindBlockTool(BaseTool):
|
|||||||
block = get_block(block_id)
|
block = get_block(block_id)
|
||||||
|
|
||||||
# Skip disabled blocks
|
# Skip disabled blocks
|
||||||
if block and not block.disabled:
|
if not block or block.disabled:
|
||||||
# Get input/output schemas
|
continue
|
||||||
input_schema = {}
|
|
||||||
output_schema = {}
|
|
||||||
try:
|
|
||||||
input_schema = block.input_schema.jsonschema()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
try:
|
|
||||||
output_schema = block.output_schema.jsonschema()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Get categories from block instance
|
# Skip blocks excluded from CoPilot (graph-only blocks)
|
||||||
categories = []
|
if (
|
||||||
if hasattr(block, "categories") and block.categories:
|
block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||||
categories = [cat.value for cat in block.categories]
|
or block.id in COPILOT_EXCLUDED_BLOCK_IDS
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
|
||||||
# Extract required inputs for easier use
|
# Get input/output schemas
|
||||||
required_inputs: list[BlockInputFieldInfo] = []
|
input_schema = {}
|
||||||
if input_schema:
|
output_schema = {}
|
||||||
properties = input_schema.get("properties", {})
|
try:
|
||||||
required_fields = set(input_schema.get("required", []))
|
input_schema = block.input_schema.jsonschema()
|
||||||
# Get credential field names to exclude from required inputs
|
except Exception as e:
|
||||||
credentials_fields = set(
|
logger.debug(
|
||||||
block.input_schema.get_credentials_fields().keys()
|
"Failed to generate input schema for block %s: %s",
|
||||||
)
|
block_id,
|
||||||
|
e,
|
||||||
for field_name, field_schema in properties.items():
|
|
||||||
# Skip credential fields - they're handled separately
|
|
||||||
if field_name in credentials_fields:
|
|
||||||
continue
|
|
||||||
|
|
||||||
required_inputs.append(
|
|
||||||
BlockInputFieldInfo(
|
|
||||||
name=field_name,
|
|
||||||
type=field_schema.get("type", "string"),
|
|
||||||
description=field_schema.get("description", ""),
|
|
||||||
required=field_name in required_fields,
|
|
||||||
default=field_schema.get("default"),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
blocks.append(
|
|
||||||
BlockInfoSummary(
|
|
||||||
id=block_id,
|
|
||||||
name=block.name,
|
|
||||||
description=block.description or "",
|
|
||||||
categories=categories,
|
|
||||||
input_schema=input_schema,
|
|
||||||
output_schema=output_schema,
|
|
||||||
required_inputs=required_inputs,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
try:
|
||||||
|
output_schema = block.output_schema.jsonschema()
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(
|
||||||
|
"Failed to generate output schema for block %s: %s",
|
||||||
|
block_id,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get categories from block instance
|
||||||
|
categories = []
|
||||||
|
if hasattr(block, "categories") and block.categories:
|
||||||
|
categories = [cat.value for cat in block.categories]
|
||||||
|
|
||||||
|
# Extract required inputs for easier use
|
||||||
|
required_inputs: list[BlockInputFieldInfo] = []
|
||||||
|
if input_schema:
|
||||||
|
properties = input_schema.get("properties", {})
|
||||||
|
required_fields = set(input_schema.get("required", []))
|
||||||
|
# Get credential field names to exclude from required inputs
|
||||||
|
credentials_fields = set(
|
||||||
|
block.input_schema.get_credentials_fields().keys()
|
||||||
|
)
|
||||||
|
|
||||||
|
for field_name, field_schema in properties.items():
|
||||||
|
# Skip credential fields - they're handled separately
|
||||||
|
if field_name in credentials_fields:
|
||||||
|
continue
|
||||||
|
|
||||||
|
required_inputs.append(
|
||||||
|
BlockInputFieldInfo(
|
||||||
|
name=field_name,
|
||||||
|
type=field_schema.get("type", "string"),
|
||||||
|
description=field_schema.get("description", ""),
|
||||||
|
required=field_name in required_fields,
|
||||||
|
default=field_schema.get("default"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
blocks.append(
|
||||||
|
BlockInfoSummary(
|
||||||
|
id=block_id,
|
||||||
|
name=block.name,
|
||||||
|
description=block.description or "",
|
||||||
|
categories=categories,
|
||||||
|
input_schema=input_schema,
|
||||||
|
output_schema=output_schema,
|
||||||
|
required_inputs=required_inputs,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(blocks) >= _TARGET_RESULTS:
|
||||||
|
break
|
||||||
|
|
||||||
|
if blocks and len(blocks) < _TARGET_RESULTS:
|
||||||
|
logger.debug(
|
||||||
|
"find_block returned %d/%d results for query '%s' "
|
||||||
|
"(filtered %d excluded/disabled blocks)",
|
||||||
|
len(blocks),
|
||||||
|
_TARGET_RESULTS,
|
||||||
|
query,
|
||||||
|
len(results) - len(blocks),
|
||||||
|
)
|
||||||
|
|
||||||
if not blocks:
|
if not blocks:
|
||||||
return NoResultsResponse(
|
return NoResultsResponse(
|
||||||
|
|||||||
@@ -0,0 +1,139 @@
|
|||||||
|
"""Tests for block filtering in FindBlockTool."""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from backend.api.features.chat.tools.find_block import (
|
||||||
|
COPILOT_EXCLUDED_BLOCK_IDS,
|
||||||
|
COPILOT_EXCLUDED_BLOCK_TYPES,
|
||||||
|
FindBlockTool,
|
||||||
|
)
|
||||||
|
from backend.api.features.chat.tools.models import BlockListResponse
|
||||||
|
from backend.data.block import BlockType
|
||||||
|
|
||||||
|
from ._test_data import make_session
|
||||||
|
|
||||||
|
_TEST_USER_ID = "test-user-find-block"
|
||||||
|
|
||||||
|
|
||||||
|
def make_mock_block(
|
||||||
|
block_id: str, name: str, block_type: BlockType, disabled: bool = False
|
||||||
|
):
|
||||||
|
"""Create a mock block for testing."""
|
||||||
|
mock = MagicMock()
|
||||||
|
mock.id = block_id
|
||||||
|
mock.name = name
|
||||||
|
mock.description = f"{name} description"
|
||||||
|
mock.block_type = block_type
|
||||||
|
mock.disabled = disabled
|
||||||
|
mock.input_schema = MagicMock()
|
||||||
|
mock.input_schema.jsonschema.return_value = {"properties": {}, "required": []}
|
||||||
|
mock.input_schema.get_credentials_fields.return_value = {}
|
||||||
|
mock.output_schema = MagicMock()
|
||||||
|
mock.output_schema.jsonschema.return_value = {}
|
||||||
|
mock.categories = []
|
||||||
|
return mock
|
||||||
|
|
||||||
|
|
||||||
|
class TestFindBlockFiltering:
|
||||||
|
"""Tests for block filtering in FindBlockTool."""
|
||||||
|
|
||||||
|
def test_excluded_block_types_contains_expected_types(self):
|
||||||
|
"""Verify COPILOT_EXCLUDED_BLOCK_TYPES contains all graph-only types."""
|
||||||
|
assert BlockType.INPUT in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||||
|
assert BlockType.OUTPUT in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||||
|
assert BlockType.WEBHOOK in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||||
|
assert BlockType.WEBHOOK_MANUAL in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||||
|
assert BlockType.NOTE in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||||
|
assert BlockType.HUMAN_IN_THE_LOOP in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||||
|
assert BlockType.AGENT in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||||
|
|
||||||
|
def test_excluded_block_ids_contains_smart_decision_maker(self):
|
||||||
|
"""Verify SmartDecisionMakerBlock is in COPILOT_EXCLUDED_BLOCK_IDS."""
|
||||||
|
assert "3b191d9f-356f-482d-8238-ba04b6d18381" in COPILOT_EXCLUDED_BLOCK_IDS
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_excluded_block_type_filtered_from_results(self):
|
||||||
|
"""Verify blocks with excluded BlockTypes are filtered from search results."""
|
||||||
|
session = make_session(user_id=_TEST_USER_ID)
|
||||||
|
|
||||||
|
# Mock search returns an INPUT block (excluded) and a STANDARD block (included)
|
||||||
|
search_results = [
|
||||||
|
{"content_id": "input-block-id", "score": 0.9},
|
||||||
|
{"content_id": "standard-block-id", "score": 0.8},
|
||||||
|
]
|
||||||
|
|
||||||
|
input_block = make_mock_block("input-block-id", "Input Block", BlockType.INPUT)
|
||||||
|
standard_block = make_mock_block(
|
||||||
|
"standard-block-id", "HTTP Request", BlockType.STANDARD
|
||||||
|
)
|
||||||
|
|
||||||
|
def mock_get_block(block_id):
|
||||||
|
return {
|
||||||
|
"input-block-id": input_block,
|
||||||
|
"standard-block-id": standard_block,
|
||||||
|
}.get(block_id)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.chat.tools.find_block.unified_hybrid_search",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=(search_results, 2),
|
||||||
|
):
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.chat.tools.find_block.get_block",
|
||||||
|
side_effect=mock_get_block,
|
||||||
|
):
|
||||||
|
tool = FindBlockTool()
|
||||||
|
response = await tool._execute(
|
||||||
|
user_id=_TEST_USER_ID, session=session, query="test"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should only return the standard block, not the INPUT block
|
||||||
|
assert isinstance(response, BlockListResponse)
|
||||||
|
assert len(response.blocks) == 1
|
||||||
|
assert response.blocks[0].id == "standard-block-id"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_excluded_block_id_filtered_from_results(self):
|
||||||
|
"""Verify SmartDecisionMakerBlock is filtered from search results."""
|
||||||
|
session = make_session(user_id=_TEST_USER_ID)
|
||||||
|
|
||||||
|
smart_decision_id = "3b191d9f-356f-482d-8238-ba04b6d18381"
|
||||||
|
search_results = [
|
||||||
|
{"content_id": smart_decision_id, "score": 0.9},
|
||||||
|
{"content_id": "normal-block-id", "score": 0.8},
|
||||||
|
]
|
||||||
|
|
||||||
|
# SmartDecisionMakerBlock has STANDARD type but is excluded by ID
|
||||||
|
smart_block = make_mock_block(
|
||||||
|
smart_decision_id, "Smart Decision Maker", BlockType.STANDARD
|
||||||
|
)
|
||||||
|
normal_block = make_mock_block(
|
||||||
|
"normal-block-id", "Normal Block", BlockType.STANDARD
|
||||||
|
)
|
||||||
|
|
||||||
|
def mock_get_block(block_id):
|
||||||
|
return {
|
||||||
|
smart_decision_id: smart_block,
|
||||||
|
"normal-block-id": normal_block,
|
||||||
|
}.get(block_id)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.chat.tools.find_block.unified_hybrid_search",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=(search_results, 2),
|
||||||
|
):
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.chat.tools.find_block.get_block",
|
||||||
|
side_effect=mock_get_block,
|
||||||
|
):
|
||||||
|
tool = FindBlockTool()
|
||||||
|
response = await tool._execute(
|
||||||
|
user_id=_TEST_USER_ID, session=session, query="decision"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should only return normal block, not SmartDecisionMakerBlock
|
||||||
|
assert isinstance(response, BlockListResponse)
|
||||||
|
assert len(response.blocks) == 1
|
||||||
|
assert response.blocks[0].id == "normal-block-id"
|
||||||
@@ -0,0 +1,29 @@
|
|||||||
|
"""Shared helpers for chat tools."""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
def get_inputs_from_schema(
|
||||||
|
input_schema: dict[str, Any],
|
||||||
|
exclude_fields: set[str] | None = None,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Extract input field info from JSON schema."""
|
||||||
|
if not isinstance(input_schema, dict):
|
||||||
|
return []
|
||||||
|
|
||||||
|
exclude = exclude_fields or set()
|
||||||
|
properties = input_schema.get("properties", {})
|
||||||
|
required = set(input_schema.get("required", []))
|
||||||
|
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"name": name,
|
||||||
|
"title": schema.get("title", name),
|
||||||
|
"type": schema.get("type", "string"),
|
||||||
|
"description": schema.get("description", ""),
|
||||||
|
"required": name in required,
|
||||||
|
"default": schema.get("default"),
|
||||||
|
}
|
||||||
|
for name, schema in properties.items()
|
||||||
|
if name not in exclude
|
||||||
|
]
|
||||||
@@ -24,6 +24,7 @@ from backend.util.timezone_utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from .base import BaseTool
|
from .base import BaseTool
|
||||||
|
from .helpers import get_inputs_from_schema
|
||||||
from .models import (
|
from .models import (
|
||||||
AgentDetails,
|
AgentDetails,
|
||||||
AgentDetailsResponse,
|
AgentDetailsResponse,
|
||||||
@@ -261,7 +262,7 @@ class RunAgentTool(BaseTool):
|
|||||||
),
|
),
|
||||||
requirements={
|
requirements={
|
||||||
"credentials": requirements_creds_list,
|
"credentials": requirements_creds_list,
|
||||||
"inputs": self._get_inputs_list(graph.input_schema),
|
"inputs": get_inputs_from_schema(graph.input_schema),
|
||||||
"execution_modes": self._get_execution_modes(graph),
|
"execution_modes": self._get_execution_modes(graph),
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
@@ -369,22 +370,6 @@ class RunAgentTool(BaseTool):
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_inputs_list(self, input_schema: dict[str, Any]) -> list[dict[str, Any]]:
|
|
||||||
"""Extract inputs list from schema."""
|
|
||||||
inputs_list = []
|
|
||||||
if isinstance(input_schema, dict) and "properties" in input_schema:
|
|
||||||
for field_name, field_schema in input_schema["properties"].items():
|
|
||||||
inputs_list.append(
|
|
||||||
{
|
|
||||||
"name": field_name,
|
|
||||||
"title": field_schema.get("title", field_name),
|
|
||||||
"type": field_schema.get("type", "string"),
|
|
||||||
"description": field_schema.get("description", ""),
|
|
||||||
"required": field_name in input_schema.get("required", []),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return inputs_list
|
|
||||||
|
|
||||||
def _get_execution_modes(self, graph: GraphModel) -> list[str]:
|
def _get_execution_modes(self, graph: GraphModel) -> list[str]:
|
||||||
"""Get available execution modes for the graph."""
|
"""Get available execution modes for the graph."""
|
||||||
trigger_info = graph.trigger_setup_info
|
trigger_info = graph.trigger_setup_info
|
||||||
@@ -398,7 +383,7 @@ class RunAgentTool(BaseTool):
|
|||||||
suffix: str,
|
suffix: str,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Build a message describing available inputs for an agent."""
|
"""Build a message describing available inputs for an agent."""
|
||||||
inputs_list = self._get_inputs_list(graph.input_schema)
|
inputs_list = get_inputs_from_schema(graph.input_schema)
|
||||||
required_names = [i["name"] for i in inputs_list if i["required"]]
|
required_names = [i["name"] for i in inputs_list if i["required"]]
|
||||||
optional_names = [i["name"] for i in inputs_list if not i["required"]]
|
optional_names = [i["name"] for i in inputs_list if not i["required"]]
|
||||||
|
|
||||||
|
|||||||
@@ -8,14 +8,19 @@ from typing import Any
|
|||||||
from pydantic_core import PydanticUndefined
|
from pydantic_core import PydanticUndefined
|
||||||
|
|
||||||
from backend.api.features.chat.model import ChatSession
|
from backend.api.features.chat.model import ChatSession
|
||||||
from backend.data.block import get_block
|
from backend.api.features.chat.tools.find_block import (
|
||||||
|
COPILOT_EXCLUDED_BLOCK_IDS,
|
||||||
|
COPILOT_EXCLUDED_BLOCK_TYPES,
|
||||||
|
)
|
||||||
|
from backend.data.block import AnyBlockSchema, get_block
|
||||||
from backend.data.execution import ExecutionContext
|
from backend.data.execution import ExecutionContext
|
||||||
from backend.data.model import CredentialsMetaInput
|
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
|
||||||
from backend.data.workspace import get_or_create_workspace
|
from backend.data.workspace import get_or_create_workspace
|
||||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||||
from backend.util.exceptions import BlockError
|
from backend.util.exceptions import BlockError
|
||||||
|
|
||||||
from .base import BaseTool
|
from .base import BaseTool
|
||||||
|
from .helpers import get_inputs_from_schema
|
||||||
from .models import (
|
from .models import (
|
||||||
BlockOutputResponse,
|
BlockOutputResponse,
|
||||||
ErrorResponse,
|
ErrorResponse,
|
||||||
@@ -24,7 +29,10 @@ from .models import (
|
|||||||
ToolResponseBase,
|
ToolResponseBase,
|
||||||
UserReadiness,
|
UserReadiness,
|
||||||
)
|
)
|
||||||
from .utils import build_missing_credentials_from_field_info
|
from .utils import (
|
||||||
|
build_missing_credentials_from_field_info,
|
||||||
|
match_credentials_to_requirements,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -73,91 +81,6 @@ class RunBlockTool(BaseTool):
|
|||||||
def requires_auth(self) -> bool:
|
def requires_auth(self) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def _check_block_credentials(
|
|
||||||
self,
|
|
||||||
user_id: str,
|
|
||||||
block: Any,
|
|
||||||
input_data: dict[str, Any] | None = None,
|
|
||||||
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
|
|
||||||
"""
|
|
||||||
Check if user has required credentials for a block.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: User ID
|
|
||||||
block: Block to check credentials for
|
|
||||||
input_data: Input data for the block (used to determine provider via discriminator)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple[matched_credentials, missing_credentials]
|
|
||||||
"""
|
|
||||||
matched_credentials: dict[str, CredentialsMetaInput] = {}
|
|
||||||
missing_credentials: list[CredentialsMetaInput] = []
|
|
||||||
input_data = input_data or {}
|
|
||||||
|
|
||||||
# Get credential field info from block's input schema
|
|
||||||
credentials_fields_info = block.input_schema.get_credentials_fields_info()
|
|
||||||
|
|
||||||
if not credentials_fields_info:
|
|
||||||
return matched_credentials, missing_credentials
|
|
||||||
|
|
||||||
# Get user's available credentials
|
|
||||||
creds_manager = IntegrationCredentialsManager()
|
|
||||||
available_creds = await creds_manager.store.get_all_creds(user_id)
|
|
||||||
|
|
||||||
for field_name, field_info in credentials_fields_info.items():
|
|
||||||
effective_field_info = field_info
|
|
||||||
if field_info.discriminator and field_info.discriminator_mapping:
|
|
||||||
# Get discriminator from input, falling back to schema default
|
|
||||||
discriminator_value = input_data.get(field_info.discriminator)
|
|
||||||
if discriminator_value is None:
|
|
||||||
field = block.input_schema.model_fields.get(
|
|
||||||
field_info.discriminator
|
|
||||||
)
|
|
||||||
if field and field.default is not PydanticUndefined:
|
|
||||||
discriminator_value = field.default
|
|
||||||
|
|
||||||
if (
|
|
||||||
discriminator_value
|
|
||||||
and discriminator_value in field_info.discriminator_mapping
|
|
||||||
):
|
|
||||||
effective_field_info = field_info.discriminate(discriminator_value)
|
|
||||||
logger.debug(
|
|
||||||
f"Discriminated provider for {field_name}: "
|
|
||||||
f"{discriminator_value} -> {effective_field_info.provider}"
|
|
||||||
)
|
|
||||||
|
|
||||||
matching_cred = next(
|
|
||||||
(
|
|
||||||
cred
|
|
||||||
for cred in available_creds
|
|
||||||
if cred.provider in effective_field_info.provider
|
|
||||||
and cred.type in effective_field_info.supported_types
|
|
||||||
),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
|
|
||||||
if matching_cred:
|
|
||||||
matched_credentials[field_name] = CredentialsMetaInput(
|
|
||||||
id=matching_cred.id,
|
|
||||||
provider=matching_cred.provider, # type: ignore
|
|
||||||
type=matching_cred.type,
|
|
||||||
title=matching_cred.title,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Create a placeholder for the missing credential
|
|
||||||
provider = next(iter(effective_field_info.provider), "unknown")
|
|
||||||
cred_type = next(iter(effective_field_info.supported_types), "api_key")
|
|
||||||
missing_credentials.append(
|
|
||||||
CredentialsMetaInput(
|
|
||||||
id=field_name,
|
|
||||||
provider=provider, # type: ignore
|
|
||||||
type=cred_type, # type: ignore
|
|
||||||
title=field_name.replace("_", " ").title(),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return matched_credentials, missing_credentials
|
|
||||||
|
|
||||||
async def _execute(
|
async def _execute(
|
||||||
self,
|
self,
|
||||||
user_id: str | None,
|
user_id: str | None,
|
||||||
@@ -212,11 +135,24 @@ class RunBlockTool(BaseTool):
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Check if block is excluded from CoPilot (graph-only blocks)
|
||||||
|
if (
|
||||||
|
block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||||
|
or block.id in COPILOT_EXCLUDED_BLOCK_IDS
|
||||||
|
):
|
||||||
|
return ErrorResponse(
|
||||||
|
message=(
|
||||||
|
f"Block '{block.name}' cannot be run directly in CoPilot. "
|
||||||
|
"This block is designed for use within graphs only."
|
||||||
|
),
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
|
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
|
||||||
|
|
||||||
creds_manager = IntegrationCredentialsManager()
|
creds_manager = IntegrationCredentialsManager()
|
||||||
matched_credentials, missing_credentials = await self._check_block_credentials(
|
matched_credentials, missing_credentials = (
|
||||||
user_id, block, input_data
|
await self._resolve_block_credentials(user_id, block, input_data)
|
||||||
)
|
)
|
||||||
|
|
||||||
if missing_credentials:
|
if missing_credentials:
|
||||||
@@ -345,29 +281,75 @@ class RunBlockTool(BaseTool):
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_inputs_list(self, block: Any) -> list[dict[str, Any]]:
|
async def _resolve_block_credentials(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
block: AnyBlockSchema,
|
||||||
|
input_data: dict[str, Any] | None = None,
|
||||||
|
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
|
||||||
|
"""
|
||||||
|
Resolve credentials for a block by matching user's available credentials.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: User ID
|
||||||
|
block: Block to resolve credentials for
|
||||||
|
input_data: Input data for the block (used to determine provider via discriminator)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple of (matched_credentials, missing_credentials) - matched credentials
|
||||||
|
are used for block execution, missing ones indicate setup requirements.
|
||||||
|
"""
|
||||||
|
input_data = input_data or {}
|
||||||
|
requirements = self._resolve_discriminated_credentials(block, input_data)
|
||||||
|
|
||||||
|
if not requirements:
|
||||||
|
return {}, []
|
||||||
|
|
||||||
|
return await match_credentials_to_requirements(user_id, requirements)
|
||||||
|
|
||||||
|
def _get_inputs_list(self, block: AnyBlockSchema) -> list[dict[str, Any]]:
|
||||||
"""Extract non-credential inputs from block schema."""
|
"""Extract non-credential inputs from block schema."""
|
||||||
inputs_list = []
|
|
||||||
schema = block.input_schema.jsonschema()
|
schema = block.input_schema.jsonschema()
|
||||||
properties = schema.get("properties", {})
|
|
||||||
required_fields = set(schema.get("required", []))
|
|
||||||
|
|
||||||
# Get credential field names to exclude
|
|
||||||
credentials_fields = set(block.input_schema.get_credentials_fields().keys())
|
credentials_fields = set(block.input_schema.get_credentials_fields().keys())
|
||||||
|
return get_inputs_from_schema(schema, exclude_fields=credentials_fields)
|
||||||
|
|
||||||
for field_name, field_schema in properties.items():
|
def _resolve_discriminated_credentials(
|
||||||
# Skip credential fields
|
self,
|
||||||
if field_name in credentials_fields:
|
block: AnyBlockSchema,
|
||||||
continue
|
input_data: dict[str, Any],
|
||||||
|
) -> dict[str, CredentialsFieldInfo]:
|
||||||
|
"""Resolve credential requirements, applying discriminator logic where needed."""
|
||||||
|
credentials_fields_info = block.input_schema.get_credentials_fields_info()
|
||||||
|
if not credentials_fields_info:
|
||||||
|
return {}
|
||||||
|
|
||||||
inputs_list.append(
|
resolved: dict[str, CredentialsFieldInfo] = {}
|
||||||
{
|
|
||||||
"name": field_name,
|
|
||||||
"title": field_schema.get("title", field_name),
|
|
||||||
"type": field_schema.get("type", "string"),
|
|
||||||
"description": field_schema.get("description", ""),
|
|
||||||
"required": field_name in required_fields,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return inputs_list
|
for field_name, field_info in credentials_fields_info.items():
|
||||||
|
effective_field_info = field_info
|
||||||
|
|
||||||
|
if field_info.discriminator and field_info.discriminator_mapping:
|
||||||
|
discriminator_value = input_data.get(field_info.discriminator)
|
||||||
|
if discriminator_value is None:
|
||||||
|
field = block.input_schema.model_fields.get(
|
||||||
|
field_info.discriminator
|
||||||
|
)
|
||||||
|
if field and field.default is not PydanticUndefined:
|
||||||
|
discriminator_value = field.default
|
||||||
|
|
||||||
|
if (
|
||||||
|
discriminator_value
|
||||||
|
and discriminator_value in field_info.discriminator_mapping
|
||||||
|
):
|
||||||
|
effective_field_info = field_info.discriminate(discriminator_value)
|
||||||
|
# For host-scoped credentials, add the discriminator value
|
||||||
|
# (e.g., URL) so _credential_is_for_host can match it
|
||||||
|
effective_field_info.discriminator_values.add(discriminator_value)
|
||||||
|
logger.debug(
|
||||||
|
f"Discriminated provider for {field_name}: "
|
||||||
|
f"{discriminator_value} -> {effective_field_info.provider}"
|
||||||
|
)
|
||||||
|
|
||||||
|
resolved[field_name] = effective_field_info
|
||||||
|
|
||||||
|
return resolved
|
||||||
|
|||||||
@@ -0,0 +1,106 @@
|
|||||||
|
"""Tests for block execution guards in RunBlockTool."""
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from backend.api.features.chat.tools.models import ErrorResponse
|
||||||
|
from backend.api.features.chat.tools.run_block import RunBlockTool
|
||||||
|
from backend.data.block import BlockType
|
||||||
|
|
||||||
|
from ._test_data import make_session
|
||||||
|
|
||||||
|
_TEST_USER_ID = "test-user-run-block"
|
||||||
|
|
||||||
|
|
||||||
|
def make_mock_block(
|
||||||
|
block_id: str, name: str, block_type: BlockType, disabled: bool = False
|
||||||
|
):
|
||||||
|
"""Create a mock block for testing."""
|
||||||
|
mock = MagicMock()
|
||||||
|
mock.id = block_id
|
||||||
|
mock.name = name
|
||||||
|
mock.block_type = block_type
|
||||||
|
mock.disabled = disabled
|
||||||
|
mock.input_schema = MagicMock()
|
||||||
|
mock.input_schema.jsonschema.return_value = {"properties": {}, "required": []}
|
||||||
|
mock.input_schema.get_credentials_fields_info.return_value = []
|
||||||
|
return mock
|
||||||
|
|
||||||
|
|
||||||
|
class TestRunBlockFiltering:
|
||||||
|
"""Tests for block execution guards in RunBlockTool."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_excluded_block_type_returns_error(self):
|
||||||
|
"""Attempting to execute a block with excluded BlockType returns error."""
|
||||||
|
session = make_session(user_id=_TEST_USER_ID)
|
||||||
|
|
||||||
|
input_block = make_mock_block("input-block-id", "Input Block", BlockType.INPUT)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.chat.tools.run_block.get_block",
|
||||||
|
return_value=input_block,
|
||||||
|
):
|
||||||
|
tool = RunBlockTool()
|
||||||
|
response = await tool._execute(
|
||||||
|
user_id=_TEST_USER_ID,
|
||||||
|
session=session,
|
||||||
|
block_id="input-block-id",
|
||||||
|
input_data={},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(response, ErrorResponse)
|
||||||
|
assert "cannot be run directly in CoPilot" in response.message
|
||||||
|
assert "designed for use within graphs only" in response.message
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_excluded_block_id_returns_error(self):
|
||||||
|
"""Attempting to execute SmartDecisionMakerBlock returns error."""
|
||||||
|
session = make_session(user_id=_TEST_USER_ID)
|
||||||
|
|
||||||
|
smart_decision_id = "3b191d9f-356f-482d-8238-ba04b6d18381"
|
||||||
|
smart_block = make_mock_block(
|
||||||
|
smart_decision_id, "Smart Decision Maker", BlockType.STANDARD
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.chat.tools.run_block.get_block",
|
||||||
|
return_value=smart_block,
|
||||||
|
):
|
||||||
|
tool = RunBlockTool()
|
||||||
|
response = await tool._execute(
|
||||||
|
user_id=_TEST_USER_ID,
|
||||||
|
session=session,
|
||||||
|
block_id=smart_decision_id,
|
||||||
|
input_data={},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(response, ErrorResponse)
|
||||||
|
assert "cannot be run directly in CoPilot" in response.message
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
|
||||||
|
async def test_non_excluded_block_passes_guard(self):
|
||||||
|
"""Non-excluded blocks pass the filtering guard (may fail later for other reasons)."""
|
||||||
|
session = make_session(user_id=_TEST_USER_ID)
|
||||||
|
|
||||||
|
standard_block = make_mock_block(
|
||||||
|
"standard-id", "HTTP Request", BlockType.STANDARD
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.api.features.chat.tools.run_block.get_block",
|
||||||
|
return_value=standard_block,
|
||||||
|
):
|
||||||
|
tool = RunBlockTool()
|
||||||
|
response = await tool._execute(
|
||||||
|
user_id=_TEST_USER_ID,
|
||||||
|
session=session,
|
||||||
|
block_id="standard-id",
|
||||||
|
input_data={},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should NOT be an ErrorResponse about CoPilot exclusion
|
||||||
|
# (may be other errors like missing credentials, but not the exclusion guard)
|
||||||
|
if isinstance(response, ErrorResponse):
|
||||||
|
assert "cannot be run directly in CoPilot" not in response.message
|
||||||
@@ -6,9 +6,9 @@ from typing import Any
|
|||||||
from backend.api.features.library import db as library_db
|
from backend.api.features.library import db as library_db
|
||||||
from backend.api.features.library import model as library_model
|
from backend.api.features.library import model as library_model
|
||||||
from backend.api.features.store import db as store_db
|
from backend.api.features.store import db as store_db
|
||||||
from backend.data import graph as graph_db
|
|
||||||
from backend.data.graph import GraphModel
|
from backend.data.graph import GraphModel
|
||||||
from backend.data.model import (
|
from backend.data.model import (
|
||||||
|
Credentials,
|
||||||
CredentialsFieldInfo,
|
CredentialsFieldInfo,
|
||||||
CredentialsMetaInput,
|
CredentialsMetaInput,
|
||||||
HostScopedCredentials,
|
HostScopedCredentials,
|
||||||
@@ -44,14 +44,8 @@ async def fetch_graph_from_store_slug(
|
|||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
# Get the graph from store listing version
|
# Get the graph from store listing version
|
||||||
graph_meta = await store_db.get_available_graph(
|
graph = await store_db.get_available_graph(
|
||||||
store_agent.store_listing_version_id
|
store_agent.store_listing_version_id, hide_nodes=False
|
||||||
)
|
|
||||||
graph = await graph_db.get_graph(
|
|
||||||
graph_id=graph_meta.id,
|
|
||||||
version=graph_meta.version,
|
|
||||||
user_id=None, # Public access
|
|
||||||
include_subgraphs=True,
|
|
||||||
)
|
)
|
||||||
return graph, store_agent
|
return graph, store_agent
|
||||||
|
|
||||||
@@ -128,7 +122,7 @@ def build_missing_credentials_from_graph(
|
|||||||
|
|
||||||
return {
|
return {
|
||||||
field_key: _serialize_missing_credential(field_key, field_info)
|
field_key: _serialize_missing_credential(field_key, field_info)
|
||||||
for field_key, (field_info, _node_fields) in aggregated_fields.items()
|
for field_key, (field_info, _, _) in aggregated_fields.items()
|
||||||
if field_key not in matched_keys
|
if field_key not in matched_keys
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -230,6 +224,99 @@ async def get_or_create_library_agent(
|
|||||||
return library_agents[0]
|
return library_agents[0]
|
||||||
|
|
||||||
|
|
||||||
|
async def match_credentials_to_requirements(
|
||||||
|
user_id: str,
|
||||||
|
requirements: dict[str, CredentialsFieldInfo],
|
||||||
|
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
|
||||||
|
"""
|
||||||
|
Match user's credentials against a dictionary of credential requirements.
|
||||||
|
|
||||||
|
This is the core matching logic shared by both graph and block credential matching.
|
||||||
|
"""
|
||||||
|
matched: dict[str, CredentialsMetaInput] = {}
|
||||||
|
missing: list[CredentialsMetaInput] = []
|
||||||
|
|
||||||
|
if not requirements:
|
||||||
|
return matched, missing
|
||||||
|
|
||||||
|
available_creds = await get_user_credentials(user_id)
|
||||||
|
|
||||||
|
for field_name, field_info in requirements.items():
|
||||||
|
matching_cred = find_matching_credential(available_creds, field_info)
|
||||||
|
|
||||||
|
if matching_cred:
|
||||||
|
try:
|
||||||
|
matched[field_name] = create_credential_meta_from_match(matching_cred)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Failed to create CredentialsMetaInput for field '{field_name}': "
|
||||||
|
f"provider={matching_cred.provider}, type={matching_cred.type}, "
|
||||||
|
f"credential_id={matching_cred.id}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
provider = next(iter(field_info.provider), "unknown")
|
||||||
|
cred_type = next(iter(field_info.supported_types), "api_key")
|
||||||
|
missing.append(
|
||||||
|
CredentialsMetaInput(
|
||||||
|
id=field_name,
|
||||||
|
provider=provider, # type: ignore
|
||||||
|
type=cred_type, # type: ignore
|
||||||
|
title=f"{field_name} (validation failed: {e})",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
provider = next(iter(field_info.provider), "unknown")
|
||||||
|
cred_type = next(iter(field_info.supported_types), "api_key")
|
||||||
|
missing.append(
|
||||||
|
CredentialsMetaInput(
|
||||||
|
id=field_name,
|
||||||
|
provider=provider, # type: ignore
|
||||||
|
type=cred_type, # type: ignore
|
||||||
|
title=field_name.replace("_", " ").title(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return matched, missing
|
||||||
|
|
||||||
|
|
||||||
|
async def get_user_credentials(user_id: str) -> list[Credentials]:
|
||||||
|
"""Get all available credentials for a user."""
|
||||||
|
creds_manager = IntegrationCredentialsManager()
|
||||||
|
return await creds_manager.store.get_all_creds(user_id)
|
||||||
|
|
||||||
|
|
||||||
|
def find_matching_credential(
|
||||||
|
available_creds: list[Credentials],
|
||||||
|
field_info: CredentialsFieldInfo,
|
||||||
|
) -> Credentials | None:
|
||||||
|
"""Find a credential that matches the required provider, type, scopes, and host."""
|
||||||
|
for cred in available_creds:
|
||||||
|
if cred.provider not in field_info.provider:
|
||||||
|
continue
|
||||||
|
if cred.type not in field_info.supported_types:
|
||||||
|
continue
|
||||||
|
if cred.type == "oauth2" and not _credential_has_required_scopes(
|
||||||
|
cred, field_info
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
if cred.type == "host_scoped" and not _credential_is_for_host(cred, field_info):
|
||||||
|
continue
|
||||||
|
return cred
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def create_credential_meta_from_match(
|
||||||
|
matching_cred: Credentials,
|
||||||
|
) -> CredentialsMetaInput:
|
||||||
|
"""Create a CredentialsMetaInput from a matched credential."""
|
||||||
|
return CredentialsMetaInput(
|
||||||
|
id=matching_cred.id,
|
||||||
|
provider=matching_cred.provider, # type: ignore
|
||||||
|
type=matching_cred.type,
|
||||||
|
title=matching_cred.title,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def match_user_credentials_to_graph(
|
async def match_user_credentials_to_graph(
|
||||||
user_id: str,
|
user_id: str,
|
||||||
graph: GraphModel,
|
graph: GraphModel,
|
||||||
@@ -269,7 +356,8 @@ async def match_user_credentials_to_graph(
|
|||||||
# provider is in the set of acceptable providers.
|
# provider is in the set of acceptable providers.
|
||||||
for credential_field_name, (
|
for credential_field_name, (
|
||||||
credential_requirements,
|
credential_requirements,
|
||||||
_node_fields,
|
_,
|
||||||
|
_,
|
||||||
) in aggregated_creds.items():
|
) in aggregated_creds.items():
|
||||||
# Find first matching credential by provider, type, and scopes
|
# Find first matching credential by provider, type, and scopes
|
||||||
matching_cred = next(
|
matching_cred = next(
|
||||||
@@ -337,8 +425,6 @@ def _credential_has_required_scopes(
|
|||||||
# If no scopes are required, any credential matches
|
# If no scopes are required, any credential matches
|
||||||
if not requirements.required_scopes:
|
if not requirements.required_scopes:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# Check that credential scopes are a superset of required scopes
|
|
||||||
return set(credential.scopes).issuperset(requirements.required_scopes)
|
return set(credential.scopes).issuperset(requirements.required_scopes)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,10 @@ from backend.data.graph import GraphSettings
|
|||||||
from backend.data.includes import AGENT_PRESET_INCLUDE, library_agent_include
|
from backend.data.includes import AGENT_PRESET_INCLUDE, library_agent_include
|
||||||
from backend.data.model import CredentialsMetaInput
|
from backend.data.model import CredentialsMetaInput
|
||||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||||
from backend.integrations.webhooks.graph_lifecycle_hooks import on_graph_activate
|
from backend.integrations.webhooks.graph_lifecycle_hooks import (
|
||||||
|
on_graph_activate,
|
||||||
|
on_graph_deactivate,
|
||||||
|
)
|
||||||
from backend.util.clients import get_scheduler_client
|
from backend.util.clients import get_scheduler_client
|
||||||
from backend.util.exceptions import DatabaseError, InvalidInputError, NotFoundError
|
from backend.util.exceptions import DatabaseError, InvalidInputError, NotFoundError
|
||||||
from backend.util.json import SafeJson
|
from backend.util.json import SafeJson
|
||||||
@@ -371,7 +374,7 @@ async def get_library_agent_by_graph_id(
|
|||||||
|
|
||||||
|
|
||||||
async def add_generated_agent_image(
|
async def add_generated_agent_image(
|
||||||
graph: graph_db.BaseGraph,
|
graph: graph_db.GraphBaseMeta,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
library_agent_id: str,
|
library_agent_id: str,
|
||||||
) -> Optional[prisma.models.LibraryAgent]:
|
) -> Optional[prisma.models.LibraryAgent]:
|
||||||
@@ -537,6 +540,92 @@ async def update_agent_version_in_library(
|
|||||||
return library_model.LibraryAgent.from_db(lib)
|
return library_model.LibraryAgent.from_db(lib)
|
||||||
|
|
||||||
|
|
||||||
|
async def create_graph_in_library(
|
||||||
|
graph: graph_db.Graph,
|
||||||
|
user_id: str,
|
||||||
|
) -> tuple[graph_db.GraphModel, library_model.LibraryAgent]:
|
||||||
|
"""Create a new graph and add it to the user's library."""
|
||||||
|
graph.version = 1
|
||||||
|
graph_model = graph_db.make_graph_model(graph, user_id)
|
||||||
|
graph_model.reassign_ids(user_id=user_id, reassign_graph_id=True)
|
||||||
|
|
||||||
|
created_graph = await graph_db.create_graph(graph_model, user_id)
|
||||||
|
|
||||||
|
library_agents = await create_library_agent(
|
||||||
|
graph=created_graph,
|
||||||
|
user_id=user_id,
|
||||||
|
sensitive_action_safe_mode=True,
|
||||||
|
create_library_agents_for_sub_graphs=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
if created_graph.is_active:
|
||||||
|
created_graph = await on_graph_activate(created_graph, user_id=user_id)
|
||||||
|
|
||||||
|
return created_graph, library_agents[0]
|
||||||
|
|
||||||
|
|
||||||
|
async def update_graph_in_library(
|
||||||
|
graph: graph_db.Graph,
|
||||||
|
user_id: str,
|
||||||
|
) -> tuple[graph_db.GraphModel, library_model.LibraryAgent]:
|
||||||
|
"""Create a new version of an existing graph and update the library entry."""
|
||||||
|
existing_versions = await graph_db.get_graph_all_versions(graph.id, user_id)
|
||||||
|
current_active_version = (
|
||||||
|
next((v for v in existing_versions if v.is_active), None)
|
||||||
|
if existing_versions
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
graph.version = (
|
||||||
|
max(v.version for v in existing_versions) + 1 if existing_versions else 1
|
||||||
|
)
|
||||||
|
|
||||||
|
graph_model = graph_db.make_graph_model(graph, user_id)
|
||||||
|
graph_model.reassign_ids(user_id=user_id, reassign_graph_id=False)
|
||||||
|
|
||||||
|
created_graph = await graph_db.create_graph(graph_model, user_id)
|
||||||
|
|
||||||
|
library_agent = await get_library_agent_by_graph_id(user_id, created_graph.id)
|
||||||
|
if not library_agent:
|
||||||
|
raise NotFoundError(f"Library agent not found for graph {created_graph.id}")
|
||||||
|
|
||||||
|
library_agent = await update_library_agent_version_and_settings(
|
||||||
|
user_id, created_graph
|
||||||
|
)
|
||||||
|
|
||||||
|
if created_graph.is_active:
|
||||||
|
created_graph = await on_graph_activate(created_graph, user_id=user_id)
|
||||||
|
await graph_db.set_graph_active_version(
|
||||||
|
graph_id=created_graph.id,
|
||||||
|
version=created_graph.version,
|
||||||
|
user_id=user_id,
|
||||||
|
)
|
||||||
|
if current_active_version:
|
||||||
|
await on_graph_deactivate(current_active_version, user_id=user_id)
|
||||||
|
|
||||||
|
return created_graph, library_agent
|
||||||
|
|
||||||
|
|
||||||
|
async def update_library_agent_version_and_settings(
|
||||||
|
user_id: str, agent_graph: graph_db.GraphModel
|
||||||
|
) -> library_model.LibraryAgent:
|
||||||
|
"""Update library agent to point to new graph version and sync settings."""
|
||||||
|
library = await update_agent_version_in_library(
|
||||||
|
user_id, agent_graph.id, agent_graph.version
|
||||||
|
)
|
||||||
|
updated_settings = GraphSettings.from_graph(
|
||||||
|
graph=agent_graph,
|
||||||
|
hitl_safe_mode=library.settings.human_in_the_loop_safe_mode,
|
||||||
|
sensitive_action_safe_mode=library.settings.sensitive_action_safe_mode,
|
||||||
|
)
|
||||||
|
if updated_settings != library.settings:
|
||||||
|
library = await update_library_agent(
|
||||||
|
library_agent_id=library.id,
|
||||||
|
user_id=user_id,
|
||||||
|
settings=updated_settings,
|
||||||
|
)
|
||||||
|
return library
|
||||||
|
|
||||||
|
|
||||||
async def update_library_agent(
|
async def update_library_agent(
|
||||||
library_agent_id: str,
|
library_agent_id: str,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal, overload
|
||||||
|
|
||||||
import fastapi
|
import fastapi
|
||||||
import prisma.enums
|
import prisma.enums
|
||||||
@@ -11,8 +11,8 @@ import prisma.types
|
|||||||
|
|
||||||
from backend.data.db import transaction
|
from backend.data.db import transaction
|
||||||
from backend.data.graph import (
|
from backend.data.graph import (
|
||||||
GraphMeta,
|
|
||||||
GraphModel,
|
GraphModel,
|
||||||
|
GraphModelWithoutNodes,
|
||||||
get_graph,
|
get_graph,
|
||||||
get_graph_as_admin,
|
get_graph_as_admin,
|
||||||
get_sub_graphs,
|
get_sub_graphs,
|
||||||
@@ -334,7 +334,22 @@ async def get_store_agent_details(
|
|||||||
raise DatabaseError("Failed to fetch agent details") from e
|
raise DatabaseError("Failed to fetch agent details") from e
|
||||||
|
|
||||||
|
|
||||||
async def get_available_graph(store_listing_version_id: str) -> GraphMeta:
|
@overload
|
||||||
|
async def get_available_graph(
|
||||||
|
store_listing_version_id: str, hide_nodes: Literal[False]
|
||||||
|
) -> GraphModel: ...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
async def get_available_graph(
|
||||||
|
store_listing_version_id: str, hide_nodes: Literal[True] = True
|
||||||
|
) -> GraphModelWithoutNodes: ...
|
||||||
|
|
||||||
|
|
||||||
|
async def get_available_graph(
|
||||||
|
store_listing_version_id: str,
|
||||||
|
hide_nodes: bool = True,
|
||||||
|
) -> GraphModelWithoutNodes | GraphModel:
|
||||||
try:
|
try:
|
||||||
# Get avaialble, non-deleted store listing version
|
# Get avaialble, non-deleted store listing version
|
||||||
store_listing_version = (
|
store_listing_version = (
|
||||||
@@ -344,7 +359,7 @@ async def get_available_graph(store_listing_version_id: str) -> GraphMeta:
|
|||||||
"isAvailable": True,
|
"isAvailable": True,
|
||||||
"isDeleted": False,
|
"isDeleted": False,
|
||||||
},
|
},
|
||||||
include={"AgentGraph": {"include": {"Nodes": True}}},
|
include={"AgentGraph": {"include": AGENT_GRAPH_INCLUDE}},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -354,7 +369,9 @@ async def get_available_graph(store_listing_version_id: str) -> GraphMeta:
|
|||||||
detail=f"Store listing version {store_listing_version_id} not found",
|
detail=f"Store listing version {store_listing_version_id} not found",
|
||||||
)
|
)
|
||||||
|
|
||||||
return GraphModel.from_db(store_listing_version.AgentGraph).meta()
|
return (GraphModelWithoutNodes if hide_nodes else GraphModel).from_db(
|
||||||
|
store_listing_version.AgentGraph
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting agent: {e}")
|
logger.error(f"Error getting agent: {e}")
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ Includes BM25 reranking for improved lexical relevance.
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
@@ -362,7 +363,11 @@ async def unified_hybrid_search(
|
|||||||
LIMIT {limit_param} OFFSET {offset_param}
|
LIMIT {limit_param} OFFSET {offset_param}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
results = await query_raw_with_schema(sql_query, *params)
|
try:
|
||||||
|
results = await query_raw_with_schema(sql_query, *params)
|
||||||
|
except Exception as e:
|
||||||
|
await _log_vector_error_diagnostics(e)
|
||||||
|
raise
|
||||||
|
|
||||||
total = results[0]["total_count"] if results else 0
|
total = results[0]["total_count"] if results else 0
|
||||||
# Apply BM25 reranking
|
# Apply BM25 reranking
|
||||||
@@ -686,7 +691,11 @@ async def hybrid_search(
|
|||||||
LIMIT {limit_param} OFFSET {offset_param}
|
LIMIT {limit_param} OFFSET {offset_param}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
results = await query_raw_with_schema(sql_query, *params)
|
try:
|
||||||
|
results = await query_raw_with_schema(sql_query, *params)
|
||||||
|
except Exception as e:
|
||||||
|
await _log_vector_error_diagnostics(e)
|
||||||
|
raise
|
||||||
|
|
||||||
total = results[0]["total_count"] if results else 0
|
total = results[0]["total_count"] if results else 0
|
||||||
|
|
||||||
@@ -718,6 +727,87 @@ async def hybrid_search_simple(
|
|||||||
return await hybrid_search(query=query, page=page, page_size=page_size)
|
return await hybrid_search(query=query, page=page, page_size=page_size)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Diagnostics
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
# Rate limit: only log vector error diagnostics once per this interval
|
||||||
|
_VECTOR_DIAG_INTERVAL_SECONDS = 60
|
||||||
|
_last_vector_diag_time: float = 0
|
||||||
|
|
||||||
|
|
||||||
|
async def _log_vector_error_diagnostics(error: Exception) -> None:
|
||||||
|
"""Log diagnostic info when 'type vector does not exist' error occurs.
|
||||||
|
|
||||||
|
Note: Diagnostic queries use query_raw_with_schema which may run on a different
|
||||||
|
pooled connection than the one that failed. Session-level search_path can differ,
|
||||||
|
so these diagnostics show cluster-wide state, not necessarily the failed session.
|
||||||
|
|
||||||
|
Includes rate limiting to avoid log spam - only logs once per minute.
|
||||||
|
Caller should re-raise the error after calling this function.
|
||||||
|
"""
|
||||||
|
global _last_vector_diag_time
|
||||||
|
|
||||||
|
# Check if this is the vector type error
|
||||||
|
error_str = str(error).lower()
|
||||||
|
if not (
|
||||||
|
"type" in error_str and "vector" in error_str and "does not exist" in error_str
|
||||||
|
):
|
||||||
|
return
|
||||||
|
|
||||||
|
# Rate limit: only log once per interval
|
||||||
|
now = time.time()
|
||||||
|
if now - _last_vector_diag_time < _VECTOR_DIAG_INTERVAL_SECONDS:
|
||||||
|
return
|
||||||
|
_last_vector_diag_time = now
|
||||||
|
|
||||||
|
try:
|
||||||
|
diagnostics: dict[str, object] = {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
search_path_result = await query_raw_with_schema("SHOW search_path")
|
||||||
|
diagnostics["search_path"] = search_path_result
|
||||||
|
except Exception as e:
|
||||||
|
diagnostics["search_path"] = f"Error: {e}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
schema_result = await query_raw_with_schema("SELECT current_schema()")
|
||||||
|
diagnostics["current_schema"] = schema_result
|
||||||
|
except Exception as e:
|
||||||
|
diagnostics["current_schema"] = f"Error: {e}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
user_result = await query_raw_with_schema(
|
||||||
|
"SELECT current_user, session_user, current_database()"
|
||||||
|
)
|
||||||
|
diagnostics["user_info"] = user_result
|
||||||
|
except Exception as e:
|
||||||
|
diagnostics["user_info"] = f"Error: {e}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Check pgvector extension installation (cluster-wide, stable info)
|
||||||
|
ext_result = await query_raw_with_schema(
|
||||||
|
"SELECT extname, extversion, nspname as schema "
|
||||||
|
"FROM pg_extension e "
|
||||||
|
"JOIN pg_namespace n ON e.extnamespace = n.oid "
|
||||||
|
"WHERE extname = 'vector'"
|
||||||
|
)
|
||||||
|
diagnostics["pgvector_extension"] = ext_result
|
||||||
|
except Exception as e:
|
||||||
|
diagnostics["pgvector_extension"] = f"Error: {e}"
|
||||||
|
|
||||||
|
logger.error(
|
||||||
|
f"Vector type error diagnostics:\n"
|
||||||
|
f" Error: {error}\n"
|
||||||
|
f" search_path: {diagnostics.get('search_path')}\n"
|
||||||
|
f" current_schema: {diagnostics.get('current_schema')}\n"
|
||||||
|
f" user_info: {diagnostics.get('user_info')}\n"
|
||||||
|
f" pgvector_extension: {diagnostics.get('pgvector_extension')}"
|
||||||
|
)
|
||||||
|
except Exception as diag_error:
|
||||||
|
logger.error(f"Failed to collect vector error diagnostics: {diag_error}")
|
||||||
|
|
||||||
|
|
||||||
# Backward compatibility alias - HybridSearchWeights maps to StoreAgentSearchWeights
|
# Backward compatibility alias - HybridSearchWeights maps to StoreAgentSearchWeights
|
||||||
# for existing code that expects the popularity parameter
|
# for existing code that expects the popularity parameter
|
||||||
HybridSearchWeights = StoreAgentSearchWeights
|
HybridSearchWeights = StoreAgentSearchWeights
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ from backend.blocks.ideogram import (
|
|||||||
StyleType,
|
StyleType,
|
||||||
UpscaleOption,
|
UpscaleOption,
|
||||||
)
|
)
|
||||||
from backend.data.graph import BaseGraph
|
from backend.data.graph import GraphBaseMeta
|
||||||
from backend.data.model import CredentialsMetaInput, ProviderName
|
from backend.data.model import CredentialsMetaInput, ProviderName
|
||||||
from backend.integrations.credentials_store import ideogram_credentials
|
from backend.integrations.credentials_store import ideogram_credentials
|
||||||
from backend.util.request import Requests
|
from backend.util.request import Requests
|
||||||
@@ -34,14 +34,14 @@ class ImageStyle(str, Enum):
|
|||||||
DIGITAL_ART = "digital art"
|
DIGITAL_ART = "digital art"
|
||||||
|
|
||||||
|
|
||||||
async def generate_agent_image(agent: BaseGraph | AgentGraph) -> io.BytesIO:
|
async def generate_agent_image(agent: GraphBaseMeta | AgentGraph) -> io.BytesIO:
|
||||||
if settings.config.use_agent_image_generation_v2:
|
if settings.config.use_agent_image_generation_v2:
|
||||||
return await generate_agent_image_v2(graph=agent)
|
return await generate_agent_image_v2(graph=agent)
|
||||||
else:
|
else:
|
||||||
return await generate_agent_image_v1(agent=agent)
|
return await generate_agent_image_v1(agent=agent)
|
||||||
|
|
||||||
|
|
||||||
async def generate_agent_image_v2(graph: BaseGraph | AgentGraph) -> io.BytesIO:
|
async def generate_agent_image_v2(graph: GraphBaseMeta | AgentGraph) -> io.BytesIO:
|
||||||
"""
|
"""
|
||||||
Generate an image for an agent using Ideogram model.
|
Generate an image for an agent using Ideogram model.
|
||||||
Returns:
|
Returns:
|
||||||
@@ -54,14 +54,17 @@ async def generate_agent_image_v2(graph: BaseGraph | AgentGraph) -> io.BytesIO:
|
|||||||
description = f"{name} ({graph.description})" if graph.description else name
|
description = f"{name} ({graph.description})" if graph.description else name
|
||||||
|
|
||||||
prompt = (
|
prompt = (
|
||||||
f"Create a visually striking retro-futuristic vector pop art illustration prominently featuring "
|
"Create a visually striking retro-futuristic vector pop art illustration "
|
||||||
f'"{name}" in bold typography. The image clearly and literally depicts a {description}, '
|
f'prominently featuring "{name}" in bold typography. The image clearly and '
|
||||||
f"along with recognizable objects directly associated with the primary function of a {name}. "
|
f"literally depicts a {description}, along with recognizable objects directly "
|
||||||
f"Ensure the imagery is concrete, intuitive, and immediately understandable, clearly conveying the "
|
f"associated with the primary function of a {name}. "
|
||||||
f"purpose of a {name}. Maintain vibrant, limited-palette colors, sharp vector lines, geometric "
|
f"Ensure the imagery is concrete, intuitive, and immediately understandable, "
|
||||||
f"shapes, flat illustration techniques, and solid colors without gradients or shading. Preserve a "
|
f"clearly conveying the purpose of a {name}. "
|
||||||
f"retro-futuristic aesthetic influenced by mid-century futurism and 1960s psychedelia, "
|
"Maintain vibrant, limited-palette colors, sharp vector lines, "
|
||||||
f"prioritizing clear visual storytelling and thematic clarity above all else."
|
"geometric shapes, flat illustration techniques, and solid colors "
|
||||||
|
"without gradients or shading. Preserve a retro-futuristic aesthetic "
|
||||||
|
"influenced by mid-century futurism and 1960s psychedelia, "
|
||||||
|
"prioritizing clear visual storytelling and thematic clarity above all else."
|
||||||
)
|
)
|
||||||
|
|
||||||
custom_colors = [
|
custom_colors = [
|
||||||
@@ -99,12 +102,12 @@ async def generate_agent_image_v2(graph: BaseGraph | AgentGraph) -> io.BytesIO:
|
|||||||
return io.BytesIO(response.content)
|
return io.BytesIO(response.content)
|
||||||
|
|
||||||
|
|
||||||
async def generate_agent_image_v1(agent: BaseGraph | AgentGraph) -> io.BytesIO:
|
async def generate_agent_image_v1(agent: GraphBaseMeta | AgentGraph) -> io.BytesIO:
|
||||||
"""
|
"""
|
||||||
Generate an image for an agent using Flux model via Replicate API.
|
Generate an image for an agent using Flux model via Replicate API.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
agent (Graph): The agent to generate an image for
|
agent (GraphBaseMeta | AgentGraph): The agent to generate an image for
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
io.BytesIO: The generated image as bytes
|
io.BytesIO: The generated image as bytes
|
||||||
@@ -114,7 +117,13 @@ async def generate_agent_image_v1(agent: BaseGraph | AgentGraph) -> io.BytesIO:
|
|||||||
raise ValueError("Missing Replicate API key in settings")
|
raise ValueError("Missing Replicate API key in settings")
|
||||||
|
|
||||||
# Construct prompt from agent details
|
# Construct prompt from agent details
|
||||||
prompt = f"Create a visually engaging app store thumbnail for the AI agent that highlights what it does in a clear and captivating way:\n- **Name**: {agent.name}\n- **Description**: {agent.description}\nFocus on showcasing its core functionality with an appealing design."
|
prompt = (
|
||||||
|
"Create a visually engaging app store thumbnail for the AI agent "
|
||||||
|
"that highlights what it does in a clear and captivating way:\n"
|
||||||
|
f"- **Name**: {agent.name}\n"
|
||||||
|
f"- **Description**: {agent.description}\n"
|
||||||
|
f"Focus on showcasing its core functionality with an appealing design."
|
||||||
|
)
|
||||||
|
|
||||||
# Set up Replicate client
|
# Set up Replicate client
|
||||||
client = ReplicateClient(api_token=settings.secrets.replicate_api_key)
|
client = ReplicateClient(api_token=settings.secrets.replicate_api_key)
|
||||||
|
|||||||
@@ -278,7 +278,7 @@ async def get_agent(
|
|||||||
)
|
)
|
||||||
async def get_graph_meta_by_store_listing_version_id(
|
async def get_graph_meta_by_store_listing_version_id(
|
||||||
store_listing_version_id: str,
|
store_listing_version_id: str,
|
||||||
) -> backend.data.graph.GraphMeta:
|
) -> backend.data.graph.GraphModelWithoutNodes:
|
||||||
"""
|
"""
|
||||||
Get Agent Graph from Store Listing Version ID.
|
Get Agent Graph from Store Listing Version ID.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -101,7 +101,6 @@ from backend.util.timezone_utils import (
|
|||||||
from backend.util.virus_scanner import scan_content_safe
|
from backend.util.virus_scanner import scan_content_safe
|
||||||
|
|
||||||
from .library import db as library_db
|
from .library import db as library_db
|
||||||
from .library import model as library_model
|
|
||||||
from .store.model import StoreAgentDetails
|
from .store.model import StoreAgentDetails
|
||||||
|
|
||||||
|
|
||||||
@@ -823,18 +822,16 @@ async def update_graph(
|
|||||||
graph: graph_db.Graph,
|
graph: graph_db.Graph,
|
||||||
user_id: Annotated[str, Security(get_user_id)],
|
user_id: Annotated[str, Security(get_user_id)],
|
||||||
) -> graph_db.GraphModel:
|
) -> graph_db.GraphModel:
|
||||||
# Sanity check
|
|
||||||
if graph.id and graph.id != graph_id:
|
if graph.id and graph.id != graph_id:
|
||||||
raise HTTPException(400, detail="Graph ID does not match ID in URI")
|
raise HTTPException(400, detail="Graph ID does not match ID in URI")
|
||||||
|
|
||||||
# Determine new version
|
|
||||||
existing_versions = await graph_db.get_graph_all_versions(graph_id, user_id=user_id)
|
existing_versions = await graph_db.get_graph_all_versions(graph_id, user_id=user_id)
|
||||||
if not existing_versions:
|
if not existing_versions:
|
||||||
raise HTTPException(404, detail=f"Graph #{graph_id} not found")
|
raise HTTPException(404, detail=f"Graph #{graph_id} not found")
|
||||||
latest_version_number = max(g.version for g in existing_versions)
|
|
||||||
graph.version = latest_version_number + 1
|
|
||||||
|
|
||||||
|
graph.version = max(g.version for g in existing_versions) + 1
|
||||||
current_active_version = next((v for v in existing_versions if v.is_active), None)
|
current_active_version = next((v for v in existing_versions if v.is_active), None)
|
||||||
|
|
||||||
graph = graph_db.make_graph_model(graph, user_id)
|
graph = graph_db.make_graph_model(graph, user_id)
|
||||||
graph.reassign_ids(user_id=user_id, reassign_graph_id=False)
|
graph.reassign_ids(user_id=user_id, reassign_graph_id=False)
|
||||||
graph.validate_graph(for_run=False)
|
graph.validate_graph(for_run=False)
|
||||||
@@ -842,27 +839,23 @@ async def update_graph(
|
|||||||
new_graph_version = await graph_db.create_graph(graph, user_id=user_id)
|
new_graph_version = await graph_db.create_graph(graph, user_id=user_id)
|
||||||
|
|
||||||
if new_graph_version.is_active:
|
if new_graph_version.is_active:
|
||||||
# Keep the library agent up to date with the new active version
|
await library_db.update_library_agent_version_and_settings(
|
||||||
await _update_library_agent_version_and_settings(user_id, new_graph_version)
|
user_id, new_graph_version
|
||||||
|
)
|
||||||
# Handle activation of the new graph first to ensure continuity
|
|
||||||
new_graph_version = await on_graph_activate(new_graph_version, user_id=user_id)
|
new_graph_version = await on_graph_activate(new_graph_version, user_id=user_id)
|
||||||
# Ensure new version is the only active version
|
|
||||||
await graph_db.set_graph_active_version(
|
await graph_db.set_graph_active_version(
|
||||||
graph_id=graph_id, version=new_graph_version.version, user_id=user_id
|
graph_id=graph_id, version=new_graph_version.version, user_id=user_id
|
||||||
)
|
)
|
||||||
if current_active_version:
|
if current_active_version:
|
||||||
# Handle deactivation of the previously active version
|
|
||||||
await on_graph_deactivate(current_active_version, user_id=user_id)
|
await on_graph_deactivate(current_active_version, user_id=user_id)
|
||||||
|
|
||||||
# Fetch new graph version *with sub-graphs* (needed for credentials input schema)
|
|
||||||
new_graph_version_with_subgraphs = await graph_db.get_graph(
|
new_graph_version_with_subgraphs = await graph_db.get_graph(
|
||||||
graph_id,
|
graph_id,
|
||||||
new_graph_version.version,
|
new_graph_version.version,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
include_subgraphs=True,
|
include_subgraphs=True,
|
||||||
)
|
)
|
||||||
assert new_graph_version_with_subgraphs # make type checker happy
|
assert new_graph_version_with_subgraphs
|
||||||
return new_graph_version_with_subgraphs
|
return new_graph_version_with_subgraphs
|
||||||
|
|
||||||
|
|
||||||
@@ -900,33 +893,15 @@ async def set_graph_active_version(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Keep the library agent up to date with the new active version
|
# Keep the library agent up to date with the new active version
|
||||||
await _update_library_agent_version_and_settings(user_id, new_active_graph)
|
await library_db.update_library_agent_version_and_settings(
|
||||||
|
user_id, new_active_graph
|
||||||
|
)
|
||||||
|
|
||||||
if current_active_graph and current_active_graph.version != new_active_version:
|
if current_active_graph and current_active_graph.version != new_active_version:
|
||||||
# Handle deactivation of the previously active version
|
# Handle deactivation of the previously active version
|
||||||
await on_graph_deactivate(current_active_graph, user_id=user_id)
|
await on_graph_deactivate(current_active_graph, user_id=user_id)
|
||||||
|
|
||||||
|
|
||||||
async def _update_library_agent_version_and_settings(
|
|
||||||
user_id: str, agent_graph: graph_db.GraphModel
|
|
||||||
) -> library_model.LibraryAgent:
|
|
||||||
library = await library_db.update_agent_version_in_library(
|
|
||||||
user_id, agent_graph.id, agent_graph.version
|
|
||||||
)
|
|
||||||
updated_settings = GraphSettings.from_graph(
|
|
||||||
graph=agent_graph,
|
|
||||||
hitl_safe_mode=library.settings.human_in_the_loop_safe_mode,
|
|
||||||
sensitive_action_safe_mode=library.settings.sensitive_action_safe_mode,
|
|
||||||
)
|
|
||||||
if updated_settings != library.settings:
|
|
||||||
library = await library_db.update_library_agent(
|
|
||||||
library_agent_id=library.id,
|
|
||||||
user_id=user_id,
|
|
||||||
settings=updated_settings,
|
|
||||||
)
|
|
||||||
return library
|
|
||||||
|
|
||||||
|
|
||||||
@v1_router.patch(
|
@v1_router.patch(
|
||||||
path="/graphs/{graph_id}/settings",
|
path="/graphs/{graph_id}/settings",
|
||||||
summary="Update graph settings",
|
summary="Update graph settings",
|
||||||
|
|||||||
28
autogpt_platform/backend/backend/blocks/elevenlabs/_auth.py
Normal file
28
autogpt_platform/backend/backend/blocks/elevenlabs/_auth.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
"""ElevenLabs integration blocks - test credentials and shared utilities."""
|
||||||
|
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from pydantic import SecretStr
|
||||||
|
|
||||||
|
from backend.data.model import APIKeyCredentials, CredentialsMetaInput
|
||||||
|
from backend.integrations.providers import ProviderName
|
||||||
|
|
||||||
|
TEST_CREDENTIALS = APIKeyCredentials(
|
||||||
|
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||||
|
provider="elevenlabs",
|
||||||
|
api_key=SecretStr("mock-elevenlabs-api-key"),
|
||||||
|
title="Mock ElevenLabs API key",
|
||||||
|
expires_at=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
TEST_CREDENTIALS_INPUT = {
|
||||||
|
"provider": TEST_CREDENTIALS.provider,
|
||||||
|
"id": TEST_CREDENTIALS.id,
|
||||||
|
"type": TEST_CREDENTIALS.type,
|
||||||
|
"title": TEST_CREDENTIALS.title,
|
||||||
|
}
|
||||||
|
|
||||||
|
ElevenLabsCredentials = APIKeyCredentials
|
||||||
|
ElevenLabsCredentialsInput = CredentialsMetaInput[
|
||||||
|
Literal[ProviderName.ELEVENLABS], Literal["api_key"]
|
||||||
|
]
|
||||||
77
autogpt_platform/backend/backend/blocks/encoder_block.py
Normal file
77
autogpt_platform/backend/backend/blocks/encoder_block.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
"""Text encoding block for converting special characters to escape sequences."""
|
||||||
|
|
||||||
|
import codecs
|
||||||
|
|
||||||
|
from backend.data.block import (
|
||||||
|
Block,
|
||||||
|
BlockCategory,
|
||||||
|
BlockOutput,
|
||||||
|
BlockSchemaInput,
|
||||||
|
BlockSchemaOutput,
|
||||||
|
)
|
||||||
|
from backend.data.model import SchemaField
|
||||||
|
|
||||||
|
|
||||||
|
class TextEncoderBlock(Block):
|
||||||
|
"""
|
||||||
|
Encodes a string by converting special characters into escape sequences.
|
||||||
|
|
||||||
|
This block is the inverse of TextDecoderBlock. It takes text containing
|
||||||
|
special characters (like newlines, tabs, etc.) and converts them into
|
||||||
|
their escape sequence representations (e.g., newline becomes \\n).
|
||||||
|
"""
|
||||||
|
|
||||||
|
class Input(BlockSchemaInput):
|
||||||
|
"""Input schema for TextEncoderBlock."""
|
||||||
|
|
||||||
|
text: str = SchemaField(
|
||||||
|
description="A string containing special characters to be encoded",
|
||||||
|
placeholder="Your text with newlines and quotes to encode",
|
||||||
|
)
|
||||||
|
|
||||||
|
class Output(BlockSchemaOutput):
|
||||||
|
"""Output schema for TextEncoderBlock."""
|
||||||
|
|
||||||
|
encoded_text: str = SchemaField(
|
||||||
|
description="The encoded text with special characters converted to escape sequences"
|
||||||
|
)
|
||||||
|
error: str = SchemaField(description="Error message if encoding fails")
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
id="5185f32e-4b65-4ecf-8fbb-873f003f09d6",
|
||||||
|
description="Encodes a string by converting special characters into escape sequences",
|
||||||
|
categories={BlockCategory.TEXT},
|
||||||
|
input_schema=TextEncoderBlock.Input,
|
||||||
|
output_schema=TextEncoderBlock.Output,
|
||||||
|
test_input={
|
||||||
|
"text": """Hello
|
||||||
|
World!
|
||||||
|
This is a "quoted" string."""
|
||||||
|
},
|
||||||
|
test_output=[
|
||||||
|
(
|
||||||
|
"encoded_text",
|
||||||
|
"""Hello\\nWorld!\\nThis is a "quoted" string.""",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||||
|
"""
|
||||||
|
Encode the input text by converting special characters to escape sequences.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_data: The input containing the text to encode.
|
||||||
|
**kwargs: Additional keyword arguments (unused).
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
The encoded text with escape sequences, or an error message if encoding fails.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
encoded_text = codecs.encode(input_data.text, "unicode_escape").decode(
|
||||||
|
"utf-8"
|
||||||
|
)
|
||||||
|
yield "encoded_text", encoded_text
|
||||||
|
except Exception as e:
|
||||||
|
yield "error", f"Encoding error: {str(e)}"
|
||||||
@@ -478,7 +478,7 @@ class ExaCreateOrFindWebsetBlock(Block):
|
|||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
try:
|
try:
|
||||||
webset = aexa.websets.get(id=input_data.external_id)
|
webset = await aexa.websets.get(id=input_data.external_id)
|
||||||
webset_result = Webset.model_validate(webset.model_dump(by_alias=True))
|
webset_result = Webset.model_validate(webset.model_dump(by_alias=True))
|
||||||
|
|
||||||
yield "webset", webset_result
|
yield "webset", webset_result
|
||||||
@@ -494,7 +494,7 @@ class ExaCreateOrFindWebsetBlock(Block):
|
|||||||
count=input_data.search_count,
|
count=input_data.search_count,
|
||||||
)
|
)
|
||||||
|
|
||||||
webset = aexa.websets.create(
|
webset = await aexa.websets.create(
|
||||||
params=CreateWebsetParameters(
|
params=CreateWebsetParameters(
|
||||||
search=search_params,
|
search=search_params,
|
||||||
external_id=input_data.external_id,
|
external_id=input_data.external_id,
|
||||||
@@ -554,7 +554,7 @@ class ExaUpdateWebsetBlock(Block):
|
|||||||
if input_data.metadata is not None:
|
if input_data.metadata is not None:
|
||||||
payload["metadata"] = input_data.metadata
|
payload["metadata"] = input_data.metadata
|
||||||
|
|
||||||
sdk_webset = aexa.websets.update(id=input_data.webset_id, params=payload)
|
sdk_webset = await aexa.websets.update(id=input_data.webset_id, params=payload)
|
||||||
|
|
||||||
status_str = (
|
status_str = (
|
||||||
sdk_webset.status.value
|
sdk_webset.status.value
|
||||||
@@ -617,7 +617,7 @@ class ExaListWebsetsBlock(Block):
|
|||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
response = aexa.websets.list(
|
response = await aexa.websets.list(
|
||||||
cursor=input_data.cursor,
|
cursor=input_data.cursor,
|
||||||
limit=input_data.limit,
|
limit=input_data.limit,
|
||||||
)
|
)
|
||||||
@@ -678,7 +678,7 @@ class ExaGetWebsetBlock(Block):
|
|||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
sdk_webset = aexa.websets.get(id=input_data.webset_id)
|
sdk_webset = await aexa.websets.get(id=input_data.webset_id)
|
||||||
|
|
||||||
status_str = (
|
status_str = (
|
||||||
sdk_webset.status.value
|
sdk_webset.status.value
|
||||||
@@ -748,7 +748,7 @@ class ExaDeleteWebsetBlock(Block):
|
|||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
deleted_webset = aexa.websets.delete(id=input_data.webset_id)
|
deleted_webset = await aexa.websets.delete(id=input_data.webset_id)
|
||||||
|
|
||||||
status_str = (
|
status_str = (
|
||||||
deleted_webset.status.value
|
deleted_webset.status.value
|
||||||
@@ -798,7 +798,7 @@ class ExaCancelWebsetBlock(Block):
|
|||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
canceled_webset = aexa.websets.cancel(id=input_data.webset_id)
|
canceled_webset = await aexa.websets.cancel(id=input_data.webset_id)
|
||||||
|
|
||||||
status_str = (
|
status_str = (
|
||||||
canceled_webset.status.value
|
canceled_webset.status.value
|
||||||
@@ -968,7 +968,7 @@ class ExaPreviewWebsetBlock(Block):
|
|||||||
entity["description"] = input_data.entity_description
|
entity["description"] = input_data.entity_description
|
||||||
payload["entity"] = entity
|
payload["entity"] = entity
|
||||||
|
|
||||||
sdk_preview = aexa.websets.preview(params=payload)
|
sdk_preview = await aexa.websets.preview(params=payload)
|
||||||
|
|
||||||
preview = PreviewWebsetModel.from_sdk(sdk_preview)
|
preview = PreviewWebsetModel.from_sdk(sdk_preview)
|
||||||
|
|
||||||
@@ -1051,7 +1051,7 @@ class ExaWebsetStatusBlock(Block):
|
|||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
webset = aexa.websets.get(id=input_data.webset_id)
|
webset = await aexa.websets.get(id=input_data.webset_id)
|
||||||
|
|
||||||
status = (
|
status = (
|
||||||
webset.status.value
|
webset.status.value
|
||||||
@@ -1185,7 +1185,7 @@ class ExaWebsetSummaryBlock(Block):
|
|||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
webset = aexa.websets.get(id=input_data.webset_id)
|
webset = await aexa.websets.get(id=input_data.webset_id)
|
||||||
|
|
||||||
# Extract basic info
|
# Extract basic info
|
||||||
webset_id = webset.id
|
webset_id = webset.id
|
||||||
@@ -1211,7 +1211,7 @@ class ExaWebsetSummaryBlock(Block):
|
|||||||
total_items = 0
|
total_items = 0
|
||||||
|
|
||||||
if input_data.include_sample_items and input_data.sample_size > 0:
|
if input_data.include_sample_items and input_data.sample_size > 0:
|
||||||
items_response = aexa.websets.items.list(
|
items_response = await aexa.websets.items.list(
|
||||||
webset_id=input_data.webset_id, limit=input_data.sample_size
|
webset_id=input_data.webset_id, limit=input_data.sample_size
|
||||||
)
|
)
|
||||||
sample_items_data = [
|
sample_items_data = [
|
||||||
@@ -1362,7 +1362,7 @@ class ExaWebsetReadyCheckBlock(Block):
|
|||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
# Get webset details
|
# Get webset details
|
||||||
webset = aexa.websets.get(id=input_data.webset_id)
|
webset = await aexa.websets.get(id=input_data.webset_id)
|
||||||
|
|
||||||
status = (
|
status = (
|
||||||
webset.status.value
|
webset.status.value
|
||||||
|
|||||||
@@ -202,7 +202,7 @@ class ExaCreateEnrichmentBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
sdk_enrichment = aexa.websets.enrichments.create(
|
sdk_enrichment = await aexa.websets.enrichments.create(
|
||||||
webset_id=input_data.webset_id, params=payload
|
webset_id=input_data.webset_id, params=payload
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -223,7 +223,7 @@ class ExaCreateEnrichmentBlock(Block):
|
|||||||
items_enriched = 0
|
items_enriched = 0
|
||||||
|
|
||||||
while time.time() - poll_start < input_data.polling_timeout:
|
while time.time() - poll_start < input_data.polling_timeout:
|
||||||
current_enrich = aexa.websets.enrichments.get(
|
current_enrich = await aexa.websets.enrichments.get(
|
||||||
webset_id=input_data.webset_id, id=enrichment_id
|
webset_id=input_data.webset_id, id=enrichment_id
|
||||||
)
|
)
|
||||||
current_status = (
|
current_status = (
|
||||||
@@ -234,7 +234,7 @@ class ExaCreateEnrichmentBlock(Block):
|
|||||||
|
|
||||||
if current_status in ["completed", "failed", "cancelled"]:
|
if current_status in ["completed", "failed", "cancelled"]:
|
||||||
# Estimate items from webset searches
|
# Estimate items from webset searches
|
||||||
webset = aexa.websets.get(id=input_data.webset_id)
|
webset = await aexa.websets.get(id=input_data.webset_id)
|
||||||
if webset.searches:
|
if webset.searches:
|
||||||
for search in webset.searches:
|
for search in webset.searches:
|
||||||
if search.progress:
|
if search.progress:
|
||||||
@@ -329,7 +329,7 @@ class ExaGetEnrichmentBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
sdk_enrichment = aexa.websets.enrichments.get(
|
sdk_enrichment = await aexa.websets.enrichments.get(
|
||||||
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -474,7 +474,7 @@ class ExaDeleteEnrichmentBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
deleted_enrichment = aexa.websets.enrichments.delete(
|
deleted_enrichment = await aexa.websets.enrichments.delete(
|
||||||
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -525,13 +525,13 @@ class ExaCancelEnrichmentBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
canceled_enrichment = aexa.websets.enrichments.cancel(
|
canceled_enrichment = await aexa.websets.enrichments.cancel(
|
||||||
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# Try to estimate how many items were enriched before cancellation
|
# Try to estimate how many items were enriched before cancellation
|
||||||
items_enriched = 0
|
items_enriched = 0
|
||||||
items_response = aexa.websets.items.list(
|
items_response = await aexa.websets.items.list(
|
||||||
webset_id=input_data.webset_id, limit=100
|
webset_id=input_data.webset_id, limit=100
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -222,7 +222,7 @@ class ExaCreateImportBlock(Block):
|
|||||||
def _create_test_mock():
|
def _create_test_mock():
|
||||||
"""Create test mocks for the AsyncExa SDK."""
|
"""Create test mocks for the AsyncExa SDK."""
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
# Create mock SDK import object
|
# Create mock SDK import object
|
||||||
mock_import = MagicMock()
|
mock_import = MagicMock()
|
||||||
@@ -247,7 +247,7 @@ class ExaCreateImportBlock(Block):
|
|||||||
return {
|
return {
|
||||||
"_get_client": lambda *args, **kwargs: MagicMock(
|
"_get_client": lambda *args, **kwargs: MagicMock(
|
||||||
websets=MagicMock(
|
websets=MagicMock(
|
||||||
imports=MagicMock(create=lambda *args, **kwargs: mock_import)
|
imports=MagicMock(create=AsyncMock(return_value=mock_import))
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@@ -294,7 +294,7 @@ class ExaCreateImportBlock(Block):
|
|||||||
if input_data.metadata:
|
if input_data.metadata:
|
||||||
payload["metadata"] = input_data.metadata
|
payload["metadata"] = input_data.metadata
|
||||||
|
|
||||||
sdk_import = aexa.websets.imports.create(
|
sdk_import = await aexa.websets.imports.create(
|
||||||
params=payload, csv_data=input_data.csv_data
|
params=payload, csv_data=input_data.csv_data
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -360,7 +360,7 @@ class ExaGetImportBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
sdk_import = aexa.websets.imports.get(import_id=input_data.import_id)
|
sdk_import = await aexa.websets.imports.get(import_id=input_data.import_id)
|
||||||
|
|
||||||
import_obj = ImportModel.from_sdk(sdk_import)
|
import_obj = ImportModel.from_sdk(sdk_import)
|
||||||
|
|
||||||
@@ -426,7 +426,7 @@ class ExaListImportsBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
response = aexa.websets.imports.list(
|
response = await aexa.websets.imports.list(
|
||||||
cursor=input_data.cursor,
|
cursor=input_data.cursor,
|
||||||
limit=input_data.limit,
|
limit=input_data.limit,
|
||||||
)
|
)
|
||||||
@@ -474,7 +474,9 @@ class ExaDeleteImportBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
deleted_import = aexa.websets.imports.delete(import_id=input_data.import_id)
|
deleted_import = await aexa.websets.imports.delete(
|
||||||
|
import_id=input_data.import_id
|
||||||
|
)
|
||||||
|
|
||||||
yield "import_id", deleted_import.id
|
yield "import_id", deleted_import.id
|
||||||
yield "success", "true"
|
yield "success", "true"
|
||||||
@@ -573,14 +575,14 @@ class ExaExportWebsetBlock(Block):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create mock iterator
|
# Create async iterator for list_all
|
||||||
mock_items = [mock_item1, mock_item2]
|
async def async_item_iterator(*args, **kwargs):
|
||||||
|
for item in [mock_item1, mock_item2]:
|
||||||
|
yield item
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"_get_client": lambda *args, **kwargs: MagicMock(
|
"_get_client": lambda *args, **kwargs: MagicMock(
|
||||||
websets=MagicMock(
|
websets=MagicMock(items=MagicMock(list_all=async_item_iterator))
|
||||||
items=MagicMock(list_all=lambda *args, **kwargs: iter(mock_items))
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -602,7 +604,7 @@ class ExaExportWebsetBlock(Block):
|
|||||||
webset_id=input_data.webset_id, limit=input_data.max_items
|
webset_id=input_data.webset_id, limit=input_data.max_items
|
||||||
)
|
)
|
||||||
|
|
||||||
for sdk_item in item_iterator:
|
async for sdk_item in item_iterator:
|
||||||
if len(all_items) >= input_data.max_items:
|
if len(all_items) >= input_data.max_items:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|||||||
@@ -178,7 +178,7 @@ class ExaGetWebsetItemBlock(Block):
|
|||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
sdk_item = aexa.websets.items.get(
|
sdk_item = await aexa.websets.items.get(
|
||||||
webset_id=input_data.webset_id, id=input_data.item_id
|
webset_id=input_data.webset_id, id=input_data.item_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -269,7 +269,7 @@ class ExaListWebsetItemsBlock(Block):
|
|||||||
response = None
|
response = None
|
||||||
|
|
||||||
while time.time() - start_time < input_data.wait_timeout:
|
while time.time() - start_time < input_data.wait_timeout:
|
||||||
response = aexa.websets.items.list(
|
response = await aexa.websets.items.list(
|
||||||
webset_id=input_data.webset_id,
|
webset_id=input_data.webset_id,
|
||||||
cursor=input_data.cursor,
|
cursor=input_data.cursor,
|
||||||
limit=input_data.limit,
|
limit=input_data.limit,
|
||||||
@@ -282,13 +282,13 @@ class ExaListWebsetItemsBlock(Block):
|
|||||||
interval = min(interval * 1.2, 10)
|
interval = min(interval * 1.2, 10)
|
||||||
|
|
||||||
if not response:
|
if not response:
|
||||||
response = aexa.websets.items.list(
|
response = await aexa.websets.items.list(
|
||||||
webset_id=input_data.webset_id,
|
webset_id=input_data.webset_id,
|
||||||
cursor=input_data.cursor,
|
cursor=input_data.cursor,
|
||||||
limit=input_data.limit,
|
limit=input_data.limit,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
response = aexa.websets.items.list(
|
response = await aexa.websets.items.list(
|
||||||
webset_id=input_data.webset_id,
|
webset_id=input_data.webset_id,
|
||||||
cursor=input_data.cursor,
|
cursor=input_data.cursor,
|
||||||
limit=input_data.limit,
|
limit=input_data.limit,
|
||||||
@@ -340,7 +340,7 @@ class ExaDeleteWebsetItemBlock(Block):
|
|||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
deleted_item = aexa.websets.items.delete(
|
deleted_item = await aexa.websets.items.delete(
|
||||||
webset_id=input_data.webset_id, id=input_data.item_id
|
webset_id=input_data.webset_id, id=input_data.item_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -408,7 +408,7 @@ class ExaBulkWebsetItemsBlock(Block):
|
|||||||
webset_id=input_data.webset_id, limit=input_data.max_items
|
webset_id=input_data.webset_id, limit=input_data.max_items
|
||||||
)
|
)
|
||||||
|
|
||||||
for sdk_item in item_iterator:
|
async for sdk_item in item_iterator:
|
||||||
if len(all_items) >= input_data.max_items:
|
if len(all_items) >= input_data.max_items:
|
||||||
break
|
break
|
||||||
|
|
||||||
@@ -475,7 +475,7 @@ class ExaWebsetItemsSummaryBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
webset = aexa.websets.get(id=input_data.webset_id)
|
webset = await aexa.websets.get(id=input_data.webset_id)
|
||||||
|
|
||||||
entity_type = "unknown"
|
entity_type = "unknown"
|
||||||
if webset.searches:
|
if webset.searches:
|
||||||
@@ -495,7 +495,7 @@ class ExaWebsetItemsSummaryBlock(Block):
|
|||||||
# Get sample items if requested
|
# Get sample items if requested
|
||||||
sample_items: List[WebsetItemModel] = []
|
sample_items: List[WebsetItemModel] = []
|
||||||
if input_data.sample_size > 0:
|
if input_data.sample_size > 0:
|
||||||
items_response = aexa.websets.items.list(
|
items_response = await aexa.websets.items.list(
|
||||||
webset_id=input_data.webset_id, limit=input_data.sample_size
|
webset_id=input_data.webset_id, limit=input_data.sample_size
|
||||||
)
|
)
|
||||||
# Convert to our stable models
|
# Convert to our stable models
|
||||||
@@ -569,7 +569,7 @@ class ExaGetNewItemsBlock(Block):
|
|||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
# Get items starting from cursor
|
# Get items starting from cursor
|
||||||
response = aexa.websets.items.list(
|
response = await aexa.websets.items.list(
|
||||||
webset_id=input_data.webset_id,
|
webset_id=input_data.webset_id,
|
||||||
cursor=input_data.since_cursor,
|
cursor=input_data.since_cursor,
|
||||||
limit=input_data.max_items,
|
limit=input_data.max_items,
|
||||||
|
|||||||
@@ -233,7 +233,7 @@ class ExaCreateMonitorBlock(Block):
|
|||||||
def _create_test_mock():
|
def _create_test_mock():
|
||||||
"""Create test mocks for the AsyncExa SDK."""
|
"""Create test mocks for the AsyncExa SDK."""
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
# Create mock SDK monitor object
|
# Create mock SDK monitor object
|
||||||
mock_monitor = MagicMock()
|
mock_monitor = MagicMock()
|
||||||
@@ -263,7 +263,7 @@ class ExaCreateMonitorBlock(Block):
|
|||||||
return {
|
return {
|
||||||
"_get_client": lambda *args, **kwargs: MagicMock(
|
"_get_client": lambda *args, **kwargs: MagicMock(
|
||||||
websets=MagicMock(
|
websets=MagicMock(
|
||||||
monitors=MagicMock(create=lambda *args, **kwargs: mock_monitor)
|
monitors=MagicMock(create=AsyncMock(return_value=mock_monitor))
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@@ -320,7 +320,7 @@ class ExaCreateMonitorBlock(Block):
|
|||||||
if input_data.metadata:
|
if input_data.metadata:
|
||||||
payload["metadata"] = input_data.metadata
|
payload["metadata"] = input_data.metadata
|
||||||
|
|
||||||
sdk_monitor = aexa.websets.monitors.create(params=payload)
|
sdk_monitor = await aexa.websets.monitors.create(params=payload)
|
||||||
|
|
||||||
monitor = MonitorModel.from_sdk(sdk_monitor)
|
monitor = MonitorModel.from_sdk(sdk_monitor)
|
||||||
|
|
||||||
@@ -384,7 +384,7 @@ class ExaGetMonitorBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
sdk_monitor = aexa.websets.monitors.get(monitor_id=input_data.monitor_id)
|
sdk_monitor = await aexa.websets.monitors.get(monitor_id=input_data.monitor_id)
|
||||||
|
|
||||||
monitor = MonitorModel.from_sdk(sdk_monitor)
|
monitor = MonitorModel.from_sdk(sdk_monitor)
|
||||||
|
|
||||||
@@ -476,7 +476,7 @@ class ExaUpdateMonitorBlock(Block):
|
|||||||
if input_data.metadata is not None:
|
if input_data.metadata is not None:
|
||||||
payload["metadata"] = input_data.metadata
|
payload["metadata"] = input_data.metadata
|
||||||
|
|
||||||
sdk_monitor = aexa.websets.monitors.update(
|
sdk_monitor = await aexa.websets.monitors.update(
|
||||||
monitor_id=input_data.monitor_id, params=payload
|
monitor_id=input_data.monitor_id, params=payload
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -522,7 +522,9 @@ class ExaDeleteMonitorBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
deleted_monitor = aexa.websets.monitors.delete(monitor_id=input_data.monitor_id)
|
deleted_monitor = await aexa.websets.monitors.delete(
|
||||||
|
monitor_id=input_data.monitor_id
|
||||||
|
)
|
||||||
|
|
||||||
yield "monitor_id", deleted_monitor.id
|
yield "monitor_id", deleted_monitor.id
|
||||||
yield "success", "true"
|
yield "success", "true"
|
||||||
@@ -579,7 +581,7 @@ class ExaListMonitorsBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
response = aexa.websets.monitors.list(
|
response = await aexa.websets.monitors.list(
|
||||||
cursor=input_data.cursor,
|
cursor=input_data.cursor,
|
||||||
limit=input_data.limit,
|
limit=input_data.limit,
|
||||||
webset_id=input_data.webset_id,
|
webset_id=input_data.webset_id,
|
||||||
|
|||||||
@@ -121,7 +121,7 @@ class ExaWaitForWebsetBlock(Block):
|
|||||||
WebsetTargetStatus.IDLE,
|
WebsetTargetStatus.IDLE,
|
||||||
WebsetTargetStatus.ANY_COMPLETE,
|
WebsetTargetStatus.ANY_COMPLETE,
|
||||||
]:
|
]:
|
||||||
final_webset = aexa.websets.wait_until_idle(
|
final_webset = await aexa.websets.wait_until_idle(
|
||||||
id=input_data.webset_id,
|
id=input_data.webset_id,
|
||||||
timeout=input_data.timeout,
|
timeout=input_data.timeout,
|
||||||
poll_interval=input_data.check_interval,
|
poll_interval=input_data.check_interval,
|
||||||
@@ -164,7 +164,7 @@ class ExaWaitForWebsetBlock(Block):
|
|||||||
interval = input_data.check_interval
|
interval = input_data.check_interval
|
||||||
while time.time() - start_time < input_data.timeout:
|
while time.time() - start_time < input_data.timeout:
|
||||||
# Get current webset status
|
# Get current webset status
|
||||||
webset = aexa.websets.get(id=input_data.webset_id)
|
webset = await aexa.websets.get(id=input_data.webset_id)
|
||||||
current_status = (
|
current_status = (
|
||||||
webset.status.value
|
webset.status.value
|
||||||
if hasattr(webset.status, "value")
|
if hasattr(webset.status, "value")
|
||||||
@@ -209,7 +209,7 @@ class ExaWaitForWebsetBlock(Block):
|
|||||||
|
|
||||||
# Timeout reached
|
# Timeout reached
|
||||||
elapsed = time.time() - start_time
|
elapsed = time.time() - start_time
|
||||||
webset = aexa.websets.get(id=input_data.webset_id)
|
webset = await aexa.websets.get(id=input_data.webset_id)
|
||||||
final_status = (
|
final_status = (
|
||||||
webset.status.value
|
webset.status.value
|
||||||
if hasattr(webset.status, "value")
|
if hasattr(webset.status, "value")
|
||||||
@@ -345,7 +345,7 @@ class ExaWaitForSearchBlock(Block):
|
|||||||
try:
|
try:
|
||||||
while time.time() - start_time < input_data.timeout:
|
while time.time() - start_time < input_data.timeout:
|
||||||
# Get current search status using SDK
|
# Get current search status using SDK
|
||||||
search = aexa.websets.searches.get(
|
search = await aexa.websets.searches.get(
|
||||||
webset_id=input_data.webset_id, id=input_data.search_id
|
webset_id=input_data.webset_id, id=input_data.search_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -401,7 +401,7 @@ class ExaWaitForSearchBlock(Block):
|
|||||||
elapsed = time.time() - start_time
|
elapsed = time.time() - start_time
|
||||||
|
|
||||||
# Get last known status
|
# Get last known status
|
||||||
search = aexa.websets.searches.get(
|
search = await aexa.websets.searches.get(
|
||||||
webset_id=input_data.webset_id, id=input_data.search_id
|
webset_id=input_data.webset_id, id=input_data.search_id
|
||||||
)
|
)
|
||||||
final_status = (
|
final_status = (
|
||||||
@@ -503,7 +503,7 @@ class ExaWaitForEnrichmentBlock(Block):
|
|||||||
try:
|
try:
|
||||||
while time.time() - start_time < input_data.timeout:
|
while time.time() - start_time < input_data.timeout:
|
||||||
# Get current enrichment status using SDK
|
# Get current enrichment status using SDK
|
||||||
enrichment = aexa.websets.enrichments.get(
|
enrichment = await aexa.websets.enrichments.get(
|
||||||
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -548,7 +548,7 @@ class ExaWaitForEnrichmentBlock(Block):
|
|||||||
elapsed = time.time() - start_time
|
elapsed = time.time() - start_time
|
||||||
|
|
||||||
# Get last known status
|
# Get last known status
|
||||||
enrichment = aexa.websets.enrichments.get(
|
enrichment = await aexa.websets.enrichments.get(
|
||||||
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
||||||
)
|
)
|
||||||
final_status = (
|
final_status = (
|
||||||
@@ -575,7 +575,7 @@ class ExaWaitForEnrichmentBlock(Block):
|
|||||||
) -> tuple[list[SampleEnrichmentModel], int]:
|
) -> tuple[list[SampleEnrichmentModel], int]:
|
||||||
"""Get sample enriched data and count."""
|
"""Get sample enriched data and count."""
|
||||||
# Get a few items to see enrichment results using SDK
|
# Get a few items to see enrichment results using SDK
|
||||||
response = aexa.websets.items.list(webset_id=webset_id, limit=5)
|
response = await aexa.websets.items.list(webset_id=webset_id, limit=5)
|
||||||
|
|
||||||
sample_data: list[SampleEnrichmentModel] = []
|
sample_data: list[SampleEnrichmentModel] = []
|
||||||
enriched_count = 0
|
enriched_count = 0
|
||||||
|
|||||||
@@ -317,7 +317,7 @@ class ExaCreateWebsetSearchBlock(Block):
|
|||||||
|
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
sdk_search = aexa.websets.searches.create(
|
sdk_search = await aexa.websets.searches.create(
|
||||||
webset_id=input_data.webset_id, params=payload
|
webset_id=input_data.webset_id, params=payload
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -350,7 +350,7 @@ class ExaCreateWebsetSearchBlock(Block):
|
|||||||
poll_start = time.time()
|
poll_start = time.time()
|
||||||
|
|
||||||
while time.time() - poll_start < input_data.polling_timeout:
|
while time.time() - poll_start < input_data.polling_timeout:
|
||||||
current_search = aexa.websets.searches.get(
|
current_search = await aexa.websets.searches.get(
|
||||||
webset_id=input_data.webset_id, id=search_id
|
webset_id=input_data.webset_id, id=search_id
|
||||||
)
|
)
|
||||||
current_status = (
|
current_status = (
|
||||||
@@ -442,7 +442,7 @@ class ExaGetWebsetSearchBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
sdk_search = aexa.websets.searches.get(
|
sdk_search = await aexa.websets.searches.get(
|
||||||
webset_id=input_data.webset_id, id=input_data.search_id
|
webset_id=input_data.webset_id, id=input_data.search_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -523,7 +523,7 @@ class ExaCancelWebsetSearchBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
canceled_search = aexa.websets.searches.cancel(
|
canceled_search = await aexa.websets.searches.cancel(
|
||||||
webset_id=input_data.webset_id, id=input_data.search_id
|
webset_id=input_data.webset_id, id=input_data.search_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -604,7 +604,7 @@ class ExaFindOrCreateSearchBlock(Block):
|
|||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
# Get webset to check existing searches
|
# Get webset to check existing searches
|
||||||
webset = aexa.websets.get(id=input_data.webset_id)
|
webset = await aexa.websets.get(id=input_data.webset_id)
|
||||||
|
|
||||||
# Look for existing search with same query
|
# Look for existing search with same query
|
||||||
existing_search = None
|
existing_search = None
|
||||||
@@ -636,7 +636,7 @@ class ExaFindOrCreateSearchBlock(Block):
|
|||||||
if input_data.entity_type != SearchEntityType.AUTO:
|
if input_data.entity_type != SearchEntityType.AUTO:
|
||||||
payload["entity"] = {"type": input_data.entity_type.value}
|
payload["entity"] = {"type": input_data.entity_type.value}
|
||||||
|
|
||||||
sdk_search = aexa.websets.searches.create(
|
sdk_search = await aexa.websets.searches.create(
|
||||||
webset_id=input_data.webset_id, params=payload
|
webset_id=input_data.webset_id, params=payload
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -115,6 +115,7 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
|||||||
CLAUDE_4_5_OPUS = "claude-opus-4-5-20251101"
|
CLAUDE_4_5_OPUS = "claude-opus-4-5-20251101"
|
||||||
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
|
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
|
||||||
CLAUDE_4_5_HAIKU = "claude-haiku-4-5-20251001"
|
CLAUDE_4_5_HAIKU = "claude-haiku-4-5-20251001"
|
||||||
|
CLAUDE_4_6_OPUS = "claude-opus-4-6"
|
||||||
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
|
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
|
||||||
# AI/ML API models
|
# AI/ML API models
|
||||||
AIML_API_QWEN2_5_72B = "Qwen/Qwen2.5-72B-Instruct-Turbo"
|
AIML_API_QWEN2_5_72B = "Qwen/Qwen2.5-72B-Instruct-Turbo"
|
||||||
@@ -270,6 +271,9 @@ MODEL_METADATA = {
|
|||||||
LlmModel.CLAUDE_4_SONNET: ModelMetadata(
|
LlmModel.CLAUDE_4_SONNET: ModelMetadata(
|
||||||
"anthropic", 200000, 64000, "Claude Sonnet 4", "Anthropic", "Anthropic", 2
|
"anthropic", 200000, 64000, "Claude Sonnet 4", "Anthropic", "Anthropic", 2
|
||||||
), # claude-4-sonnet-20250514
|
), # claude-4-sonnet-20250514
|
||||||
|
LlmModel.CLAUDE_4_6_OPUS: ModelMetadata(
|
||||||
|
"anthropic", 200000, 128000, "Claude Opus 4.6", "Anthropic", "Anthropic", 3
|
||||||
|
), # claude-opus-4-6
|
||||||
LlmModel.CLAUDE_4_5_OPUS: ModelMetadata(
|
LlmModel.CLAUDE_4_5_OPUS: ModelMetadata(
|
||||||
"anthropic", 200000, 64000, "Claude Opus 4.5", "Anthropic", "Anthropic", 3
|
"anthropic", 200000, 64000, "Claude Opus 4.5", "Anthropic", "Anthropic", 3
|
||||||
), # claude-opus-4-5-20251101
|
), # claude-opus-4-5-20251101
|
||||||
@@ -527,12 +531,12 @@ class LLMResponse(BaseModel):
|
|||||||
|
|
||||||
def convert_openai_tool_fmt_to_anthropic(
|
def convert_openai_tool_fmt_to_anthropic(
|
||||||
openai_tools: list[dict] | None = None,
|
openai_tools: list[dict] | None = None,
|
||||||
) -> Iterable[ToolParam] | anthropic.NotGiven:
|
) -> Iterable[ToolParam] | anthropic.Omit:
|
||||||
"""
|
"""
|
||||||
Convert OpenAI tool format to Anthropic tool format.
|
Convert OpenAI tool format to Anthropic tool format.
|
||||||
"""
|
"""
|
||||||
if not openai_tools or len(openai_tools) == 0:
|
if not openai_tools or len(openai_tools) == 0:
|
||||||
return anthropic.NOT_GIVEN
|
return anthropic.omit
|
||||||
|
|
||||||
anthropic_tools = []
|
anthropic_tools = []
|
||||||
for tool in openai_tools:
|
for tool in openai_tools:
|
||||||
@@ -592,10 +596,10 @@ def extract_openai_tool_calls(response) -> list[ToolContentBlock] | None:
|
|||||||
|
|
||||||
def get_parallel_tool_calls_param(
|
def get_parallel_tool_calls_param(
|
||||||
llm_model: LlmModel, parallel_tool_calls: bool | None
|
llm_model: LlmModel, parallel_tool_calls: bool | None
|
||||||
):
|
) -> bool | openai.Omit:
|
||||||
"""Get the appropriate parallel_tool_calls parameter for OpenAI-compatible APIs."""
|
"""Get the appropriate parallel_tool_calls parameter for OpenAI-compatible APIs."""
|
||||||
if llm_model.startswith("o") or parallel_tool_calls is None:
|
if llm_model.startswith("o") or parallel_tool_calls is None:
|
||||||
return openai.NOT_GIVEN
|
return openai.omit
|
||||||
return parallel_tool_calls
|
return parallel_tool_calls
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,246 +0,0 @@
|
|||||||
import os
|
|
||||||
import tempfile
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from moviepy.audio.io.AudioFileClip import AudioFileClip
|
|
||||||
from moviepy.video.fx.Loop import Loop
|
|
||||||
from moviepy.video.io.VideoFileClip import VideoFileClip
|
|
||||||
|
|
||||||
from backend.data.block import (
|
|
||||||
Block,
|
|
||||||
BlockCategory,
|
|
||||||
BlockOutput,
|
|
||||||
BlockSchemaInput,
|
|
||||||
BlockSchemaOutput,
|
|
||||||
)
|
|
||||||
from backend.data.execution import ExecutionContext
|
|
||||||
from backend.data.model import SchemaField
|
|
||||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
|
||||||
|
|
||||||
|
|
||||||
class MediaDurationBlock(Block):
|
|
||||||
|
|
||||||
class Input(BlockSchemaInput):
|
|
||||||
media_in: MediaFileType = SchemaField(
|
|
||||||
description="Media input (URL, data URI, or local path)."
|
|
||||||
)
|
|
||||||
is_video: bool = SchemaField(
|
|
||||||
description="Whether the media is a video (True) or audio (False).",
|
|
||||||
default=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
|
||||||
duration: float = SchemaField(
|
|
||||||
description="Duration of the media file (in seconds)."
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
id="d8b91fd4-da26-42d4-8ecb-8b196c6d84b6",
|
|
||||||
description="Block to get the duration of a media file.",
|
|
||||||
categories={BlockCategory.MULTIMEDIA},
|
|
||||||
input_schema=MediaDurationBlock.Input,
|
|
||||||
output_schema=MediaDurationBlock.Output,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def run(
|
|
||||||
self,
|
|
||||||
input_data: Input,
|
|
||||||
*,
|
|
||||||
execution_context: ExecutionContext,
|
|
||||||
**kwargs,
|
|
||||||
) -> BlockOutput:
|
|
||||||
# 1) Store the input media locally
|
|
||||||
local_media_path = await store_media_file(
|
|
||||||
file=input_data.media_in,
|
|
||||||
execution_context=execution_context,
|
|
||||||
return_format="for_local_processing",
|
|
||||||
)
|
|
||||||
assert execution_context.graph_exec_id is not None
|
|
||||||
media_abspath = get_exec_file_path(
|
|
||||||
execution_context.graph_exec_id, local_media_path
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2) Load the clip
|
|
||||||
if input_data.is_video:
|
|
||||||
clip = VideoFileClip(media_abspath)
|
|
||||||
else:
|
|
||||||
clip = AudioFileClip(media_abspath)
|
|
||||||
|
|
||||||
yield "duration", clip.duration
|
|
||||||
|
|
||||||
|
|
||||||
class LoopVideoBlock(Block):
|
|
||||||
"""
|
|
||||||
Block for looping (repeating) a video clip until a given duration or number of loops.
|
|
||||||
"""
|
|
||||||
|
|
||||||
class Input(BlockSchemaInput):
|
|
||||||
video_in: MediaFileType = SchemaField(
|
|
||||||
description="The input video (can be a URL, data URI, or local path)."
|
|
||||||
)
|
|
||||||
# Provide EITHER a `duration` or `n_loops` or both. We'll demonstrate `duration`.
|
|
||||||
duration: Optional[float] = SchemaField(
|
|
||||||
description="Target duration (in seconds) to loop the video to. If omitted, defaults to no looping.",
|
|
||||||
default=None,
|
|
||||||
ge=0.0,
|
|
||||||
)
|
|
||||||
n_loops: Optional[int] = SchemaField(
|
|
||||||
description="Number of times to repeat the video. If omitted, defaults to 1 (no repeat).",
|
|
||||||
default=None,
|
|
||||||
ge=1,
|
|
||||||
)
|
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
|
||||||
video_out: str = SchemaField(
|
|
||||||
description="Looped video returned either as a relative path or a data URI."
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
id="8bf9eef6-5451-4213-b265-25306446e94b",
|
|
||||||
description="Block to loop a video to a given duration or number of repeats.",
|
|
||||||
categories={BlockCategory.MULTIMEDIA},
|
|
||||||
input_schema=LoopVideoBlock.Input,
|
|
||||||
output_schema=LoopVideoBlock.Output,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def run(
|
|
||||||
self,
|
|
||||||
input_data: Input,
|
|
||||||
*,
|
|
||||||
execution_context: ExecutionContext,
|
|
||||||
**kwargs,
|
|
||||||
) -> BlockOutput:
|
|
||||||
assert execution_context.graph_exec_id is not None
|
|
||||||
assert execution_context.node_exec_id is not None
|
|
||||||
graph_exec_id = execution_context.graph_exec_id
|
|
||||||
node_exec_id = execution_context.node_exec_id
|
|
||||||
|
|
||||||
# 1) Store the input video locally
|
|
||||||
local_video_path = await store_media_file(
|
|
||||||
file=input_data.video_in,
|
|
||||||
execution_context=execution_context,
|
|
||||||
return_format="for_local_processing",
|
|
||||||
)
|
|
||||||
input_abspath = get_exec_file_path(graph_exec_id, local_video_path)
|
|
||||||
|
|
||||||
# 2) Load the clip
|
|
||||||
clip = VideoFileClip(input_abspath)
|
|
||||||
|
|
||||||
# 3) Apply the loop effect
|
|
||||||
looped_clip = clip
|
|
||||||
if input_data.duration:
|
|
||||||
# Loop until we reach the specified duration
|
|
||||||
looped_clip = looped_clip.with_effects([Loop(duration=input_data.duration)])
|
|
||||||
elif input_data.n_loops:
|
|
||||||
looped_clip = looped_clip.with_effects([Loop(n=input_data.n_loops)])
|
|
||||||
else:
|
|
||||||
raise ValueError("Either 'duration' or 'n_loops' must be provided.")
|
|
||||||
|
|
||||||
assert isinstance(looped_clip, VideoFileClip)
|
|
||||||
|
|
||||||
# 4) Save the looped output
|
|
||||||
output_filename = MediaFileType(
|
|
||||||
f"{node_exec_id}_looped_{os.path.basename(local_video_path)}"
|
|
||||||
)
|
|
||||||
output_abspath = get_exec_file_path(graph_exec_id, output_filename)
|
|
||||||
|
|
||||||
looped_clip = looped_clip.with_audio(clip.audio)
|
|
||||||
looped_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac")
|
|
||||||
|
|
||||||
# Return output - for_block_output returns workspace:// if available, else data URI
|
|
||||||
video_out = await store_media_file(
|
|
||||||
file=output_filename,
|
|
||||||
execution_context=execution_context,
|
|
||||||
return_format="for_block_output",
|
|
||||||
)
|
|
||||||
|
|
||||||
yield "video_out", video_out
|
|
||||||
|
|
||||||
|
|
||||||
class AddAudioToVideoBlock(Block):
|
|
||||||
"""
|
|
||||||
Block that adds (attaches) an audio track to an existing video.
|
|
||||||
Optionally scale the volume of the new track.
|
|
||||||
"""
|
|
||||||
|
|
||||||
class Input(BlockSchemaInput):
|
|
||||||
video_in: MediaFileType = SchemaField(
|
|
||||||
description="Video input (URL, data URI, or local path)."
|
|
||||||
)
|
|
||||||
audio_in: MediaFileType = SchemaField(
|
|
||||||
description="Audio input (URL, data URI, or local path)."
|
|
||||||
)
|
|
||||||
volume: float = SchemaField(
|
|
||||||
description="Volume scale for the newly attached audio track (1.0 = original).",
|
|
||||||
default=1.0,
|
|
||||||
)
|
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
|
||||||
video_out: MediaFileType = SchemaField(
|
|
||||||
description="Final video (with attached audio), as a path or data URI."
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
id="3503748d-62b6-4425-91d6-725b064af509",
|
|
||||||
description="Block to attach an audio file to a video file using moviepy.",
|
|
||||||
categories={BlockCategory.MULTIMEDIA},
|
|
||||||
input_schema=AddAudioToVideoBlock.Input,
|
|
||||||
output_schema=AddAudioToVideoBlock.Output,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def run(
|
|
||||||
self,
|
|
||||||
input_data: Input,
|
|
||||||
*,
|
|
||||||
execution_context: ExecutionContext,
|
|
||||||
**kwargs,
|
|
||||||
) -> BlockOutput:
|
|
||||||
assert execution_context.graph_exec_id is not None
|
|
||||||
assert execution_context.node_exec_id is not None
|
|
||||||
graph_exec_id = execution_context.graph_exec_id
|
|
||||||
node_exec_id = execution_context.node_exec_id
|
|
||||||
|
|
||||||
# 1) Store the inputs locally
|
|
||||||
local_video_path = await store_media_file(
|
|
||||||
file=input_data.video_in,
|
|
||||||
execution_context=execution_context,
|
|
||||||
return_format="for_local_processing",
|
|
||||||
)
|
|
||||||
local_audio_path = await store_media_file(
|
|
||||||
file=input_data.audio_in,
|
|
||||||
execution_context=execution_context,
|
|
||||||
return_format="for_local_processing",
|
|
||||||
)
|
|
||||||
|
|
||||||
abs_temp_dir = os.path.join(tempfile.gettempdir(), "exec_file", graph_exec_id)
|
|
||||||
video_abspath = os.path.join(abs_temp_dir, local_video_path)
|
|
||||||
audio_abspath = os.path.join(abs_temp_dir, local_audio_path)
|
|
||||||
|
|
||||||
# 2) Load video + audio with moviepy
|
|
||||||
video_clip = VideoFileClip(video_abspath)
|
|
||||||
audio_clip = AudioFileClip(audio_abspath)
|
|
||||||
# Optionally scale volume
|
|
||||||
if input_data.volume != 1.0:
|
|
||||||
audio_clip = audio_clip.with_volume_scaled(input_data.volume)
|
|
||||||
|
|
||||||
# 3) Attach the new audio track
|
|
||||||
final_clip = video_clip.with_audio(audio_clip)
|
|
||||||
|
|
||||||
# 4) Write to output file
|
|
||||||
output_filename = MediaFileType(
|
|
||||||
f"{node_exec_id}_audio_attached_{os.path.basename(local_video_path)}"
|
|
||||||
)
|
|
||||||
output_abspath = os.path.join(abs_temp_dir, output_filename)
|
|
||||||
final_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac")
|
|
||||||
|
|
||||||
# 5) Return output - for_block_output returns workspace:// if available, else data URI
|
|
||||||
video_out = await store_media_file(
|
|
||||||
file=output_filename,
|
|
||||||
execution_context=execution_context,
|
|
||||||
return_format="for_block_output",
|
|
||||||
)
|
|
||||||
|
|
||||||
yield "video_out", video_out
|
|
||||||
@@ -0,0 +1,77 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from backend.blocks.encoder_block import TextEncoderBlock
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_text_encoder_basic():
|
||||||
|
"""Test basic encoding of newlines and special characters."""
|
||||||
|
block = TextEncoderBlock()
|
||||||
|
result = []
|
||||||
|
async for output in block.run(TextEncoderBlock.Input(text="Hello\nWorld")):
|
||||||
|
result.append(output)
|
||||||
|
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0][0] == "encoded_text"
|
||||||
|
assert result[0][1] == "Hello\\nWorld"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_text_encoder_multiple_escapes():
|
||||||
|
"""Test encoding of multiple escape sequences."""
|
||||||
|
block = TextEncoderBlock()
|
||||||
|
result = []
|
||||||
|
async for output in block.run(
|
||||||
|
TextEncoderBlock.Input(text="Line1\nLine2\tTabbed\rCarriage")
|
||||||
|
):
|
||||||
|
result.append(output)
|
||||||
|
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0][0] == "encoded_text"
|
||||||
|
assert "\\n" in result[0][1]
|
||||||
|
assert "\\t" in result[0][1]
|
||||||
|
assert "\\r" in result[0][1]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_text_encoder_unicode():
|
||||||
|
"""Test that unicode characters are handled correctly."""
|
||||||
|
block = TextEncoderBlock()
|
||||||
|
result = []
|
||||||
|
async for output in block.run(TextEncoderBlock.Input(text="Hello 世界\n")):
|
||||||
|
result.append(output)
|
||||||
|
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0][0] == "encoded_text"
|
||||||
|
# Unicode characters should be escaped as \uXXXX sequences
|
||||||
|
assert "\\n" in result[0][1]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_text_encoder_empty_string():
|
||||||
|
"""Test encoding of an empty string."""
|
||||||
|
block = TextEncoderBlock()
|
||||||
|
result = []
|
||||||
|
async for output in block.run(TextEncoderBlock.Input(text="")):
|
||||||
|
result.append(output)
|
||||||
|
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0][0] == "encoded_text"
|
||||||
|
assert result[0][1] == ""
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_text_encoder_error_handling():
|
||||||
|
"""Test that encoding errors are handled gracefully."""
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
block = TextEncoderBlock()
|
||||||
|
result = []
|
||||||
|
|
||||||
|
with patch("codecs.encode", side_effect=Exception("Mocked encoding error")):
|
||||||
|
async for output in block.run(TextEncoderBlock.Input(text="test")):
|
||||||
|
result.append(output)
|
||||||
|
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0][0] == "error"
|
||||||
|
assert "Mocked encoding error" in result[0][1]
|
||||||
37
autogpt_platform/backend/backend/blocks/video/__init__.py
Normal file
37
autogpt_platform/backend/backend/blocks/video/__init__.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
"""Video editing blocks for AutoGPT Platform.
|
||||||
|
|
||||||
|
This module provides blocks for:
|
||||||
|
- Downloading videos from URLs (YouTube, Vimeo, news sites, direct links)
|
||||||
|
- Clipping/trimming video segments
|
||||||
|
- Concatenating multiple videos
|
||||||
|
- Adding text overlays
|
||||||
|
- Adding AI-generated narration
|
||||||
|
- Getting media duration
|
||||||
|
- Looping videos
|
||||||
|
- Adding audio to videos
|
||||||
|
|
||||||
|
Dependencies:
|
||||||
|
- yt-dlp: For video downloading
|
||||||
|
- moviepy: For video editing operations
|
||||||
|
- elevenlabs: For AI narration (optional)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from backend.blocks.video.add_audio import AddAudioToVideoBlock
|
||||||
|
from backend.blocks.video.clip import VideoClipBlock
|
||||||
|
from backend.blocks.video.concat import VideoConcatBlock
|
||||||
|
from backend.blocks.video.download import VideoDownloadBlock
|
||||||
|
from backend.blocks.video.duration import MediaDurationBlock
|
||||||
|
from backend.blocks.video.loop import LoopVideoBlock
|
||||||
|
from backend.blocks.video.narration import VideoNarrationBlock
|
||||||
|
from backend.blocks.video.text_overlay import VideoTextOverlayBlock
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AddAudioToVideoBlock",
|
||||||
|
"LoopVideoBlock",
|
||||||
|
"MediaDurationBlock",
|
||||||
|
"VideoClipBlock",
|
||||||
|
"VideoConcatBlock",
|
||||||
|
"VideoDownloadBlock",
|
||||||
|
"VideoNarrationBlock",
|
||||||
|
"VideoTextOverlayBlock",
|
||||||
|
]
|
||||||
131
autogpt_platform/backend/backend/blocks/video/_utils.py
Normal file
131
autogpt_platform/backend/backend/blocks/video/_utils.py
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
"""Shared utilities for video blocks."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import subprocess
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Known operation tags added by video blocks
|
||||||
|
_VIDEO_OPS = (
|
||||||
|
r"(?:clip|overlay|narrated|looped|concat|audio_attached|with_audio|narration)"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Matches: {node_exec_id}_{operation}_ where node_exec_id contains a UUID
|
||||||
|
_BLOCK_PREFIX_RE = re.compile(
|
||||||
|
r"^[a-zA-Z0-9_-]*"
|
||||||
|
r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"
|
||||||
|
r"[a-zA-Z0-9_-]*"
|
||||||
|
r"_" + _VIDEO_OPS + r"_"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Matches: a lone {node_exec_id}_ prefix (no operation keyword, e.g. download output)
|
||||||
|
_UUID_PREFIX_RE = re.compile(
|
||||||
|
r"^[a-zA-Z0-9_-]*"
|
||||||
|
r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"
|
||||||
|
r"[a-zA-Z0-9_-]*_"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_source_name(input_path: str, max_length: int = 50) -> str:
|
||||||
|
"""Extract the original source filename by stripping block-generated prefixes.
|
||||||
|
|
||||||
|
Iteratively removes {node_exec_id}_{operation}_ prefixes that accumulate
|
||||||
|
when chaining video blocks, recovering the original human-readable name.
|
||||||
|
|
||||||
|
Safe for plain filenames (no UUID -> no stripping).
|
||||||
|
Falls back to "video" if everything is stripped.
|
||||||
|
"""
|
||||||
|
stem = Path(input_path).stem
|
||||||
|
|
||||||
|
# Pass 1: strip {node_exec_id}_{operation}_ prefixes iteratively
|
||||||
|
while _BLOCK_PREFIX_RE.match(stem):
|
||||||
|
stem = _BLOCK_PREFIX_RE.sub("", stem, count=1)
|
||||||
|
|
||||||
|
# Pass 2: strip a lone {node_exec_id}_ prefix (e.g. from download block)
|
||||||
|
if _UUID_PREFIX_RE.match(stem):
|
||||||
|
stem = _UUID_PREFIX_RE.sub("", stem, count=1)
|
||||||
|
|
||||||
|
if not stem:
|
||||||
|
return "video"
|
||||||
|
|
||||||
|
return stem[:max_length]
|
||||||
|
|
||||||
|
|
||||||
|
def get_video_codecs(output_path: str) -> tuple[str, str]:
|
||||||
|
"""Get appropriate video and audio codecs based on output file extension.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output_path: Path to the output file (used to determine extension)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (video_codec, audio_codec)
|
||||||
|
|
||||||
|
Codec mappings:
|
||||||
|
- .mp4: H.264 + AAC (universal compatibility)
|
||||||
|
- .webm: VP8 + Vorbis (web streaming)
|
||||||
|
- .mkv: H.264 + AAC (container supports many codecs)
|
||||||
|
- .mov: H.264 + AAC (Apple QuickTime, widely compatible)
|
||||||
|
- .m4v: H.264 + AAC (Apple iTunes/devices)
|
||||||
|
- .avi: MPEG-4 + MP3 (legacy Windows)
|
||||||
|
"""
|
||||||
|
ext = os.path.splitext(output_path)[1].lower()
|
||||||
|
|
||||||
|
codec_map: dict[str, tuple[str, str]] = {
|
||||||
|
".mp4": ("libx264", "aac"),
|
||||||
|
".webm": ("libvpx", "libvorbis"),
|
||||||
|
".mkv": ("libx264", "aac"),
|
||||||
|
".mov": ("libx264", "aac"),
|
||||||
|
".m4v": ("libx264", "aac"),
|
||||||
|
".avi": ("mpeg4", "libmp3lame"),
|
||||||
|
}
|
||||||
|
|
||||||
|
return codec_map.get(ext, ("libx264", "aac"))
|
||||||
|
|
||||||
|
|
||||||
|
def strip_chapters_inplace(video_path: str) -> None:
|
||||||
|
"""Strip chapter metadata from a media file in-place using ffmpeg.
|
||||||
|
|
||||||
|
MoviePy 2.x crashes with IndexError when parsing files with embedded
|
||||||
|
chapter metadata (https://github.com/Zulko/moviepy/issues/2419).
|
||||||
|
This strips chapters without re-encoding.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
video_path: Absolute path to the media file to strip chapters from.
|
||||||
|
"""
|
||||||
|
base, ext = os.path.splitext(video_path)
|
||||||
|
tmp_path = base + ".tmp" + ext
|
||||||
|
try:
|
||||||
|
result = subprocess.run(
|
||||||
|
[
|
||||||
|
"ffmpeg",
|
||||||
|
"-y",
|
||||||
|
"-i",
|
||||||
|
video_path,
|
||||||
|
"-map_chapters",
|
||||||
|
"-1",
|
||||||
|
"-codec",
|
||||||
|
"copy",
|
||||||
|
tmp_path,
|
||||||
|
],
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
timeout=300,
|
||||||
|
)
|
||||||
|
if result.returncode != 0:
|
||||||
|
logger.warning(
|
||||||
|
"ffmpeg chapter strip failed (rc=%d): %s",
|
||||||
|
result.returncode,
|
||||||
|
result.stderr,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
os.replace(tmp_path, video_path)
|
||||||
|
except FileNotFoundError:
|
||||||
|
logger.warning("ffmpeg not found; skipping chapter strip")
|
||||||
|
finally:
|
||||||
|
if os.path.exists(tmp_path):
|
||||||
|
os.unlink(tmp_path)
|
||||||
113
autogpt_platform/backend/backend/blocks/video/add_audio.py
Normal file
113
autogpt_platform/backend/backend/blocks/video/add_audio.py
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
"""AddAudioToVideoBlock - Attach an audio track to a video file."""
|
||||||
|
|
||||||
|
from moviepy.audio.io.AudioFileClip import AudioFileClip
|
||||||
|
from moviepy.video.io.VideoFileClip import VideoFileClip
|
||||||
|
|
||||||
|
from backend.blocks.video._utils import extract_source_name, strip_chapters_inplace
|
||||||
|
from backend.data.block import (
|
||||||
|
Block,
|
||||||
|
BlockCategory,
|
||||||
|
BlockOutput,
|
||||||
|
BlockSchemaInput,
|
||||||
|
BlockSchemaOutput,
|
||||||
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
|
from backend.data.model import SchemaField
|
||||||
|
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
||||||
|
|
||||||
|
|
||||||
|
class AddAudioToVideoBlock(Block):
|
||||||
|
"""Add (attach) an audio track to an existing video."""
|
||||||
|
|
||||||
|
class Input(BlockSchemaInput):
|
||||||
|
video_in: MediaFileType = SchemaField(
|
||||||
|
description="Video input (URL, data URI, or local path)."
|
||||||
|
)
|
||||||
|
audio_in: MediaFileType = SchemaField(
|
||||||
|
description="Audio input (URL, data URI, or local path)."
|
||||||
|
)
|
||||||
|
volume: float = SchemaField(
|
||||||
|
description="Volume scale for the newly attached audio track (1.0 = original).",
|
||||||
|
default=1.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
class Output(BlockSchemaOutput):
|
||||||
|
video_out: MediaFileType = SchemaField(
|
||||||
|
description="Final video (with attached audio), as a path or data URI."
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
id="3503748d-62b6-4425-91d6-725b064af509",
|
||||||
|
description="Block to attach an audio file to a video file using moviepy.",
|
||||||
|
categories={BlockCategory.MULTIMEDIA},
|
||||||
|
input_schema=AddAudioToVideoBlock.Input,
|
||||||
|
output_schema=AddAudioToVideoBlock.Output,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def run(
|
||||||
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
execution_context: ExecutionContext,
|
||||||
|
**kwargs,
|
||||||
|
) -> BlockOutput:
|
||||||
|
assert execution_context.graph_exec_id is not None
|
||||||
|
assert execution_context.node_exec_id is not None
|
||||||
|
graph_exec_id = execution_context.graph_exec_id
|
||||||
|
node_exec_id = execution_context.node_exec_id
|
||||||
|
|
||||||
|
# 1) Store the inputs locally
|
||||||
|
local_video_path = await store_media_file(
|
||||||
|
file=input_data.video_in,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_local_processing",
|
||||||
|
)
|
||||||
|
local_audio_path = await store_media_file(
|
||||||
|
file=input_data.audio_in,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_local_processing",
|
||||||
|
)
|
||||||
|
|
||||||
|
video_abspath = get_exec_file_path(graph_exec_id, local_video_path)
|
||||||
|
audio_abspath = get_exec_file_path(graph_exec_id, local_audio_path)
|
||||||
|
|
||||||
|
# 2) Load video + audio with moviepy
|
||||||
|
strip_chapters_inplace(video_abspath)
|
||||||
|
strip_chapters_inplace(audio_abspath)
|
||||||
|
video_clip = None
|
||||||
|
audio_clip = None
|
||||||
|
final_clip = None
|
||||||
|
try:
|
||||||
|
video_clip = VideoFileClip(video_abspath)
|
||||||
|
audio_clip = AudioFileClip(audio_abspath)
|
||||||
|
# Optionally scale volume
|
||||||
|
if input_data.volume != 1.0:
|
||||||
|
audio_clip = audio_clip.with_volume_scaled(input_data.volume)
|
||||||
|
|
||||||
|
# 3) Attach the new audio track
|
||||||
|
final_clip = video_clip.with_audio(audio_clip)
|
||||||
|
|
||||||
|
# 4) Write to output file
|
||||||
|
source = extract_source_name(local_video_path)
|
||||||
|
output_filename = MediaFileType(f"{node_exec_id}_with_audio_{source}.mp4")
|
||||||
|
output_abspath = get_exec_file_path(graph_exec_id, output_filename)
|
||||||
|
final_clip.write_videofile(
|
||||||
|
output_abspath, codec="libx264", audio_codec="aac"
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
if final_clip:
|
||||||
|
final_clip.close()
|
||||||
|
if audio_clip:
|
||||||
|
audio_clip.close()
|
||||||
|
if video_clip:
|
||||||
|
video_clip.close()
|
||||||
|
|
||||||
|
# 5) Return output - for_block_output returns workspace:// if available, else data URI
|
||||||
|
video_out = await store_media_file(
|
||||||
|
file=output_filename,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_block_output",
|
||||||
|
)
|
||||||
|
|
||||||
|
yield "video_out", video_out
|
||||||
167
autogpt_platform/backend/backend/blocks/video/clip.py
Normal file
167
autogpt_platform/backend/backend/blocks/video/clip.py
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
"""VideoClipBlock - Extract a segment from a video file."""
|
||||||
|
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from moviepy.video.io.VideoFileClip import VideoFileClip
|
||||||
|
|
||||||
|
from backend.blocks.video._utils import (
|
||||||
|
extract_source_name,
|
||||||
|
get_video_codecs,
|
||||||
|
strip_chapters_inplace,
|
||||||
|
)
|
||||||
|
from backend.data.block import (
|
||||||
|
Block,
|
||||||
|
BlockCategory,
|
||||||
|
BlockOutput,
|
||||||
|
BlockSchemaInput,
|
||||||
|
BlockSchemaOutput,
|
||||||
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
|
from backend.data.model import SchemaField
|
||||||
|
from backend.util.exceptions import BlockExecutionError
|
||||||
|
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
||||||
|
|
||||||
|
|
||||||
|
class VideoClipBlock(Block):
|
||||||
|
"""Extract a time segment from a video."""
|
||||||
|
|
||||||
|
class Input(BlockSchemaInput):
|
||||||
|
video_in: MediaFileType = SchemaField(
|
||||||
|
description="Input video (URL, data URI, or local path)"
|
||||||
|
)
|
||||||
|
start_time: float = SchemaField(description="Start time in seconds", ge=0.0)
|
||||||
|
end_time: float = SchemaField(description="End time in seconds", ge=0.0)
|
||||||
|
output_format: Literal["mp4", "webm", "mkv", "mov"] = SchemaField(
|
||||||
|
description="Output format", default="mp4", advanced=True
|
||||||
|
)
|
||||||
|
|
||||||
|
class Output(BlockSchemaOutput):
|
||||||
|
video_out: MediaFileType = SchemaField(
|
||||||
|
description="Clipped video file (path or data URI)"
|
||||||
|
)
|
||||||
|
duration: float = SchemaField(description="Clip duration in seconds")
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
id="8f539119-e580-4d86-ad41-86fbcb22abb1",
|
||||||
|
description="Extract a time segment from a video",
|
||||||
|
categories={BlockCategory.MULTIMEDIA},
|
||||||
|
input_schema=self.Input,
|
||||||
|
output_schema=self.Output,
|
||||||
|
test_input={
|
||||||
|
"video_in": "/tmp/test.mp4",
|
||||||
|
"start_time": 0.0,
|
||||||
|
"end_time": 10.0,
|
||||||
|
},
|
||||||
|
test_output=[("video_out", str), ("duration", float)],
|
||||||
|
test_mock={
|
||||||
|
"_clip_video": lambda *args: 10.0,
|
||||||
|
"_store_input_video": lambda *args, **kwargs: "test.mp4",
|
||||||
|
"_store_output_video": lambda *args, **kwargs: "clip_test.mp4",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _store_input_video(
|
||||||
|
self, execution_context: ExecutionContext, file: MediaFileType
|
||||||
|
) -> MediaFileType:
|
||||||
|
"""Store input video. Extracted for testability."""
|
||||||
|
return await store_media_file(
|
||||||
|
file=file,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_local_processing",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _store_output_video(
|
||||||
|
self, execution_context: ExecutionContext, file: MediaFileType
|
||||||
|
) -> MediaFileType:
|
||||||
|
"""Store output video. Extracted for testability."""
|
||||||
|
return await store_media_file(
|
||||||
|
file=file,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_block_output",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _clip_video(
|
||||||
|
self,
|
||||||
|
video_abspath: str,
|
||||||
|
output_abspath: str,
|
||||||
|
start_time: float,
|
||||||
|
end_time: float,
|
||||||
|
) -> float:
|
||||||
|
"""Extract a clip from a video. Extracted for testability."""
|
||||||
|
clip = None
|
||||||
|
subclip = None
|
||||||
|
try:
|
||||||
|
strip_chapters_inplace(video_abspath)
|
||||||
|
clip = VideoFileClip(video_abspath)
|
||||||
|
subclip = clip.subclipped(start_time, end_time)
|
||||||
|
video_codec, audio_codec = get_video_codecs(output_abspath)
|
||||||
|
subclip.write_videofile(
|
||||||
|
output_abspath, codec=video_codec, audio_codec=audio_codec
|
||||||
|
)
|
||||||
|
return subclip.duration
|
||||||
|
finally:
|
||||||
|
if subclip:
|
||||||
|
subclip.close()
|
||||||
|
if clip:
|
||||||
|
clip.close()
|
||||||
|
|
||||||
|
async def run(
|
||||||
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
execution_context: ExecutionContext,
|
||||||
|
node_exec_id: str,
|
||||||
|
**kwargs,
|
||||||
|
) -> BlockOutput:
|
||||||
|
# Validate time range
|
||||||
|
if input_data.end_time <= input_data.start_time:
|
||||||
|
raise BlockExecutionError(
|
||||||
|
message=f"end_time ({input_data.end_time}) must be greater than start_time ({input_data.start_time})",
|
||||||
|
block_name=self.name,
|
||||||
|
block_id=str(self.id),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
assert execution_context.graph_exec_id is not None
|
||||||
|
|
||||||
|
# Store the input video locally
|
||||||
|
local_video_path = await self._store_input_video(
|
||||||
|
execution_context, input_data.video_in
|
||||||
|
)
|
||||||
|
video_abspath = get_exec_file_path(
|
||||||
|
execution_context.graph_exec_id, local_video_path
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build output path
|
||||||
|
source = extract_source_name(local_video_path)
|
||||||
|
output_filename = MediaFileType(
|
||||||
|
f"{node_exec_id}_clip_{source}.{input_data.output_format}"
|
||||||
|
)
|
||||||
|
output_abspath = get_exec_file_path(
|
||||||
|
execution_context.graph_exec_id, output_filename
|
||||||
|
)
|
||||||
|
|
||||||
|
duration = self._clip_video(
|
||||||
|
video_abspath,
|
||||||
|
output_abspath,
|
||||||
|
input_data.start_time,
|
||||||
|
input_data.end_time,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Return as workspace path or data URI based on context
|
||||||
|
video_out = await self._store_output_video(
|
||||||
|
execution_context, output_filename
|
||||||
|
)
|
||||||
|
|
||||||
|
yield "video_out", video_out
|
||||||
|
yield "duration", duration
|
||||||
|
|
||||||
|
except BlockExecutionError:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise BlockExecutionError(
|
||||||
|
message=f"Failed to clip video: {e}",
|
||||||
|
block_name=self.name,
|
||||||
|
block_id=str(self.id),
|
||||||
|
) from e
|
||||||
227
autogpt_platform/backend/backend/blocks/video/concat.py
Normal file
227
autogpt_platform/backend/backend/blocks/video/concat.py
Normal file
@@ -0,0 +1,227 @@
|
|||||||
|
"""VideoConcatBlock - Concatenate multiple video clips into one."""
|
||||||
|
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from moviepy import concatenate_videoclips
|
||||||
|
from moviepy.video.fx import CrossFadeIn, CrossFadeOut, FadeIn, FadeOut
|
||||||
|
from moviepy.video.io.VideoFileClip import VideoFileClip
|
||||||
|
|
||||||
|
from backend.blocks.video._utils import (
|
||||||
|
extract_source_name,
|
||||||
|
get_video_codecs,
|
||||||
|
strip_chapters_inplace,
|
||||||
|
)
|
||||||
|
from backend.data.block import (
|
||||||
|
Block,
|
||||||
|
BlockCategory,
|
||||||
|
BlockOutput,
|
||||||
|
BlockSchemaInput,
|
||||||
|
BlockSchemaOutput,
|
||||||
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
|
from backend.data.model import SchemaField
|
||||||
|
from backend.util.exceptions import BlockExecutionError
|
||||||
|
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
||||||
|
|
||||||
|
|
||||||
|
class VideoConcatBlock(Block):
|
||||||
|
"""Merge multiple video clips into one continuous video."""
|
||||||
|
|
||||||
|
class Input(BlockSchemaInput):
|
||||||
|
videos: list[MediaFileType] = SchemaField(
|
||||||
|
description="List of video files to concatenate (in order)"
|
||||||
|
)
|
||||||
|
transition: Literal["none", "crossfade", "fade_black"] = SchemaField(
|
||||||
|
description="Transition between clips", default="none"
|
||||||
|
)
|
||||||
|
transition_duration: int = SchemaField(
|
||||||
|
description="Transition duration in seconds",
|
||||||
|
default=1,
|
||||||
|
ge=0,
|
||||||
|
advanced=True,
|
||||||
|
)
|
||||||
|
output_format: Literal["mp4", "webm", "mkv", "mov"] = SchemaField(
|
||||||
|
description="Output format", default="mp4", advanced=True
|
||||||
|
)
|
||||||
|
|
||||||
|
class Output(BlockSchemaOutput):
|
||||||
|
video_out: MediaFileType = SchemaField(
|
||||||
|
description="Concatenated video file (path or data URI)"
|
||||||
|
)
|
||||||
|
total_duration: float = SchemaField(description="Total duration in seconds")
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
id="9b0f531a-1118-487f-aeec-3fa63ea8900a",
|
||||||
|
description="Merge multiple video clips into one continuous video",
|
||||||
|
categories={BlockCategory.MULTIMEDIA},
|
||||||
|
input_schema=self.Input,
|
||||||
|
output_schema=self.Output,
|
||||||
|
test_input={
|
||||||
|
"videos": ["/tmp/a.mp4", "/tmp/b.mp4"],
|
||||||
|
},
|
||||||
|
test_output=[
|
||||||
|
("video_out", str),
|
||||||
|
("total_duration", float),
|
||||||
|
],
|
||||||
|
test_mock={
|
||||||
|
"_concat_videos": lambda *args: 20.0,
|
||||||
|
"_store_input_video": lambda *args, **kwargs: "test.mp4",
|
||||||
|
"_store_output_video": lambda *args, **kwargs: "concat_test.mp4",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _store_input_video(
|
||||||
|
self, execution_context: ExecutionContext, file: MediaFileType
|
||||||
|
) -> MediaFileType:
|
||||||
|
"""Store input video. Extracted for testability."""
|
||||||
|
return await store_media_file(
|
||||||
|
file=file,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_local_processing",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _store_output_video(
|
||||||
|
self, execution_context: ExecutionContext, file: MediaFileType
|
||||||
|
) -> MediaFileType:
|
||||||
|
"""Store output video. Extracted for testability."""
|
||||||
|
return await store_media_file(
|
||||||
|
file=file,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_block_output",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _concat_videos(
|
||||||
|
self,
|
||||||
|
video_abspaths: list[str],
|
||||||
|
output_abspath: str,
|
||||||
|
transition: str,
|
||||||
|
transition_duration: int,
|
||||||
|
) -> float:
|
||||||
|
"""Concatenate videos. Extracted for testability.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Total duration of the concatenated video.
|
||||||
|
"""
|
||||||
|
clips = []
|
||||||
|
faded_clips = []
|
||||||
|
final = None
|
||||||
|
try:
|
||||||
|
# Load clips
|
||||||
|
for v in video_abspaths:
|
||||||
|
strip_chapters_inplace(v)
|
||||||
|
clips.append(VideoFileClip(v))
|
||||||
|
|
||||||
|
# Validate transition_duration against shortest clip
|
||||||
|
if transition in {"crossfade", "fade_black"} and transition_duration > 0:
|
||||||
|
min_duration = min(c.duration for c in clips)
|
||||||
|
if transition_duration >= min_duration:
|
||||||
|
raise BlockExecutionError(
|
||||||
|
message=(
|
||||||
|
f"transition_duration ({transition_duration}s) must be "
|
||||||
|
f"shorter than the shortest clip ({min_duration:.2f}s)"
|
||||||
|
),
|
||||||
|
block_name=self.name,
|
||||||
|
block_id=str(self.id),
|
||||||
|
)
|
||||||
|
|
||||||
|
if transition == "crossfade":
|
||||||
|
for i, clip in enumerate(clips):
|
||||||
|
effects = []
|
||||||
|
if i > 0:
|
||||||
|
effects.append(CrossFadeIn(transition_duration))
|
||||||
|
if i < len(clips) - 1:
|
||||||
|
effects.append(CrossFadeOut(transition_duration))
|
||||||
|
if effects:
|
||||||
|
clip = clip.with_effects(effects)
|
||||||
|
faded_clips.append(clip)
|
||||||
|
final = concatenate_videoclips(
|
||||||
|
faded_clips,
|
||||||
|
method="compose",
|
||||||
|
padding=-transition_duration,
|
||||||
|
)
|
||||||
|
elif transition == "fade_black":
|
||||||
|
for clip in clips:
|
||||||
|
faded = clip.with_effects(
|
||||||
|
[FadeIn(transition_duration), FadeOut(transition_duration)]
|
||||||
|
)
|
||||||
|
faded_clips.append(faded)
|
||||||
|
final = concatenate_videoclips(faded_clips)
|
||||||
|
else:
|
||||||
|
final = concatenate_videoclips(clips)
|
||||||
|
|
||||||
|
video_codec, audio_codec = get_video_codecs(output_abspath)
|
||||||
|
final.write_videofile(
|
||||||
|
output_abspath, codec=video_codec, audio_codec=audio_codec
|
||||||
|
)
|
||||||
|
|
||||||
|
return final.duration
|
||||||
|
finally:
|
||||||
|
if final:
|
||||||
|
final.close()
|
||||||
|
for clip in faded_clips:
|
||||||
|
clip.close()
|
||||||
|
for clip in clips:
|
||||||
|
clip.close()
|
||||||
|
|
||||||
|
async def run(
|
||||||
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
execution_context: ExecutionContext,
|
||||||
|
node_exec_id: str,
|
||||||
|
**kwargs,
|
||||||
|
) -> BlockOutput:
|
||||||
|
# Validate minimum clips
|
||||||
|
if len(input_data.videos) < 2:
|
||||||
|
raise BlockExecutionError(
|
||||||
|
message="At least 2 videos are required for concatenation",
|
||||||
|
block_name=self.name,
|
||||||
|
block_id=str(self.id),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
assert execution_context.graph_exec_id is not None
|
||||||
|
|
||||||
|
# Store all input videos locally
|
||||||
|
video_abspaths = []
|
||||||
|
for video in input_data.videos:
|
||||||
|
local_path = await self._store_input_video(execution_context, video)
|
||||||
|
video_abspaths.append(
|
||||||
|
get_exec_file_path(execution_context.graph_exec_id, local_path)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build output path
|
||||||
|
source = (
|
||||||
|
extract_source_name(video_abspaths[0]) if video_abspaths else "video"
|
||||||
|
)
|
||||||
|
output_filename = MediaFileType(
|
||||||
|
f"{node_exec_id}_concat_{source}.{input_data.output_format}"
|
||||||
|
)
|
||||||
|
output_abspath = get_exec_file_path(
|
||||||
|
execution_context.graph_exec_id, output_filename
|
||||||
|
)
|
||||||
|
|
||||||
|
total_duration = self._concat_videos(
|
||||||
|
video_abspaths,
|
||||||
|
output_abspath,
|
||||||
|
input_data.transition,
|
||||||
|
input_data.transition_duration,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Return as workspace path or data URI based on context
|
||||||
|
video_out = await self._store_output_video(
|
||||||
|
execution_context, output_filename
|
||||||
|
)
|
||||||
|
|
||||||
|
yield "video_out", video_out
|
||||||
|
yield "total_duration", total_duration
|
||||||
|
|
||||||
|
except BlockExecutionError:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise BlockExecutionError(
|
||||||
|
message=f"Failed to concatenate videos: {e}",
|
||||||
|
block_name=self.name,
|
||||||
|
block_id=str(self.id),
|
||||||
|
) from e
|
||||||
172
autogpt_platform/backend/backend/blocks/video/download.py
Normal file
172
autogpt_platform/backend/backend/blocks/video/download.py
Normal file
@@ -0,0 +1,172 @@
|
|||||||
|
"""VideoDownloadBlock - Download video from URL (YouTube, Vimeo, news sites, direct links)."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import typing
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
import yt_dlp
|
||||||
|
|
||||||
|
if typing.TYPE_CHECKING:
|
||||||
|
from yt_dlp import _Params
|
||||||
|
|
||||||
|
from backend.data.block import (
|
||||||
|
Block,
|
||||||
|
BlockCategory,
|
||||||
|
BlockOutput,
|
||||||
|
BlockSchemaInput,
|
||||||
|
BlockSchemaOutput,
|
||||||
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
|
from backend.data.model import SchemaField
|
||||||
|
from backend.util.exceptions import BlockExecutionError
|
||||||
|
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
||||||
|
|
||||||
|
|
||||||
|
class VideoDownloadBlock(Block):
|
||||||
|
"""Download video from URL using yt-dlp."""
|
||||||
|
|
||||||
|
class Input(BlockSchemaInput):
|
||||||
|
url: str = SchemaField(
|
||||||
|
description="URL of the video to download (YouTube, Vimeo, direct link, etc.)",
|
||||||
|
placeholder="https://www.youtube.com/watch?v=...",
|
||||||
|
)
|
||||||
|
quality: Literal["best", "1080p", "720p", "480p", "audio_only"] = SchemaField(
|
||||||
|
description="Video quality preference", default="720p"
|
||||||
|
)
|
||||||
|
output_format: Literal["mp4", "webm", "mkv"] = SchemaField(
|
||||||
|
description="Output video format", default="mp4", advanced=True
|
||||||
|
)
|
||||||
|
|
||||||
|
class Output(BlockSchemaOutput):
|
||||||
|
video_file: MediaFileType = SchemaField(
|
||||||
|
description="Downloaded video (path or data URI)"
|
||||||
|
)
|
||||||
|
duration: float = SchemaField(description="Video duration in seconds")
|
||||||
|
title: str = SchemaField(description="Video title from source")
|
||||||
|
source_url: str = SchemaField(description="Original source URL")
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
id="c35daabb-cd60-493b-b9ad-51f1fe4b50c4",
|
||||||
|
description="Download video from URL (YouTube, Vimeo, news sites, direct links)",
|
||||||
|
categories={BlockCategory.MULTIMEDIA},
|
||||||
|
input_schema=self.Input,
|
||||||
|
output_schema=self.Output,
|
||||||
|
disabled=True, # Disable until we can sandbox yt-dlp and handle security implications
|
||||||
|
test_input={
|
||||||
|
"url": "https://www.youtube.com/watch?v=dQw4w9WgXcQ",
|
||||||
|
"quality": "480p",
|
||||||
|
},
|
||||||
|
test_output=[
|
||||||
|
("video_file", str),
|
||||||
|
("duration", float),
|
||||||
|
("title", str),
|
||||||
|
("source_url", str),
|
||||||
|
],
|
||||||
|
test_mock={
|
||||||
|
"_download_video": lambda *args: (
|
||||||
|
"video.mp4",
|
||||||
|
212.0,
|
||||||
|
"Test Video",
|
||||||
|
),
|
||||||
|
"_store_output_video": lambda *args, **kwargs: "video.mp4",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _store_output_video(
|
||||||
|
self, execution_context: ExecutionContext, file: MediaFileType
|
||||||
|
) -> MediaFileType:
|
||||||
|
"""Store output video. Extracted for testability."""
|
||||||
|
return await store_media_file(
|
||||||
|
file=file,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_block_output",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_format_string(self, quality: str) -> str:
|
||||||
|
formats = {
|
||||||
|
"best": "bestvideo+bestaudio/best",
|
||||||
|
"1080p": "bestvideo[height<=1080]+bestaudio/best[height<=1080]",
|
||||||
|
"720p": "bestvideo[height<=720]+bestaudio/best[height<=720]",
|
||||||
|
"480p": "bestvideo[height<=480]+bestaudio/best[height<=480]",
|
||||||
|
"audio_only": "bestaudio/best",
|
||||||
|
}
|
||||||
|
return formats.get(quality, formats["720p"])
|
||||||
|
|
||||||
|
def _download_video(
|
||||||
|
self,
|
||||||
|
url: str,
|
||||||
|
quality: str,
|
||||||
|
output_format: str,
|
||||||
|
output_dir: str,
|
||||||
|
node_exec_id: str,
|
||||||
|
) -> tuple[str, float, str]:
|
||||||
|
"""Download video. Extracted for testability."""
|
||||||
|
output_template = os.path.join(
|
||||||
|
output_dir, f"{node_exec_id}_%(title).50s.%(ext)s"
|
||||||
|
)
|
||||||
|
|
||||||
|
ydl_opts: "_Params" = {
|
||||||
|
"format": f"{self._get_format_string(quality)}/best",
|
||||||
|
"outtmpl": output_template,
|
||||||
|
"merge_output_format": output_format,
|
||||||
|
"quiet": True,
|
||||||
|
"no_warnings": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
with yt_dlp.YoutubeDL(ydl_opts) as ydl:
|
||||||
|
info = ydl.extract_info(url, download=True)
|
||||||
|
video_path = ydl.prepare_filename(info)
|
||||||
|
|
||||||
|
# Handle format conversion in filename
|
||||||
|
if not video_path.endswith(f".{output_format}"):
|
||||||
|
video_path = video_path.rsplit(".", 1)[0] + f".{output_format}"
|
||||||
|
|
||||||
|
# Return just the filename, not the full path
|
||||||
|
filename = os.path.basename(video_path)
|
||||||
|
|
||||||
|
return (
|
||||||
|
filename,
|
||||||
|
info.get("duration") or 0.0,
|
||||||
|
info.get("title") or "Unknown",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def run(
|
||||||
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
execution_context: ExecutionContext,
|
||||||
|
node_exec_id: str,
|
||||||
|
**kwargs,
|
||||||
|
) -> BlockOutput:
|
||||||
|
try:
|
||||||
|
assert execution_context.graph_exec_id is not None
|
||||||
|
|
||||||
|
# Get the exec file directory
|
||||||
|
output_dir = get_exec_file_path(execution_context.graph_exec_id, "")
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
filename, duration, title = self._download_video(
|
||||||
|
input_data.url,
|
||||||
|
input_data.quality,
|
||||||
|
input_data.output_format,
|
||||||
|
output_dir,
|
||||||
|
node_exec_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Return as workspace path or data URI based on context
|
||||||
|
video_out = await self._store_output_video(
|
||||||
|
execution_context, MediaFileType(filename)
|
||||||
|
)
|
||||||
|
|
||||||
|
yield "video_file", video_out
|
||||||
|
yield "duration", duration
|
||||||
|
yield "title", title
|
||||||
|
yield "source_url", input_data.url
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise BlockExecutionError(
|
||||||
|
message=f"Failed to download video: {e}",
|
||||||
|
block_name=self.name,
|
||||||
|
block_id=str(self.id),
|
||||||
|
) from e
|
||||||
77
autogpt_platform/backend/backend/blocks/video/duration.py
Normal file
77
autogpt_platform/backend/backend/blocks/video/duration.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
"""MediaDurationBlock - Get the duration of a media file."""
|
||||||
|
|
||||||
|
from moviepy.audio.io.AudioFileClip import AudioFileClip
|
||||||
|
from moviepy.video.io.VideoFileClip import VideoFileClip
|
||||||
|
|
||||||
|
from backend.blocks.video._utils import strip_chapters_inplace
|
||||||
|
from backend.data.block import (
|
||||||
|
Block,
|
||||||
|
BlockCategory,
|
||||||
|
BlockOutput,
|
||||||
|
BlockSchemaInput,
|
||||||
|
BlockSchemaOutput,
|
||||||
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
|
from backend.data.model import SchemaField
|
||||||
|
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
||||||
|
|
||||||
|
|
||||||
|
class MediaDurationBlock(Block):
|
||||||
|
"""Get the duration of a media file (video or audio)."""
|
||||||
|
|
||||||
|
class Input(BlockSchemaInput):
|
||||||
|
media_in: MediaFileType = SchemaField(
|
||||||
|
description="Media input (URL, data URI, or local path)."
|
||||||
|
)
|
||||||
|
is_video: bool = SchemaField(
|
||||||
|
description="Whether the media is a video (True) or audio (False).",
|
||||||
|
default=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
class Output(BlockSchemaOutput):
|
||||||
|
duration: float = SchemaField(
|
||||||
|
description="Duration of the media file (in seconds)."
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
id="d8b91fd4-da26-42d4-8ecb-8b196c6d84b6",
|
||||||
|
description="Block to get the duration of a media file.",
|
||||||
|
categories={BlockCategory.MULTIMEDIA},
|
||||||
|
input_schema=MediaDurationBlock.Input,
|
||||||
|
output_schema=MediaDurationBlock.Output,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def run(
|
||||||
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
execution_context: ExecutionContext,
|
||||||
|
**kwargs,
|
||||||
|
) -> BlockOutput:
|
||||||
|
# 1) Store the input media locally
|
||||||
|
local_media_path = await store_media_file(
|
||||||
|
file=input_data.media_in,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_local_processing",
|
||||||
|
)
|
||||||
|
assert execution_context.graph_exec_id is not None
|
||||||
|
media_abspath = get_exec_file_path(
|
||||||
|
execution_context.graph_exec_id, local_media_path
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2) Strip chapters to avoid MoviePy crash, then load the clip
|
||||||
|
strip_chapters_inplace(media_abspath)
|
||||||
|
clip = None
|
||||||
|
try:
|
||||||
|
if input_data.is_video:
|
||||||
|
clip = VideoFileClip(media_abspath)
|
||||||
|
else:
|
||||||
|
clip = AudioFileClip(media_abspath)
|
||||||
|
|
||||||
|
duration = clip.duration
|
||||||
|
finally:
|
||||||
|
if clip:
|
||||||
|
clip.close()
|
||||||
|
|
||||||
|
yield "duration", duration
|
||||||
115
autogpt_platform/backend/backend/blocks/video/loop.py
Normal file
115
autogpt_platform/backend/backend/blocks/video/loop.py
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
"""LoopVideoBlock - Loop a video to a given duration or number of repeats."""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from moviepy.video.fx.Loop import Loop
|
||||||
|
from moviepy.video.io.VideoFileClip import VideoFileClip
|
||||||
|
|
||||||
|
from backend.blocks.video._utils import extract_source_name, strip_chapters_inplace
|
||||||
|
from backend.data.block import (
|
||||||
|
Block,
|
||||||
|
BlockCategory,
|
||||||
|
BlockOutput,
|
||||||
|
BlockSchemaInput,
|
||||||
|
BlockSchemaOutput,
|
||||||
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
|
from backend.data.model import SchemaField
|
||||||
|
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
||||||
|
|
||||||
|
|
||||||
|
class LoopVideoBlock(Block):
|
||||||
|
"""Loop (repeat) a video clip until a given duration or number of loops."""
|
||||||
|
|
||||||
|
class Input(BlockSchemaInput):
|
||||||
|
video_in: MediaFileType = SchemaField(
|
||||||
|
description="The input video (can be a URL, data URI, or local path)."
|
||||||
|
)
|
||||||
|
duration: Optional[float] = SchemaField(
|
||||||
|
description="Target duration (in seconds) to loop the video to. Either duration or n_loops must be provided.",
|
||||||
|
default=None,
|
||||||
|
ge=0.0,
|
||||||
|
le=3600.0, # Max 1 hour to prevent disk exhaustion
|
||||||
|
)
|
||||||
|
n_loops: Optional[int] = SchemaField(
|
||||||
|
description="Number of times to repeat the video. Either n_loops or duration must be provided.",
|
||||||
|
default=None,
|
||||||
|
ge=1,
|
||||||
|
le=10, # Max 10 loops to prevent disk exhaustion
|
||||||
|
)
|
||||||
|
|
||||||
|
class Output(BlockSchemaOutput):
|
||||||
|
video_out: MediaFileType = SchemaField(
|
||||||
|
description="Looped video returned either as a relative path or a data URI."
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
id="8bf9eef6-5451-4213-b265-25306446e94b",
|
||||||
|
description="Block to loop a video to a given duration or number of repeats.",
|
||||||
|
categories={BlockCategory.MULTIMEDIA},
|
||||||
|
input_schema=LoopVideoBlock.Input,
|
||||||
|
output_schema=LoopVideoBlock.Output,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def run(
|
||||||
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
execution_context: ExecutionContext,
|
||||||
|
**kwargs,
|
||||||
|
) -> BlockOutput:
|
||||||
|
assert execution_context.graph_exec_id is not None
|
||||||
|
assert execution_context.node_exec_id is not None
|
||||||
|
graph_exec_id = execution_context.graph_exec_id
|
||||||
|
node_exec_id = execution_context.node_exec_id
|
||||||
|
|
||||||
|
# 1) Store the input video locally
|
||||||
|
local_video_path = await store_media_file(
|
||||||
|
file=input_data.video_in,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_local_processing",
|
||||||
|
)
|
||||||
|
input_abspath = get_exec_file_path(graph_exec_id, local_video_path)
|
||||||
|
|
||||||
|
# 2) Load the clip
|
||||||
|
strip_chapters_inplace(input_abspath)
|
||||||
|
clip = None
|
||||||
|
looped_clip = None
|
||||||
|
try:
|
||||||
|
clip = VideoFileClip(input_abspath)
|
||||||
|
|
||||||
|
# 3) Apply the loop effect
|
||||||
|
if input_data.duration:
|
||||||
|
# Loop until we reach the specified duration
|
||||||
|
looped_clip = clip.with_effects([Loop(duration=input_data.duration)])
|
||||||
|
elif input_data.n_loops:
|
||||||
|
looped_clip = clip.with_effects([Loop(n=input_data.n_loops)])
|
||||||
|
else:
|
||||||
|
raise ValueError("Either 'duration' or 'n_loops' must be provided.")
|
||||||
|
|
||||||
|
assert isinstance(looped_clip, VideoFileClip)
|
||||||
|
|
||||||
|
# 4) Save the looped output
|
||||||
|
source = extract_source_name(local_video_path)
|
||||||
|
output_filename = MediaFileType(f"{node_exec_id}_looped_{source}.mp4")
|
||||||
|
output_abspath = get_exec_file_path(graph_exec_id, output_filename)
|
||||||
|
|
||||||
|
looped_clip = looped_clip.with_audio(clip.audio)
|
||||||
|
looped_clip.write_videofile(
|
||||||
|
output_abspath, codec="libx264", audio_codec="aac"
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
if looped_clip:
|
||||||
|
looped_clip.close()
|
||||||
|
if clip:
|
||||||
|
clip.close()
|
||||||
|
|
||||||
|
# Return output - for_block_output returns workspace:// if available, else data URI
|
||||||
|
video_out = await store_media_file(
|
||||||
|
file=output_filename,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_block_output",
|
||||||
|
)
|
||||||
|
|
||||||
|
yield "video_out", video_out
|
||||||
267
autogpt_platform/backend/backend/blocks/video/narration.py
Normal file
267
autogpt_platform/backend/backend/blocks/video/narration.py
Normal file
@@ -0,0 +1,267 @@
|
|||||||
|
"""VideoNarrationBlock - Generate AI voice narration and add to video."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from elevenlabs import ElevenLabs
|
||||||
|
from moviepy import CompositeAudioClip
|
||||||
|
from moviepy.audio.io.AudioFileClip import AudioFileClip
|
||||||
|
from moviepy.video.io.VideoFileClip import VideoFileClip
|
||||||
|
|
||||||
|
from backend.blocks.elevenlabs._auth import (
|
||||||
|
TEST_CREDENTIALS,
|
||||||
|
TEST_CREDENTIALS_INPUT,
|
||||||
|
ElevenLabsCredentials,
|
||||||
|
ElevenLabsCredentialsInput,
|
||||||
|
)
|
||||||
|
from backend.blocks.video._utils import (
|
||||||
|
extract_source_name,
|
||||||
|
get_video_codecs,
|
||||||
|
strip_chapters_inplace,
|
||||||
|
)
|
||||||
|
from backend.data.block import (
|
||||||
|
Block,
|
||||||
|
BlockCategory,
|
||||||
|
BlockOutput,
|
||||||
|
BlockSchemaInput,
|
||||||
|
BlockSchemaOutput,
|
||||||
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
|
from backend.data.model import CredentialsField, SchemaField
|
||||||
|
from backend.util.exceptions import BlockExecutionError
|
||||||
|
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
||||||
|
|
||||||
|
|
||||||
|
class VideoNarrationBlock(Block):
|
||||||
|
"""Generate AI narration and add to video."""
|
||||||
|
|
||||||
|
class Input(BlockSchemaInput):
|
||||||
|
credentials: ElevenLabsCredentialsInput = CredentialsField(
|
||||||
|
description="ElevenLabs API key for voice synthesis"
|
||||||
|
)
|
||||||
|
video_in: MediaFileType = SchemaField(
|
||||||
|
description="Input video (URL, data URI, or local path)"
|
||||||
|
)
|
||||||
|
script: str = SchemaField(description="Narration script text")
|
||||||
|
voice_id: str = SchemaField(
|
||||||
|
description="ElevenLabs voice ID", default="21m00Tcm4TlvDq8ikWAM" # Rachel
|
||||||
|
)
|
||||||
|
model_id: Literal[
|
||||||
|
"eleven_multilingual_v2",
|
||||||
|
"eleven_flash_v2_5",
|
||||||
|
"eleven_turbo_v2_5",
|
||||||
|
"eleven_turbo_v2",
|
||||||
|
] = SchemaField(
|
||||||
|
description="ElevenLabs TTS model",
|
||||||
|
default="eleven_multilingual_v2",
|
||||||
|
)
|
||||||
|
mix_mode: Literal["replace", "mix", "ducking"] = SchemaField(
|
||||||
|
description="How to combine with original audio. 'ducking' applies stronger attenuation than 'mix'.",
|
||||||
|
default="ducking",
|
||||||
|
)
|
||||||
|
narration_volume: float = SchemaField(
|
||||||
|
description="Narration volume (0.0 to 2.0)",
|
||||||
|
default=1.0,
|
||||||
|
ge=0.0,
|
||||||
|
le=2.0,
|
||||||
|
advanced=True,
|
||||||
|
)
|
||||||
|
original_volume: float = SchemaField(
|
||||||
|
description="Original audio volume when mixing (0.0 to 1.0)",
|
||||||
|
default=0.3,
|
||||||
|
ge=0.0,
|
||||||
|
le=1.0,
|
||||||
|
advanced=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
class Output(BlockSchemaOutput):
|
||||||
|
video_out: MediaFileType = SchemaField(
|
||||||
|
description="Video with narration (path or data URI)"
|
||||||
|
)
|
||||||
|
audio_file: MediaFileType = SchemaField(
|
||||||
|
description="Generated audio file (path or data URI)"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
id="3d036b53-859c-4b17-9826-ca340f736e0e",
|
||||||
|
description="Generate AI narration and add to video",
|
||||||
|
categories={BlockCategory.MULTIMEDIA, BlockCategory.AI},
|
||||||
|
input_schema=self.Input,
|
||||||
|
output_schema=self.Output,
|
||||||
|
test_input={
|
||||||
|
"video_in": "/tmp/test.mp4",
|
||||||
|
"script": "Hello world",
|
||||||
|
"credentials": TEST_CREDENTIALS_INPUT,
|
||||||
|
},
|
||||||
|
test_credentials=TEST_CREDENTIALS,
|
||||||
|
test_output=[("video_out", str), ("audio_file", str)],
|
||||||
|
test_mock={
|
||||||
|
"_generate_narration_audio": lambda *args: b"mock audio content",
|
||||||
|
"_add_narration_to_video": lambda *args: None,
|
||||||
|
"_store_input_video": lambda *args, **kwargs: "test.mp4",
|
||||||
|
"_store_output_video": lambda *args, **kwargs: "narrated_test.mp4",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _store_input_video(
|
||||||
|
self, execution_context: ExecutionContext, file: MediaFileType
|
||||||
|
) -> MediaFileType:
|
||||||
|
"""Store input video. Extracted for testability."""
|
||||||
|
return await store_media_file(
|
||||||
|
file=file,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_local_processing",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _store_output_video(
|
||||||
|
self, execution_context: ExecutionContext, file: MediaFileType
|
||||||
|
) -> MediaFileType:
|
||||||
|
"""Store output video. Extracted for testability."""
|
||||||
|
return await store_media_file(
|
||||||
|
file=file,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_block_output",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _generate_narration_audio(
|
||||||
|
self, api_key: str, script: str, voice_id: str, model_id: str
|
||||||
|
) -> bytes:
|
||||||
|
"""Generate narration audio via ElevenLabs API."""
|
||||||
|
client = ElevenLabs(api_key=api_key)
|
||||||
|
audio_generator = client.text_to_speech.convert(
|
||||||
|
voice_id=voice_id,
|
||||||
|
text=script,
|
||||||
|
model_id=model_id,
|
||||||
|
)
|
||||||
|
# The SDK returns a generator, collect all chunks
|
||||||
|
return b"".join(audio_generator)
|
||||||
|
|
||||||
|
def _add_narration_to_video(
|
||||||
|
self,
|
||||||
|
video_abspath: str,
|
||||||
|
audio_abspath: str,
|
||||||
|
output_abspath: str,
|
||||||
|
mix_mode: str,
|
||||||
|
narration_volume: float,
|
||||||
|
original_volume: float,
|
||||||
|
) -> None:
|
||||||
|
"""Add narration audio to video. Extracted for testability."""
|
||||||
|
video = None
|
||||||
|
final = None
|
||||||
|
narration_original = None
|
||||||
|
narration_scaled = None
|
||||||
|
original = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
strip_chapters_inplace(video_abspath)
|
||||||
|
video = VideoFileClip(video_abspath)
|
||||||
|
narration_original = AudioFileClip(audio_abspath)
|
||||||
|
narration_scaled = narration_original.with_volume_scaled(narration_volume)
|
||||||
|
narration = narration_scaled
|
||||||
|
|
||||||
|
if mix_mode == "replace":
|
||||||
|
final_audio = narration
|
||||||
|
elif mix_mode == "mix":
|
||||||
|
if video.audio:
|
||||||
|
original = video.audio.with_volume_scaled(original_volume)
|
||||||
|
final_audio = CompositeAudioClip([original, narration])
|
||||||
|
else:
|
||||||
|
final_audio = narration
|
||||||
|
else: # ducking - apply stronger attenuation
|
||||||
|
if video.audio:
|
||||||
|
# Ducking uses a much lower volume for original audio
|
||||||
|
ducking_volume = original_volume * 0.3
|
||||||
|
original = video.audio.with_volume_scaled(ducking_volume)
|
||||||
|
final_audio = CompositeAudioClip([original, narration])
|
||||||
|
else:
|
||||||
|
final_audio = narration
|
||||||
|
|
||||||
|
final = video.with_audio(final_audio)
|
||||||
|
video_codec, audio_codec = get_video_codecs(output_abspath)
|
||||||
|
final.write_videofile(
|
||||||
|
output_abspath, codec=video_codec, audio_codec=audio_codec
|
||||||
|
)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
if original:
|
||||||
|
original.close()
|
||||||
|
if narration_scaled:
|
||||||
|
narration_scaled.close()
|
||||||
|
if narration_original:
|
||||||
|
narration_original.close()
|
||||||
|
if final:
|
||||||
|
final.close()
|
||||||
|
if video:
|
||||||
|
video.close()
|
||||||
|
|
||||||
|
async def run(
|
||||||
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
credentials: ElevenLabsCredentials,
|
||||||
|
execution_context: ExecutionContext,
|
||||||
|
node_exec_id: str,
|
||||||
|
**kwargs,
|
||||||
|
) -> BlockOutput:
|
||||||
|
try:
|
||||||
|
assert execution_context.graph_exec_id is not None
|
||||||
|
|
||||||
|
# Store the input video locally
|
||||||
|
local_video_path = await self._store_input_video(
|
||||||
|
execution_context, input_data.video_in
|
||||||
|
)
|
||||||
|
video_abspath = get_exec_file_path(
|
||||||
|
execution_context.graph_exec_id, local_video_path
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate narration audio via ElevenLabs
|
||||||
|
audio_content = self._generate_narration_audio(
|
||||||
|
credentials.api_key.get_secret_value(),
|
||||||
|
input_data.script,
|
||||||
|
input_data.voice_id,
|
||||||
|
input_data.model_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save audio to exec file path
|
||||||
|
audio_filename = MediaFileType(f"{node_exec_id}_narration.mp3")
|
||||||
|
audio_abspath = get_exec_file_path(
|
||||||
|
execution_context.graph_exec_id, audio_filename
|
||||||
|
)
|
||||||
|
os.makedirs(os.path.dirname(audio_abspath), exist_ok=True)
|
||||||
|
with open(audio_abspath, "wb") as f:
|
||||||
|
f.write(audio_content)
|
||||||
|
|
||||||
|
# Add narration to video
|
||||||
|
source = extract_source_name(local_video_path)
|
||||||
|
output_filename = MediaFileType(f"{node_exec_id}_narrated_{source}.mp4")
|
||||||
|
output_abspath = get_exec_file_path(
|
||||||
|
execution_context.graph_exec_id, output_filename
|
||||||
|
)
|
||||||
|
|
||||||
|
self._add_narration_to_video(
|
||||||
|
video_abspath,
|
||||||
|
audio_abspath,
|
||||||
|
output_abspath,
|
||||||
|
input_data.mix_mode,
|
||||||
|
input_data.narration_volume,
|
||||||
|
input_data.original_volume,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Return as workspace path or data URI based on context
|
||||||
|
video_out = await self._store_output_video(
|
||||||
|
execution_context, output_filename
|
||||||
|
)
|
||||||
|
audio_out = await self._store_output_video(
|
||||||
|
execution_context, audio_filename
|
||||||
|
)
|
||||||
|
|
||||||
|
yield "video_out", video_out
|
||||||
|
yield "audio_file", audio_out
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise BlockExecutionError(
|
||||||
|
message=f"Failed to add narration: {e}",
|
||||||
|
block_name=self.name,
|
||||||
|
block_id=str(self.id),
|
||||||
|
) from e
|
||||||
231
autogpt_platform/backend/backend/blocks/video/text_overlay.py
Normal file
231
autogpt_platform/backend/backend/blocks/video/text_overlay.py
Normal file
@@ -0,0 +1,231 @@
|
|||||||
|
"""VideoTextOverlayBlock - Add text overlay to video."""
|
||||||
|
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from moviepy import CompositeVideoClip, TextClip
|
||||||
|
from moviepy.video.io.VideoFileClip import VideoFileClip
|
||||||
|
|
||||||
|
from backend.blocks.video._utils import (
|
||||||
|
extract_source_name,
|
||||||
|
get_video_codecs,
|
||||||
|
strip_chapters_inplace,
|
||||||
|
)
|
||||||
|
from backend.data.block import (
|
||||||
|
Block,
|
||||||
|
BlockCategory,
|
||||||
|
BlockOutput,
|
||||||
|
BlockSchemaInput,
|
||||||
|
BlockSchemaOutput,
|
||||||
|
)
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
|
from backend.data.model import SchemaField
|
||||||
|
from backend.util.exceptions import BlockExecutionError
|
||||||
|
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
||||||
|
|
||||||
|
|
||||||
|
class VideoTextOverlayBlock(Block):
|
||||||
|
"""Add text overlay/caption to video."""
|
||||||
|
|
||||||
|
class Input(BlockSchemaInput):
|
||||||
|
video_in: MediaFileType = SchemaField(
|
||||||
|
description="Input video (URL, data URI, or local path)"
|
||||||
|
)
|
||||||
|
text: str = SchemaField(description="Text to overlay on video")
|
||||||
|
position: Literal[
|
||||||
|
"top",
|
||||||
|
"center",
|
||||||
|
"bottom",
|
||||||
|
"top-left",
|
||||||
|
"top-right",
|
||||||
|
"bottom-left",
|
||||||
|
"bottom-right",
|
||||||
|
] = SchemaField(description="Position of text on screen", default="bottom")
|
||||||
|
start_time: float | None = SchemaField(
|
||||||
|
description="When to show text (seconds). None = entire video",
|
||||||
|
default=None,
|
||||||
|
advanced=True,
|
||||||
|
)
|
||||||
|
end_time: float | None = SchemaField(
|
||||||
|
description="When to hide text (seconds). None = until end",
|
||||||
|
default=None,
|
||||||
|
advanced=True,
|
||||||
|
)
|
||||||
|
font_size: int = SchemaField(
|
||||||
|
description="Font size", default=48, ge=12, le=200, advanced=True
|
||||||
|
)
|
||||||
|
font_color: str = SchemaField(
|
||||||
|
description="Font color (hex or name)", default="white", advanced=True
|
||||||
|
)
|
||||||
|
bg_color: str | None = SchemaField(
|
||||||
|
description="Background color behind text (None for transparent)",
|
||||||
|
default=None,
|
||||||
|
advanced=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
class Output(BlockSchemaOutput):
|
||||||
|
video_out: MediaFileType = SchemaField(
|
||||||
|
description="Video with text overlay (path or data URI)"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
id="8ef14de6-cc90-430a-8cfa-3a003be92454",
|
||||||
|
description="Add text overlay/caption to video",
|
||||||
|
categories={BlockCategory.MULTIMEDIA},
|
||||||
|
input_schema=self.Input,
|
||||||
|
output_schema=self.Output,
|
||||||
|
disabled=True, # Disable until we can lockdown imagemagick security policy
|
||||||
|
test_input={"video_in": "/tmp/test.mp4", "text": "Hello World"},
|
||||||
|
test_output=[("video_out", str)],
|
||||||
|
test_mock={
|
||||||
|
"_add_text_overlay": lambda *args: None,
|
||||||
|
"_store_input_video": lambda *args, **kwargs: "test.mp4",
|
||||||
|
"_store_output_video": lambda *args, **kwargs: "overlay_test.mp4",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _store_input_video(
|
||||||
|
self, execution_context: ExecutionContext, file: MediaFileType
|
||||||
|
) -> MediaFileType:
|
||||||
|
"""Store input video. Extracted for testability."""
|
||||||
|
return await store_media_file(
|
||||||
|
file=file,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_local_processing",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _store_output_video(
|
||||||
|
self, execution_context: ExecutionContext, file: MediaFileType
|
||||||
|
) -> MediaFileType:
|
||||||
|
"""Store output video. Extracted for testability."""
|
||||||
|
return await store_media_file(
|
||||||
|
file=file,
|
||||||
|
execution_context=execution_context,
|
||||||
|
return_format="for_block_output",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _add_text_overlay(
|
||||||
|
self,
|
||||||
|
video_abspath: str,
|
||||||
|
output_abspath: str,
|
||||||
|
text: str,
|
||||||
|
position: str,
|
||||||
|
start_time: float | None,
|
||||||
|
end_time: float | None,
|
||||||
|
font_size: int,
|
||||||
|
font_color: str,
|
||||||
|
bg_color: str | None,
|
||||||
|
) -> None:
|
||||||
|
"""Add text overlay to video. Extracted for testability."""
|
||||||
|
video = None
|
||||||
|
final = None
|
||||||
|
txt_clip = None
|
||||||
|
try:
|
||||||
|
strip_chapters_inplace(video_abspath)
|
||||||
|
video = VideoFileClip(video_abspath)
|
||||||
|
|
||||||
|
txt_clip = TextClip(
|
||||||
|
text=text,
|
||||||
|
font_size=font_size,
|
||||||
|
color=font_color,
|
||||||
|
bg_color=bg_color,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Position mapping
|
||||||
|
pos_map = {
|
||||||
|
"top": ("center", "top"),
|
||||||
|
"center": ("center", "center"),
|
||||||
|
"bottom": ("center", "bottom"),
|
||||||
|
"top-left": ("left", "top"),
|
||||||
|
"top-right": ("right", "top"),
|
||||||
|
"bottom-left": ("left", "bottom"),
|
||||||
|
"bottom-right": ("right", "bottom"),
|
||||||
|
}
|
||||||
|
|
||||||
|
txt_clip = txt_clip.with_position(pos_map[position])
|
||||||
|
|
||||||
|
# Set timing
|
||||||
|
start = start_time or 0
|
||||||
|
end = end_time or video.duration
|
||||||
|
duration = max(0, end - start)
|
||||||
|
txt_clip = txt_clip.with_start(start).with_end(end).with_duration(duration)
|
||||||
|
|
||||||
|
final = CompositeVideoClip([video, txt_clip])
|
||||||
|
video_codec, audio_codec = get_video_codecs(output_abspath)
|
||||||
|
final.write_videofile(
|
||||||
|
output_abspath, codec=video_codec, audio_codec=audio_codec
|
||||||
|
)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
if txt_clip:
|
||||||
|
txt_clip.close()
|
||||||
|
if final:
|
||||||
|
final.close()
|
||||||
|
if video:
|
||||||
|
video.close()
|
||||||
|
|
||||||
|
async def run(
|
||||||
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
execution_context: ExecutionContext,
|
||||||
|
node_exec_id: str,
|
||||||
|
**kwargs,
|
||||||
|
) -> BlockOutput:
|
||||||
|
# Validate time range if both are provided
|
||||||
|
if (
|
||||||
|
input_data.start_time is not None
|
||||||
|
and input_data.end_time is not None
|
||||||
|
and input_data.end_time <= input_data.start_time
|
||||||
|
):
|
||||||
|
raise BlockExecutionError(
|
||||||
|
message=f"end_time ({input_data.end_time}) must be greater than start_time ({input_data.start_time})",
|
||||||
|
block_name=self.name,
|
||||||
|
block_id=str(self.id),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
assert execution_context.graph_exec_id is not None
|
||||||
|
|
||||||
|
# Store the input video locally
|
||||||
|
local_video_path = await self._store_input_video(
|
||||||
|
execution_context, input_data.video_in
|
||||||
|
)
|
||||||
|
video_abspath = get_exec_file_path(
|
||||||
|
execution_context.graph_exec_id, local_video_path
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build output path
|
||||||
|
source = extract_source_name(local_video_path)
|
||||||
|
output_filename = MediaFileType(f"{node_exec_id}_overlay_{source}.mp4")
|
||||||
|
output_abspath = get_exec_file_path(
|
||||||
|
execution_context.graph_exec_id, output_filename
|
||||||
|
)
|
||||||
|
|
||||||
|
self._add_text_overlay(
|
||||||
|
video_abspath,
|
||||||
|
output_abspath,
|
||||||
|
input_data.text,
|
||||||
|
input_data.position,
|
||||||
|
input_data.start_time,
|
||||||
|
input_data.end_time,
|
||||||
|
input_data.font_size,
|
||||||
|
input_data.font_color,
|
||||||
|
input_data.bg_color,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Return as workspace path or data URI based on context
|
||||||
|
video_out = await self._store_output_video(
|
||||||
|
execution_context, output_filename
|
||||||
|
)
|
||||||
|
|
||||||
|
yield "video_out", video_out
|
||||||
|
|
||||||
|
except BlockExecutionError:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise BlockExecutionError(
|
||||||
|
message=f"Failed to add text overlay: {e}",
|
||||||
|
block_name=self.name,
|
||||||
|
block_id=str(self.id),
|
||||||
|
) from e
|
||||||
@@ -165,10 +165,13 @@ class TranscribeYoutubeVideoBlock(Block):
|
|||||||
credentials: WebshareProxyCredentials,
|
credentials: WebshareProxyCredentials,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
video_id = self.extract_video_id(input_data.youtube_url)
|
try:
|
||||||
yield "video_id", video_id
|
video_id = self.extract_video_id(input_data.youtube_url)
|
||||||
|
transcript = self.get_transcript(video_id, credentials)
|
||||||
|
transcript_text = self.format_transcript(transcript=transcript)
|
||||||
|
|
||||||
transcript = self.get_transcript(video_id, credentials)
|
# Only yield after all operations succeed
|
||||||
transcript_text = self.format_transcript(transcript=transcript)
|
yield "video_id", video_id
|
||||||
|
yield "transcript", transcript_text
|
||||||
yield "transcript", transcript_text
|
except Exception as e:
|
||||||
|
yield "error", str(e)
|
||||||
|
|||||||
@@ -246,7 +246,9 @@ class BlockSchema(BaseModel):
|
|||||||
f"is not of type {CredentialsMetaInput.__name__}"
|
f"is not of type {CredentialsMetaInput.__name__}"
|
||||||
)
|
)
|
||||||
|
|
||||||
credentials_fields[field_name].validate_credentials_field_schema(cls)
|
CredentialsMetaInput.validate_credentials_field_schema(
|
||||||
|
cls.get_field_schema(field_name), field_name
|
||||||
|
)
|
||||||
|
|
||||||
elif field_name in credentials_fields:
|
elif field_name in credentials_fields:
|
||||||
raise KeyError(
|
raise KeyError(
|
||||||
|
|||||||
@@ -36,12 +36,14 @@ from backend.blocks.replicate.replicate_block import ReplicateModelBlock
|
|||||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||||
from backend.blocks.talking_head import CreateTalkingAvatarVideoBlock
|
from backend.blocks.talking_head import CreateTalkingAvatarVideoBlock
|
||||||
from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
|
from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
|
||||||
|
from backend.blocks.video.narration import VideoNarrationBlock
|
||||||
from backend.data.block import Block, BlockCost, BlockCostType
|
from backend.data.block import Block, BlockCost, BlockCostType
|
||||||
from backend.integrations.credentials_store import (
|
from backend.integrations.credentials_store import (
|
||||||
aiml_api_credentials,
|
aiml_api_credentials,
|
||||||
anthropic_credentials,
|
anthropic_credentials,
|
||||||
apollo_credentials,
|
apollo_credentials,
|
||||||
did_credentials,
|
did_credentials,
|
||||||
|
elevenlabs_credentials,
|
||||||
enrichlayer_credentials,
|
enrichlayer_credentials,
|
||||||
groq_credentials,
|
groq_credentials,
|
||||||
ideogram_credentials,
|
ideogram_credentials,
|
||||||
@@ -78,6 +80,7 @@ MODEL_COST: dict[LlmModel, int] = {
|
|||||||
LlmModel.CLAUDE_4_1_OPUS: 21,
|
LlmModel.CLAUDE_4_1_OPUS: 21,
|
||||||
LlmModel.CLAUDE_4_OPUS: 21,
|
LlmModel.CLAUDE_4_OPUS: 21,
|
||||||
LlmModel.CLAUDE_4_SONNET: 5,
|
LlmModel.CLAUDE_4_SONNET: 5,
|
||||||
|
LlmModel.CLAUDE_4_6_OPUS: 14,
|
||||||
LlmModel.CLAUDE_4_5_HAIKU: 4,
|
LlmModel.CLAUDE_4_5_HAIKU: 4,
|
||||||
LlmModel.CLAUDE_4_5_OPUS: 14,
|
LlmModel.CLAUDE_4_5_OPUS: 14,
|
||||||
LlmModel.CLAUDE_4_5_SONNET: 9,
|
LlmModel.CLAUDE_4_5_SONNET: 9,
|
||||||
@@ -639,4 +642,16 @@ BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = {
|
|||||||
},
|
},
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
|
VideoNarrationBlock: [
|
||||||
|
BlockCost(
|
||||||
|
cost_amount=5, # ElevenLabs TTS cost
|
||||||
|
cost_filter={
|
||||||
|
"credentials": {
|
||||||
|
"id": elevenlabs_credentials.id,
|
||||||
|
"provider": elevenlabs_credentials.provider,
|
||||||
|
"type": elevenlabs_credentials.type,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
],
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -134,6 +134,16 @@ async def test_block_credit_reset(server: SpinTestServer):
|
|||||||
month1 = datetime.now(timezone.utc).replace(month=1, day=1)
|
month1 = datetime.now(timezone.utc).replace(month=1, day=1)
|
||||||
user_credit.time_now = lambda: month1
|
user_credit.time_now = lambda: month1
|
||||||
|
|
||||||
|
# IMPORTANT: Set updatedAt to December of previous year to ensure it's
|
||||||
|
# in a different month than month1 (January). This fixes a timing bug
|
||||||
|
# where if the test runs in early February, 35 days ago would be January,
|
||||||
|
# matching the mocked month1 and preventing the refill from triggering.
|
||||||
|
dec_previous_year = month1.replace(year=month1.year - 1, month=12, day=15)
|
||||||
|
await UserBalance.prisma().update(
|
||||||
|
where={"userId": DEFAULT_USER_ID},
|
||||||
|
data={"updatedAt": dec_previous_year},
|
||||||
|
)
|
||||||
|
|
||||||
# First call in month 1 should trigger refill
|
# First call in month 1 should trigger refill
|
||||||
balance = await user_credit.get_credits(DEFAULT_USER_ID)
|
balance = await user_credit.get_credits(DEFAULT_USER_ID)
|
||||||
assert balance == REFILL_VALUE # Should get 1000 credits
|
assert balance == REFILL_VALUE # Should get 1000 credits
|
||||||
|
|||||||
@@ -1,9 +1,8 @@
|
|||||||
import logging
|
import logging
|
||||||
|
import queue
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from multiprocessing import Manager
|
|
||||||
from queue import Empty
|
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Annotated,
|
Annotated,
|
||||||
@@ -1200,12 +1199,16 @@ class NodeExecutionEntry(BaseModel):
|
|||||||
|
|
||||||
class ExecutionQueue(Generic[T]):
|
class ExecutionQueue(Generic[T]):
|
||||||
"""
|
"""
|
||||||
Queue for managing the execution of agents.
|
Thread-safe queue for managing node execution within a single graph execution.
|
||||||
This will be shared between different processes
|
|
||||||
|
Note: Uses queue.Queue (not multiprocessing.Queue) since all access is from
|
||||||
|
threads within the same process. If migrating back to ProcessPoolExecutor,
|
||||||
|
replace with multiprocessing.Manager().Queue() for cross-process safety.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.queue = Manager().Queue()
|
# Thread-safe queue (not multiprocessing) — see class docstring
|
||||||
|
self.queue: queue.Queue[T] = queue.Queue()
|
||||||
|
|
||||||
def add(self, execution: T) -> T:
|
def add(self, execution: T) -> T:
|
||||||
self.queue.put(execution)
|
self.queue.put(execution)
|
||||||
@@ -1220,7 +1223,7 @@ class ExecutionQueue(Generic[T]):
|
|||||||
def get_or_none(self) -> T | None:
|
def get_or_none(self) -> T | None:
|
||||||
try:
|
try:
|
||||||
return self.queue.get_nowait()
|
return self.queue.get_nowait()
|
||||||
except Empty:
|
except queue.Empty:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,58 @@
|
|||||||
|
"""Tests for ExecutionQueue thread-safety."""
|
||||||
|
|
||||||
|
import queue
|
||||||
|
import threading
|
||||||
|
|
||||||
|
from backend.data.execution import ExecutionQueue
|
||||||
|
|
||||||
|
|
||||||
|
def test_execution_queue_uses_stdlib_queue():
|
||||||
|
"""Verify ExecutionQueue uses queue.Queue (not multiprocessing)."""
|
||||||
|
q = ExecutionQueue()
|
||||||
|
assert isinstance(q.queue, queue.Queue)
|
||||||
|
|
||||||
|
|
||||||
|
def test_basic_operations():
|
||||||
|
"""Test add, get, empty, and get_or_none."""
|
||||||
|
q = ExecutionQueue()
|
||||||
|
|
||||||
|
assert q.empty() is True
|
||||||
|
assert q.get_or_none() is None
|
||||||
|
|
||||||
|
result = q.add("item1")
|
||||||
|
assert result == "item1"
|
||||||
|
assert q.empty() is False
|
||||||
|
|
||||||
|
item = q.get()
|
||||||
|
assert item == "item1"
|
||||||
|
assert q.empty() is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_thread_safety():
|
||||||
|
"""Test concurrent access from multiple threads."""
|
||||||
|
q = ExecutionQueue()
|
||||||
|
results = []
|
||||||
|
num_items = 100
|
||||||
|
|
||||||
|
def producer():
|
||||||
|
for i in range(num_items):
|
||||||
|
q.add(f"item_{i}")
|
||||||
|
|
||||||
|
def consumer():
|
||||||
|
count = 0
|
||||||
|
while count < num_items:
|
||||||
|
item = q.get_or_none()
|
||||||
|
if item is not None:
|
||||||
|
results.append(item)
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
producer_thread = threading.Thread(target=producer)
|
||||||
|
consumer_thread = threading.Thread(target=consumer)
|
||||||
|
|
||||||
|
producer_thread.start()
|
||||||
|
consumer_thread.start()
|
||||||
|
|
||||||
|
producer_thread.join(timeout=5)
|
||||||
|
consumer_thread.join(timeout=5)
|
||||||
|
|
||||||
|
assert len(results) == num_items
|
||||||
@@ -3,7 +3,7 @@ import logging
|
|||||||
import uuid
|
import uuid
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional, cast
|
from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional, Self, cast
|
||||||
|
|
||||||
from prisma.enums import SubmissionStatus
|
from prisma.enums import SubmissionStatus
|
||||||
from prisma.models import (
|
from prisma.models import (
|
||||||
@@ -20,7 +20,7 @@ from prisma.types import (
|
|||||||
AgentNodeLinkCreateInput,
|
AgentNodeLinkCreateInput,
|
||||||
StoreListingVersionWhereInput,
|
StoreListingVersionWhereInput,
|
||||||
)
|
)
|
||||||
from pydantic import BaseModel, BeforeValidator, Field, create_model
|
from pydantic import BaseModel, BeforeValidator, Field
|
||||||
from pydantic.fields import computed_field
|
from pydantic.fields import computed_field
|
||||||
|
|
||||||
from backend.blocks.agent import AgentExecutorBlock
|
from backend.blocks.agent import AgentExecutorBlock
|
||||||
@@ -30,7 +30,6 @@ from backend.data.db import prisma as db
|
|||||||
from backend.data.dynamic_fields import is_tool_pin, sanitize_pin_name
|
from backend.data.dynamic_fields import is_tool_pin, sanitize_pin_name
|
||||||
from backend.data.includes import MAX_GRAPH_VERSIONS_FETCH
|
from backend.data.includes import MAX_GRAPH_VERSIONS_FETCH
|
||||||
from backend.data.model import (
|
from backend.data.model import (
|
||||||
CredentialsField,
|
|
||||||
CredentialsFieldInfo,
|
CredentialsFieldInfo,
|
||||||
CredentialsMetaInput,
|
CredentialsMetaInput,
|
||||||
is_credentials_field_name,
|
is_credentials_field_name,
|
||||||
@@ -45,7 +44,6 @@ from .block import (
|
|||||||
AnyBlockSchema,
|
AnyBlockSchema,
|
||||||
Block,
|
Block,
|
||||||
BlockInput,
|
BlockInput,
|
||||||
BlockSchema,
|
|
||||||
BlockType,
|
BlockType,
|
||||||
EmptySchema,
|
EmptySchema,
|
||||||
get_block,
|
get_block,
|
||||||
@@ -113,10 +111,12 @@ class Link(BaseDbModel):
|
|||||||
|
|
||||||
class Node(BaseDbModel):
|
class Node(BaseDbModel):
|
||||||
block_id: str
|
block_id: str
|
||||||
input_default: BlockInput = {} # dict[input_name, default_value]
|
input_default: BlockInput = Field( # dict[input_name, default_value]
|
||||||
metadata: dict[str, Any] = {}
|
default_factory=dict
|
||||||
input_links: list[Link] = []
|
)
|
||||||
output_links: list[Link] = []
|
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
input_links: list[Link] = Field(default_factory=list)
|
||||||
|
output_links: list[Link] = Field(default_factory=list)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def credentials_optional(self) -> bool:
|
def credentials_optional(self) -> bool:
|
||||||
@@ -221,18 +221,33 @@ class NodeModel(Node):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
class BaseGraph(BaseDbModel):
|
class GraphBaseMeta(BaseDbModel):
|
||||||
|
"""
|
||||||
|
Shared base for `GraphMeta` and `BaseGraph`, with core graph metadata fields.
|
||||||
|
"""
|
||||||
|
|
||||||
version: int = 1
|
version: int = 1
|
||||||
is_active: bool = True
|
is_active: bool = True
|
||||||
name: str
|
name: str
|
||||||
description: str
|
description: str
|
||||||
instructions: str | None = None
|
instructions: str | None = None
|
||||||
recommended_schedule_cron: str | None = None
|
recommended_schedule_cron: str | None = None
|
||||||
nodes: list[Node] = []
|
|
||||||
links: list[Link] = []
|
|
||||||
forked_from_id: str | None = None
|
forked_from_id: str | None = None
|
||||||
forked_from_version: int | None = None
|
forked_from_version: int | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class BaseGraph(GraphBaseMeta):
|
||||||
|
"""
|
||||||
|
Graph with nodes, links, and computed I/O schema fields.
|
||||||
|
|
||||||
|
Used to represent sub-graphs within a `Graph`. Contains the full graph
|
||||||
|
structure including nodes and links, plus computed fields for schemas
|
||||||
|
and trigger info. Does NOT include user_id or created_at (see GraphModel).
|
||||||
|
"""
|
||||||
|
|
||||||
|
nodes: list[Node] = Field(default_factory=list)
|
||||||
|
links: list[Link] = Field(default_factory=list)
|
||||||
|
|
||||||
@computed_field
|
@computed_field
|
||||||
@property
|
@property
|
||||||
def input_schema(self) -> dict[str, Any]:
|
def input_schema(self) -> dict[str, Any]:
|
||||||
@@ -361,44 +376,79 @@ class GraphTriggerInfo(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class Graph(BaseGraph):
|
class Graph(BaseGraph):
|
||||||
sub_graphs: list[BaseGraph] = [] # Flattened sub-graphs
|
"""Creatable graph model used in API create/update endpoints."""
|
||||||
|
|
||||||
|
sub_graphs: list[BaseGraph] = Field(default_factory=list) # Flattened sub-graphs
|
||||||
|
|
||||||
|
|
||||||
|
class GraphMeta(GraphBaseMeta):
|
||||||
|
"""
|
||||||
|
Lightweight graph metadata model representing an existing graph from the database,
|
||||||
|
for use in listings and summaries.
|
||||||
|
|
||||||
|
Lacks `GraphModel`'s nodes, links, and expensive computed fields.
|
||||||
|
Use for list endpoints where full graph data is not needed and performance matters.
|
||||||
|
"""
|
||||||
|
|
||||||
|
id: str # type: ignore
|
||||||
|
version: int # type: ignore
|
||||||
|
user_id: str
|
||||||
|
created_at: datetime
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_db(cls, graph: "AgentGraph") -> Self:
|
||||||
|
return cls(
|
||||||
|
id=graph.id,
|
||||||
|
version=graph.version,
|
||||||
|
is_active=graph.isActive,
|
||||||
|
name=graph.name or "",
|
||||||
|
description=graph.description or "",
|
||||||
|
instructions=graph.instructions,
|
||||||
|
recommended_schedule_cron=graph.recommendedScheduleCron,
|
||||||
|
forked_from_id=graph.forkedFromId,
|
||||||
|
forked_from_version=graph.forkedFromVersion,
|
||||||
|
user_id=graph.userId,
|
||||||
|
created_at=graph.createdAt,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class GraphModel(Graph, GraphMeta):
|
||||||
|
"""
|
||||||
|
Full graph model representing an existing graph from the database.
|
||||||
|
|
||||||
|
This is the primary model for working with persisted graphs. Includes all
|
||||||
|
graph data (nodes, links, sub_graphs) plus user ownership and timestamps.
|
||||||
|
Provides computed fields (input_schema, output_schema, etc.) used during
|
||||||
|
set-up (frontend) and execution (backend).
|
||||||
|
|
||||||
|
Inherits from:
|
||||||
|
- `Graph`: provides structure (nodes, links, sub_graphs) and computed schemas
|
||||||
|
- `GraphMeta`: provides user_id, created_at for database records
|
||||||
|
"""
|
||||||
|
|
||||||
|
nodes: list[NodeModel] = Field(default_factory=list) # type: ignore
|
||||||
|
|
||||||
|
@property
|
||||||
|
def starting_nodes(self) -> list[NodeModel]:
|
||||||
|
outbound_nodes = {link.sink_id for link in self.links}
|
||||||
|
input_nodes = {
|
||||||
|
node.id for node in self.nodes if node.block.block_type == BlockType.INPUT
|
||||||
|
}
|
||||||
|
return [
|
||||||
|
node
|
||||||
|
for node in self.nodes
|
||||||
|
if node.id not in outbound_nodes or node.id in input_nodes
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def webhook_input_node(self) -> NodeModel | None: # type: ignore
|
||||||
|
return cast(NodeModel, super().webhook_input_node)
|
||||||
|
|
||||||
@computed_field
|
@computed_field
|
||||||
@property
|
@property
|
||||||
def credentials_input_schema(self) -> dict[str, Any]:
|
def credentials_input_schema(self) -> dict[str, Any]:
|
||||||
schema = self._credentials_input_schema.jsonschema()
|
|
||||||
|
|
||||||
# Determine which credential fields are required based on credentials_optional metadata
|
|
||||||
graph_credentials_inputs = self.aggregate_credentials_inputs()
|
graph_credentials_inputs = self.aggregate_credentials_inputs()
|
||||||
required_fields = []
|
|
||||||
|
|
||||||
# Build a map of node_id -> node for quick lookup
|
|
||||||
all_nodes = {node.id: node for node in self.nodes}
|
|
||||||
for sub_graph in self.sub_graphs:
|
|
||||||
for node in sub_graph.nodes:
|
|
||||||
all_nodes[node.id] = node
|
|
||||||
|
|
||||||
for field_key, (
|
|
||||||
_field_info,
|
|
||||||
node_field_pairs,
|
|
||||||
) in graph_credentials_inputs.items():
|
|
||||||
# A field is required if ANY node using it has credentials_optional=False
|
|
||||||
is_required = False
|
|
||||||
for node_id, _field_name in node_field_pairs:
|
|
||||||
node = all_nodes.get(node_id)
|
|
||||||
if node and not node.credentials_optional:
|
|
||||||
is_required = True
|
|
||||||
break
|
|
||||||
|
|
||||||
if is_required:
|
|
||||||
required_fields.append(field_key)
|
|
||||||
|
|
||||||
schema["required"] = required_fields
|
|
||||||
return schema
|
|
||||||
|
|
||||||
@property
|
|
||||||
def _credentials_input_schema(self) -> type[BlockSchema]:
|
|
||||||
graph_credentials_inputs = self.aggregate_credentials_inputs()
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Combined credentials input fields for graph #{self.id} ({self.name}): "
|
f"Combined credentials input fields for graph #{self.id} ({self.name}): "
|
||||||
f"{graph_credentials_inputs}"
|
f"{graph_credentials_inputs}"
|
||||||
@@ -406,8 +456,8 @@ class Graph(BaseGraph):
|
|||||||
|
|
||||||
# Warn if same-provider credentials inputs can't be combined (= bad UX)
|
# Warn if same-provider credentials inputs can't be combined (= bad UX)
|
||||||
graph_cred_fields = list(graph_credentials_inputs.values())
|
graph_cred_fields = list(graph_credentials_inputs.values())
|
||||||
for i, (field, keys) in enumerate(graph_cred_fields):
|
for i, (field, keys, _) in enumerate(graph_cred_fields):
|
||||||
for other_field, other_keys in list(graph_cred_fields)[i + 1 :]:
|
for other_field, other_keys, _ in list(graph_cred_fields)[i + 1 :]:
|
||||||
if field.provider != other_field.provider:
|
if field.provider != other_field.provider:
|
||||||
continue
|
continue
|
||||||
if ProviderName.HTTP in field.provider:
|
if ProviderName.HTTP in field.provider:
|
||||||
@@ -423,31 +473,78 @@ class Graph(BaseGraph):
|
|||||||
f"keys: {keys} <> {other_keys}."
|
f"keys: {keys} <> {other_keys}."
|
||||||
)
|
)
|
||||||
|
|
||||||
fields: dict[str, tuple[type[CredentialsMetaInput], CredentialsMetaInput]] = {
|
# Build JSON schema directly to avoid expensive create_model + validation overhead
|
||||||
agg_field_key: (
|
properties = {}
|
||||||
CredentialsMetaInput[
|
required_fields = []
|
||||||
Literal[tuple(field_info.provider)], # type: ignore
|
|
||||||
Literal[tuple(field_info.supported_types)], # type: ignore
|
|
||||||
],
|
|
||||||
CredentialsField(
|
|
||||||
required_scopes=set(field_info.required_scopes or []),
|
|
||||||
discriminator=field_info.discriminator,
|
|
||||||
discriminator_mapping=field_info.discriminator_mapping,
|
|
||||||
discriminator_values=field_info.discriminator_values,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
for agg_field_key, (field_info, _) in graph_credentials_inputs.items()
|
|
||||||
}
|
|
||||||
|
|
||||||
return create_model(
|
for agg_field_key, (
|
||||||
self.name.replace(" ", "") + "CredentialsInputSchema",
|
field_info,
|
||||||
__base__=BlockSchema,
|
_,
|
||||||
**fields, # type: ignore
|
is_required,
|
||||||
)
|
) in graph_credentials_inputs.items():
|
||||||
|
providers = list(field_info.provider)
|
||||||
|
cred_types = list(field_info.supported_types)
|
||||||
|
|
||||||
|
field_schema: dict[str, Any] = {
|
||||||
|
"credentials_provider": providers,
|
||||||
|
"credentials_types": cred_types,
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"id": {"title": "Id", "type": "string"},
|
||||||
|
"title": {
|
||||||
|
"anyOf": [{"type": "string"}, {"type": "null"}],
|
||||||
|
"default": None,
|
||||||
|
"title": "Title",
|
||||||
|
},
|
||||||
|
"provider": {
|
||||||
|
"title": "Provider",
|
||||||
|
"type": "string",
|
||||||
|
**(
|
||||||
|
{"enum": providers}
|
||||||
|
if len(providers) > 1
|
||||||
|
else {"const": providers[0]}
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"title": "Type",
|
||||||
|
"type": "string",
|
||||||
|
**(
|
||||||
|
{"enum": cred_types}
|
||||||
|
if len(cred_types) > 1
|
||||||
|
else {"const": cred_types[0]}
|
||||||
|
),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["id", "provider", "type"],
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add other (optional) field info items
|
||||||
|
field_schema.update(
|
||||||
|
field_info.model_dump(
|
||||||
|
by_alias=True,
|
||||||
|
exclude_defaults=True,
|
||||||
|
exclude={"provider", "supported_types"}, # already included above
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ensure field schema is well-formed
|
||||||
|
CredentialsMetaInput.validate_credentials_field_schema(
|
||||||
|
field_schema, agg_field_key
|
||||||
|
)
|
||||||
|
|
||||||
|
properties[agg_field_key] = field_schema
|
||||||
|
if is_required:
|
||||||
|
required_fields.append(agg_field_key)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": properties,
|
||||||
|
"required": required_fields,
|
||||||
|
}
|
||||||
|
|
||||||
def aggregate_credentials_inputs(
|
def aggregate_credentials_inputs(
|
||||||
self,
|
self,
|
||||||
) -> dict[str, tuple[CredentialsFieldInfo, set[tuple[str, str]]]]:
|
) -> dict[str, tuple[CredentialsFieldInfo, set[tuple[str, str]], bool]]:
|
||||||
"""
|
"""
|
||||||
Returns:
|
Returns:
|
||||||
dict[aggregated_field_key, tuple(
|
dict[aggregated_field_key, tuple(
|
||||||
@@ -455,13 +552,19 @@ class Graph(BaseGraph):
|
|||||||
(now includes discriminator_values from matching nodes)
|
(now includes discriminator_values from matching nodes)
|
||||||
set[(node_id, field_name)]: Node credentials fields that are
|
set[(node_id, field_name)]: Node credentials fields that are
|
||||||
compatible with this aggregated field spec
|
compatible with this aggregated field spec
|
||||||
|
bool: True if the field is required (any node has credentials_optional=False)
|
||||||
)]
|
)]
|
||||||
"""
|
"""
|
||||||
# First collect all credential field data with input defaults
|
# First collect all credential field data with input defaults
|
||||||
node_credential_data = []
|
# Track (field_info, (node_id, field_name), is_required) for each credential field
|
||||||
|
node_credential_data: list[tuple[CredentialsFieldInfo, tuple[str, str]]] = []
|
||||||
|
node_required_map: dict[str, bool] = {} # node_id -> is_required
|
||||||
|
|
||||||
for graph in [self] + self.sub_graphs:
|
for graph in [self] + self.sub_graphs:
|
||||||
for node in graph.nodes:
|
for node in graph.nodes:
|
||||||
|
# Track if this node requires credentials (credentials_optional=False means required)
|
||||||
|
node_required_map[node.id] = not node.credentials_optional
|
||||||
|
|
||||||
for (
|
for (
|
||||||
field_name,
|
field_name,
|
||||||
field_info,
|
field_info,
|
||||||
@@ -485,37 +588,21 @@ class Graph(BaseGraph):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Combine credential field info (this will merge discriminator_values automatically)
|
# Combine credential field info (this will merge discriminator_values automatically)
|
||||||
return CredentialsFieldInfo.combine(*node_credential_data)
|
combined = CredentialsFieldInfo.combine(*node_credential_data)
|
||||||
|
|
||||||
|
# Add is_required flag to each aggregated field
|
||||||
class GraphModel(Graph):
|
# A field is required if ANY node using it has credentials_optional=False
|
||||||
user_id: str
|
return {
|
||||||
nodes: list[NodeModel] = [] # type: ignore
|
key: (
|
||||||
|
field_info,
|
||||||
created_at: datetime
|
node_field_pairs,
|
||||||
|
any(
|
||||||
@property
|
node_required_map.get(node_id, True)
|
||||||
def starting_nodes(self) -> list[NodeModel]:
|
for node_id, _ in node_field_pairs
|
||||||
outbound_nodes = {link.sink_id for link in self.links}
|
),
|
||||||
input_nodes = {
|
)
|
||||||
node.id for node in self.nodes if node.block.block_type == BlockType.INPUT
|
for key, (field_info, node_field_pairs) in combined.items()
|
||||||
}
|
}
|
||||||
return [
|
|
||||||
node
|
|
||||||
for node in self.nodes
|
|
||||||
if node.id not in outbound_nodes or node.id in input_nodes
|
|
||||||
]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def webhook_input_node(self) -> NodeModel | None: # type: ignore
|
|
||||||
return cast(NodeModel, super().webhook_input_node)
|
|
||||||
|
|
||||||
def meta(self) -> "GraphMeta":
|
|
||||||
"""
|
|
||||||
Returns a GraphMeta object with metadata about the graph.
|
|
||||||
This is used to return metadata about the graph without exposing nodes and links.
|
|
||||||
"""
|
|
||||||
return GraphMeta.from_graph(self)
|
|
||||||
|
|
||||||
def reassign_ids(self, user_id: str, reassign_graph_id: bool = False):
|
def reassign_ids(self, user_id: str, reassign_graph_id: bool = False):
|
||||||
"""
|
"""
|
||||||
@@ -799,13 +886,14 @@ class GraphModel(Graph):
|
|||||||
if is_static_output_block(link.source_id):
|
if is_static_output_block(link.source_id):
|
||||||
link.is_static = True # Each value block output should be static.
|
link.is_static = True # Each value block output should be static.
|
||||||
|
|
||||||
@staticmethod
|
@classmethod
|
||||||
def from_db(
|
def from_db( # type: ignore[reportIncompatibleMethodOverride]
|
||||||
|
cls,
|
||||||
graph: AgentGraph,
|
graph: AgentGraph,
|
||||||
for_export: bool = False,
|
for_export: bool = False,
|
||||||
sub_graphs: list[AgentGraph] | None = None,
|
sub_graphs: list[AgentGraph] | None = None,
|
||||||
) -> "GraphModel":
|
) -> Self:
|
||||||
return GraphModel(
|
return cls(
|
||||||
id=graph.id,
|
id=graph.id,
|
||||||
user_id=graph.userId if not for_export else "",
|
user_id=graph.userId if not for_export else "",
|
||||||
version=graph.version,
|
version=graph.version,
|
||||||
@@ -831,17 +919,28 @@ class GraphModel(Graph):
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def hide_nodes(self) -> "GraphModelWithoutNodes":
|
||||||
|
"""
|
||||||
|
Returns a copy of the `GraphModel` with nodes, links, and sub-graphs hidden
|
||||||
|
(excluded from serialization). They are still present in the model instance
|
||||||
|
so all computed fields (e.g. `credentials_input_schema`) still work.
|
||||||
|
"""
|
||||||
|
return GraphModelWithoutNodes.model_validate(self, from_attributes=True)
|
||||||
|
|
||||||
class GraphMeta(Graph):
|
|
||||||
user_id: str
|
|
||||||
|
|
||||||
# Easy work-around to prevent exposing nodes and links in the API response
|
class GraphModelWithoutNodes(GraphModel):
|
||||||
nodes: list[NodeModel] = Field(default=[], exclude=True) # type: ignore
|
"""
|
||||||
links: list[Link] = Field(default=[], exclude=True)
|
GraphModel variant that excludes nodes, links, and sub-graphs from serialization.
|
||||||
|
|
||||||
@staticmethod
|
Used in contexts like the store where exposing internal graph structure
|
||||||
def from_graph(graph: GraphModel) -> "GraphMeta":
|
is not desired. Inherits all computed fields from GraphModel but marks
|
||||||
return GraphMeta(**graph.model_dump())
|
nodes and links as excluded from JSON output.
|
||||||
|
"""
|
||||||
|
|
||||||
|
nodes: list[NodeModel] = Field(default_factory=list, exclude=True)
|
||||||
|
links: list[Link] = Field(default_factory=list, exclude=True)
|
||||||
|
|
||||||
|
sub_graphs: list[BaseGraph] = Field(default_factory=list, exclude=True)
|
||||||
|
|
||||||
|
|
||||||
class GraphsPaginated(BaseModel):
|
class GraphsPaginated(BaseModel):
|
||||||
@@ -912,21 +1011,11 @@ async def list_graphs_paginated(
|
|||||||
where=where_clause,
|
where=where_clause,
|
||||||
distinct=["id"],
|
distinct=["id"],
|
||||||
order={"version": "desc"},
|
order={"version": "desc"},
|
||||||
include=AGENT_GRAPH_INCLUDE,
|
|
||||||
skip=offset,
|
skip=offset,
|
||||||
take=page_size,
|
take=page_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
graph_models: list[GraphMeta] = []
|
graph_models = [GraphMeta.from_db(graph) for graph in graphs]
|
||||||
for graph in graphs:
|
|
||||||
try:
|
|
||||||
graph_meta = GraphModel.from_db(graph).meta()
|
|
||||||
# Trigger serialization to validate that the graph is well formed
|
|
||||||
graph_meta.model_dump()
|
|
||||||
graph_models.append(graph_meta)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error processing graph {graph.id}: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
return GraphsPaginated(
|
return GraphsPaginated(
|
||||||
graphs=graph_models,
|
graphs=graph_models,
|
||||||
|
|||||||
@@ -163,7 +163,6 @@ class User(BaseModel):
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from prisma.models import User as PrismaUser
|
from prisma.models import User as PrismaUser
|
||||||
|
|
||||||
from backend.data.block import BlockSchema
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -508,15 +507,13 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
|
|||||||
def allowed_cred_types(cls) -> tuple[CredentialsType, ...]:
|
def allowed_cred_types(cls) -> tuple[CredentialsType, ...]:
|
||||||
return get_args(cls.model_fields["type"].annotation)
|
return get_args(cls.model_fields["type"].annotation)
|
||||||
|
|
||||||
@classmethod
|
@staticmethod
|
||||||
def validate_credentials_field_schema(cls, model: type["BlockSchema"]):
|
def validate_credentials_field_schema(
|
||||||
|
field_schema: dict[str, Any], field_name: str
|
||||||
|
):
|
||||||
"""Validates the schema of a credentials input field"""
|
"""Validates the schema of a credentials input field"""
|
||||||
field_name = next(
|
|
||||||
name for name, type in model.get_credentials_fields().items() if type is cls
|
|
||||||
)
|
|
||||||
field_schema = model.jsonschema()["properties"][field_name]
|
|
||||||
try:
|
try:
|
||||||
schema_extra = CredentialsFieldInfo[CP, CT].model_validate(field_schema)
|
field_info = CredentialsFieldInfo[CP, CT].model_validate(field_schema)
|
||||||
except ValidationError as e:
|
except ValidationError as e:
|
||||||
if "Field required [type=missing" not in str(e):
|
if "Field required [type=missing" not in str(e):
|
||||||
raise
|
raise
|
||||||
@@ -526,11 +523,11 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
|
|||||||
f"{field_schema}"
|
f"{field_schema}"
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
providers = cls.allowed_providers()
|
providers = field_info.provider
|
||||||
if (
|
if (
|
||||||
providers is not None
|
providers is not None
|
||||||
and len(providers) > 1
|
and len(providers) > 1
|
||||||
and not schema_extra.discriminator
|
and not field_info.discriminator
|
||||||
):
|
):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"Multi-provider CredentialsField '{field_name}' "
|
f"Multi-provider CredentialsField '{field_name}' "
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
@@ -225,6 +226,10 @@ class SyncRabbitMQ(RabbitMQBase):
|
|||||||
class AsyncRabbitMQ(RabbitMQBase):
|
class AsyncRabbitMQ(RabbitMQBase):
|
||||||
"""Asynchronous RabbitMQ client"""
|
"""Asynchronous RabbitMQ client"""
|
||||||
|
|
||||||
|
def __init__(self, config: RabbitMQConfig):
|
||||||
|
super().__init__(config)
|
||||||
|
self._reconnect_lock: asyncio.Lock | None = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_connected(self) -> bool:
|
def is_connected(self) -> bool:
|
||||||
return bool(self._connection and not self._connection.is_closed)
|
return bool(self._connection and not self._connection.is_closed)
|
||||||
@@ -235,7 +240,17 @@ class AsyncRabbitMQ(RabbitMQBase):
|
|||||||
|
|
||||||
@conn_retry("AsyncRabbitMQ", "Acquiring async connection")
|
@conn_retry("AsyncRabbitMQ", "Acquiring async connection")
|
||||||
async def connect(self):
|
async def connect(self):
|
||||||
if self.is_connected:
|
if self.is_connected and self._channel and not self._channel.is_closed:
|
||||||
|
return
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.is_connected
|
||||||
|
and self._connection
|
||||||
|
and (self._channel is None or self._channel.is_closed)
|
||||||
|
):
|
||||||
|
self._channel = await self._connection.channel()
|
||||||
|
await self._channel.set_qos(prefetch_count=1)
|
||||||
|
await self.declare_infrastructure()
|
||||||
return
|
return
|
||||||
|
|
||||||
self._connection = await aio_pika.connect_robust(
|
self._connection = await aio_pika.connect_robust(
|
||||||
@@ -291,24 +306,46 @@ class AsyncRabbitMQ(RabbitMQBase):
|
|||||||
exchange, routing_key=queue.routing_key or queue.name
|
exchange, routing_key=queue.routing_key or queue.name
|
||||||
)
|
)
|
||||||
|
|
||||||
@func_retry
|
@property
|
||||||
async def publish_message(
|
def _lock(self) -> asyncio.Lock:
|
||||||
|
if self._reconnect_lock is None:
|
||||||
|
self._reconnect_lock = asyncio.Lock()
|
||||||
|
return self._reconnect_lock
|
||||||
|
|
||||||
|
async def _ensure_channel(self) -> aio_pika.abc.AbstractChannel:
|
||||||
|
"""Get a valid channel, reconnecting if the current one is stale.
|
||||||
|
|
||||||
|
Uses a lock to prevent concurrent reconnection attempts from racing.
|
||||||
|
"""
|
||||||
|
if self.is_ready:
|
||||||
|
return self._channel # type: ignore # is_ready guarantees non-None
|
||||||
|
|
||||||
|
async with self._lock:
|
||||||
|
# Double-check after acquiring lock
|
||||||
|
if self.is_ready:
|
||||||
|
return self._channel # type: ignore
|
||||||
|
|
||||||
|
self._channel = None
|
||||||
|
await self.connect()
|
||||||
|
|
||||||
|
if self._channel is None:
|
||||||
|
raise RuntimeError("Channel should be established after connect")
|
||||||
|
|
||||||
|
return self._channel
|
||||||
|
|
||||||
|
async def _publish_once(
|
||||||
self,
|
self,
|
||||||
routing_key: str,
|
routing_key: str,
|
||||||
message: str,
|
message: str,
|
||||||
exchange: Optional[Exchange] = None,
|
exchange: Optional[Exchange] = None,
|
||||||
persistent: bool = True,
|
persistent: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
if not self.is_ready:
|
channel = await self._ensure_channel()
|
||||||
await self.connect()
|
|
||||||
|
|
||||||
if self._channel is None:
|
|
||||||
raise RuntimeError("Channel should be established after connect")
|
|
||||||
|
|
||||||
if exchange:
|
if exchange:
|
||||||
exchange_obj = await self._channel.get_exchange(exchange.name)
|
exchange_obj = await channel.get_exchange(exchange.name)
|
||||||
else:
|
else:
|
||||||
exchange_obj = self._channel.default_exchange
|
exchange_obj = channel.default_exchange
|
||||||
|
|
||||||
await exchange_obj.publish(
|
await exchange_obj.publish(
|
||||||
aio_pika.Message(
|
aio_pika.Message(
|
||||||
@@ -322,9 +359,23 @@ class AsyncRabbitMQ(RabbitMQBase):
|
|||||||
routing_key=routing_key,
|
routing_key=routing_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@func_retry
|
||||||
|
async def publish_message(
|
||||||
|
self,
|
||||||
|
routing_key: str,
|
||||||
|
message: str,
|
||||||
|
exchange: Optional[Exchange] = None,
|
||||||
|
persistent: bool = True,
|
||||||
|
) -> None:
|
||||||
|
try:
|
||||||
|
await self._publish_once(routing_key, message, exchange, persistent)
|
||||||
|
except aio_pika.exceptions.ChannelInvalidStateError:
|
||||||
|
logger.warning(
|
||||||
|
"RabbitMQ channel invalid, forcing reconnect and retrying publish"
|
||||||
|
)
|
||||||
|
async with self._lock:
|
||||||
|
self._channel = None
|
||||||
|
await self._publish_once(routing_key, message, exchange, persistent)
|
||||||
|
|
||||||
async def get_channel(self) -> aio_pika.abc.AbstractChannel:
|
async def get_channel(self) -> aio_pika.abc.AbstractChannel:
|
||||||
if not self.is_ready:
|
return await self._ensure_channel()
|
||||||
await self.connect()
|
|
||||||
if self._channel is None:
|
|
||||||
raise RuntimeError("Channel should be established after connect")
|
|
||||||
return self._channel
|
|
||||||
|
|||||||
@@ -373,7 +373,7 @@ def make_node_credentials_input_map(
|
|||||||
# Get aggregated credentials fields for the graph
|
# Get aggregated credentials fields for the graph
|
||||||
graph_cred_inputs = graph.aggregate_credentials_inputs()
|
graph_cred_inputs = graph.aggregate_credentials_inputs()
|
||||||
|
|
||||||
for graph_input_name, (_, compatible_node_fields) in graph_cred_inputs.items():
|
for graph_input_name, (_, compatible_node_fields, _) in graph_cred_inputs.items():
|
||||||
# Best-effort map: skip missing items
|
# Best-effort map: skip missing items
|
||||||
if graph_input_name not in graph_credentials_input:
|
if graph_input_name not in graph_credentials_input:
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -224,6 +224,14 @@ openweathermap_credentials = APIKeyCredentials(
|
|||||||
expires_at=None,
|
expires_at=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elevenlabs_credentials = APIKeyCredentials(
|
||||||
|
id="f4a8b6c2-3d1e-4f5a-9b8c-7d6e5f4a3b2c",
|
||||||
|
provider="elevenlabs",
|
||||||
|
api_key=SecretStr(settings.secrets.elevenlabs_api_key),
|
||||||
|
title="Use Credits for ElevenLabs",
|
||||||
|
expires_at=None,
|
||||||
|
)
|
||||||
|
|
||||||
DEFAULT_CREDENTIALS = [
|
DEFAULT_CREDENTIALS = [
|
||||||
ollama_credentials,
|
ollama_credentials,
|
||||||
revid_credentials,
|
revid_credentials,
|
||||||
@@ -252,6 +260,7 @@ DEFAULT_CREDENTIALS = [
|
|||||||
v0_credentials,
|
v0_credentials,
|
||||||
webshare_proxy_credentials,
|
webshare_proxy_credentials,
|
||||||
openweathermap_credentials,
|
openweathermap_credentials,
|
||||||
|
elevenlabs_credentials,
|
||||||
]
|
]
|
||||||
|
|
||||||
SYSTEM_CREDENTIAL_IDS = {cred.id for cred in DEFAULT_CREDENTIALS}
|
SYSTEM_CREDENTIAL_IDS = {cred.id for cred in DEFAULT_CREDENTIALS}
|
||||||
@@ -366,6 +375,8 @@ class IntegrationCredentialsStore:
|
|||||||
all_credentials.append(webshare_proxy_credentials)
|
all_credentials.append(webshare_proxy_credentials)
|
||||||
if settings.secrets.openweathermap_api_key:
|
if settings.secrets.openweathermap_api_key:
|
||||||
all_credentials.append(openweathermap_credentials)
|
all_credentials.append(openweathermap_credentials)
|
||||||
|
if settings.secrets.elevenlabs_api_key:
|
||||||
|
all_credentials.append(elevenlabs_credentials)
|
||||||
return all_credentials
|
return all_credentials
|
||||||
|
|
||||||
async def get_creds_by_id(
|
async def get_creds_by_id(
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ class ProviderName(str, Enum):
|
|||||||
DISCORD = "discord"
|
DISCORD = "discord"
|
||||||
D_ID = "d_id"
|
D_ID = "d_id"
|
||||||
E2B = "e2b"
|
E2B = "e2b"
|
||||||
|
ELEVENLABS = "elevenlabs"
|
||||||
FAL = "fal"
|
FAL = "fal"
|
||||||
GITHUB = "github"
|
GITHUB = "github"
|
||||||
GOOGLE = "google"
|
GOOGLE = "google"
|
||||||
|
|||||||
@@ -8,6 +8,8 @@ from pathlib import Path
|
|||||||
from typing import TYPE_CHECKING, Literal
|
from typing import TYPE_CHECKING, Literal
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from backend.util.cloud_storage import get_cloud_storage_handler
|
from backend.util.cloud_storage import get_cloud_storage_handler
|
||||||
from backend.util.request import Requests
|
from backend.util.request import Requests
|
||||||
from backend.util.settings import Config
|
from backend.util.settings import Config
|
||||||
@@ -17,6 +19,35 @@ from backend.util.virus_scanner import scan_content_safe
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from backend.data.execution import ExecutionContext
|
from backend.data.execution import ExecutionContext
|
||||||
|
|
||||||
|
|
||||||
|
class WorkspaceUri(BaseModel):
|
||||||
|
"""Parsed workspace:// URI."""
|
||||||
|
|
||||||
|
file_ref: str # File ID or path (e.g. "abc123" or "/path/to/file.txt")
|
||||||
|
mime_type: str | None = None # MIME type from fragment (e.g. "video/mp4")
|
||||||
|
is_path: bool = False # True if file_ref is a path (starts with "/")
|
||||||
|
|
||||||
|
|
||||||
|
def parse_workspace_uri(uri: str) -> WorkspaceUri:
|
||||||
|
"""Parse a workspace:// URI into its components.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
"workspace://abc123" → WorkspaceUri(file_ref="abc123", mime_type=None, is_path=False)
|
||||||
|
"workspace://abc123#video/mp4" → WorkspaceUri(file_ref="abc123", mime_type="video/mp4", is_path=False)
|
||||||
|
"workspace:///path/to/file.txt" → WorkspaceUri(file_ref="/path/to/file.txt", mime_type=None, is_path=True)
|
||||||
|
"""
|
||||||
|
raw = uri.removeprefix("workspace://")
|
||||||
|
mime_type: str | None = None
|
||||||
|
if "#" in raw:
|
||||||
|
raw, fragment = raw.split("#", 1)
|
||||||
|
mime_type = fragment or None
|
||||||
|
return WorkspaceUri(
|
||||||
|
file_ref=raw,
|
||||||
|
mime_type=mime_type,
|
||||||
|
is_path=raw.startswith("/"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Return format options for store_media_file
|
# Return format options for store_media_file
|
||||||
# - "for_local_processing": Returns local file path - use with ffmpeg, MoviePy, PIL, etc.
|
# - "for_local_processing": Returns local file path - use with ffmpeg, MoviePy, PIL, etc.
|
||||||
# - "for_external_api": Returns data URI (base64) - use when sending content to external APIs
|
# - "for_external_api": Returns data URI (base64) - use when sending content to external APIs
|
||||||
@@ -183,22 +214,20 @@ async def store_media_file(
|
|||||||
"This file type is only available in CoPilot sessions."
|
"This file type is only available in CoPilot sessions."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Parse workspace reference
|
# Parse workspace reference (strips #mimeType fragment from file ID)
|
||||||
# workspace://abc123 - by file ID
|
ws = parse_workspace_uri(file)
|
||||||
# workspace:///path/to/file.txt - by virtual path
|
|
||||||
file_ref = file[12:] # Remove "workspace://"
|
|
||||||
|
|
||||||
if file_ref.startswith("/"):
|
if ws.is_path:
|
||||||
# Path reference
|
# Path reference: workspace:///path/to/file.txt
|
||||||
workspace_content = await workspace_manager.read_file(file_ref)
|
workspace_content = await workspace_manager.read_file(ws.file_ref)
|
||||||
file_info = await workspace_manager.get_file_info_by_path(file_ref)
|
file_info = await workspace_manager.get_file_info_by_path(ws.file_ref)
|
||||||
filename = sanitize_filename(
|
filename = sanitize_filename(
|
||||||
file_info.name if file_info else f"{uuid.uuid4()}.bin"
|
file_info.name if file_info else f"{uuid.uuid4()}.bin"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# ID reference
|
# ID reference: workspace://abc123 or workspace://abc123#video/mp4
|
||||||
workspace_content = await workspace_manager.read_file_by_id(file_ref)
|
workspace_content = await workspace_manager.read_file_by_id(ws.file_ref)
|
||||||
file_info = await workspace_manager.get_file_info(file_ref)
|
file_info = await workspace_manager.get_file_info(ws.file_ref)
|
||||||
filename = sanitize_filename(
|
filename = sanitize_filename(
|
||||||
file_info.name if file_info else f"{uuid.uuid4()}.bin"
|
file_info.name if file_info else f"{uuid.uuid4()}.bin"
|
||||||
)
|
)
|
||||||
@@ -313,6 +342,14 @@ async def store_media_file(
|
|||||||
if not target_path.is_file():
|
if not target_path.is_file():
|
||||||
raise ValueError(f"Local file does not exist: {target_path}")
|
raise ValueError(f"Local file does not exist: {target_path}")
|
||||||
|
|
||||||
|
# Virus scan the local file before any further processing
|
||||||
|
local_content = target_path.read_bytes()
|
||||||
|
if len(local_content) > MAX_FILE_SIZE_BYTES:
|
||||||
|
raise ValueError(
|
||||||
|
f"File too large: {len(local_content)} bytes > {MAX_FILE_SIZE_BYTES} bytes"
|
||||||
|
)
|
||||||
|
await scan_content_safe(local_content, filename=sanitized_file)
|
||||||
|
|
||||||
# Return based on requested format
|
# Return based on requested format
|
||||||
if return_format == "for_local_processing":
|
if return_format == "for_local_processing":
|
||||||
# Use when processing files locally with tools like ffmpeg, MoviePy, PIL
|
# Use when processing files locally with tools like ffmpeg, MoviePy, PIL
|
||||||
@@ -334,7 +371,21 @@ async def store_media_file(
|
|||||||
|
|
||||||
# Don't re-save if input was already from workspace
|
# Don't re-save if input was already from workspace
|
||||||
if is_from_workspace:
|
if is_from_workspace:
|
||||||
# Return original workspace reference
|
# Return original workspace reference, ensuring MIME type fragment
|
||||||
|
ws = parse_workspace_uri(file)
|
||||||
|
if not ws.mime_type:
|
||||||
|
# Add MIME type fragment if missing (older refs without it)
|
||||||
|
try:
|
||||||
|
if ws.is_path:
|
||||||
|
info = await workspace_manager.get_file_info_by_path(
|
||||||
|
ws.file_ref
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
info = await workspace_manager.get_file_info(ws.file_ref)
|
||||||
|
if info:
|
||||||
|
return MediaFileType(f"{file}#{info.mimeType}")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
return MediaFileType(file)
|
return MediaFileType(file)
|
||||||
|
|
||||||
# Save new content to workspace
|
# Save new content to workspace
|
||||||
@@ -346,7 +397,7 @@ async def store_media_file(
|
|||||||
filename=filename,
|
filename=filename,
|
||||||
overwrite=True,
|
overwrite=True,
|
||||||
)
|
)
|
||||||
return MediaFileType(f"workspace://{file_record.id}")
|
return MediaFileType(f"workspace://{file_record.id}#{file_record.mimeType}")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid return_format: {return_format}")
|
raise ValueError(f"Invalid return_format: {return_format}")
|
||||||
|
|||||||
@@ -247,3 +247,100 @@ class TestFileCloudIntegration:
|
|||||||
execution_context=make_test_context(graph_exec_id=graph_exec_id),
|
execution_context=make_test_context(graph_exec_id=graph_exec_id),
|
||||||
return_format="for_local_processing",
|
return_format="for_local_processing",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_store_media_file_local_path_scanned(self):
|
||||||
|
"""Test that local file paths are scanned for viruses."""
|
||||||
|
graph_exec_id = "test-exec-123"
|
||||||
|
local_file = "test_video.mp4"
|
||||||
|
file_content = b"fake video content"
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.util.file.get_cloud_storage_handler"
|
||||||
|
) as mock_handler_getter, patch(
|
||||||
|
"backend.util.file.scan_content_safe"
|
||||||
|
) as mock_scan, patch(
|
||||||
|
"backend.util.file.Path"
|
||||||
|
) as mock_path_class:
|
||||||
|
|
||||||
|
# Mock cloud storage handler - not a cloud path
|
||||||
|
mock_handler = MagicMock()
|
||||||
|
mock_handler.is_cloud_path.return_value = False
|
||||||
|
mock_handler_getter.return_value = mock_handler
|
||||||
|
|
||||||
|
# Mock virus scanner
|
||||||
|
mock_scan.return_value = None
|
||||||
|
|
||||||
|
# Mock file system operations
|
||||||
|
mock_base_path = MagicMock()
|
||||||
|
mock_target_path = MagicMock()
|
||||||
|
mock_resolved_path = MagicMock()
|
||||||
|
|
||||||
|
mock_path_class.return_value = mock_base_path
|
||||||
|
mock_base_path.mkdir = MagicMock()
|
||||||
|
mock_base_path.__truediv__ = MagicMock(return_value=mock_target_path)
|
||||||
|
mock_target_path.resolve.return_value = mock_resolved_path
|
||||||
|
mock_resolved_path.is_relative_to.return_value = True
|
||||||
|
mock_resolved_path.is_file.return_value = True
|
||||||
|
mock_resolved_path.read_bytes.return_value = file_content
|
||||||
|
mock_resolved_path.relative_to.return_value = Path(local_file)
|
||||||
|
mock_resolved_path.name = local_file
|
||||||
|
|
||||||
|
result = await store_media_file(
|
||||||
|
file=MediaFileType(local_file),
|
||||||
|
execution_context=make_test_context(graph_exec_id=graph_exec_id),
|
||||||
|
return_format="for_local_processing",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify virus scan was called for local file
|
||||||
|
mock_scan.assert_called_once_with(file_content, filename=local_file)
|
||||||
|
|
||||||
|
# Result should be the relative path
|
||||||
|
assert str(result) == local_file
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_store_media_file_local_path_virus_detected(self):
|
||||||
|
"""Test that infected local files raise VirusDetectedError."""
|
||||||
|
from backend.api.features.store.exceptions import VirusDetectedError
|
||||||
|
|
||||||
|
graph_exec_id = "test-exec-123"
|
||||||
|
local_file = "infected.exe"
|
||||||
|
file_content = b"malicious content"
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"backend.util.file.get_cloud_storage_handler"
|
||||||
|
) as mock_handler_getter, patch(
|
||||||
|
"backend.util.file.scan_content_safe"
|
||||||
|
) as mock_scan, patch(
|
||||||
|
"backend.util.file.Path"
|
||||||
|
) as mock_path_class:
|
||||||
|
|
||||||
|
# Mock cloud storage handler - not a cloud path
|
||||||
|
mock_handler = MagicMock()
|
||||||
|
mock_handler.is_cloud_path.return_value = False
|
||||||
|
mock_handler_getter.return_value = mock_handler
|
||||||
|
|
||||||
|
# Mock virus scanner to detect virus
|
||||||
|
mock_scan.side_effect = VirusDetectedError(
|
||||||
|
"EICAR-Test-File", "File rejected due to virus detection"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock file system operations
|
||||||
|
mock_base_path = MagicMock()
|
||||||
|
mock_target_path = MagicMock()
|
||||||
|
mock_resolved_path = MagicMock()
|
||||||
|
|
||||||
|
mock_path_class.return_value = mock_base_path
|
||||||
|
mock_base_path.mkdir = MagicMock()
|
||||||
|
mock_base_path.__truediv__ = MagicMock(return_value=mock_target_path)
|
||||||
|
mock_target_path.resolve.return_value = mock_resolved_path
|
||||||
|
mock_resolved_path.is_relative_to.return_value = True
|
||||||
|
mock_resolved_path.is_file.return_value = True
|
||||||
|
mock_resolved_path.read_bytes.return_value = file_content
|
||||||
|
|
||||||
|
with pytest.raises(VirusDetectedError):
|
||||||
|
await store_media_file(
|
||||||
|
file=MediaFileType(local_file),
|
||||||
|
execution_context=make_test_context(graph_exec_id=graph_exec_id),
|
||||||
|
return_format="for_local_processing",
|
||||||
|
)
|
||||||
|
|||||||
@@ -656,6 +656,7 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
|
|||||||
e2b_api_key: str = Field(default="", description="E2B API key")
|
e2b_api_key: str = Field(default="", description="E2B API key")
|
||||||
nvidia_api_key: str = Field(default="", description="Nvidia API key")
|
nvidia_api_key: str = Field(default="", description="Nvidia API key")
|
||||||
mem0_api_key: str = Field(default="", description="Mem0 API key")
|
mem0_api_key: str = Field(default="", description="Mem0 API key")
|
||||||
|
elevenlabs_api_key: str = Field(default="", description="ElevenLabs API key")
|
||||||
|
|
||||||
linear_client_id: str = Field(default="", description="Linear client ID")
|
linear_client_id: str = Field(default="", description="Linear client ID")
|
||||||
linear_client_secret: str = Field(default="", description="Linear client secret")
|
linear_client_secret: str = Field(default="", description="Linear client secret")
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ from backend.data.workspace import (
|
|||||||
soft_delete_workspace_file,
|
soft_delete_workspace_file,
|
||||||
)
|
)
|
||||||
from backend.util.settings import Config
|
from backend.util.settings import Config
|
||||||
|
from backend.util.virus_scanner import scan_content_safe
|
||||||
from backend.util.workspace_storage import compute_file_checksum, get_workspace_storage
|
from backend.util.workspace_storage import compute_file_checksum, get_workspace_storage
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -187,6 +188,9 @@ class WorkspaceManager:
|
|||||||
f"{Config().max_file_size_mb}MB limit"
|
f"{Config().max_file_size_mb}MB limit"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Virus scan content before persisting (defense in depth)
|
||||||
|
await scan_content_safe(content, filename=filename)
|
||||||
|
|
||||||
# Determine path with session scoping
|
# Determine path with session scoping
|
||||||
if path is None:
|
if path is None:
|
||||||
path = f"/{filename}"
|
path = f"/{filename}"
|
||||||
|
|||||||
7148
autogpt_platform/backend/poetry.lock
generated
7148
autogpt_platform/backend/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -12,15 +12,16 @@ python = ">=3.10,<3.14"
|
|||||||
aio-pika = "^9.5.5"
|
aio-pika = "^9.5.5"
|
||||||
aiohttp = "^3.10.0"
|
aiohttp = "^3.10.0"
|
||||||
aiodns = "^3.5.0"
|
aiodns = "^3.5.0"
|
||||||
anthropic = "^0.59.0"
|
anthropic = "^0.79.0"
|
||||||
apscheduler = "^3.11.1"
|
apscheduler = "^3.11.1"
|
||||||
autogpt-libs = { path = "../autogpt_libs", develop = true }
|
autogpt-libs = { path = "../autogpt_libs", develop = true }
|
||||||
bleach = { extras = ["css"], version = "^6.2.0" }
|
bleach = { extras = ["css"], version = "^6.2.0" }
|
||||||
click = "^8.2.0"
|
click = "^8.2.0"
|
||||||
cryptography = "^45.0"
|
cryptography = "^46.0"
|
||||||
discord-py = "^2.5.2"
|
discord-py = "^2.5.2"
|
||||||
e2b-code-interpreter = "^1.5.2"
|
e2b-code-interpreter = "^1.5.2"
|
||||||
fastapi = "^0.116.1"
|
elevenlabs = "^1.50.0"
|
||||||
|
fastapi = "^0.128.6"
|
||||||
feedparser = "^6.0.11"
|
feedparser = "^6.0.11"
|
||||||
flake8 = "^7.3.0"
|
flake8 = "^7.3.0"
|
||||||
google-api-python-client = "^2.177.0"
|
google-api-python-client = "^2.177.0"
|
||||||
@@ -33,11 +34,11 @@ html2text = "^2024.2.26"
|
|||||||
jinja2 = "^3.1.6"
|
jinja2 = "^3.1.6"
|
||||||
jsonref = "^1.1.0"
|
jsonref = "^1.1.0"
|
||||||
jsonschema = "^4.25.0"
|
jsonschema = "^4.25.0"
|
||||||
langfuse = "^3.11.0"
|
langfuse = "^3.14.1"
|
||||||
launchdarkly-server-sdk = "^9.12.0"
|
launchdarkly-server-sdk = "^9.14.1"
|
||||||
mem0ai = "^0.1.115"
|
mem0ai = "^0.1.115"
|
||||||
moviepy = "^2.1.2"
|
moviepy = "^2.1.2"
|
||||||
ollama = "^0.5.1"
|
ollama = "^0.6.1"
|
||||||
openai = "^1.97.1"
|
openai = "^1.97.1"
|
||||||
orjson = "^3.10.0"
|
orjson = "^3.10.0"
|
||||||
pika = "^1.3.2"
|
pika = "^1.3.2"
|
||||||
@@ -47,16 +48,16 @@ postmarker = "^1.0"
|
|||||||
praw = "~7.8.1"
|
praw = "~7.8.1"
|
||||||
prisma = "^0.15.0"
|
prisma = "^0.15.0"
|
||||||
rank-bm25 = "^0.2.2"
|
rank-bm25 = "^0.2.2"
|
||||||
prometheus-client = "^0.22.1"
|
prometheus-client = "^0.24.1"
|
||||||
prometheus-fastapi-instrumentator = "^7.0.0"
|
prometheus-fastapi-instrumentator = "^7.0.0"
|
||||||
psutil = "^7.0.0"
|
psutil = "^7.0.0"
|
||||||
psycopg2-binary = "^2.9.10"
|
psycopg2-binary = "^2.9.10"
|
||||||
pydantic = { extras = ["email"], version = "^2.11.7" }
|
pydantic = { extras = ["email"], version = "^2.12.5" }
|
||||||
pydantic-settings = "^2.10.1"
|
pydantic-settings = "^2.12.0"
|
||||||
pytest = "^8.4.1"
|
pytest = "^8.4.1"
|
||||||
pytest-asyncio = "^1.1.0"
|
pytest-asyncio = "^1.1.0"
|
||||||
python-dotenv = "^1.1.1"
|
python-dotenv = "^1.1.1"
|
||||||
python-multipart = "^0.0.20"
|
python-multipart = "^0.0.22"
|
||||||
redis = "^6.2.0"
|
redis = "^6.2.0"
|
||||||
regex = "^2025.9.18"
|
regex = "^2025.9.18"
|
||||||
replicate = "^1.0.6"
|
replicate = "^1.0.6"
|
||||||
@@ -64,18 +65,19 @@ sentry-sdk = {extras = ["anthropic", "fastapi", "launchdarkly", "openai", "sqlal
|
|||||||
sqlalchemy = "^2.0.40"
|
sqlalchemy = "^2.0.40"
|
||||||
strenum = "^0.4.9"
|
strenum = "^0.4.9"
|
||||||
stripe = "^11.5.0"
|
stripe = "^11.5.0"
|
||||||
supabase = "2.17.0"
|
supabase = "2.27.3"
|
||||||
tenacity = "^9.1.2"
|
tenacity = "^9.1.4"
|
||||||
todoist-api-python = "^2.1.7"
|
todoist-api-python = "^2.1.7"
|
||||||
tweepy = "^4.16.0"
|
tweepy = "^4.16.0"
|
||||||
uvicorn = { extras = ["standard"], version = "^0.35.0" }
|
uvicorn = { extras = ["standard"], version = "^0.40.0" }
|
||||||
websockets = "^15.0"
|
websockets = "^15.0"
|
||||||
youtube-transcript-api = "^1.2.1"
|
youtube-transcript-api = "^1.2.1"
|
||||||
|
yt-dlp = "2025.12.08"
|
||||||
zerobouncesdk = "^1.1.2"
|
zerobouncesdk = "^1.1.2"
|
||||||
# NOTE: please insert new dependencies in their alphabetical location
|
# NOTE: please insert new dependencies in their alphabetical location
|
||||||
pytest-snapshot = "^0.9.0"
|
pytest-snapshot = "^0.9.0"
|
||||||
aiofiles = "^24.1.0"
|
aiofiles = "^24.1.0"
|
||||||
tiktoken = "^0.9.0"
|
tiktoken = "^0.12.0"
|
||||||
aioclamd = "^1.0.0"
|
aioclamd = "^1.0.0"
|
||||||
setuptools = "^80.9.0"
|
setuptools = "^80.9.0"
|
||||||
gcloud-aio-storage = "^9.5.0"
|
gcloud-aio-storage = "^9.5.0"
|
||||||
@@ -93,13 +95,13 @@ black = "^24.10.0"
|
|||||||
faker = "^38.2.0"
|
faker = "^38.2.0"
|
||||||
httpx = "^0.28.1"
|
httpx = "^0.28.1"
|
||||||
isort = "^5.13.2"
|
isort = "^5.13.2"
|
||||||
poethepoet = "^0.37.0"
|
poethepoet = "^0.41.0"
|
||||||
pre-commit = "^4.4.0"
|
pre-commit = "^4.4.0"
|
||||||
pyright = "^1.1.407"
|
pyright = "^1.1.407"
|
||||||
pytest-mock = "^3.15.1"
|
pytest-mock = "^3.15.1"
|
||||||
pytest-watcher = "^0.4.2"
|
pytest-watcher = "^0.6.3"
|
||||||
requests = "^2.32.5"
|
requests = "^2.32.5"
|
||||||
ruff = "^0.14.5"
|
ruff = "^0.15.0"
|
||||||
# NOTE: please insert new dependencies in their alphabetical location
|
# NOTE: please insert new dependencies in their alphabetical location
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
|
|||||||
@@ -3,7 +3,6 @@
|
|||||||
"credentials_input_schema": {
|
"credentials_input_schema": {
|
||||||
"properties": {},
|
"properties": {},
|
||||||
"required": [],
|
"required": [],
|
||||||
"title": "TestGraphCredentialsInputSchema",
|
|
||||||
"type": "object"
|
"type": "object"
|
||||||
},
|
},
|
||||||
"description": "A test graph",
|
"description": "A test graph",
|
||||||
|
|||||||
@@ -1,34 +1,14 @@
|
|||||||
[
|
[
|
||||||
{
|
{
|
||||||
"credentials_input_schema": {
|
"created_at": "2025-09-04T13:37:00",
|
||||||
"properties": {},
|
|
||||||
"required": [],
|
|
||||||
"title": "TestGraphCredentialsInputSchema",
|
|
||||||
"type": "object"
|
|
||||||
},
|
|
||||||
"description": "A test graph",
|
"description": "A test graph",
|
||||||
"forked_from_id": null,
|
"forked_from_id": null,
|
||||||
"forked_from_version": null,
|
"forked_from_version": null,
|
||||||
"has_external_trigger": false,
|
|
||||||
"has_human_in_the_loop": false,
|
|
||||||
"has_sensitive_action": false,
|
|
||||||
"id": "graph-123",
|
"id": "graph-123",
|
||||||
"input_schema": {
|
|
||||||
"properties": {},
|
|
||||||
"required": [],
|
|
||||||
"type": "object"
|
|
||||||
},
|
|
||||||
"instructions": null,
|
"instructions": null,
|
||||||
"is_active": true,
|
"is_active": true,
|
||||||
"name": "Test Graph",
|
"name": "Test Graph",
|
||||||
"output_schema": {
|
|
||||||
"properties": {},
|
|
||||||
"required": [],
|
|
||||||
"type": "object"
|
|
||||||
},
|
|
||||||
"recommended_schedule_cron": null,
|
"recommended_schedule_cron": null,
|
||||||
"sub_graphs": [],
|
|
||||||
"trigger_setup_info": null,
|
|
||||||
"user_id": "3e53486c-cf57-477e-ba2a-cb02dc828e1a",
|
"user_id": "3e53486c-cf57-477e-ba2a-cb02dc828e1a",
|
||||||
"version": 1
|
"version": 1
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -25,8 +25,12 @@ RUN if [ -f .env.production ]; then \
|
|||||||
cp .env.default .env; \
|
cp .env.default .env; \
|
||||||
fi
|
fi
|
||||||
RUN pnpm run generate:api
|
RUN pnpm run generate:api
|
||||||
|
# Disable source-map generation in Docker builds to halve webpack memory usage.
|
||||||
|
# Source maps are only useful when SENTRY_AUTH_TOKEN is set (Vercel deploys);
|
||||||
|
# the Docker image never uploads them, so generating them just wastes RAM.
|
||||||
|
ENV NEXT_PUBLIC_SOURCEMAPS="false"
|
||||||
# In CI, we want NEXT_PUBLIC_PW_TEST=true during build so Next.js inlines it
|
# In CI, we want NEXT_PUBLIC_PW_TEST=true during build so Next.js inlines it
|
||||||
RUN if [ "$NEXT_PUBLIC_PW_TEST" = "true" ]; then NEXT_PUBLIC_PW_TEST=true NODE_OPTIONS="--max-old-space-size=4096" pnpm build; else NODE_OPTIONS="--max-old-space-size=4096" pnpm build; fi
|
RUN if [ "$NEXT_PUBLIC_PW_TEST" = "true" ]; then NEXT_PUBLIC_PW_TEST=true NODE_OPTIONS="--max-old-space-size=8192" pnpm build; else NODE_OPTIONS="--max-old-space-size=8192" pnpm build; fi
|
||||||
|
|
||||||
# Prod stage - based on NextJS reference Dockerfile https://github.com/vercel/next.js/blob/64271354533ed16da51be5dce85f0dbd15f17517/examples/with-docker/Dockerfile
|
# Prod stage - based on NextJS reference Dockerfile https://github.com/vercel/next.js/blob/64271354533ed16da51be5dce85f0dbd15f17517/examples/with-docker/Dockerfile
|
||||||
FROM node:21-alpine AS prod
|
FROM node:21-alpine AS prod
|
||||||
|
|||||||
@@ -1,8 +1,12 @@
|
|||||||
import { withSentryConfig } from "@sentry/nextjs";
|
import { withSentryConfig } from "@sentry/nextjs";
|
||||||
|
|
||||||
|
// Allow Docker builds to skip source-map generation (halves memory usage).
|
||||||
|
// Defaults to true so Vercel/local builds are unaffected.
|
||||||
|
const enableSourceMaps = process.env.NEXT_PUBLIC_SOURCEMAPS !== "false";
|
||||||
|
|
||||||
/** @type {import('next').NextConfig} */
|
/** @type {import('next').NextConfig} */
|
||||||
const nextConfig = {
|
const nextConfig = {
|
||||||
productionBrowserSourceMaps: true,
|
productionBrowserSourceMaps: enableSourceMaps,
|
||||||
// Externalize OpenTelemetry packages to fix Turbopack HMR issues
|
// Externalize OpenTelemetry packages to fix Turbopack HMR issues
|
||||||
serverExternalPackages: [
|
serverExternalPackages: [
|
||||||
"@opentelemetry/instrumentation",
|
"@opentelemetry/instrumentation",
|
||||||
@@ -14,9 +18,37 @@ const nextConfig = {
|
|||||||
serverActions: {
|
serverActions: {
|
||||||
bodySizeLimit: "256mb",
|
bodySizeLimit: "256mb",
|
||||||
},
|
},
|
||||||
// Increase body size limit for API routes (file uploads) - 256MB to match backend limit
|
|
||||||
proxyClientMaxBodySize: "256mb",
|
|
||||||
middlewareClientMaxBodySize: "256mb",
|
middlewareClientMaxBodySize: "256mb",
|
||||||
|
// Limit parallel webpack workers to reduce peak memory during builds.
|
||||||
|
cpus: 2,
|
||||||
|
},
|
||||||
|
// Work around cssnano "Invalid array length" bug in Next.js's bundled
|
||||||
|
// cssnano-simple comment parser when processing very large CSS chunks.
|
||||||
|
// CSS is still bundled correctly; gzip handles most of the size savings anyway.
|
||||||
|
webpack: (config, { dev }) => {
|
||||||
|
if (!dev) {
|
||||||
|
// Next.js adds CssMinimizerPlugin internally (after user config), so we
|
||||||
|
// can't filter it from config.plugins. Instead, intercept the webpack
|
||||||
|
// compilation hooks and replace the buggy plugin's tap with a no-op.
|
||||||
|
config.plugins.push({
|
||||||
|
apply(compiler) {
|
||||||
|
compiler.hooks.compilation.tap(
|
||||||
|
"DisableCssMinimizer",
|
||||||
|
(compilation) => {
|
||||||
|
compilation.hooks.processAssets.intercept({
|
||||||
|
register: (tap) => {
|
||||||
|
if (tap.name === "CssMinimizerPlugin") {
|
||||||
|
return { ...tap, fn: async () => {} };
|
||||||
|
}
|
||||||
|
return tap;
|
||||||
|
},
|
||||||
|
});
|
||||||
|
},
|
||||||
|
);
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
return config;
|
||||||
},
|
},
|
||||||
images: {
|
images: {
|
||||||
domains: [
|
domains: [
|
||||||
@@ -54,9 +86,16 @@ const nextConfig = {
|
|||||||
transpilePackages: ["geist"],
|
transpilePackages: ["geist"],
|
||||||
};
|
};
|
||||||
|
|
||||||
const isDevelopmentBuild = process.env.NODE_ENV !== "production";
|
// Only run the Sentry webpack plugin when we can actually upload source maps
|
||||||
|
// (i.e. on Vercel with SENTRY_AUTH_TOKEN set). The Sentry *runtime* SDK
|
||||||
|
// (imported in app code) still captures errors without the plugin.
|
||||||
|
// Skipping the plugin saves ~1 GB of peak memory during `next build`.
|
||||||
|
const skipSentryPlugin =
|
||||||
|
process.env.NODE_ENV !== "production" ||
|
||||||
|
!enableSourceMaps ||
|
||||||
|
!process.env.SENTRY_AUTH_TOKEN;
|
||||||
|
|
||||||
export default isDevelopmentBuild
|
export default skipSentryPlugin
|
||||||
? nextConfig
|
? nextConfig
|
||||||
: withSentryConfig(nextConfig, {
|
: withSentryConfig(nextConfig, {
|
||||||
// For all available options, see:
|
// For all available options, see:
|
||||||
@@ -96,7 +135,7 @@ export default isDevelopmentBuild
|
|||||||
|
|
||||||
// This helps Sentry with sourcemaps... https://docs.sentry.io/platforms/javascript/guides/nextjs/sourcemaps/
|
// This helps Sentry with sourcemaps... https://docs.sentry.io/platforms/javascript/guides/nextjs/sourcemaps/
|
||||||
sourcemaps: {
|
sourcemaps: {
|
||||||
disable: false,
|
disable: !enableSourceMaps,
|
||||||
assets: [".next/**/*.js", ".next/**/*.js.map"],
|
assets: [".next/**/*.js", ".next/**/*.js.map"],
|
||||||
ignore: ["**/node_modules/**"],
|
ignore: ["**/node_modules/**"],
|
||||||
deleteSourcemapsAfterUpload: false, // Source is public anyway :)
|
deleteSourcemapsAfterUpload: false, // Source is public anyway :)
|
||||||
|
|||||||
@@ -7,7 +7,7 @@
|
|||||||
},
|
},
|
||||||
"scripts": {
|
"scripts": {
|
||||||
"dev": "pnpm run generate:api:force && next dev --turbo",
|
"dev": "pnpm run generate:api:force && next dev --turbo",
|
||||||
"build": "next build",
|
"build": "cross-env NODE_OPTIONS=--max-old-space-size=16384 next build",
|
||||||
"start": "next start",
|
"start": "next start",
|
||||||
"start:standalone": "cd .next/standalone && node server.js",
|
"start:standalone": "cd .next/standalone && node server.js",
|
||||||
"lint": "next lint && prettier --check .",
|
"lint": "next lint && prettier --check .",
|
||||||
@@ -30,6 +30,7 @@
|
|||||||
"defaults"
|
"defaults"
|
||||||
],
|
],
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
|
"@ai-sdk/react": "3.0.61",
|
||||||
"@faker-js/faker": "10.0.0",
|
"@faker-js/faker": "10.0.0",
|
||||||
"@hookform/resolvers": "5.2.2",
|
"@hookform/resolvers": "5.2.2",
|
||||||
"@next/third-parties": "15.4.6",
|
"@next/third-parties": "15.4.6",
|
||||||
@@ -60,6 +61,10 @@
|
|||||||
"@rjsf/utils": "6.1.2",
|
"@rjsf/utils": "6.1.2",
|
||||||
"@rjsf/validator-ajv8": "6.1.2",
|
"@rjsf/validator-ajv8": "6.1.2",
|
||||||
"@sentry/nextjs": "10.27.0",
|
"@sentry/nextjs": "10.27.0",
|
||||||
|
"@streamdown/cjk": "1.0.1",
|
||||||
|
"@streamdown/code": "1.0.1",
|
||||||
|
"@streamdown/math": "1.0.1",
|
||||||
|
"@streamdown/mermaid": "1.0.1",
|
||||||
"@supabase/ssr": "0.7.0",
|
"@supabase/ssr": "0.7.0",
|
||||||
"@supabase/supabase-js": "2.78.0",
|
"@supabase/supabase-js": "2.78.0",
|
||||||
"@tanstack/react-query": "5.90.6",
|
"@tanstack/react-query": "5.90.6",
|
||||||
@@ -68,6 +73,7 @@
|
|||||||
"@vercel/analytics": "1.5.0",
|
"@vercel/analytics": "1.5.0",
|
||||||
"@vercel/speed-insights": "1.2.0",
|
"@vercel/speed-insights": "1.2.0",
|
||||||
"@xyflow/react": "12.9.2",
|
"@xyflow/react": "12.9.2",
|
||||||
|
"ai": "6.0.59",
|
||||||
"boring-avatars": "1.11.2",
|
"boring-avatars": "1.11.2",
|
||||||
"class-variance-authority": "0.7.1",
|
"class-variance-authority": "0.7.1",
|
||||||
"clsx": "2.1.1",
|
"clsx": "2.1.1",
|
||||||
@@ -87,7 +93,6 @@
|
|||||||
"launchdarkly-react-client-sdk": "3.9.0",
|
"launchdarkly-react-client-sdk": "3.9.0",
|
||||||
"lodash": "4.17.21",
|
"lodash": "4.17.21",
|
||||||
"lucide-react": "0.552.0",
|
"lucide-react": "0.552.0",
|
||||||
"moment": "2.30.1",
|
|
||||||
"next": "15.4.10",
|
"next": "15.4.10",
|
||||||
"next-themes": "0.4.6",
|
"next-themes": "0.4.6",
|
||||||
"nuqs": "2.7.2",
|
"nuqs": "2.7.2",
|
||||||
@@ -102,7 +107,7 @@
|
|||||||
"react-markdown": "9.0.3",
|
"react-markdown": "9.0.3",
|
||||||
"react-modal": "3.16.3",
|
"react-modal": "3.16.3",
|
||||||
"react-shepherd": "6.1.9",
|
"react-shepherd": "6.1.9",
|
||||||
"react-window": "1.8.11",
|
"react-window": "2.2.0",
|
||||||
"recharts": "3.3.0",
|
"recharts": "3.3.0",
|
||||||
"rehype-autolink-headings": "7.1.0",
|
"rehype-autolink-headings": "7.1.0",
|
||||||
"rehype-highlight": "7.0.2",
|
"rehype-highlight": "7.0.2",
|
||||||
@@ -112,9 +117,11 @@
|
|||||||
"remark-math": "6.0.0",
|
"remark-math": "6.0.0",
|
||||||
"shepherd.js": "14.5.1",
|
"shepherd.js": "14.5.1",
|
||||||
"sonner": "2.0.7",
|
"sonner": "2.0.7",
|
||||||
|
"streamdown": "2.1.0",
|
||||||
"tailwind-merge": "2.6.0",
|
"tailwind-merge": "2.6.0",
|
||||||
"tailwind-scrollbar": "3.1.0",
|
"tailwind-scrollbar": "3.1.0",
|
||||||
"tailwindcss-animate": "1.0.7",
|
"tailwindcss-animate": "1.0.7",
|
||||||
|
"use-stick-to-bottom": "1.1.2",
|
||||||
"uuid": "11.1.0",
|
"uuid": "11.1.0",
|
||||||
"vaul": "1.1.2",
|
"vaul": "1.1.2",
|
||||||
"zod": "3.25.76",
|
"zod": "3.25.76",
|
||||||
@@ -140,7 +147,7 @@
|
|||||||
"@types/react": "18.3.17",
|
"@types/react": "18.3.17",
|
||||||
"@types/react-dom": "18.3.5",
|
"@types/react-dom": "18.3.5",
|
||||||
"@types/react-modal": "3.16.3",
|
"@types/react-modal": "3.16.3",
|
||||||
"@types/react-window": "1.8.8",
|
"@types/react-window": "2.0.0",
|
||||||
"@vitejs/plugin-react": "5.1.2",
|
"@vitejs/plugin-react": "5.1.2",
|
||||||
"axe-playwright": "2.2.2",
|
"axe-playwright": "2.2.2",
|
||||||
"chromatic": "13.3.3",
|
"chromatic": "13.3.3",
|
||||||
@@ -172,7 +179,8 @@
|
|||||||
},
|
},
|
||||||
"pnpm": {
|
"pnpm": {
|
||||||
"overrides": {
|
"overrides": {
|
||||||
"@opentelemetry/instrumentation": "0.209.0"
|
"@opentelemetry/instrumentation": "0.209.0",
|
||||||
|
"lodash-es": "4.17.23"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"packageManager": "pnpm@10.20.0+sha512.cf9998222162dd85864d0a8102e7892e7ba4ceadebbf5a31f9c2fce48dfce317a9c53b9f6464d1ef9042cba2e02ae02a9f7c143a2b438cd93c91840f0192b9dd"
|
"packageManager": "pnpm@10.20.0+sha512.cf9998222162dd85864d0a8102e7892e7ba4ceadebbf5a31f9c2fce48dfce317a9c53b9f6464d1ef9042cba2e02ae02a9f7c143a2b438cd93c91840f0192b9dd"
|
||||||
|
|||||||
1218
autogpt_platform/frontend/pnpm-lock.yaml
generated
1218
autogpt_platform/frontend/pnpm-lock.yaml
generated
File diff suppressed because it is too large
Load Diff
@@ -1,5 +1,5 @@
|
|||||||
import { CredentialsMetaInput } from "@/app/api/__generated__/models/credentialsMetaInput";
|
import { CredentialsMetaInput } from "@/app/api/__generated__/models/credentialsMetaInput";
|
||||||
import { GraphMeta } from "@/app/api/__generated__/models/graphMeta";
|
import { GraphModel } from "@/app/api/__generated__/models/graphModel";
|
||||||
import { CredentialsInput } from "@/components/contextual/CredentialsInput/CredentialsInput";
|
import { CredentialsInput } from "@/components/contextual/CredentialsInput/CredentialsInput";
|
||||||
import { useState } from "react";
|
import { useState } from "react";
|
||||||
import { getSchemaDefaultCredentials } from "../../helpers";
|
import { getSchemaDefaultCredentials } from "../../helpers";
|
||||||
@@ -9,7 +9,7 @@ type Credential = CredentialsMetaInput | undefined;
|
|||||||
type Credentials = Record<string, Credential>;
|
type Credentials = Record<string, Credential>;
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
agent: GraphMeta | null;
|
agent: GraphModel | null;
|
||||||
siblingInputs?: Record<string, any>;
|
siblingInputs?: Record<string, any>;
|
||||||
onCredentialsChange: (
|
onCredentialsChange: (
|
||||||
credentials: Record<string, CredentialsMetaInput>,
|
credentials: Record<string, CredentialsMetaInput>,
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
import { CredentialsMetaInput } from "@/app/api/__generated__/models/credentialsMetaInput";
|
import { CredentialsMetaInput } from "@/app/api/__generated__/models/credentialsMetaInput";
|
||||||
import { GraphMeta } from "@/app/api/__generated__/models/graphMeta";
|
import { GraphModel } from "@/app/api/__generated__/models/graphModel";
|
||||||
import { BlockIOCredentialsSubSchema } from "@/lib/autogpt-server-api/types";
|
import { BlockIOCredentialsSubSchema } from "@/lib/autogpt-server-api/types";
|
||||||
|
|
||||||
export function getCredentialFields(
|
export function getCredentialFields(
|
||||||
agent: GraphMeta | null,
|
agent: GraphModel | null,
|
||||||
): AgentCredentialsFields {
|
): AgentCredentialsFields {
|
||||||
if (!agent) return {};
|
if (!agent) return {};
|
||||||
|
|
||||||
|
|||||||
@@ -3,10 +3,10 @@ import type {
|
|||||||
CredentialsMetaInput,
|
CredentialsMetaInput,
|
||||||
} from "@/lib/autogpt-server-api/types";
|
} from "@/lib/autogpt-server-api/types";
|
||||||
import type { InputValues } from "./types";
|
import type { InputValues } from "./types";
|
||||||
import { GraphMeta } from "@/app/api/__generated__/models/graphMeta";
|
import { GraphModel } from "@/app/api/__generated__/models/graphModel";
|
||||||
|
|
||||||
export function computeInitialAgentInputs(
|
export function computeInitialAgentInputs(
|
||||||
agent: GraphMeta | null,
|
agent: GraphModel | null,
|
||||||
existingInputs?: InputValues | null,
|
existingInputs?: InputValues | null,
|
||||||
): InputValues {
|
): InputValues {
|
||||||
const properties = agent?.input_schema?.properties || {};
|
const properties = agent?.input_schema?.properties || {};
|
||||||
@@ -29,7 +29,7 @@ export function computeInitialAgentInputs(
|
|||||||
}
|
}
|
||||||
|
|
||||||
type IsRunDisabledParams = {
|
type IsRunDisabledParams = {
|
||||||
agent: GraphMeta | null;
|
agent: GraphModel | null;
|
||||||
isRunning: boolean;
|
isRunning: boolean;
|
||||||
agentInputs: InputValues | null | undefined;
|
agentInputs: InputValues | null | undefined;
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
import { debounce } from "lodash";
|
import debounce from "lodash/debounce";
|
||||||
import { useCallback, useEffect, useRef, useState } from "react";
|
import { useCallback, useEffect, useRef, useState } from "react";
|
||||||
import { useBlockMenuStore } from "../../../../stores/blockMenuStore";
|
import { useBlockMenuStore } from "../../../../stores/blockMenuStore";
|
||||||
import { getQueryClient } from "@/lib/react-query/queryClient";
|
import { getQueryClient } from "@/lib/react-query/queryClient";
|
||||||
|
|||||||
@@ -70,10 +70,10 @@ export const HorizontalScroll: React.FC<HorizontalScrollAreaProps> = ({
|
|||||||
{children}
|
{children}
|
||||||
</div>
|
</div>
|
||||||
{canScrollLeft && (
|
{canScrollLeft && (
|
||||||
<div className="pointer-events-none absolute inset-y-0 left-0 w-8 bg-gradient-to-r from-white via-white/80 to-white/0" />
|
<div className="pointer-events-none absolute inset-y-0 left-0 w-8 bg-gradient-to-r from-background via-background/80 to-background/0" />
|
||||||
)}
|
)}
|
||||||
{canScrollRight && (
|
{canScrollRight && (
|
||||||
<div className="pointer-events-none absolute inset-y-0 right-0 w-8 bg-gradient-to-l from-white via-white/80 to-white/0" />
|
<div className="pointer-events-none absolute inset-y-0 right-0 w-8 bg-gradient-to-l from-background via-background/80 to-background/0" />
|
||||||
)}
|
)}
|
||||||
{canScrollLeft && (
|
{canScrollLeft && (
|
||||||
<button
|
<button
|
||||||
|
|||||||
@@ -30,6 +30,8 @@ import {
|
|||||||
} from "@/components/atoms/Tooltip/BaseTooltip";
|
} from "@/components/atoms/Tooltip/BaseTooltip";
|
||||||
import { GraphMeta } from "@/lib/autogpt-server-api";
|
import { GraphMeta } from "@/lib/autogpt-server-api";
|
||||||
import jaro from "jaro-winkler";
|
import jaro from "jaro-winkler";
|
||||||
|
import { getV1GetSpecificGraph } from "@/app/api/__generated__/endpoints/graphs/graphs";
|
||||||
|
import { okData } from "@/app/api/helpers";
|
||||||
|
|
||||||
type _Block = Omit<Block, "inputSchema" | "outputSchema"> & {
|
type _Block = Omit<Block, "inputSchema" | "outputSchema"> & {
|
||||||
uiKey?: string;
|
uiKey?: string;
|
||||||
@@ -107,6 +109,8 @@ export function BlocksControl({
|
|||||||
.filter((b) => b.uiType !== BlockUIType.AGENT)
|
.filter((b) => b.uiType !== BlockUIType.AGENT)
|
||||||
.sort((a, b) => a.name.localeCompare(b.name));
|
.sort((a, b) => a.name.localeCompare(b.name));
|
||||||
|
|
||||||
|
// Agent blocks are created from GraphMeta which doesn't include schemas.
|
||||||
|
// Schemas will be fetched on-demand when the block is actually added.
|
||||||
const agentBlockList = flows
|
const agentBlockList = flows
|
||||||
.map((flow): _Block => {
|
.map((flow): _Block => {
|
||||||
return {
|
return {
|
||||||
@@ -116,8 +120,9 @@ export function BlocksControl({
|
|||||||
`Ver.${flow.version}` +
|
`Ver.${flow.version}` +
|
||||||
(flow.description ? ` | ${flow.description}` : ""),
|
(flow.description ? ` | ${flow.description}` : ""),
|
||||||
categories: [{ category: "AGENT", description: "" }],
|
categories: [{ category: "AGENT", description: "" }],
|
||||||
inputSchema: flow.input_schema,
|
// Empty schemas - will be populated when block is added
|
||||||
outputSchema: flow.output_schema,
|
inputSchema: { type: "object", properties: {} },
|
||||||
|
outputSchema: { type: "object", properties: {} },
|
||||||
staticOutput: false,
|
staticOutput: false,
|
||||||
uiType: BlockUIType.AGENT,
|
uiType: BlockUIType.AGENT,
|
||||||
costs: [],
|
costs: [],
|
||||||
@@ -125,8 +130,7 @@ export function BlocksControl({
|
|||||||
hardcodedValues: {
|
hardcodedValues: {
|
||||||
graph_id: flow.id,
|
graph_id: flow.id,
|
||||||
graph_version: flow.version,
|
graph_version: flow.version,
|
||||||
input_schema: flow.input_schema,
|
// Schemas will be fetched on-demand when block is added
|
||||||
output_schema: flow.output_schema,
|
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
})
|
})
|
||||||
@@ -182,6 +186,37 @@ export function BlocksControl({
|
|||||||
setSelectedCategory(null);
|
setSelectedCategory(null);
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
|
// Handler to add a block, fetching graph data on-demand for agent blocks
|
||||||
|
const handleAddBlock = useCallback(
|
||||||
|
async (block: _Block & { notAvailable: string | null }) => {
|
||||||
|
if (block.notAvailable) return;
|
||||||
|
|
||||||
|
// For agent blocks, fetch the full graph to get schemas
|
||||||
|
if (block.uiType === BlockUIType.AGENT && block.hardcodedValues) {
|
||||||
|
const graphID = block.hardcodedValues.graph_id as string;
|
||||||
|
const graphVersion = block.hardcodedValues.graph_version as number;
|
||||||
|
const graphData = okData(
|
||||||
|
await getV1GetSpecificGraph(graphID, { version: graphVersion }),
|
||||||
|
);
|
||||||
|
|
||||||
|
if (graphData) {
|
||||||
|
addBlock(block.id, block.name, {
|
||||||
|
...block.hardcodedValues,
|
||||||
|
input_schema: graphData.input_schema,
|
||||||
|
output_schema: graphData.output_schema,
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
// Fallback: add without schemas (will be incomplete)
|
||||||
|
console.error("Failed to fetch graph data for agent block");
|
||||||
|
addBlock(block.id, block.name, block.hardcodedValues || {});
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
addBlock(block.id, block.name, block.hardcodedValues || {});
|
||||||
|
}
|
||||||
|
},
|
||||||
|
[addBlock],
|
||||||
|
);
|
||||||
|
|
||||||
// Extract unique categories from blocks
|
// Extract unique categories from blocks
|
||||||
const categories = useMemo(() => {
|
const categories = useMemo(() => {
|
||||||
return Array.from(
|
return Array.from(
|
||||||
@@ -303,10 +338,7 @@ export function BlocksControl({
|
|||||||
}),
|
}),
|
||||||
);
|
);
|
||||||
}}
|
}}
|
||||||
onClick={() =>
|
onClick={() => handleAddBlock(block)}
|
||||||
!block.notAvailable &&
|
|
||||||
addBlock(block.id, block.name, block?.hardcodedValues || {})
|
|
||||||
}
|
|
||||||
title={block.notAvailable ?? undefined}
|
title={block.notAvailable ?? undefined}
|
||||||
>
|
>
|
||||||
<div
|
<div
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import { beautifyString } from "@/lib/utils";
|
import { beautifyString } from "@/lib/utils";
|
||||||
import { Clipboard, Maximize2 } from "lucide-react";
|
import { Clipboard, Maximize2 } from "lucide-react";
|
||||||
import React, { useState } from "react";
|
import React, { useMemo, useState } from "react";
|
||||||
import { Button } from "../../../../../components/__legacy__/ui/button";
|
import { Button } from "../../../../../components/__legacy__/ui/button";
|
||||||
import { ContentRenderer } from "../../../../../components/__legacy__/ui/render";
|
import { ContentRenderer } from "../../../../../components/__legacy__/ui/render";
|
||||||
import {
|
import {
|
||||||
@@ -11,6 +11,12 @@ import {
|
|||||||
TableHeader,
|
TableHeader,
|
||||||
TableRow,
|
TableRow,
|
||||||
} from "../../../../../components/__legacy__/ui/table";
|
} from "../../../../../components/__legacy__/ui/table";
|
||||||
|
import type { OutputMetadata } from "@/components/contextual/OutputRenderers";
|
||||||
|
import {
|
||||||
|
globalRegistry,
|
||||||
|
OutputItem,
|
||||||
|
} from "@/components/contextual/OutputRenderers";
|
||||||
|
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
|
||||||
import { useToast } from "../../../../../components/molecules/Toast/use-toast";
|
import { useToast } from "../../../../../components/molecules/Toast/use-toast";
|
||||||
import ExpandableOutputDialog from "./ExpandableOutputDialog";
|
import ExpandableOutputDialog from "./ExpandableOutputDialog";
|
||||||
|
|
||||||
@@ -26,6 +32,9 @@ export default function DataTable({
|
|||||||
data,
|
data,
|
||||||
}: DataTableProps) {
|
}: DataTableProps) {
|
||||||
const { toast } = useToast();
|
const { toast } = useToast();
|
||||||
|
const enableEnhancedOutputHandling = useGetFlag(
|
||||||
|
Flag.ENABLE_ENHANCED_OUTPUT_HANDLING,
|
||||||
|
);
|
||||||
const [expandedDialog, setExpandedDialog] = useState<{
|
const [expandedDialog, setExpandedDialog] = useState<{
|
||||||
isOpen: boolean;
|
isOpen: boolean;
|
||||||
execId: string;
|
execId: string;
|
||||||
@@ -33,6 +42,15 @@ export default function DataTable({
|
|||||||
data: any[];
|
data: any[];
|
||||||
} | null>(null);
|
} | null>(null);
|
||||||
|
|
||||||
|
// Prepare renderers for each item when enhanced mode is enabled
|
||||||
|
const getItemRenderer = useMemo(() => {
|
||||||
|
if (!enableEnhancedOutputHandling) return null;
|
||||||
|
return (item: unknown) => {
|
||||||
|
const metadata: OutputMetadata = {};
|
||||||
|
return globalRegistry.getRenderer(item, metadata);
|
||||||
|
};
|
||||||
|
}, [enableEnhancedOutputHandling]);
|
||||||
|
|
||||||
const copyData = (pin: string, data: string) => {
|
const copyData = (pin: string, data: string) => {
|
||||||
navigator.clipboard.writeText(data).then(() => {
|
navigator.clipboard.writeText(data).then(() => {
|
||||||
toast({
|
toast({
|
||||||
@@ -102,15 +120,31 @@ export default function DataTable({
|
|||||||
<Clipboard size={18} />
|
<Clipboard size={18} />
|
||||||
</Button>
|
</Button>
|
||||||
</div>
|
</div>
|
||||||
{value.map((item, index) => (
|
{value.map((item, index) => {
|
||||||
<React.Fragment key={index}>
|
const renderer = getItemRenderer?.(item);
|
||||||
<ContentRenderer
|
if (enableEnhancedOutputHandling && renderer) {
|
||||||
value={item}
|
const metadata: OutputMetadata = {};
|
||||||
truncateLongData={truncateLongData}
|
return (
|
||||||
/>
|
<React.Fragment key={index}>
|
||||||
{index < value.length - 1 && ", "}
|
<OutputItem
|
||||||
</React.Fragment>
|
value={item}
|
||||||
))}
|
metadata={metadata}
|
||||||
|
renderer={renderer}
|
||||||
|
/>
|
||||||
|
{index < value.length - 1 && ", "}
|
||||||
|
</React.Fragment>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
return (
|
||||||
|
<React.Fragment key={index}>
|
||||||
|
<ContentRenderer
|
||||||
|
value={item}
|
||||||
|
truncateLongData={truncateLongData}
|
||||||
|
/>
|
||||||
|
{index < value.length - 1 && ", "}
|
||||||
|
</React.Fragment>
|
||||||
|
);
|
||||||
|
})}
|
||||||
</div>
|
</div>
|
||||||
</TableCell>
|
</TableCell>
|
||||||
</TableRow>
|
</TableRow>
|
||||||
|
|||||||
@@ -29,13 +29,17 @@ import "@xyflow/react/dist/style.css";
|
|||||||
import { ConnectedEdge, CustomNode } from "../CustomNode/CustomNode";
|
import { ConnectedEdge, CustomNode } from "../CustomNode/CustomNode";
|
||||||
import "./flow.css";
|
import "./flow.css";
|
||||||
import {
|
import {
|
||||||
|
BlockIORootSchema,
|
||||||
BlockUIType,
|
BlockUIType,
|
||||||
formatEdgeID,
|
formatEdgeID,
|
||||||
GraphExecutionID,
|
GraphExecutionID,
|
||||||
GraphID,
|
GraphID,
|
||||||
GraphMeta,
|
GraphMeta,
|
||||||
LibraryAgent,
|
LibraryAgent,
|
||||||
|
SpecialBlockID,
|
||||||
} from "@/lib/autogpt-server-api";
|
} from "@/lib/autogpt-server-api";
|
||||||
|
import { getV1GetSpecificGraph } from "@/app/api/__generated__/endpoints/graphs/graphs";
|
||||||
|
import { okData } from "@/app/api/helpers";
|
||||||
import { IncompatibilityInfo } from "../../../hooks/useSubAgentUpdate/types";
|
import { IncompatibilityInfo } from "../../../hooks/useSubAgentUpdate/types";
|
||||||
import { Key, storage } from "@/services/storage/local-storage";
|
import { Key, storage } from "@/services/storage/local-storage";
|
||||||
import { findNewlyAddedBlockCoordinates, getTypeColor } from "@/lib/utils";
|
import { findNewlyAddedBlockCoordinates, getTypeColor } from "@/lib/utils";
|
||||||
@@ -687,8 +691,94 @@ const FlowEditor: React.FC<{
|
|||||||
[getNode, updateNode, nodes],
|
[getNode, updateNode, nodes],
|
||||||
);
|
);
|
||||||
|
|
||||||
|
/* Shared helper to create and add a node */
|
||||||
|
const createAndAddNode = useCallback(
|
||||||
|
async (
|
||||||
|
blockID: string,
|
||||||
|
blockName: string,
|
||||||
|
hardcodedValues: Record<string, any>,
|
||||||
|
position: { x: number; y: number },
|
||||||
|
): Promise<CustomNode | null> => {
|
||||||
|
const nodeSchema = availableBlocks.find((node) => node.id === blockID);
|
||||||
|
if (!nodeSchema) {
|
||||||
|
console.error(`Schema not found for block ID: ${blockID}`);
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
// For agent blocks, fetch the full graph to get schemas
|
||||||
|
let inputSchema: BlockIORootSchema = nodeSchema.inputSchema;
|
||||||
|
let outputSchema: BlockIORootSchema = nodeSchema.outputSchema;
|
||||||
|
let finalHardcodedValues = hardcodedValues;
|
||||||
|
|
||||||
|
if (blockID === SpecialBlockID.AGENT) {
|
||||||
|
const graphID = hardcodedValues.graph_id as string;
|
||||||
|
const graphVersion = hardcodedValues.graph_version as number;
|
||||||
|
const graphData = okData(
|
||||||
|
await getV1GetSpecificGraph(graphID, { version: graphVersion }),
|
||||||
|
);
|
||||||
|
|
||||||
|
if (graphData) {
|
||||||
|
inputSchema = graphData.input_schema as BlockIORootSchema;
|
||||||
|
outputSchema = graphData.output_schema as BlockIORootSchema;
|
||||||
|
finalHardcodedValues = {
|
||||||
|
...hardcodedValues,
|
||||||
|
input_schema: graphData.input_schema,
|
||||||
|
output_schema: graphData.output_schema,
|
||||||
|
};
|
||||||
|
} else {
|
||||||
|
console.error("Failed to fetch graph data for agent block");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const newNode: CustomNode = {
|
||||||
|
id: nodeId.toString(),
|
||||||
|
type: "custom",
|
||||||
|
position,
|
||||||
|
data: {
|
||||||
|
blockType: blockName,
|
||||||
|
blockCosts: nodeSchema.costs || [],
|
||||||
|
title: `${blockName} ${nodeId}`,
|
||||||
|
description: nodeSchema.description,
|
||||||
|
categories: nodeSchema.categories,
|
||||||
|
inputSchema: inputSchema,
|
||||||
|
outputSchema: outputSchema,
|
||||||
|
hardcodedValues: finalHardcodedValues,
|
||||||
|
connections: [],
|
||||||
|
isOutputOpen: false,
|
||||||
|
block_id: blockID,
|
||||||
|
isOutputStatic: nodeSchema.staticOutput,
|
||||||
|
uiType: nodeSchema.uiType,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
addNodes(newNode);
|
||||||
|
setNodeId((prevId) => prevId + 1);
|
||||||
|
clearNodesStatusAndOutput();
|
||||||
|
|
||||||
|
history.push({
|
||||||
|
type: "ADD_NODE",
|
||||||
|
payload: { node: { ...newNode, ...newNode.data } },
|
||||||
|
undo: () => deleteElements({ nodes: [{ id: newNode.id }] }),
|
||||||
|
redo: () => addNodes(newNode),
|
||||||
|
});
|
||||||
|
|
||||||
|
return newNode;
|
||||||
|
},
|
||||||
|
[
|
||||||
|
availableBlocks,
|
||||||
|
nodeId,
|
||||||
|
addNodes,
|
||||||
|
deleteElements,
|
||||||
|
clearNodesStatusAndOutput,
|
||||||
|
],
|
||||||
|
);
|
||||||
|
|
||||||
const addNode = useCallback(
|
const addNode = useCallback(
|
||||||
(blockId: string, nodeType: string, hardcodedValues: any = {}) => {
|
async (
|
||||||
|
blockId: string,
|
||||||
|
nodeType: string,
|
||||||
|
hardcodedValues: Record<string, any> = {},
|
||||||
|
) => {
|
||||||
const nodeSchema = availableBlocks.find((node) => node.id === blockId);
|
const nodeSchema = availableBlocks.find((node) => node.id === blockId);
|
||||||
if (!nodeSchema) {
|
if (!nodeSchema) {
|
||||||
console.error(`Schema not found for block ID: ${blockId}`);
|
console.error(`Schema not found for block ID: ${blockId}`);
|
||||||
@@ -707,73 +797,42 @@ const FlowEditor: React.FC<{
|
|||||||
// Alternative: We could also use D3 force, Intersection for this (React flow Pro examples)
|
// Alternative: We could also use D3 force, Intersection for this (React flow Pro examples)
|
||||||
|
|
||||||
const { x, y } = getViewport();
|
const { x, y } = getViewport();
|
||||||
const viewportCoordinates =
|
const position =
|
||||||
nodeDimensions && Object.keys(nodeDimensions).length > 0
|
nodeDimensions && Object.keys(nodeDimensions).length > 0
|
||||||
? // we will get all the dimension of nodes, then store
|
? findNewlyAddedBlockCoordinates(
|
||||||
findNewlyAddedBlockCoordinates(
|
|
||||||
nodeDimensions,
|
nodeDimensions,
|
||||||
nodeSchema.uiType == BlockUIType.NOTE ? 300 : 500,
|
nodeSchema.uiType == BlockUIType.NOTE ? 300 : 500,
|
||||||
60,
|
60,
|
||||||
1.0,
|
1.0,
|
||||||
)
|
)
|
||||||
: // we will get all the dimension of nodes, then store
|
: {
|
||||||
{
|
|
||||||
x: window.innerWidth / 2 - x,
|
x: window.innerWidth / 2 - x,
|
||||||
y: window.innerHeight / 2 - y,
|
y: window.innerHeight / 2 - y,
|
||||||
};
|
};
|
||||||
|
|
||||||
const newNode: CustomNode = {
|
const newNode = await createAndAddNode(
|
||||||
id: nodeId.toString(),
|
blockId,
|
||||||
type: "custom",
|
nodeType,
|
||||||
position: viewportCoordinates, // Set the position to the calculated viewport center
|
hardcodedValues,
|
||||||
data: {
|
position,
|
||||||
blockType: nodeType,
|
);
|
||||||
blockCosts: nodeSchema.costs,
|
if (!newNode) return;
|
||||||
title: `${nodeType} ${nodeId}`,
|
|
||||||
description: nodeSchema.description,
|
|
||||||
categories: nodeSchema.categories,
|
|
||||||
inputSchema: nodeSchema.inputSchema,
|
|
||||||
outputSchema: nodeSchema.outputSchema,
|
|
||||||
hardcodedValues: hardcodedValues,
|
|
||||||
connections: [],
|
|
||||||
isOutputOpen: false,
|
|
||||||
block_id: blockId,
|
|
||||||
isOutputStatic: nodeSchema.staticOutput,
|
|
||||||
uiType: nodeSchema.uiType,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
addNodes(newNode);
|
|
||||||
setNodeId((prevId) => prevId + 1);
|
|
||||||
clearNodesStatusAndOutput(); // Clear status and output when a new node is added
|
|
||||||
|
|
||||||
setViewport(
|
setViewport(
|
||||||
{
|
{
|
||||||
// Rough estimate of the dimension of the node is: 500x400px.
|
x: -position.x * 0.8 + (window.innerWidth - 0.0) / 2,
|
||||||
// Though we skip shifting the X, considering the block menu side-bar.
|
y: -position.y * 0.8 + (window.innerHeight - 400) / 2,
|
||||||
x: -viewportCoordinates.x * 0.8 + (window.innerWidth - 0.0) / 2,
|
|
||||||
y: -viewportCoordinates.y * 0.8 + (window.innerHeight - 400) / 2,
|
|
||||||
zoom: 0.8,
|
zoom: 0.8,
|
||||||
},
|
},
|
||||||
{ duration: 500 },
|
{ duration: 500 },
|
||||||
);
|
);
|
||||||
|
|
||||||
history.push({
|
|
||||||
type: "ADD_NODE",
|
|
||||||
payload: { node: { ...newNode, ...newNode.data } },
|
|
||||||
undo: () => deleteElements({ nodes: [{ id: newNode.id }] }),
|
|
||||||
redo: () => addNodes(newNode),
|
|
||||||
});
|
|
||||||
},
|
},
|
||||||
[
|
[
|
||||||
nodeId,
|
|
||||||
getViewport,
|
getViewport,
|
||||||
setViewport,
|
setViewport,
|
||||||
availableBlocks,
|
availableBlocks,
|
||||||
addNodes,
|
|
||||||
nodeDimensions,
|
nodeDimensions,
|
||||||
deleteElements,
|
createAndAddNode,
|
||||||
clearNodesStatusAndOutput,
|
|
||||||
],
|
],
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -920,7 +979,7 @@ const FlowEditor: React.FC<{
|
|||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
const onDrop = useCallback(
|
const onDrop = useCallback(
|
||||||
(event: React.DragEvent) => {
|
async (event: React.DragEvent) => {
|
||||||
event.preventDefault();
|
event.preventDefault();
|
||||||
|
|
||||||
const blockData = event.dataTransfer.getData("application/reactflow");
|
const blockData = event.dataTransfer.getData("application/reactflow");
|
||||||
@@ -935,62 +994,17 @@ const FlowEditor: React.FC<{
|
|||||||
y: event.clientY,
|
y: event.clientY,
|
||||||
});
|
});
|
||||||
|
|
||||||
// Find the block schema
|
await createAndAddNode(
|
||||||
const nodeSchema = availableBlocks.find((node) => node.id === blockId);
|
blockId,
|
||||||
if (!nodeSchema) {
|
blockName,
|
||||||
console.error(`Schema not found for block ID: ${blockId}`);
|
hardcodedValues || {},
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create the new node at the drop position
|
|
||||||
const newNode: CustomNode = {
|
|
||||||
id: nodeId.toString(),
|
|
||||||
type: "custom",
|
|
||||||
position,
|
position,
|
||||||
data: {
|
);
|
||||||
blockType: blockName,
|
|
||||||
blockCosts: nodeSchema.costs || [],
|
|
||||||
title: `${blockName} ${nodeId}`,
|
|
||||||
description: nodeSchema.description,
|
|
||||||
categories: nodeSchema.categories,
|
|
||||||
inputSchema: nodeSchema.inputSchema,
|
|
||||||
outputSchema: nodeSchema.outputSchema,
|
|
||||||
hardcodedValues: hardcodedValues,
|
|
||||||
connections: [],
|
|
||||||
isOutputOpen: false,
|
|
||||||
block_id: blockId,
|
|
||||||
uiType: nodeSchema.uiType,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
history.push({
|
|
||||||
type: "ADD_NODE",
|
|
||||||
payload: { node: { ...newNode, ...newNode.data } },
|
|
||||||
undo: () => {
|
|
||||||
deleteElements({ nodes: [{ id: newNode.id } as any], edges: [] });
|
|
||||||
},
|
|
||||||
redo: () => {
|
|
||||||
addNodes([newNode]);
|
|
||||||
},
|
|
||||||
});
|
|
||||||
addNodes([newNode]);
|
|
||||||
clearNodesStatusAndOutput();
|
|
||||||
|
|
||||||
setNodeId((prevId) => prevId + 1);
|
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error("Failed to drop block:", error);
|
console.error("Failed to drop block:", error);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
[
|
[screenToFlowPosition, createAndAddNode],
|
||||||
nodeId,
|
|
||||||
availableBlocks,
|
|
||||||
nodes,
|
|
||||||
edges,
|
|
||||||
addNodes,
|
|
||||||
screenToFlowPosition,
|
|
||||||
deleteElements,
|
|
||||||
clearNodesStatusAndOutput,
|
|
||||||
],
|
|
||||||
);
|
);
|
||||||
|
|
||||||
const buildContextValue: BuilderContextType = useMemo(
|
const buildContextValue: BuilderContextType = useMemo(
|
||||||
|
|||||||
@@ -1,8 +1,14 @@
|
|||||||
import React, { useContext, useState } from "react";
|
import React, { useContext, useMemo, useState } from "react";
|
||||||
import { Button } from "@/components/__legacy__/ui/button";
|
import { Button } from "@/components/__legacy__/ui/button";
|
||||||
import { Maximize2 } from "lucide-react";
|
import { Maximize2 } from "lucide-react";
|
||||||
import * as Separator from "@radix-ui/react-separator";
|
import * as Separator from "@radix-ui/react-separator";
|
||||||
import { ContentRenderer } from "@/components/__legacy__/ui/render";
|
import { ContentRenderer } from "@/components/__legacy__/ui/render";
|
||||||
|
import type { OutputMetadata } from "@/components/contextual/OutputRenderers";
|
||||||
|
import {
|
||||||
|
globalRegistry,
|
||||||
|
OutputItem,
|
||||||
|
} from "@/components/contextual/OutputRenderers";
|
||||||
|
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
|
||||||
|
|
||||||
import { beautifyString } from "@/lib/utils";
|
import { beautifyString } from "@/lib/utils";
|
||||||
|
|
||||||
@@ -21,6 +27,9 @@ export default function NodeOutputs({
|
|||||||
data,
|
data,
|
||||||
}: NodeOutputsProps) {
|
}: NodeOutputsProps) {
|
||||||
const builderContext = useContext(BuilderContext);
|
const builderContext = useContext(BuilderContext);
|
||||||
|
const enableEnhancedOutputHandling = useGetFlag(
|
||||||
|
Flag.ENABLE_ENHANCED_OUTPUT_HANDLING,
|
||||||
|
);
|
||||||
|
|
||||||
const [expandedDialog, setExpandedDialog] = useState<{
|
const [expandedDialog, setExpandedDialog] = useState<{
|
||||||
isOpen: boolean;
|
isOpen: boolean;
|
||||||
@@ -37,6 +46,15 @@ export default function NodeOutputs({
|
|||||||
|
|
||||||
const { getNodeTitle } = builderContext;
|
const { getNodeTitle } = builderContext;
|
||||||
|
|
||||||
|
// Prepare renderers for each item when enhanced mode is enabled
|
||||||
|
const getItemRenderer = useMemo(() => {
|
||||||
|
if (!enableEnhancedOutputHandling) return null;
|
||||||
|
return (item: unknown) => {
|
||||||
|
const metadata: OutputMetadata = {};
|
||||||
|
return globalRegistry.getRenderer(item, metadata);
|
||||||
|
};
|
||||||
|
}, [enableEnhancedOutputHandling]);
|
||||||
|
|
||||||
const getBeautifiedPinName = (pin: string) => {
|
const getBeautifiedPinName = (pin: string) => {
|
||||||
if (!pin.startsWith("tools_^_")) {
|
if (!pin.startsWith("tools_^_")) {
|
||||||
return beautifyString(pin);
|
return beautifyString(pin);
|
||||||
@@ -87,15 +105,31 @@ export default function NodeOutputs({
|
|||||||
<div className="mt-2">
|
<div className="mt-2">
|
||||||
<strong className="mr-2">Data:</strong>
|
<strong className="mr-2">Data:</strong>
|
||||||
<div className="mt-1">
|
<div className="mt-1">
|
||||||
{dataArray.slice(0, 10).map((item, index) => (
|
{dataArray.slice(0, 10).map((item, index) => {
|
||||||
<React.Fragment key={index}>
|
const renderer = getItemRenderer?.(item);
|
||||||
<ContentRenderer
|
if (enableEnhancedOutputHandling && renderer) {
|
||||||
value={item}
|
const metadata: OutputMetadata = {};
|
||||||
truncateLongData={truncateLongData}
|
return (
|
||||||
/>
|
<React.Fragment key={index}>
|
||||||
{index < Math.min(dataArray.length, 10) - 1 && ", "}
|
<OutputItem
|
||||||
</React.Fragment>
|
value={item}
|
||||||
))}
|
metadata={metadata}
|
||||||
|
renderer={renderer}
|
||||||
|
/>
|
||||||
|
{index < Math.min(dataArray.length, 10) - 1 && ", "}
|
||||||
|
</React.Fragment>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
return (
|
||||||
|
<React.Fragment key={index}>
|
||||||
|
<ContentRenderer
|
||||||
|
value={item}
|
||||||
|
truncateLongData={truncateLongData}
|
||||||
|
/>
|
||||||
|
{index < Math.min(dataArray.length, 10) - 1 && ", "}
|
||||||
|
</React.Fragment>
|
||||||
|
);
|
||||||
|
})}
|
||||||
{dataArray.length > 10 && (
|
{dataArray.length > 10 && (
|
||||||
<span style={{ color: "#888" }}>
|
<span style={{ color: "#888" }}>
|
||||||
<br />
|
<br />
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user